jax.experimental.pallas.mosaic_gpu.TilingTransform

class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)

Represents a tiling transformation for memory refs.

A tiling of (X, Y) applied to an array of shape (M, N) produces a transformed shape of (M // X, N // Y, X, Y). For example, a (256, 256) block tiled with (64, 32) becomes (4, 8, 64, 32): 256 // 64 = 4 tiles along the first axis and 256 // 32 = 8 tiles along the second, each tile having shape (64, 32).
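The shape arithmetic above can be sketched in plain Python. The helpers `tiled_shape` and `tiled_index` are hypothetical illustrations, not part of the Pallas API. The rank-equality and divisibility checks, and the element-index mapping (i, j) -> (i // X, j // Y, i % X, j % Y), are assumptions of this sketch describing conventional tiling, not documented behavior of `TilingTransform`.

```python
def tiled_shape(shape: tuple[int, ...], tiling: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: (M, N) tiled by (X, Y) -> (M // X, N // Y, X, Y)."""
    # Assumption: tiling has the same rank as the shape and divides it exactly.
    if len(shape) != len(tiling):
        raise ValueError(f"rank mismatch: shape {shape} vs tiling {tiling}")
    for dim, tile in zip(shape, tiling):
        if dim % tile != 0:
            raise ValueError(f"dimension {dim} is not divisible by tile size {tile}")
    # Outer dims count tiles along each axis; inner dims are the tile shape itself.
    return tuple(d // t for d, t in zip(shape, tiling)) + tuple(tiling)


def tiled_index(index: tuple[int, ...], tiling: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: where an untiled element lands in the tiled array."""
    outer = tuple(i // t for i, t in zip(index, tiling))  # which tile
    inner = tuple(i % t for i, t in zip(index, tiling))   # offset inside the tile
    return outer + inner


print(tiled_shape((256, 256), (64, 32)))  # (4, 8, 64, 32)
print(tiled_index((70, 40), (64, 32)))    # (1, 1, 6, 8)
```

Under this mapping every element of the (M, N) array lands at a distinct position of the (M // X, N // Y, X, Y) array, so tiling only reorganizes the layout without duplicating or dropping data.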

Parameters:

tiling (tuple[int, ...])

__init__(tiling)
Parameters:

tiling (tuple[int, ...])

Return type:

None

Methods

__init__(tiling)

pretty_print(context)

transform_type(x)

undo(x)

Attributes

tiling