jax.ref.get

jax.ref.get(ref, idx=None)

Read a value from a Ref.

This is equivalent to ref[idx] for a NumPy-style indexer idx. For more on mutable array refs, refer to the Ref guide.

Parameters:
  • ref (core.Ref | TransformedRef) – a jax.ref.Ref object.

  • idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer.

Returns:

A jax.Array object (not a jax.ref.Ref) containing the indexed elements of the mutable reference.

Return type:

Array | HijaxType

Examples

>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.get(ref, slice(1, 3))
Array([1, 2], dtype=int32)

Equivalent operation via indexing syntax:

>>> ref[1:3]
Array([1, 2], dtype=int32)

Use ... to extract the full buffer:

>>> ref[...]
Array([0, 1, 2, 3, 4], dtype=int32)