jax.experimental.pallas.mosaic_gpu.planar_snake

jax.experimental.pallas.mosaic_gpu.planar_snake(lin_idx, shape, minor_dim, tile_width)

Converts a linear index into an index into shape, trying to optimize locality.

The “space-filling curve” computed by this function splits the minor dimension into tiles of length tile_width. Every other tile traverses its major dimension in reverse, so the iteration order “snakes around” when going from one tile to the next.
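One way to write this mapping as formulas is shown below. It is a sketch that assumes shape[minor_dim] is a multiple of tile_width (the partial-tile case is not covered). Let ℓ be lin_idx, T = tile_width, and M = shape[1 - minor_dim] the size of the major dimension. The symbols t (tile index) and r (offset within the tile) are introduced only for this derivation.

```latex
\begin{aligned}
t &= \left\lfloor \ell / (T M) \right\rfloor, \qquad r = \ell \bmod (T M), \\
\text{minor} &= t\,T + (r \bmod T), \\
\text{major} &= \begin{cases}
  \lfloor r / T \rfloor & \text{if } t \text{ is even},\\
  M - 1 - \lfloor r / T \rfloor & \text{if } t \text{ is odd}.
\end{cases}
\end{aligned}
```

For example, with shape (8, 8), minor_dim=0 and tile_width=2, ℓ = 16 gives t = 1 and r = 0, so minor = 2 and major = 7 − 0 = 7, i.e. position (2, 7). This matches the example that follows.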

For a shape of (8, 8), minor_dim=0 and tile_width=2, the iteration order is:

 0   2   4   6   8  10  12  14
 1   3   5   7   9  11  13  15
30  28  26  24  22  20  18  16
31  29  27  25  23  21  19  17
32  34  36  38  40  42  44  46
33  35  37  39  41  43  45  47
62  60  58  56  54  52  50  48
63  61  59  57  55  53  51  49

Notice how each pair of rows forms a tile (minor_dim=0, tile_width=2). When moving from one tile to the next, the indices increase along the columns in one tile and decrease in the other.
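To check the ordering, here is a small pure-Python sketch that reproduces the 8×8 example. It is an illustration under stated assumptions, not the JAX implementation. The helper name snake_index is hypothetical, it works on plain Python ints rather than JAX scalars, and it assumes shape[minor_dim] is a multiple of tile_width.

```python
def snake_index(lin_idx, shape, minor_dim, tile_width):
    """Map a linear index to a 2D index along a planar snake curve.

    Hypothetical pure-Python helper for illustration; not part of the JAX API.
    """
    if minor_dim not in (0, 1):
        raise ValueError("minor_dim must be 0 or 1")
    minor_size = shape[minor_dim]
    major_size = shape[1 - minor_dim]
    # Assumption of this sketch: tiles evenly divide the minor dimension.
    if minor_size % tile_width:
        raise ValueError("shape[minor_dim] must be a multiple of tile_width")
    # Each tile covers tile_width minor positions times every major position.
    tile_idx, in_tile = divmod(lin_idx, tile_width * major_size)
    # Within a tile, the minor coordinate varies fastest.
    major, minor_in_tile = divmod(in_tile, tile_width)
    # Every other tile walks the major dimension in reverse ("snaking").
    if tile_idx % 2 == 1:
        major = major_size - 1 - major
    idx = [0, 0]
    idx[minor_dim] = tile_idx * tile_width + minor_in_tile
    idx[1 - minor_dim] = major
    return tuple(idx)


# Rebuild the example table: grid[row][col] holds the linear index
# that snake_index maps to (row, col).
shape = (8, 8)
grid = [[None] * shape[1] for _ in range(shape[0])]
for i in range(shape[0] * shape[1]):
    row, col = snake_index(i, shape, minor_dim=0, tile_width=2)
    grid[row][col] = i

for r in grid:
    print(" ".join(f"{v:3d}" for v in r))
```

Passing minor_dim=1 instead tiles the columns rather than the rows, which for a square shape produces the transpose of this table.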

Parameters: