jax.experimental.pallas.mosaic_gpu.planar_snake
- jax.experimental.pallas.mosaic_gpu.planar_snake(lin_idx, shape, minor_dim, tile_width)
Converts a linear index into an index into shape, trying to optimize locality.
The “space-filling curve” this function computes splits the minor dimension into tiles of length tile_width. Every other tile traverses its major dimension in reverse, so that the iteration order “snakes around” when going from one tile to the next.

For a shape of (8, 8), minor_dim=0 and tile_width=2, the iteration order is shown below; each entry is the linear index that maps to that position:

     0   2   4   6   8  10  12  14
     1   3   5   7   9  11  13  15
    30  28  26  24  22  20  18  16
    31  29  27  25  23  21  19  17
    32  34  36  38  40  42  44  46
    33  35  37  39  41  43  45  47
    62  60  58  56  54  52  50  48
    63  61  59  57  55  53  51  49
Notice how each pair of rows forms a tile (minor_dim=0, tile_width=2), and how, when moving from one tile to the next, the indices increase along the columns in one tile and decrease in the other.
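To make the ordering concrete, here is a minimal pure-Python model that reproduces the grid above. The names planar_snake_reference and iteration_order are hypothetical helpers written for illustration; they are not part of jax.experimental.pallas.mosaic_gpu, and the sketch assumes a 2-D shape whose minor dimension divides evenly into tiles of tile_width (the library function may handle other cases differently).

```python
def planar_snake_reference(
    lin_idx: int, shape: tuple[int, int], minor_dim: int, tile_width: int
) -> tuple[int, int]:
    """Hypothetical pure-Python model of the snake order (not the JAX API)."""
    if minor_dim not in (0, 1):
        raise ValueError("minor_dim must be 0 or 1")
    minor_size = shape[minor_dim]
    major_size = shape[1 - minor_dim]
    # Assumption of this sketch: the minor dimension splits into whole tiles.
    if minor_size % tile_width != 0:
        raise ValueError("sketch assumes shape[minor_dim] % tile_width == 0")
    if not 0 <= lin_idx < minor_size * major_size:
        raise IndexError("lin_idx out of range")

    tile_size = tile_width * major_size  # number of elements in one tile
    tile_idx, in_tile = divmod(lin_idx, tile_size)
    # Inside a tile, the minor offset varies fastest, then the major index.
    major_idx, minor_offset = divmod(in_tile, tile_width)
    if tile_idx % 2 == 1:
        # Every other tile walks the major dimension in reverse ("snakes").
        major_idx = major_size - 1 - major_idx
    minor_idx = tile_idx * tile_width + minor_offset

    idx = [0, 0]
    idx[minor_dim] = minor_idx
    idx[1 - minor_dim] = major_idx
    return idx[0], idx[1]


def iteration_order(shape: tuple[int, int], minor_dim: int, tile_width: int):
    """Returns a grid whose entry at (i, j) is the lin_idx mapped to (i, j)."""
    grid = [[0] * shape[1] for _ in range(shape[0])]
    for lin_idx in range(shape[0] * shape[1]):
        i, j = planar_snake_reference(lin_idx, shape, minor_dim, tile_width)
        grid[i][j] = lin_idx
    return grid


# Print the (8, 8), minor_dim=0, tile_width=2 example.
for row in iteration_order((8, 8), minor_dim=0, tile_width=2):
    print(" ".join(f"{v:3d}" for v in row))
```

Running it prints the 8 × 8 grid above. In this sketch, passing minor_dim=1 for a square shape produces the transpose of that grid, with tiles formed by pairs of columns instead of pairs of rows.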