jax.numpy.put_along_axis

jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)

Put values into the destination array by matching 1d index and data slices.

JAX implementation of numpy.put_along_axis().

The semantics of numpy.put_along_axis() are to modify arrays in place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input. It also adds the inplace parameter, which the user must set to False as a reminder of this API difference.

Parameters:
  • arr (ArrayLike) – array into which values will be put.

  • indices (ArrayLike) – array of indices at which to put values.

  • values (ArrayLike) – array of values to put into the array.

  • axis (int | None) – the axis along which to put values. If None, the array is flattened before indexing is applied.

  • inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.

  • mode (str | None) – Out-of-bounds indexing mode. For more discussion of mode options, see jax.numpy.ndarray.at.

Returns:

A copy of arr with the specified entries updated.

Return type:

Array


Examples

>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]