jax.experimental.pallas.mosaic_gpu.nd_loop
- jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: None = None) → Callable[[Callable[[NDLoopInfo], None]], None]
- jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: _T) → Callable[[Callable[[NDLoopInfo, _T], _T]], _T]
A loop over a multi-dimensional grid partitioned along the given axes.
The body of the loop takes a single argument, loop_info, which is an NDLoopInfo object containing index and iteration information. However, if a carry is specified, the body will expect a second keyword argument, carry, containing the loop carry.

For example, if collective_axes is "x" with lax.axis_size() equal to 4 and the grid is (2, 3), the implementation would produce the following iteration order:

| Loop step | Index  | Axis index |
|-----------|--------|------------|
| 0         | (0, 0) | 0          |
| 1         | (0, 1) | 1          |
| 2         | (0, 2) | 2          |
| 3         | (1, 0) | 3          |
| 4         | (1, 1) | 0          |
| 5         | (1, 2) | 1          |
This order comes from partitioning the flat iteration space into chunks in an interleaved fashion with respect to the "x" axis index.

Note that in this example the total number of loop steps (6) is not divisible by the axis size of "x" (4), so for some "x" axis indices the loop performs one fewer iteration:

| Axis index | Indices        |
|------------|----------------|
| 0          | (0, 0), (1, 1) |
| 1          | (0, 1), (1, 2) |
| 2          | (0, 2)         |
| 3          | (1, 0)         |
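The following is a minimal pure-Python sketch of the schedule and carry semantics described on this page, useful for reasoning about which grid indices each "x" axis index visits. It assumes, as the tables show, that flat loop step s (row-major over the grid) runs on axis index s % axis_size. The helpers interleaved_schedule and simulate_nd_loop are hypothetical and not part of Pallas: the sketch runs sequentially on the host, passes the raw grid-index tuple where nd_loop passes an NDLoopInfo, and does not model tiling.

```python
import itertools


def interleaved_schedule(grid, axis_size):
    """Hypothetical model: map each axis index to the grid indices it visits.

    Flat loop step s (row-major over `grid`) is assigned to axis index
    s % axis_size, matching the interleaved partition in the tables above.
    """
    schedule = {a: [] for a in range(axis_size)}
    # itertools.product enumerates grid indices in row-major order, so
    # enumerate() yields the flat loop step for each index.
    for step, index in enumerate(itertools.product(*(range(d) for d in grid))):
        schedule[step % axis_size].append(index)
    return schedule


def simulate_nd_loop(grid, axis_size, axis_index, init_carry=None):
    """Hypothetical sequential stand-in for nd_loop on a single axis index.

    With init_carry=None the body is called as body(index) and the decorator
    returns None. Otherwise the body is called as body(index, carry=carry),
    must return the new carry, and the decorator returns the final carry.
    """
    indices = interleaved_schedule(grid, axis_size)[axis_index]

    def decorator(body):
        if init_carry is None:
            for index in indices:
                body(index)
            return None
        carry = init_carry
        for index in indices:
            carry = body(index, carry=carry)  # carry is passed by keyword
        return carry

    return decorator


# Usage: axis index 0 of a size-4 axis over a (2, 3) grid visits (0, 0) and (1, 1).
@simulate_nd_loop((2, 3), axis_size=4, axis_index=0, init_carry=0)
def total(index, carry):
    i, j = index
    return carry + (i * 3 + j)  # accumulate the flat step numbers visited

print(interleaved_schedule((2, 3), axis_size=4))
print(total)  # 0 + 4
```

Used as a decorator with init_carry, the sketch replaces the body with the final carry, mirroring the second overload above, whose decorator returns _T; without a carry it returns None, like the first overload.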
If init_carry is passed, nd_loop() expects the body to take and return the carry. If it is None, no carry argument is expected.

See also
jax.experimental.pallas.loop(): A loop over a single dimension.