jax.ref.freeze

jax.ref.freeze(ref)

Invalidate the given reference and return its final value.

For more information about mutable array references, refer to the Ref guide.

Parameters:

ref (Ref) – A jax.ref.Ref object.

Returns:

A jax.Array containing the contents of ref.

Return type:

Array

Examples

>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> ref[3] = 100
>>> ref
Ref([  0,   1,   2, 100,   4], dtype=int32)
>>> jax.ref.freeze(ref)
Array([  0,   1,   2, 100,   4], dtype=int32)