Stores a value to the given ref.
If mask is not specified, this function has the same semantics as the
indexed assignment ref[idx] = val in JAX.
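To illustrate the unmasked case, the sketch below writes through a ref with indexed assignment inside a Pallas kernel, which is the form the sentence above says an unmasked store is equivalent to. The kernel and wrapper names (double_kernel, double) are hypothetical, and the example assumes jax.experimental.pallas is available. It passes interpret=True so it can run on CPU without a GPU or TPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def double_kernel(x_ref, o_ref):
    # Hypothetical kernel: read the whole input block and write it back doubled.
    # Indexed assignment on a ref is the unmasked store described above.
    o_ref[...] = x_ref[...] * 2


def double(x: jax.Array) -> jax.Array:
    # Hypothetical wrapper: with no grid, each ref covers the whole array.
    return pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run through the Pallas interpreter (works on CPU)
    )(x)


result = double(jnp.arange(8, dtype=jnp.float32))
```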
- Parameters:
ref (Ref) – The ref to store to.
val (jax.Array) – The value to store.
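mask (jax.Array | None) – An optional boolean mask specifying which elements to store: where mask is True the corresponding element of val is written, and where it is False the ref is left unchanged.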
- Return type:
None
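The effect of the mask can be modeled functionally. The sketch below treats the ref's contents as an ordinary array and returns what the ref would hold after a masked store. masked_store_reference is a hypothetical helper that describes the semantics only. It is not part of the Pallas API and does not call this function.

```python
import jax
import jax.numpy as jnp


def masked_store_reference(ref_contents: jax.Array, val: jax.Array,
                           mask: jax.Array) -> jax.Array:
    """Hypothetical helper: the ref's contents after a masked store.

    Where mask is True the element of val is written; where it is False
    the previous contents are kept.
    """
    return jnp.where(mask, val, ref_contents)


old = jnp.zeros(4, dtype=jnp.int32)
val = jnp.array([1, 2, 3, 4], dtype=jnp.int32)
mask = jnp.array([True, False, True, False])
new = masked_store_reference(old, val, mask)
```

With an all-True mask the result is simply val, which matches the unmasked ref[idx] = val behavior.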