jax.experimental.pallas.mosaic_gpu.multimem_load_reduce


jax.experimental.pallas.mosaic_gpu.multimem_load_reduce(ref, *, collective_axes, reduction_op)

Loads the contents of a GMEM reference from every device in collective_axes and returns the elementwise reduction of the loaded values.
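To make the semantics concrete, here is a minimal host-side reference model, not the GPU kernel itself. It assumes each device's copy of ref can be modeled as one slice along axis 0 of a stacked array; the function name reference_load_reduce and the string op names are hypothetical. Floating-point results on hardware may differ slightly because the reduction order is not specified here.

```python
import functools

import jax.numpy as jnp

# Hypothetical mapping from reduction op names to elementwise binary ops.
_REDUCERS = {
    "add": jnp.add,
    "min": jnp.minimum,
    "max": jnp.maximum,
    "and": jnp.bitwise_and,
    "or": jnp.bitwise_or,
    "xor": jnp.bitwise_xor,
}


def reference_load_reduce(per_device, reduction_op):
    """Hypothetical model: per_device[i] stands for device i's copy of `ref`.

    Returns the elementwise reduction across devices, which has the same
    shape as a single device's buffer.
    """
    combine = _REDUCERS[reduction_op]
    # Fold the per-device slices together one device at a time.
    return functools.reduce(combine, list(per_device))


# Three modeled "devices", each holding a length-4 int32 buffer.
vals = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8], [0, 1, 0, 1]], dtype=jnp.int32)
summed = reference_load_reduce(vals, "add")  # elementwise sum across devices
```

Every device that calls the real function in the same collective observes the same reduced result, which is what makes this a building block for all-reduce style collectives.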

The supported dtypes are: jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.int32 and jnp.int64.

The 8-bit floating-point dtypes (jnp.float8_e5m2 and jnp.float8_e4m3fn) are only supported on Blackwell GPUs.

Parameters:
  • ref (_Ref) – The GMEM reference to load from.

  • collective_axes (Hashable | tuple[Hashable, ...]) – The JAX mesh axes indicating the devices to load from.

  • reduction_op (mgpu.MultimemReductionOp) – The reduction operation to perform on the loaded values. The allowed values are add (all supported dtypes); min and max (all supported dtypes except jnp.float32); and and, or and xor (integer dtypes only).
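The dtype and reduction-op constraints above can be summarized as a small compatibility check. This is a sketch that encodes only the rules stated on this page; the helper is_supported, the string dtype names and the blackwell flag are hypothetical and not part of the JAX API.

```python
# Hypothetical helper encoding the documented dtype / reduction_op rules.
SUPPORTED_DTYPES = {
    "float32", "float16", "bfloat16",
    "float8_e5m2", "float8_e4m3fn",
    "int32", "int64",
}
FP8_DTYPES = {"float8_e5m2", "float8_e4m3fn"}
INT_DTYPES = {"int32", "int64"}


def is_supported(dtype: str, op: str, *, blackwell: bool) -> bool:
    """Return True if (dtype, op) is allowed according to the rules above."""
    if dtype not in SUPPORTED_DTYPES:
        return False
    # 8-bit float dtypes require a Blackwell GPU.
    if dtype in FP8_DTYPES and not blackwell:
        return False
    if op == "add":
        return True  # add works for every supported dtype
    if op in ("min", "max"):
        return dtype != "float32"  # every supported dtype except float32
    if op in ("and", "or", "xor"):
        return dtype in INT_DTYPES  # bitwise ops are integer-only
    return False  # unknown op name
```

For example, is_supported("float32", "max", blackwell=True) is False, while is_supported("int64", "xor", blackwell=False) is True.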

Return type:

jax.Array