jax.experimental.pallas.mosaic_gpu.TilingTransform
class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)
Represents a tiling transformation for memory refs.
A tiling of (X, Y) on an array of shape (M, N) results in a transformed shape of (M // X, N // Y, X, Y). For example, tiling a (256, 256) block with a tiling of (64, 32) yields a shape of (4, 8, 64, 32), since 256 // 64 = 4 and 256 // 32 = 8.
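The shape arithmetic above can be illustrated with a minimal NumPy sketch, assuming a 2-D array whose dimensions are evenly divisible by the tile sizes. The helpers `tiled_shape`, `tile_array`, and `untile_array` are hypothetical and not part of the Pallas API; they only show which elements land in which tile under a tiling of (X, Y), not how Mosaic GPU actually lowers the transform.

```python
import numpy as np

# Hypothetical helpers for illustration only; not part of jax.experimental.pallas.


def tiled_shape(shape, tiling):
    """Shape produced by tiling a 2-D array: (M, N) -> (M // X, N // Y, X, Y)."""
    (m, n), (x, y) = shape, tiling
    # Assumption: the tile sizes evenly divide the array dimensions.
    assert m % x == 0 and n % y == 0, "tiling must evenly divide the shape"
    return (m // x, n // y, x, y)


def tile_array(a, tiling):
    """Rearrange a 2-D array so that out[i, j] is the (i, j)-th (X, Y) tile of `a`."""
    m, n = a.shape
    x, y = tiling
    # Split each axis into (number of tiles, tile size), then move the two
    # tile-index axes to the front: (M//X, X, N//Y, Y) -> (M//X, N//Y, X, Y).
    return a.reshape(m // x, x, n // y, y).transpose(0, 2, 1, 3)


def untile_array(t):
    """Inverse of tile_array: (M//X, N//Y, X, Y) -> (M, N)."""
    mt, nt, x, y = t.shape
    # Move the tile-size axes back next to their tile-index axes, then merge.
    return t.transpose(0, 2, 1, 3).reshape(mt * x, nt * y)


# The (256, 256) block tiled with (64, 32) from the example above.
block = np.arange(256 * 256).reshape(256, 256)
tiles = tile_array(block, (64, 32))
print(tiles.shape)  # (4, 8, 64, 32)
```

In this layout, `tiles[1, 2]` holds rows 64 to 127 and columns 64 to 95 of the original block, i.e. `block[64:128, 64:96]`.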
Methods
__init__(tiling)
pretty_print(context)
transform_type(x)
undo(x)
Attributes
tiling