jax.experimental.pallas.mosaic_gpu module
Experimental GPU backend for Pallas targeting H100.
These APIs are highly unstable and can change weekly. Use at your own risk.
Classes
Describes a barrier reference.
A GPU-specific …
Mosaic GPU compiler parameters.
Represents a tiling transformation for memory refs.
Transpose a tiled memref.
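To make the tiling and transpose transforms concrete, here is a conceptual model in plain NumPy rather than the Mosaic GPU API. It assumes a 2-D array whose dimensions are divisible by the tile shape, and a tiled layout ordered as (row tiles, column tiles, tile rows, tile columns). The helpers `tile`, `untile`, and `transpose_tiled` are hypothetical. The library's transform classes may order or represent dimensions differently.

```python
import numpy as np


def tile(x: np.ndarray, tile_shape: tuple) -> np.ndarray:
    """Hypothetical helper: (M, N) -> (M // tm, N // tn, tm, tn)."""
    m, n = x.shape
    tm, tn = tile_shape
    if m % tm or n % tn:
        raise ValueError("array shape must be divisible by the tile shape")
    # Split each dim into (num_tiles, tile_size), then move the tile-size dims last.
    return x.reshape(m // tm, tm, n // tn, tn).transpose(0, 2, 1, 3)


def untile(t: np.ndarray) -> np.ndarray:
    """Inverse of tile(): (M // tm, N // tn, tm, tn) -> (M, N)."""
    gm, gn, tm, tn = t.shape
    return t.transpose(0, 2, 1, 3).reshape(gm * tm, gn * tn)


def transpose_tiled(t: np.ndarray) -> np.ndarray:
    """Transpose a tiled array so that untile() yields the transposed matrix."""
    # Swap the tile-grid axes and the intra-tile axes together.
    return t.transpose(1, 0, 3, 2)
```

Transposing a tiled array requires swapping both the tile-grid axes and the intra-tile axes. Swapping only one pair does not correspond to transposing the underlying matrix.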
Functions
Makes a Mosaic GPU kernel callable with PyTorch tensors.
Entry point for defining a Mosaic GPU kernel.
Casts the layout of the given array.
Sets the maximum number of registers owned by a warp.
Converts a linear index into an index into the given shape, ordering the result to improve locality.
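One way to improve locality when mapping a linear index onto a 2-D grid is to walk the grid in snake (boustrophedon) order. In that order, consecutive linear indices always land on neighboring cells. The helper `snake_index` below is hypothetical and uses an assumed, simplified traversal. It is not claimed to be the ordering used by the module's function.

```python
def snake_index(linear: int, shape: tuple) -> tuple:
    """Hypothetical helper: map a linear index to (row, col) in snake order."""
    rows, cols = shape
    if not 0 <= linear < rows * cols:
        raise IndexError("linear index out of range")
    row, offset = divmod(linear, cols)
    # Reverse direction on odd rows so consecutive indices stay adjacent.
    col = offset if row % 2 == 0 else cols - 1 - offset
    return row, col
```

On a (2, 3) grid in plain row-major order, index 2 maps to (0, 2) and index 3 maps to (1, 0), a jump across the whole row. In snake order, index 3 maps to (1, 2), directly below (0, 2).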
Loop-like functions
Creates a function to emit a manual pipeline within a Pallas kernel.
Creates a function to emit a warp-specialized pipeline.
A loop over a multi-dimensional grid partitioned along the given axes.
A loop over program instances using dynamic work scheduling.
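A simple way to partition a multi-dimensional grid across a fixed number of workers (for example, one per SM) is a strided split of the row-major linear index. The sketch below models that idea in plain Python. The helper `partitioned_grid` and its strided assignment are assumptions for illustration. They are not the module's loop functions or their actual partitioning scheme.

```python
from math import prod


def partitioned_grid(grid: tuple, worker: int, num_workers: int):
    """Yield the grid coordinates assigned to `worker` under a strided split.

    Hypothetical helper: worker w gets linear indices w, w + num_workers, ...
    """
    if not 0 <= worker < num_workers:
        raise ValueError("worker must be in [0, num_workers)")
    for linear in range(worker, prod(grid), num_workers):
        # Unravel the linear index in row-major order (last axis fastest).
        rest = linear
        coords = []
        for dim in reversed(grid):
            rest, c = divmod(rest, dim)
            coords.append(c)
        yield tuple(reversed(coords))
```

A static split like this fixes each worker's share up front. A dynamically scheduled loop instead hands out work as workers become free.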
Synchronization
Arrives at the given barrier.
Waits on the given barrier.
Signals multiple semaphores without any guaranteed ordering of signal arrivals.
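GPU arrival-count barriers are commonly described as a counter of expected arrivals plus a phase bit. The bit flips once all arrivals for the current phase have happened, and a wait on a phase succeeds when that phase has completed. The single-threaded toy below models that behavior under this assumption. `PhaseBarrier` is hypothetical and does not reflect the module's barrier reference type or the exact semantics of its arrive and wait functions.

```python
class PhaseBarrier:
    """Hypothetical toy model of an arrival-count barrier with a phase bit."""

    def __init__(self, num_arrivals: int):
        if num_arrivals < 1:
            raise ValueError("num_arrivals must be positive")
        self.num_arrivals = num_arrivals
        self.pending = num_arrivals  # arrivals still expected in this phase
        self.phase = 0  # flips each time a phase completes

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            # Completing a phase resets the count so the barrier can be reused.
            self.phase ^= 1
            self.pending = self.num_arrivals

    def is_complete(self, waited_phase: int) -> bool:
        """True once the barrier has moved past `waited_phase`.

        Parity only distinguishes adjacent phases, so callers must track
        which phase they are waiting on.
        """
        return self.phase != waited_phase
```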
Asynchronous copies
Commits all writes to SMEM, making them visible to TMA and MMA operations.
Asynchronously copies a GMEM reference to an SMEM reference.
Asynchronously copies an SMEM reference to a GMEM reference.
Waits until no more than the most recent …
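A wait that allows only "the most recent" copies to remain in flight can be modeled as a FIFO of copies issued by one thread. Waiting retires the oldest copies until at most a given number remain. In this hypothetical sketch, `CopyTracker` and the `allow_in_flight` parameter are illustrative names. Simulating completion in issue order is an assumption of the model.

```python
from collections import deque


class CopyTracker:
    """Hypothetical toy model of in-order async copies issued by one thread."""

    def __init__(self):
        self.in_flight = deque()
        self.completed = []

    def issue(self, name: str) -> None:
        self.in_flight.append(name)

    def wait(self, allow_in_flight: int) -> None:
        """Retire the oldest copies until at most `allow_in_flight` remain."""
        if allow_in_flight < 0:
            raise ValueError("allow_in_flight must be non-negative")
        while len(self.in_flight) > allow_in_flight:
            self.completed.append(self.in_flight.popleft())
```

Allowing a small number of copies to stay in flight lets a kernel reuse the buffer of the oldest copy while newer copies are still draining.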
Hopper-specific functions
Performs an asynchronous warp group matmul-accumulate on the given references.
Waits until there is no more than …
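Setting aside asynchrony, operand layouts, and hardware shape constraints, a matmul-accumulate computes acc + a @ b. The NumPy reference below shows the math, and shows how repeated accumulation over slices of the K dimension reproduces a full matmul. The helpers `mma_reference` and `tiled_matmul` are hypothetical and not part of the module. Using float16 inputs with a float32 accumulator is an assumed configuration.

```python
import numpy as np


def mma_reference(acc, a, b):
    """Return acc + a @ b, with the product computed in float32."""
    return acc + a.astype(np.float32) @ b.astype(np.float32)


def tiled_matmul(a, b, block_k: int):
    """Compute a @ b by accumulating partial products over K-slices."""
    m, k = a.shape
    k2, n = b.shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    acc = np.zeros((m, n), dtype=np.float32)  # float32 accumulator
    for k0 in range(0, k, block_k):
        # Each step adds one (m, block_k) x (block_k, n) product into acc.
        acc = mma_reference(acc, a[:, k0:k0 + block_k], b[k0:k0 + block_k, :])
    return acc
```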
Blackwell-specific functions
Asynchronous matrix multiply-accumulate for TensorCore gen 5 (Blackwell).
Tracks completion of a preceding …
Performs an asynchronous load from the TMEM array.
Stores the value to TMEM.
Awaits all asynchronous TMEM loads previously issued by the calling thread.
Commits all writes to TMEM issued by the current thread.
Initiates an async request to claim a new work unit from the grid.
Decodes the result of a …
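Claiming work units from the grid at run time lets each worker take the next unprocessed unit instead of following a fixed assignment. The thread-based toy below models only the claim step, using a lock-protected counter. `WorkQueue`, `try_claim`, and `run_workers` are hypothetical names. The asynchronous request and decode mechanism on Blackwell is not modeled.

```python
import threading


class WorkQueue:
    """Hypothetical toy model: workers claim grid units from a shared counter."""

    def __init__(self, num_units: int):
        self.num_units = num_units
        self._next = 0
        self._lock = threading.Lock()

    def try_claim(self):
        """Return the next unclaimed unit index, or None once all are claimed."""
        with self._lock:
            if self._next >= self.num_units:
                return None
            unit = self._next
            self._next += 1
            return unit


def run_workers(num_units: int, num_workers: int) -> dict:
    """Run threads that keep claiming units; return the units each one processed."""
    queue = WorkQueue(num_units)
    done = {w: [] for w in range(num_workers)}

    def worker(wid: int) -> None:
        while (unit := queue.try_claim()) is not None:
            done[wid].append(unit)  # each thread writes only to its own list

    threads = [threading.Thread(target=worker, args=(w,)) for w in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return done
```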
Multimem operations
Stores the value to ref on all devices present in collective_axes.
Loads from a GMEM reference on all devices present in collective_axes and reduces the loaded values.
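Semantically, a multimem store writes the same value into every participating device's copy of a buffer, and a multimem load-reduce combines those copies elementwise. The NumPy model below represents the devices as axis 0 of a single array. `multimem_store_model`, `multimem_load_reduce_model`, and the `op` argument with its "add" and "max" values are hypothetical. Which reductions the real operations support is not assumed here.

```python
import numpy as np


def multimem_store_model(buffers: np.ndarray, value: np.ndarray) -> None:
    """Write `value` into every device's copy (axis 0 indexes devices)."""
    buffers[...] = value  # broadcasts over the device axis


def multimem_load_reduce_model(buffers: np.ndarray, op: str = "add") -> np.ndarray:
    """Load every device's copy and combine them elementwise."""
    if op == "add":
        return buffers.sum(axis=0)
    if op == "max":
        return buffers.max(axis=0)
    raise ValueError(f"unsupported reduction: {op}")
```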
Aliases
Alias of …
Alias of …
Alias of …