jax.ref.swap

jax.ref.swap(ref, idx, value, _function_name='ref_swap')

Update an array value in place while returning the previous value.

For a NumPy-style indexer idx, this is equivalent to the statement ref[idx], prev = value, ref[idx] followed by returning prev. For more on mutable array refs, refer to the Ref guide.
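The read-then-write order in that equivalence can be modeled on a plain NumPy array. The helper swap_model below is hypothetical and is not how JAX implements swap; it only mirrors the documented semantics. One difference is worth noting: NumPy basic indexing can return a view, so the model copies the old value before writing, whereas reading a jax.ref.Ref already yields an immutable jax.Array.

```python
import numpy as np


def swap_model(buf, idx, value):
    """Hypothetical NumPy model of jax.ref.swap semantics (not JAX's code).

    Reads buf[idx], writes value there in place, and returns the old contents.
    """
    # Copy first: a slice or Ellipsis index returns a view in NumPy, which
    # would otherwise reflect the write on the next line.
    prev = np.array(buf[idx], copy=True)
    buf[idx] = value
    return prev


buf = np.arange(5)
prev = swap_model(buf, slice(1, 3), [10, 11])
print(prev)  # [1 2]
print(buf)   # [ 0 10 11  3  4]
```

Without the copy, prev would alias the updated slice and report [10 11]; JAX refs avoid this because every read produces a fresh immutable array.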

Parameters:
  • ref (core.Ref | TransformedRef) – a jax.ref.Ref object. Its buffer is mutated in place by this operation.

  • idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer

  • value (ArrayLike | HijaxType) – an array-like value, such as a jax.Array or a Python scalar (but not a jax.ref.Ref), containing the values to write at idx.

  • _function_name (str)

Returns:

A jax.Array containing the previous value at idx.

Return type:

Array | HijaxType

Examples

>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.swap(ref, 3, 10)
Array(3, dtype=int32)
>>> ref
Ref([ 0,  1,  2, 10,  4], dtype=int32)

Equivalent operation via indexing syntax:

>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> ref[3], prev = 10, ref[3]
>>> prev
Array(3, dtype=int32)
>>> ref
Ref([ 0,  1,  2, 10,  4], dtype=int32)

Use ... (Ellipsis) as the indexer to swap the value of a scalar ref:

>>> ref = jax.new_ref(jax.numpy.int32(5))
>>> jax.ref.swap(ref, ..., 10)
Array(5, dtype=int32)
>>> ref
Ref(10, dtype=int32)