jax.numpy.take

jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array.

JAX implementation of numpy.take(), built on jax.lax.gather(). JAX’s behavior differs from NumPy for out-of-bounds indices: NumPy raises an IndexError by default, whereas JAX fills or clips them according to the mode parameter below.
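As a rough sketch of what "built on jax.lax.gather()" means, the call below reproduces jnp.take(x, indices, axis=1) for a 2-D array with a single hand-written gather. The dimension numbers are our own choice for this one case and are not necessarily the ones the jnp.take wrapper computes; the sketch also skips the negative-index normalization that jnp.take performs in fill mode.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
indices = jnp.array([2, 0])

# Gather one full column of x per index, i.e. out[r, j] = x[r, indices[j]].
dnums = lax.GatherDimensionNumbers(
    offset_dims=(0,),           # output dim 0 holds the rows of each sliced column
    collapsed_slice_dims=(1,),  # the size-1 column dimension is dropped
    start_index_map=(1,),       # each index addresses operand dim 1 (axis=1)
)
out = lax.gather(
    x,
    indices[:, None],           # trailing dim of size 1 is the index vector
    dimension_numbers=dnums,
    slice_sizes=(x.shape[0], 1),
    mode=lax.GatherScatterMode.FILL_OR_DROP,  # out-of-bounds slices are filled
    fill_value=jnp.nan,
)
# out has the same values as jnp.take(x, indices, axis=1)
```

The batch dimension of the output comes from the shape of indices, which is why the result has one column per index.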

Parameters:
  • a (ArrayLike) – array from which to take values.

  • indices (ArrayLike) – N-dimensional array of integer indices of values to take from the array.

  • axis (int | None) – the axis along which to take values. If not specified, the array will be flattened before indexing is applied.

  • mode (str | None) – Out-of-bounds indexing mode, either "fill" or "clip". The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices; the fill_value argument gives control over this value. For more discussion of mode options, see jax.numpy.ndarray.at.

  • fill_value (StaticScalar | None) – The value to return for out-of-bounds indices when mode is "fill"; ignored otherwise. Defaults to NaN for inexact types, the most negative (minimum) value for signed integer types, the largest value for unsigned integer types, and True for booleans.

  • unique_indices (bool) – If True, the implementation will assume that the indices are unique after normalization of negative indices, which lets the compiler emit more efficient code during the backward pass. If set to True and normalized indices are not unique, the result is implementation-defined and may be non-deterministic.

  • indices_are_sorted (bool) – If True, the implementation will assume that the indices are sorted in ascending order after normalization of negative indices, which can lead to more efficient execution on some backends. If set to True and normalized indices are not sorted, the output is implementation-defined.

  • out (None) – Not supported in JAX; must be None.
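A short sketch of the dtype-dependent default for fill_value described above, followed by an explicit override. The dtypes are fixed explicitly so that the results do not depend on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp

# Index 5 is out of bounds for every length-3 array below.
idx = jnp.array([0, 5])

floats = jnp.take(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), idx)  # inexact -> NaN
ints = jnp.take(jnp.array([1, 2, 3], dtype=jnp.int32), idx)            # signed -> iinfo(int32).min
uints = jnp.take(jnp.array([1, 2, 3], dtype=jnp.uint8), idx)           # unsigned -> iinfo(uint8).max
bools = jnp.take(jnp.array([False, False, False]), idx)                # bool -> True

# An explicit fill_value overrides the dtype-dependent default.
ints_custom = jnp.take(jnp.array([1, 2, 3], dtype=jnp.int32), idx, fill_value=-1)
```

Because the defaults are extreme or invalid values, an explicit fill_value is usually clearer when the out-of-bounds entries feed into later arithmetic.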

Returns:

Array of values extracted from a.

Return type:

Array

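See also

  • jax.numpy.ndarray.at: take values via indexing syntax.

  • jax.numpy.take_along_axis: take values along an axis.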

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([2, 0])

Passing no axis results in indexing into the flattened array:

>>> jnp.take(x, indices)
Array([3., 1.], dtype=float32)
>>> x.ravel()[indices]  # equivalent indexing syntax
Array([3., 1.], dtype=float32)
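Because indices may be N-dimensional, taking from the flattened array returns a result with the shape of indices. A minimal sketch with a 2-D index array (idx2d is an illustrative name):

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
# A 2-D array of indices into the flattened x = [1., 2., 3., 4., 5., 6.].
idx2d = jnp.array([[0, 5],
                   [1, 4]])

out = jnp.take(x, idx2d)   # result has the shape of idx2d: (2, 2)
same = x.ravel()[idx2d]    # equivalent indexing syntax
# out == [[1., 6.],
#         [2., 5.]]
```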

Passing an axis results in applying the indices to every subarray along that axis:

>>> jnp.take(x, indices, axis=1)
Array([[3., 1.],
       [6., 4.]], dtype=float32)
>>> x[:, indices]  # equivalent indexing syntax
Array([[3., 1.],
       [6., 4.]], dtype=float32)
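More generally, following NumPy's take, the indexed axis is replaced by the shape of indices, so the output shape is a.shape[:axis] + indices.shape + a.shape[axis+1:]. A sketch with a 3-D array and 2-D indices (the arrays are illustrative):

```python
import jax.numpy as jnp

a = jnp.arange(24).reshape(2, 3, 4)
idx = jnp.array([[0, 2],
                 [1, 1]])   # 2-D indices into axis 1 (length 3)

out = jnp.take(a, idx, axis=1)
# The indexed axis is replaced by idx.shape: (2,) + (2, 2) + (4,)
assert out.shape == a.shape[:1] + idx.shape + a.shape[2:]
# Equivalent indexing syntax for axis=1:
assert jnp.array_equal(out, a[:, idx, :])
# A negative axis counts from the end, so axis=-2 is the same axis here.
out_neg = jnp.take(a, idx, axis=-2)
```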

Out-of-bounds indices are filled with invalid values; for float inputs, this is NaN. Here index 2 is out of bounds for axis 0, which has length 2:

>>> jnp.take(x, indices, axis=0)
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
>>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
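A sketch combining a custom fill_value with a negative index. Per the parameter descriptions above, negative indices are normalized before the bounds check, so -1 selects the last row while 2 is still out of bounds and receives the fill value.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
idx = jnp.array([-1, 2])   # -1 is normalized to row 1; 2 is out of bounds for axis 0

out = jnp.take(x, idx, axis=0, fill_value=0.0)
# out == [[4., 5., 6.],
#         [0., 0., 0.]]
```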

This default out-of-bounds behavior can be adjusted using the mode parameter. For example, mode='clip' instead clips out-of-bounds indices to the last valid index:

>>> jnp.take(x, indices, axis=0, mode='clip')
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
>>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
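For non-negative indices, mode='clip' should behave like clamping the indices into range by hand before taking. The sketch below checks this under jax.jit; take_manual_clip is a hypothetical helper written for comparison, not part of JAX.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

@jax.jit
def take_clipped(arr, idx):
    # mode is a Python string closed over at trace time; idx may be traced.
    return jnp.take(arr, idx, axis=0, mode='clip')

@jax.jit
def take_manual_clip(arr, idx):
    # Hypothetical helper: clamp indices into [0, n - 1] by hand.
    # arr.shape[0] is static under jit, so this bound is a Python int.
    safe = jnp.clip(idx, 0, arr.shape[0] - 1)
    return jnp.take(arr, safe, axis=0)

idx = jnp.array([2, 0, 7])
clipped = take_clipped(x, idx)
manual = take_manual_clip(x, idx)
# clipped == manual == [[4., 5., 6.],
#                       [1., 2., 3.],
#                       [4., 5., 6.]]
```

Only non-negative indices are used here, since this sketch does not attempt to describe how clip mode treats negative indices.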