jax.experimental.pallas.triton.load

jax.experimental.pallas.triton.load(ref, *, mask=None, other=None, cache_modifier=None, eviction_policy=None, volatile=False)

Loads an array from the given ref.

If neither mask nor other is specified, this function has the same semantics as ordinary indexing in JAX: load(ref.at[idx]) returns the same values as ref[idx].

Parameters:
  • ref (Ref) – The ref to load from.

  • mask (jax.Array | None) – An optional boolean mask specifying which elements to load. At positions where mask is False and other is not given, the values in the resulting array are unspecified, so no assumptions can be made about them.

  • other (jax.typing.ArrayLike | None) – An optional value to use for indices where mask is False.

  • cache_modifier (str | None) – TO BE DOCUMENTED.

  • eviction_policy (str | None) – TO BE DOCUMENTED.

  • volatile (bool) – TO BE DOCUMENTED.

Return type:

jax.Array