jax.experimental.pallas.mosaic_gpu module#

Experimental GPU backend for Pallas targeting NVIDIA Hopper (H100) and Blackwell GPUs.

These APIs are highly unstable and can change weekly. Use at your own risk.

Classes#

Barrier(*[, num_arrivals, num_barriers, ...])

Describes a barrier reference.

BlockSpec([block_shape, index_map, ...])

A GPU-specific BlockSpec.

CompilerParams(*[, approx_math, ...])

Mosaic GPU compiler parameters (see the sketch at the end of this section).

MemorySpace(value[, names, module, ...])

An enumeration of GPU memory spaces (e.g. GMEM and SMEM).

Layout(value[, names, module, qualname, ...])

SemaphoreType(value[, names, module, ...])

SwizzleTransform(swizzle)

Represents a swizzling transformation for memory refs.

TilingTransform(tiling)

Represents a tiling transformation for memory refs.

TransposeTransform(permutation)

Transposes a tiled memref.

WGMMAAccumulatorRef(shape, dtype, _init)

Describes a reference to a WGMMA accumulator.
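
As a rough illustration of how these classes fit together, here is a minimal sketch of an elementwise pl.pallas_call that uses the GPU-specific BlockSpec and CompilerParams. The backend="mosaic_gpu" argument and the approx_math setting are assumptions about a recent JAX version and are purely illustrative, not recommendations.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def add_one_kernel(x_ref, o_ref):
        # Block refs are staged through SMEM by the pallas_call pipeline.
        o_ref[...] = x_ref[...] + 1.0

    x = jnp.zeros((128, 128), jnp.float32)
    add_one = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        in_specs=[plgpu.BlockSpec((128, 128), lambda: (0, 0))],
        out_specs=plgpu.BlockSpec((128, 128), lambda: (0, 0)),
        compiler_params=plgpu.CompilerParams(approx_math=True),
        backend="mosaic_gpu",  # assumed spelling of the backend selector
    )
    y = add_one(x)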

Functions#

as_torch_kernel(fn)

Makes a Mosaic GPU kernel callable with PyTorch tensors.

kernel(body, out_shape, *[, scratch_shapes, ...])

Entry point for defining a Mosaic GPU kernel (see the sketch at the end of this section).

layout_cast(x, new_layout)

Casts the layout of the given array.

set_max_registers(n, *, action)

Sets the maximum number of registers owned by a warp.

planar_snake(lin_idx, shape, minor_dim, ...)

Converts a linear index into a multidimensional index into shape, trying to optimize locality.
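
The sketch below shows a minimal plgpu.kernel that stages its input through an SMEM scratch buffer with an explicit async copy and barrier, then writes the result back to GMEM. It assumes scratch_shapes accepts plgpu.SMEM and plgpu.Barrier entries as shown; the copy and synchronization primitives it uses are listed in the sections below.

    import jax
    import jax.numpy as jnp
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def add_one_body(x_gmem, o_gmem, scratch, barrier):
        # Stage the input into SMEM and wait for the async copy to complete.
        plgpu.copy_gmem_to_smem(x_gmem, scratch, barrier)
        plgpu.barrier_wait(barrier)
        # Compute on the SMEM tile.
        scratch[...] = scratch[...] + 1.0
        # Make the SMEM writes visible to TMA, copy back out, and drain.
        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(scratch, o_gmem)
        plgpu.wait_smem_to_gmem(0)

    x = jnp.zeros((128, 128), jnp.float32)
    add_one = plgpu.kernel(
        add_one_body,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        scratch_shapes=[
            plgpu.SMEM((128, 128), jnp.float32),  # SMEM staging buffer
            plgpu.Barrier(num_arrivals=1),        # completion barrier for the copy
        ],
    )
    y = add_one(x)

A function built this way is a regular JAX callable; per as_torch_kernel above, wrapping it with plgpu.as_torch_kernel(add_one) should yield an equivalent callable that accepts PyTorch tensors.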

Loop-like functions#

emit_pipeline(body, *, grid[, in_specs, ...])

Creates a function to emit a manual pipeline within a Pallas kernel (see the sketch at the end of this section).

emit_pipeline_warp_specialized(body, *, ...)

Creates a function to emit a warp-specialized pipeline.

nd_loop()

A loop over a multi-dimensional grid partitioned along the given axes.

dynamic_scheduling_loop()

A loop over program instances using dynamic work scheduling.
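
A minimal sketch of emit_pipeline, assuming the pipeline body receives the current grid indices as its first argument followed by the SMEM input and output refs, and that the returned pipeline is invoked on GMEM refs inside a kernel:

    import jax
    import jax.numpy as jnp
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def pipelined_body(x_gmem, o_gmem):
        def compute(step, x_smem, o_smem):
            # Runs once per tile on SMEM refs; the pipeline handles the
            # GMEM<->SMEM copies and double-buffering.
            o_smem[...] = x_smem[...] * 2.0

        tile = plgpu.BlockSpec((128, 128), lambda i: (i, 0))
        pipeline = plgpu.emit_pipeline(
            compute,
            grid=(8,),          # 8 sequential (128, 128) tiles along the rows
            in_specs=[tile],
            out_specs=[tile],
        )
        pipeline(x_gmem, o_gmem)

    x = jnp.zeros((1024, 128), jnp.float32)
    double = plgpu.kernel(
        pipelined_body, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))
    y = double(x)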

Synchronization#

barrier_arrive(barrier)

Arrives at the given barrier (see the sketch at the end of this section).

barrier_wait(barrier)

Waits on the given barrier.

semaphore_signal_parallel(*signals)

Signals multiple semaphores without any guaranteed ordering of signal arrivals.

SemaphoreSignal(ref, *, device_id[, inc])

Describes a single semaphore signal, for use with semaphore_signal_parallel.
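
As a sketch of barrier_arrive and barrier_wait used for a producer/consumer handshake, the fragment below assumes the surrounding kernel was configured to launch two Pallas threads (warpgroups) under a hypothetical axis name "wg", and that data_ready is a plgpu.Barrier scratch with num_arrivals=1:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def produce_and_consume(smem_buf, data_ready):
        wg = jax.lax.axis_index("wg")  # which Pallas thread (warpgroup) we are

        @pl.when(wg == 0)
        def _produce():
            smem_buf[...] = jnp.full(smem_buf.shape, 1.0, smem_buf.dtype)
            plgpu.barrier_arrive(data_ready)  # publish: the SMEM data is ready

        @pl.when(wg == 1)
        def _consume():
            plgpu.barrier_wait(data_ready)       # block until the producer arrived
            smem_buf[...] = smem_buf[...] * 2.0  # now safe to read and update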

Asynchronous copies#

commit_smem()

Commits all writes to SMEM, making them visible to TMA and MMA operations.

copy_gmem_to_smem(src, dst, barrier, *[, ...])

Asynchronously copies a GMEM reference to a SMEM reference.

copy_smem_to_gmem(src, dst[, predicate, ...])

Asynchronously copies a SMEM reference to a GMEM reference.

wait_smem_to_gmem(n[, wait_read_only])

Waits until at most the n most recent SMEM->GMEM copies issued by the calling thread remain in flight (see the sketch at the end of this section).
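
The fragment below sketches the SMEM->GMEM side. It assumes it runs inside a Mosaic GPU kernel body, with an SMEM scratch carrying a leading buffer dimension and a matching GMEM output, and uses wait_smem_to_gmem to bound how many copies stay in flight.

    from jax.experimental.pallas import mosaic_gpu as plgpu

    def drain_outputs(smem_bufs, o_gmem):
        num_bufs = smem_bufs.shape[0]
        for i in range(num_bufs):        # static Python loop, unrolled at trace time
            plgpu.commit_smem()          # make preceding SMEM writes visible to TMA
            plgpu.copy_smem_to_gmem(smem_bufs.at[i], o_gmem.at[i])
            # Keep issuing, but let at most one copy still be reading its SMEM source.
            plgpu.wait_smem_to_gmem(1, wait_read_only=True)
        plgpu.wait_smem_to_gmem(0)       # drain all copies before buffer reuse/exit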

Hopper-specific functions#

wgmma(acc, a, b)

Performs an asynchronous warp group matmul-accumulate on the given references (see the sketch at the end of this section).

wgmma_wait(n)

Waits until there are no more than n WGMMA operations in flight.
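
A sketch of the WGMMA flow, written as a kernel-body fragment: it assumes a_smem and b_smem are SMEM tiles laid out with the tiling/swizzle transforms WGMMA expects (see TilingTransform and SwizzleTransform above), and it uses pl.run_scoped to allocate the accumulator via the ACC alias.

    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def matmul_fragment(a_smem, b_smem, o_smem):
        # a_smem: (128, 64) and b_smem: (64, 128) operand tiles in SMEM.
        def scope(acc_ref):
            plgpu.wgmma(acc_ref, a_smem, b_smem)  # issue the async matmul
            plgpu.wgmma_wait(0)                   # wait until no WGMMAs are in flight
            return acc_ref[...]                   # read out the accumulated tile

        o_smem[...] = pl.run_scoped(scope, plgpu.ACC((128, 128), jnp.float32))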

Blackwell-specific functions#

tcgen05_mma(acc, a, b[, barrier, a_scale, ...])

Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell); see the sketch at the end of this section.

tcgen05_commit_arrive(barrier[, collective_axis])

Tracks completion of a preceding tcgen05_mma call.

async_load_tmem(src, *[, layout])

Performs an asynchronous load from the TMEM array.

async_store_tmem(ref, value)

Stores the value to TMEM.

wait_load_tmem()

Awaits all asynchronous TMEM loads previously issued by the calling thread.

commit_tmem()

Commits all writes to TMEM issued by the current thread.

try_cluster_cancel(result_ref, barrier)

Initiates an async request to claim a new work unit from the grid.

query_cluster_cancel(result_ref, grid_names)

Decodes the result of a try_cluster_cancel operation.
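
A heavily hedged sketch of the Blackwell MMA flow as a kernel-body fragment: acc_tmem stands for a TMEM accumulator reference and mma_done for a plgpu.Barrier, both assumed to be allocated elsewhere (TMEM allocation is not covered by the entries above), and it is assumed that async_load_tmem returns the loaded array.

    from jax.experimental.pallas import mosaic_gpu as plgpu

    def blackwell_mma_fragment(a_smem, b_smem, acc_tmem, mma_done, o_smem):
        plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, barrier=mma_done)
        plgpu.barrier_wait(mma_done)              # MMA results are now in TMEM
        result = plgpu.async_load_tmem(acc_tmem)  # async TMEM -> registers load
        o_smem[...] = result                      # stage the tile in SMEM
        plgpu.wait_load_tmem()                    # loads done: TMEM can be reused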

Multimem operations#

multimem_store(source, ref, collective_axes)

Stores the value to ref on all devices present in collective_axes.

multimem_load_reduce(ref, *, ...)

Loads from a GMEM reference on all devices present in collective_axes and reduces the loaded values.

Aliases#

ACC

Alias of WGMMAAccumulatorRef.

GMEM

Alias of jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM.

SMEM

Alias of jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM.