jax.experimental.pallas.tpu.store

jax.experimental.pallas.tpu.store(ref, val, *, mask=None)

Stores a value to the given ref.

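If mask is not specified, this function has the same semantics as the indexed assignment ref[idx] = val in JAX. Because the function takes no separate index argument, any indexing is carried by ref itself (for example, a view created with ref.at[idx]).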
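Without a mask, the store is an ordinary indexed write. The sketch below shows that indexed form inside a Pallas kernel. It deliberately uses plain ref indexing rather than pltpu.store so that it can run off-TPU with interpret=True; the kernel name double_kernel and the array shapes are illustrative assumptions, not part of this API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def double_kernel(x_ref, o_ref):
    # Unmasked store written as plain ref indexing: o_ref[...] = val.
    # Per the semantics above, pltpu.store(o_ref, x_ref[...] * 2) would
    # express the same write on TPU (not executed here).
    o_ref[...] = x_ref[...] * 2


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)

# interpret=True runs the kernel with the Pallas interpreter, so no TPU is needed.
out = pl.pallas_call(
    double_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x)
```

Every element of the output block is written, which is what distinguishes the unmasked case from a masked store.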

Parameters:
  • ref (Ref) – The ref to store to.

  • val (jax.Array) – The value to store.

  • mask (jax.Array | None) – An optional boolean mask specifying which indices to store. Positions where the mask is False are not written, so ref keeps its previous contents there.
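The effect of mask can be modeled functionally in plain JAX. The sketch below is a hypothetical reference model, not the Pallas implementation: masked_store_model is an illustrative name, current stands for the contents of ref before the store, and mask is assumed to have the same shape as val.

```python
import jax.numpy as jnp


def masked_store_model(current, val, mask):
    """Hypothetical functional model of store(ref, val, mask=mask).

    Where mask is True the result takes val; where it is False the
    previous contents (current) are kept, mirroring a skipped write.
    Assumes current, val and mask all have the same shape.
    """
    return jnp.where(mask, val, current)


current = jnp.zeros(8, dtype=jnp.int32)
val = jnp.arange(8, dtype=jnp.int32)
mask = val % 2 == 0  # write only the even positions

print(masked_store_model(current, val, mask))  # [0 0 2 0 4 0 6 0]
```

Inside a TPU kernel, the corresponding call would be pltpu.store(o_ref, val, mask=mask), with pltpu imported as from jax.experimental.pallas import tpu as pltpu.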

Return type:

None