jax.lax.axis_index

jax.lax.axis_index(axis_name)

Return the position of the current shard (or batch element) along the mapped axis named axis_name.

Parameters:

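axis_name (AxisName) – hashable Python object used to name the mapped axis, or a tuple of such names. For a tuple, the result is a single index over those axes, flattened in row-major order (the last name varies fastest).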

Returns:

A scalar integer array holding the index.

Return type:

Array
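Because axis_index reads the position along any named mapped axis, it also works under jax.vmap when an axis_name is supplied, which makes it easy to try on a single device. A minimal sketch, assuming only the public jax.vmap and jax.lax.axis_index APIs; the axis name 'batch' and the helper add_position are arbitrary, illustrative choices:

```python
import jax
import jax.numpy as jnp
from jax import lax


def add_position(x):
    # Inside vmap, axis_index('batch') is this element's position
    # along the mapped axis named 'batch'.
    return x + lax.axis_index('batch')


# Map over a length-4 axis named 'batch'; each element adds its own index.
result = jax.vmap(add_position, axis_name='batch')(jnp.zeros(4, dtype=jnp.int32))
print(result)
```

The same function body would run unchanged inside shard_map, where the index would instead identify the device along a mesh axis of that name.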

For example, with 8 XLA devices available:

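>>> import jax
>>> from jax import lax
>>> mesh = jax.make_mesh((8,), ('i',), axis_types=(jax.sharding.AxisType.Explicit,))
>>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i'))
... def f():
...   # axis_index returns a scalar; [None] gives each device a length-1 array,
...   # which out_specs=jax.P('i') concatenates across the 8 devices.
...   return lax.axis_index('i')[None]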
...
>>> f()
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
>>> mesh = jax.make_mesh((4, 2), ('i', 'j'),
...                       axis_types=(jax.sharding.AxisType.Explicit,) * 2)
>>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i', 'j'))
... def f():
...   # A tuple of names yields one index over both mesh axes ('j' varies fastest).
...   return lax.axis_index(('i', 'j'))[None, None]
...
>>> f()
Array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]], dtype=int32)
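The combined index in the second example is row-major: axis_index(('i', 'j')) equals axis_index('i') * size_j + axis_index('j'), where size_j is the size of axis 'j'. The sketch below checks this on a single device by emulating the 4-by-2 mesh with two nested jax.vmap calls instead of shard_map. Using vmap here is a substitution for illustration, the helper name indices is hypothetical, and lax.psum(1, 'j') is used as the usual idiom for obtaining the size of axis 'j':

```python
import jax
import jax.numpy as jnp
from jax import lax


def indices(_):
    # Flattened index over both named axes, as in the shard_map example.
    combined = lax.axis_index(('i', 'j'))
    # The same index computed by hand: row-major, with 'j' varying fastest.
    size_j = lax.psum(1, 'j')
    manual = lax.axis_index('i') * size_j + lax.axis_index('j')
    return combined, manual


# Outer vmap maps axis 'i' (size 4); inner vmap maps axis 'j' (size 2).
grid = jnp.zeros((4, 2))
combined, manual = jax.vmap(jax.vmap(indices, axis_name='j'), axis_name='i')(grid)
print(combined)
```

The printed grid has the same values as the shard_map output above, since each position (i, j) receives i * 2 + j.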