jax.ref.new_ref

jax.ref.new_ref(init_val, *, memory_space=None, kind=None)

Create a mutable array reference with initial value init_val.

For more discussion, see the Ref guide.

Parameters:
  • init_val (Any) – A jax.Array representing the initial state of the buffer.

  • memory_space (Any) – An optional memory space attribute for the Ref.

  • kind (str | None) – An optional string indicating the mutation semantics under rematerialization. Currently, only 'no_grad_no_remat' and None are supported.

Returns:

A jax.ref.Ref containing a reference to a mutable buffer.

Return type:

Ref