jax.experimental.pallas.program_id
- jax.experimental.pallas.program_id(axis)
Returns the kernel execution position along the given axis of the grid.
For example, with a 2D grid, in the kernel execution corresponding to the grid coordinates (1, 2), program_id(axis=0) returns 1 and program_id(axis=1) returns 2.

The returned value is an array of shape () and dtype int32.
- Parameters:
axis (int) – the axis of the grid along which to count the program.
- Return type:
jax_typing.Array
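The 2D-grid example above can be checked with a small kernel. This is a minimal sketch that assumes a recent JAX release, where `pl.BlockSpec` accepts `block_shape` and `index_map` keywords and `pl.pallas_call(..., interpret=True)` runs on CPU. The names `kernel` and `grid_ids` are hypothetical. On a (2, 3) grid, each program writes `10 * program_id(axis=0) + program_id(axis=1)` into its own 1x1 output block. The result therefore records which grid coordinates produced each element.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(o_ref):
    # Position of this program along each grid axis (int32 scalars).
    i = pl.program_id(axis=0)
    j = pl.program_id(axis=1)
    # Encode the grid coordinates into this program's 1x1 output block.
    o_ref[...] = jnp.full(o_ref.shape, i * 10 + j, dtype=o_ref.dtype)


def grid_ids(n_rows=2, n_cols=3):
    # Hypothetical helper: launch one program per output element.
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((n_rows, n_cols), jnp.int32),
        grid=(n_rows, n_cols),
        # Program (i, j) owns the 1x1 block at block index (i, j).
        out_specs=pl.BlockSpec(block_shape=(1, 1), index_map=lambda i, j: (i, j)),
        interpret=True,  # run on CPU without a GPU/TPU backend
    )()


print(grid_ids())
```

The element at row 1, column 2 encodes program_id(axis=0) == 1 and program_id(axis=1) == 2, which matches the (1, 2) example. The 1x1 blocks only keep the demonstration readable. Real kernels typically use larger blocks and combine program_id with the block size to compute offsets.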