jax.experimental.pallas.tpu.load

jax.experimental.pallas.tpu.load(ref, *, mask=None)

Loads an array from the given ref. If mask is not specified, this function has the same semantics as ref[idx] in JAX.

Parameters:
    ref (Ref) – The ref to load from.
    mask (jax.Array | None) – An optional boolean mask specifying which indices to load.

Returns:
    The loaded array.

Return type:
    jax.Array
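As an illustration, here is a minimal sketch of an unmasked load inside a Pallas kernel. It assumes the installed JAX exposes jax.experimental.pallas.tpu.load with the signature above; the names ON_TPU, double_kernel, and double are hypothetical. Because pltpu.load targets TPU, the sketch only calls it when the default backend is "tpu". Elsewhere it falls back to plain ref indexing in interpret mode, which is the equivalent the description names for the unmasked case. The (8, 128) float32 shape is chosen to fit the TPU's native float32 tile. Masked loads are not shown because the expected mask shape and the contents of masked-out positions are not specified here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Hypothetical switch: use the TPU-specific load only on real TPU hardware.
ON_TPU = jax.default_backend() == "tpu"


def double_kernel(x_ref, o_ref):
    if ON_TPU:
        # Unmasked load (mask=None): documented to behave like indexing the ref.
        x = pltpu.load(x_ref)
    else:
        # Portable equivalent used off-TPU, where the kernel runs in interpret mode.
        x = x_ref[...]
    o_ref[...] = x * 2


def double(x):
    # No grid or BlockSpecs: each ref covers the whole array.
    return pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=not ON_TPU,
    )(x)


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
y = double(x)
```

On either path, y should equal 2 * x. The backend check happens at trace time, so only one of the two read styles ends up in the compiled kernel.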