jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop

jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop(grid_names: Sequence[Hashable], *, thread_axis: Hashable | None = None, init_carry: None = None) → Callable[[Callable[[NDLoopInfo], None]], None]
jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop(grid_names: Sequence[Hashable], *, thread_axis: Hashable | None = None, init_carry: _T) → Callable[[Callable[[NDLoopInfo, _T], _T]], _T]

A loop over program instances using dynamic work scheduling.

This loop iterates through available program instances until all work has been scheduled. The kernel should be instantiated with a grid equal to the logical amount of work to be done (as opposed to a persistent kernel, where the grid is set to the number of cores). Each core running this loop repeatedly queries the scheduler for the next available block of work, and the loop terminates once the entire grid has been scheduled.

Example usage:

@plgpu.dynamic_scheduling_loop(grid_names)
def body(loop_info):
  work(loop_info.index)  # do work...

Parameters:
  • grid_names – The names of the axes in the grid.

  • thread_axis – The name of the thread axis. This must be passed in if the kernel uses multiple threads.

  • init_carry – An optional initial carry for the loop. If passed in, the body function should expect a carry keyword argument and return the next carry value.
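To illustrate the carry-threading contract described for init_carry, the following is a minimal plain-Python emulation of the loop's semantics, not the real GPU scheduler: the body receives the current loop info together with the carry and returns the next carry, and the loop itself evaluates to the final carry. The `emulated_scheduling_loop` function and the `LoopInfo` stand-in (modeling only the `index` field of `NDLoopInfo`) are hypothetical names introduced here for the sketch.

```python
from collections import namedtuple

# Hypothetical stand-in for plgpu.NDLoopInfo; only `index` is modeled here.
LoopInfo = namedtuple("LoopInfo", ["index"])

def emulated_scheduling_loop(grid_size, body, init_carry):
    """Visit every program instance in a 1-D grid, threading a carry.

    Mirrors the init_carry overload: body(loop_info, carry) -> next carry,
    and the loop returns the final carry. The real loop hands out work
    blocks dynamically across cores; this sketch visits them in order.
    """
    carry = init_carry
    for i in range(grid_size):
        carry = body(LoopInfo(index=(i,)), carry)
    return carry

# Example: sum the linear indices of an 8-instance grid.
total = emulated_scheduling_loop(8, lambda info, carry: carry + info.index[0], 0)
# total == 0 + 1 + ... + 7 == 28
```

In the real API the decorated body would instead accumulate into on-device state (for example a reduction buffer), but the shape of the contract, carry in, carry out, final carry returned, is the same.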