jax.ref.get
jax.ref.get(ref, idx=None)
Read a value from a Ref.
This is equivalent to ref[idx] for a NumPy-style indexer idx. For more on mutable array refs, refer to the Ref guide.
Parameters:
ref (core.Ref | TransformedRef) – a jax.ref.Ref object.
idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer.
Returns:
A jax.Array object (note: not a jax.ref.Ref) containing the indexed elements of the mutable reference.
Return type:
Array | HijaxType
Examples
>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.get(ref, slice(1, 3))
Array([1, 2], dtype=int32)
Equivalent operation via indexing syntax:
>>> ref[1:3]
Array([1, 2], dtype=int32)
Use ... to extract the full buffer:
>>> ref[...]
Array([0, 1, 2, 3, 4], dtype=int32)