jax.lax.axis_index
- jax.lax.axis_index(axis_name)
Return the index along the mapped axis `axis_name`.
- Parameters:
axis_name (AxisName) – hashable Python object used to name the mapped axis.
- Returns:
An integer representing the index.
- Return type:
Array
For example, with 8 XLA devices available:
>>> import jax
>>> from jax import lax
>>> mesh = jax.make_mesh((8,), 'i', axis_types=(jax.sharding.AxisType.Explicit,))
>>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i'))
... def f():
...   # [None] gives each shard a length-1 axis to concatenate along 'i'
...   return lax.axis_index('i')[None]
...
>>> f()
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
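To try the same indexing without 8 devices, a minimal sketch can assume that `jax.vmap` with `axis_name='i'` stands in for the mesh axis: `vmap` binds a named axis whose size is the batch size, and `lax.axis_index('i')` then reports each batch element's position. The second half shows one common use, turning that position into a global offset; `BLOCK` and `global_range` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

# vmap with axis_name='i' binds a named axis of size 8 (one entry per
# element of the input), so lax.axis_index('i') works on a single device.
idx = jax.vmap(lambda _: lax.axis_index('i'), axis_name='i')(jnp.zeros(8))
print(idx)  # [0 1 2 3 4 5 6 7]

BLOCK = 3  # hypothetical: number of elements owned by each position

def global_range(_):
    # Position k owns global indices k*BLOCK through k*BLOCK + BLOCK - 1.
    start = lax.axis_index('i') * BLOCK
    return start + jnp.arange(BLOCK)

ranges = jax.vmap(global_range, axis_name='i')(jnp.zeros(4))
print(ranges)  # a (4, 3) array of consecutive global indices
```

Under `vmap` the named axis is a batch dimension rather than a device axis, so this reproduces the indexing logic of the `shard_map` example but not its placement of data across devices.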
Passing a tuple of axis names returns the index over the combined axes in row-major order (here `i * 2 + j`):

>>> mesh = jax.make_mesh((4, 2), ('i', 'j'),
...                      axis_types=(jax.sharding.AxisType.Explicit,) * 2)
>>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i', 'j'))
... def f():
...   return lax.axis_index(('i', 'j'))[None, None]
...
>>> f()
Array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]], dtype=int32)
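The output of the 4 x 2 example follows a row-major rule: the outer index `i` advances once per full row of the inner axis `j`, so the combined index is `i * 2 + j`. A minimal sketch that rebuilds the same grid from the two single-axis indices, assuming nested `jax.vmap` calls with `axis_name` stand in for the two mesh axes (the helper name `flat_index` is hypothetical):

```python
import jax
import jax.numpy as jnp
from jax import lax

grid = jnp.zeros((4, 2))  # outer axis 'i' has size 4, inner axis 'j' has size 2
size_j = grid.shape[1]

def flat_index(_):
    # Row-major flattening: each step along 'i' skips a full row of size_j.
    return lax.axis_index('i') * size_j + lax.axis_index('j')

# The inner vmap names the column axis 'j'; the outer vmap names the row axis 'i'.
result = jax.vmap(jax.vmap(flat_index, axis_name='j'), axis_name='i')(grid)
print(result)
```

Computing the index by hand like this makes the ordering explicit; passing the tuple `('i', 'j')` to `lax.axis_index`, as in the example above, yields the same values.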