jax.experimental.pallas.mosaic_gpu.multimem_store
- jax.experimental.pallas.mosaic_gpu.multimem_store(source, ref, collective_axes)
Stores the value source to ref on every device present in collective_axes.
The store is performed with multimem instructions, so the data is sent to the switch only once and is broadcast from there to all other devices.
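To make the "sent to the switch only once" claim concrete, here is a toy Python model of the data each device pushes over its own link. It is not the Pallas API. SwitchModel, unicast_store, and multimem_store in this sketch are hypothetical names. The model assumes N devices attached to a single switch, counts only payload bytes, and compares the multimem path against an assumed baseline in which the issuing device writes each remote copy separately.

```python
from dataclasses import dataclass, field


@dataclass
class SwitchModel:
    """Toy model of N devices behind one switch (hypothetical, not a JAX API)."""

    num_devices: int
    buffers: dict = field(default_factory=dict)  # device id -> stored value
    uplink_bytes: int = 0  # bytes the issuing device sends over its own link

    def unicast_store(self, src_device: int, value: bytes) -> None:
        # Assumed baseline: the issuer writes every remote copy itself,
        # so the payload crosses its link once per remote device.
        for dev in range(self.num_devices):
            if dev != src_device:
                self.uplink_bytes += len(value)
            self.buffers[dev] = value

    def multimem_store(self, value: bytes) -> None:
        # multimem-style store: the payload goes to the switch once,
        # and the switch replicates it to every device in the group.
        self.uplink_bytes += len(value)
        for dev in range(self.num_devices):
            self.buffers[dev] = value


payload = bytes(1024)  # 1 KiB of zeros

naive = SwitchModel(num_devices=8)
naive.unicast_store(src_device=0, value=payload)

mm = SwitchModel(num_devices=8)
mm.multimem_store(value=payload)

print(naive.uplink_bytes, mm.uplink_bytes)  # 7168 1024
```

In this model, both paths leave the same value on all 8 devices. The issuer's link traffic, however, drops from 7 copies of the payload to 1. The model ignores switch internals, latency, and synchronization, which matter for real multimem stores.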