jax.ref.set
- jax.ref.set(ref, idx, value)
Set a value in a Ref in-place.
This is equivalent to ref[idx] = value for a NumPy-style indexer idx. For more on mutable array refs, refer to the Ref guide.
- Parameters:
ref (core.Ref | TransformedRef) – a jax.ref.Ref object. On return, the buffer will be mutated by this operation.
idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer.
value (ArrayLike | HijaxType) – a jax.Array object (note, not a jax.ref.Ref) containing the values to set in the array.
- Returns:
None
- Return type:
None
Examples
>>> import jax
>>> ref = jax.new_ref(jax.numpy.zeros(5))
>>> jax.ref.set(ref, 1, 10.0)
>>> ref
Ref([ 0., 10., 0., 0., 0.], dtype=float32)
Equivalent operation via indexing syntax:
>>> ref = jax.new_ref(jax.numpy.zeros(5))
>>> ref[1] = 10.0
>>> ref
Ref([ 0., 10., 0., 0., 0.], dtype=float32)
Use ... to set the value of a scalar ref:
>>> ref = jax.new_ref(jax.numpy.int32(0))
>>> ref[...] = 4
>>> ref
Ref(4, dtype=int32)