jax.ref.addupdate

jax.ref.addupdate(ref, idx, x)

Add to an element in a Ref in place.

This is analogous to ref[idx] += x for a NumPy array ref and a NumPy-style indexer idx. However, for a Ref ref, executing ref[idx] += x actually performs a ref_get, an add, and a ref_set, so using this function can be more efficient under autodiff. For more on mutable array refs, refer to the Ref guide.

Parameters:
  • ref (core.Ref | TransformedRef) – a jax.ref.Ref object. On return, the buffer will be mutated by this operation.

  • idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer selecting the elements to update.

  • x (ArrayLike | HijaxType) – a jax.Array object (note, not a jax.ref.Ref) containing the values to add at the specified indices.

Returns:

None

Return type:

None

Examples

>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.addupdate(ref, 2, 10)
>>> ref
Ref([ 0,  1, 12,  3,  4], dtype=int32)

Equivalent operation via indexing syntax:

>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> ref[2] += 10
>>> ref
Ref([ 0,  1, 12,  3,  4], dtype=int32)

Use ... (Ellipsis) as the index to add to a scalar ref:

>>> ref = jax.new_ref(jax.numpy.int32(2))
>>> ref[...] += 10
>>> ref
Ref(12, dtype=int32)