jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem
- jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem(src, dst, barrier, *, collective_axes=None, leader_tracked=None, oob_mode=OOBFillMode.ZEROS)
Asynchronously copies a GMEM reference to a SMEM reference.
If collective_axes is specified, this performs a multicast copy where all CUDA blocks that share the same index along the collective axes receive a copy of the same block of data loaded from src into dst.
If collective_axes is specified and leader_tracked is CopyPartition.PARTITIONED(axis), this performs a partitioned collective copy where each block in the cluster receives a tile of transfer_size // cluster_size data from the src Ref. For example, if src has a shape of (256, 256) and a partitioned copy is performed along axis 0 with cluster size 2, then the first block receives src[0:128, :] and the second receives src[128:256, :].
If collective_axes is specified and leader_tracked is CopyPartition.REPLICATED, this performs a replicated copy where all blocks load the same data, but only the first block in the collective tracks progress via barrier arrivals.
NOTE: Only the first block in the cluster arrives on the barrier, so an additional cluster barrier is necessary to ensure that all blocks in the cluster have finished the copy.
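The tile each block receives under a partitioned copy can be illustrated with plain NumPy slicing. This is only a sketch of the indexing arithmetic; the names partitioned_tile, cluster_size, and block_idx are illustrative and not part of the Pallas API:

```python
import numpy as np

def partitioned_tile(src, axis, cluster_size, block_idx):
    # Each block receives 1/cluster_size of the partitioned axis.
    tile = src.shape[axis] // cluster_size
    index = [slice(None)] * src.ndim
    index[axis] = slice(block_idx * tile, (block_idx + 1) * tile)
    return src[tuple(index)]

src = np.arange(256 * 256).reshape(256, 256)
# Cluster size 2, partitioned along axis 0:
# block 0 gets src[0:128, :], block 1 gets src[128:256, :].
t0 = partitioned_tile(src, axis=0, cluster_size=2, block_idx=0)
t1 = partitioned_tile(src, axis=0, cluster_size=2, block_idx=1)
```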
- Parameters:
src (_Ref) – The source Ref. Must be in GMEM.
dst (_Ref) – The destination Ref. Must be in SMEM.
barrier (_Ref) – The barrier to use for tracking completion of the copy.
collective_axes (str | tuple[str, ...] | None) – The collective axes to use for the copy.
leader_tracked (CopyPartition | None) – If specified, only the leader block in the cluster will observe the completion of the copy. If CopyPartition.PARTITIONED(axis), performs a partitioned collective copy along the given axis. If CopyPartition.REPLICATED, all blocks load the same data.
oob_mode (OOBFillMode) – The optional out-of-bounds fill mode. Can be OOBFillMode.UNDEFINED, OOBFillMode.PROMISE_IN_BOUNDS, or OOBFillMode.ZEROS.
- Return type:
None
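A minimal usage sketch, pairing the asynchronous copy with a barrier wait inside a Pallas kernel. This follows the common Pallas Mosaic GPU pattern of staging a GMEM input through SMEM scratch; it assumes a GPU supported by the Mosaic GPU backend, so it is not runnable on CPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel(x_gmem, o_ref, scratch, barrier):
    # Kick off the asynchronous GMEM -> SMEM copy, then block on the
    # barrier until the data has landed in scratch.
    plgpu.copy_gmem_to_smem(x_gmem, scratch, barrier)
    plgpu.barrier_wait(barrier)
    o_ref[...] = scratch[...] + 1.0

x = jnp.arange(256, dtype=jnp.float32)
out = pl.pallas_call(
    kernel,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],  # keep the input in GMEM
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[
        plgpu.SMEM(x.shape, jnp.float32),   # SMEM staging buffer
        plgpu.Barrier(num_arrivals=1),      # completion barrier for the copy
    ],
)(x)
```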