jax.ref.swap
- jax.ref.swap(ref, idx, value, _function_name='ref_swap')
Update an array value in place and return the previous value.
This is equivalent to the statement ref[idx], prev = value, ref[idx] followed by returning prev, where idx is a NumPy-style indexer. For more on mutable array refs, refer to the Ref guide.
- Parameters:
  - ref (core.Ref | TransformedRef) – a jax.ref.Ref object. On return, the buffer will have been mutated by this operation.
  - idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer.
  - value (ArrayLike | HijaxType) – a jax.Array object (note: not a jax.ref.Ref) containing the values to set in the array.
  - _function_name (str)
- Returns:
  A jax.Array containing the previous value at idx.
- Return type:
  Array | HijaxType
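As a rough mental model of the semantics described above (not JAX's implementation), the sketch below mimics swap on an ordinary mutable NumPy array. The helper name numpy_swap is hypothetical. The explicit copy matters: NumPy basic indexing returns a view, so without it the "previous value" would alias the buffer and show the new contents after the write, whereas jax.ref.swap returns an immutable jax.Array.

```python
import numpy as np

def numpy_swap(buf, idx, value):
    """Hypothetical NumPy model of jax.ref.swap (not the JAX implementation):
    write `value` at `idx` in place and return a copy of what was there."""
    prev = np.array(buf[idx], copy=True)  # copy, since basic indexing returns a view
    buf[idx] = value                      # in-place update of the underlying buffer
    return prev

buf = np.arange(5)
prev = numpy_swap(buf, slice(1, 3), [7, 8])
print(prev)  # [1 2]
print(buf)   # [0 7 8 3 4]

# Ellipsis selects the whole (here 0-d) array, as in the scalar-ref example.
scalar = np.array(5)
old = numpy_swap(scalar, ..., 10)
```

Reading the old value before writing is what the tuple assignment ref[idx], prev = value, ref[idx] achieves, because Python evaluates the right-hand side before performing either assignment.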
Examples
>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.swap(ref, 3, 10)
Array(3, dtype=int32)
>>> ref
Ref([ 0, 1, 2, 10, 4], dtype=int32)
Equivalent operation via indexing syntax:
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> ref[3], prev = 10, ref[3]
>>> prev
Array(3, dtype=int32)
>>> ref
Ref([ 0, 1, 2, 10, 4], dtype=int32)
Use ... (Ellipsis) as the index to swap the value of a scalar ref:
>>> ref = jax.new_ref(jax.numpy.int32(5))
>>> jax.ref.swap(ref, ..., 10)
Array(5, dtype=int32)
>>> ref
Ref(10, dtype=int32)