jax.experimental.pallas.mosaic_gpu.nd_loop

jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: None = None) → Callable[[Callable[[NDLoopInfo], None]], None]
jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: _T) → Callable[[Callable[[NDLoopInfo, _T], _T]], _T]

A loop over a multi-dimensional grid partitioned along the given axes.

The body of the loop takes a single argument, loop_info, which is an NDLoopInfo object containing index and iteration information. However, if a carry is specified, the body also expects a second keyword argument, carry, containing the loop carry.

For example, if collective_axes is "x" with lax.axis_size() equal to 4 and the grid is (2, 3), the implementation would produce the following iteration order:

loop step | index  | axis index
----------|--------|-----------
0         | (0, 0) | 0
1         | (0, 1) | 1
2         | (0, 2) | 2
3         | (1, 0) | 3
4         | (1, 1) | 0
5         | (1, 2) | 1

This order comes from partitioning the flat iteration space into chunks in an interleaved fashion with respect to the "x" axis index: in this example, flat step s runs on axis index s mod 4.
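The iteration order in the table can be reproduced with a small host-side model. This is a sketch under stated assumptions, not the library's implementation: it assumes row-major traversal of the grid and round-robin assignment of flat steps to axis indices, it does not model the tiling parameter, and the helper name interleaved_schedule is hypothetical.

```python
import math

import numpy as np


def interleaved_schedule(grid, axis_size):
    """Hypothetical model: list (flat step, grid index, axis index) triples.

    Assumes row-major traversal of `grid` and round-robin assignment of flat
    steps to axis indices; nd_loop's `tiling` option is not modeled.
    """
    total = math.prod(grid)
    schedule = []
    for step in range(total):
        # Row-major unflattening of the flat step into a grid index.
        index = tuple(int(i) for i in np.unravel_index(step, grid))
        # Interleaved (round-robin) partitioning over the collective axis.
        schedule.append((step, index, step % axis_size))
    return schedule


for step, index, axis_index in interleaved_schedule((2, 3), axis_size=4):
    print(step, index, axis_index)
```

Running it for grid (2, 3) and an axis size of 4 prints one row per loop step, matching the table row for row.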

Note that in the example the total number of loop steps (6) is not divisible by the axis size of "x" (4), so axis indices 2 and 3 perform one fewer iteration than axis indices 0 and 1:

axis index | indices
-----------|---------------
0          | (0, 0), (1, 1)
1          | (0, 1), (1, 2)
2          | (0, 2)
3          | (1, 0)
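The per-axis-index view, including the uneven iteration counts, can be checked with a hypothetical host-side model. The helpers local_indices and num_local_steps are illustrative names, not Pallas APIs. The sketch assumes row-major order and round-robin partitioning and ignores tiling. Under those assumptions, axis index a owns flat steps a, a + axis_size, a + 2 * axis_size, and so on, so its step count is ceil((total - a) / axis_size), which gives 2, 2, 1, 1 for the 6-step grid with axis size 4.

```python
import math

import numpy as np


def local_indices(grid, axis_size, axis_index):
    """Hypothetical model: grid indices visited by a single axis index."""
    total = math.prod(grid)
    # This axis index owns flat steps axis_index, axis_index + axis_size, ...
    return [
        tuple(int(i) for i in np.unravel_index(step, grid))
        for step in range(axis_index, total, axis_size)
    ]


def num_local_steps(grid, axis_size, axis_index):
    """Closed form for len(local_indices(...)): ceil((total - axis_index) / axis_size)."""
    total = math.prod(grid)
    # -(-x // n) is integer ceiling division; clamp at 0 for idle axis indices.
    return max(0, -(-(total - axis_index) // axis_size))


for a in range(4):
    print(a, local_indices((2, 3), 4, a), num_local_steps((2, 3), 4, a))
```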

If init_carry is passed, nd_loop() expects the body to take the carry and return the updated carry. If init_carry is None, no carry argument is expected.
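The difference between the two overloads can be illustrated with a host-side model. Everything in this sketch is hypothetical: nd_loop_model and LoopInfoModel are illustrative stand-ins, not Pallas APIs. The model takes an explicit axis_size and axis_index instead of collective_axes, ignores tiling, models only the index field of the loop info, and runs in plain Python rather than inside a Mosaic GPU kernel. Following the description of the body signature, the carry is passed as a keyword argument named carry.

```python
import dataclasses
import math

import numpy as np


@dataclasses.dataclass(frozen=True)
class LoopInfoModel:
    # Hypothetical stand-in for NDLoopInfo; only the grid index is modeled.
    index: tuple


def nd_loop_model(grid, *, axis_size, axis_index, init_carry=None):
    """Hypothetical model of nd_loop's carry handling for one axis index."""
    total = math.prod(grid)

    def decorator(body):
        carry = init_carry
        for step in range(axis_index, total, axis_size):
            info = LoopInfoModel(tuple(int(i) for i in np.unravel_index(step, grid)))
            if init_carry is None:
                body(info)  # No carry: the body takes only the loop info.
            else:
                carry = body(info, carry=carry)  # Carry in, updated carry out.
        # Mirrors the overloads: None without a carry, the final carry otherwise.
        return None if init_carry is None else carry

    return decorator


visited = []


@nd_loop_model((2, 3), axis_size=4, axis_index=0)
def no_carry(info):
    visited.append(info.index)


@nd_loop_model((2, 3), axis_size=4, axis_index=0, init_carry=0)
def steps_taken(info, carry):
    return carry + 1


print(visited, no_carry, steps_taken)
```

Because a decorator's return value replaces the decorated name, steps_taken ends up bound to the final carry (2 for axis index 0), while no_carry is bound to None, mirroring the return types of the two overloads.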

See also