jax.experimental.pallas.mosaic_gpu.multimem_store

jax.experimental.pallas.mosaic_gpu.multimem_store(source, ref, collective_axes)

Stores source to ref on every device present in collective_axes.

The store is performed with multimem instructions. The data is therefore transferred to the switch only once, and the switch broadcasts it to all other devices.
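To see why a single transfer to the switch matters, the following plain-Python model counts the bytes that leave the storing GPU. It compares two ways of writing the same buffer to every device in a group: one write per remote peer, or one multicast write that the switch replicates. This is a simplified sketch, not JAX or Pallas code. The helper names are hypothetical, and the model deliberately ignores protocol overheads and switch-side costs.

```python
# Toy traffic model (not JAX/Pallas code): bytes sent over the storing GPU's
# own link when one buffer must reach every device in a group.


def unicast_uplink_bytes(nbytes: int, num_devices: int) -> int:
    """Hypothetical baseline: write to each remote peer separately."""
    if num_devices < 2:
        raise ValueError("need at least two devices")
    # One copy per remote device; the local write does not use the link.
    return nbytes * (num_devices - 1)


def multimem_uplink_bytes(nbytes: int, num_devices: int) -> int:
    """Multicast store: the payload goes to the switch once, which fans it out."""
    if num_devices < 2:
        raise ValueError("need at least two devices")
    # Assumed model: replication happens in the switch, so the sender's link
    # carries the payload once regardless of group size.
    return nbytes


if __name__ == "__main__":
    nbytes = 128 * 4  # e.g. 128 float32 values
    for n in (2, 4, 8):
        print(n, unicast_uplink_bytes(nbytes, n), multimem_uplink_bytes(nbytes, n))
```

Under this model, the sender's traffic for unicast writes grows linearly with the group size. The multicast store stays constant, at one copy of the payload.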

Parameters:
  • source (jax.Array) – The value to store.

  • ref (_Ref) – The GMEM reference to store the value to.

  • collective_axes (Hashable | tuple[Hashable, ...]) – The JAX mesh axes indicating the devices to store to.
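The next snippet is a sketch of how collective_axes could select the destination devices. It is plain Python, and `devices_reached` is a hypothetical helper. It assumes the usual JAX convention for named-axis collectives: the group contains every device that matches the caller's coordinates on all mesh axes not listed in collective_axes. This page does not say whether `multimem_store` accepts a strict subset of the mesh axes, so treat that part as an assumption.

```python
from __future__ import annotations

import itertools
from collections.abc import Hashable


def devices_reached(
    mesh_shape: dict[Hashable, int],
    coords: dict[Hashable, int],
    collective_axes: Hashable | tuple[Hashable, ...],
) -> list[tuple[int, ...]]:
    """Hypothetical helper: mesh coordinates that would receive the store.

    Assumes the JAX axis-group convention: coordinates along collective_axes
    vary, while every other axis keeps the caller's coordinate.
    """
    # A single axis name is accepted as shorthand for a 1-tuple.
    if not isinstance(collective_axes, tuple):
        collective_axes = (collective_axes,)
    names = list(mesh_shape)
    unknown = set(collective_axes) - set(names)
    if unknown:
        raise ValueError(f"unknown mesh axes: {sorted(map(str, unknown))}")
    reached = []
    # Enumerate every index combination along the collective axes.
    for idx in itertools.product(*(range(mesh_shape[a]) for a in collective_axes)):
        target = dict(coords)
        target.update(zip(collective_axes, idx))
        # Report coordinates in mesh-axis order.
        reached.append(tuple(target[n] for n in names))
    return reached


# On a 2x2 mesh, a device at x=0, y=1 storing along "x" would reach (0, 1) and (1, 1).
print(devices_reached({"x": 2, "y": 2}, {"x": 0, "y": 1}, "x"))
```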