jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem

jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem(src, dst, barrier, *, collective_axes=None, leader_tracked=None, oob_mode=OOBFillMode.ZEROS)

Asynchronously copies a GMEM reference to a SMEM reference.

If collective_axes is specified, this performs a multicast copy: all CUDA blocks that share the same index along the collective axis receive a copy of the same block of data, loaded from src into dst.

If collective_axes is specified and leader_tracked is CopyPartition.PARTITIONED(axis), this performs a partitioned collective copy in which each block in the cluster receives a tile of transfer_size // cluster_size data from the src Ref. For example, if src has shape (256, 256) and a partitioned copy is performed along axis 0 with a cluster size of 2, the first block receives src[0:128, :] and the second receives src[128:256, :].
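The tile arithmetic in the example above can be sketched in plain Python. This is only a model of the slicing, not a call into Pallas: the helper name partitioned_tile and its block_index argument are hypothetical, and it assumes the extent along the partitioned axis divides evenly by the cluster size.

```python
def partitioned_tile(shape, axis, cluster_size, block_index):
    """Return the per-dimension slices of the tile one block would receive.

    Hypothetical helper: it models the slicing arithmetic only and
    performs no copy. Assumes the partitioned extent divides evenly.
    """
    extent = shape[axis]
    if extent % cluster_size != 0:
        raise ValueError("extent along the partitioned axis must divide evenly")
    if not 0 <= block_index < cluster_size:
        raise ValueError("block_index out of range")
    tile = extent // cluster_size
    # Every dimension is taken whole except the partitioned one.
    slices = [slice(0, dim) for dim in shape]
    slices[axis] = slice(block_index * tile, (block_index + 1) * tile)
    return tuple(slices)


tiles = [
    partitioned_tile((256, 256), axis=0, cluster_size=2, block_index=i)
    for i in range(2)
]
print(tiles[0])  # (slice(0, 128, None), slice(0, 256, None))
print(tiles[1])  # (slice(128, 256, None), slice(0, 256, None))
```

The two results correspond to src[0:128, :] and src[128:256, :] from the example.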

If collective_axes is specified and leader_tracked is CopyPartition.REPLICATED, this performs a replicated copy in which all blocks load the same data, but only the first block in the collective tracks progress via barrier arrivals.

NOTE: Only the first block in the cluster will arrive on the barrier, and an additional cluster barrier is necessary to ensure that all blocks in the cluster have finished the copy.

Parameters:
  • src (_Ref) – The source Ref. Must be in GMEM.

  • dst (_Ref) – The destination Ref. Must be in SMEM.

  • barrier (_Ref) – The barrier to use for tracking completion of the copy.

  • collective_axes (str | tuple[str, ...] | None) – The collective axes to use for the copy.

  • leader_tracked (CopyPartition | None) – If specified, only the leader block in the cluster will observe the completion of the copy. If CopyPartition.PARTITIONED(axis), performs a partitioned collective copy along the given axis. If CopyPartition.REPLICATED, all blocks load the same data.

  • oob_mode (OOBFillMode) – The optional out-of-bounds fill mode. Can be OOBFillMode.UNDEFINED, OOBFillMode.PROMISE_IN_BOUNDS or OOBFillMode.ZEROS.

Return type:

None
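
As an illustration of the oob_mode parameter, the following pure-Python sketch models what a zero-filling load might look like for a one-dimensional window that extends past the end of the source. It assumes that OOBFillMode.ZEROS means out-of-bounds elements arrive in the destination as zeros; the helper load_window_zero_fill is hypothetical, and the other two modes are not modeled.

```python
def load_window_zero_fill(src, start, size):
    """Read src[start:start + size], substituting 0 for out-of-bounds indices.

    Hypothetical model of a zero-filling load; it does not call Pallas.
    """
    return [
        src[i] if 0 <= i < len(src) else 0
        for i in range(start, start + size)
    ]


window = load_window_zero_fill([1, 2, 3, 4], start=2, size=4)
print(window)  # [3, 4, 0, 0]
```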