jax.experimental.pallas.tpu.load



jax.experimental.pallas.tpu.load(ref, *, mask=None)

Loads an array from the given ref.

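If mask is not specified, this function has the same semantics as ref[idx] in JAX, where idx stands for any indexing already applied to ref (for example, through a view such as ref.at[idx]).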

Parameters:
  • ref (Ref) – The ref to load from.

  • mask (jax.Array | None) – An optional boolean mask specifying which indices to load.

Returns:

The loaded array.

Return type:

jax.Array
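Example (a minimal sketch, not taken from the upstream docs). The kernel below shows the mask-free case: an ordinary ref read, which the description above says load matches. It reads a static 8-row window with x_ref[:8, :]. Writing the same read as pltpu.load(x_ref.at[:8, :]) is an interpretation of the ref[idx] equivalence, not something this page states. Setting interpret=True is an assumption so the sketch runs without a TPU, and the names double_top_half_kernel and double_top_half are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def double_top_half_kernel(x_ref, o_ref):
    # Plain ref indexing: read the first 8 rows of the input block.
    # Assumption: with mask=None, pltpu.load(x_ref.at[:8, :]) is meant to
    # have the same semantics as this read.
    top = x_ref[:8, :]
    o_ref[...] = top * 2.0


def double_top_half(x):
    # Hypothetical wrapper: no grid, so each kernel sees whole arrays.
    return pl.pallas_call(
        double_top_half_kernel,
        out_shape=jax.ShapeDtypeStruct((8, x.shape[1]), x.dtype),
        interpret=True,  # assumption: interpreter mode so it runs on CPU
    )(x)


x = jnp.arange(16 * 128, dtype=jnp.float32).reshape(16, 128)
y = double_top_half(x)
```

Dropping interpret=True should compile the same kernel for a TPU; the 8 x 128 float32 output matches one native TPU vector register tile.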