jax.experimental.pallas.mosaic_gpu.kernel
- jax.experimental.pallas.mosaic_gpu.kernel(body, out_shape, *, scratch_shapes=(), compiler_params=None, grid=(), grid_names=(), cluster=(), cluster_names=(), num_threads=None, thread_name=None, **mesh_kwargs)
Entry point for defining a Mosaic GPU kernel.
- Parameters:
body (Callable[..., None]) – The kernel body. It takes the input, output, and scratch Refs as arguments, in that order. The number of input Refs is determined by the number of operands passed to the function returned by kernel. The numbers of output and scratch Refs are determined by out_shape and scratch_shapes, respectively.
out_shape (object) – A PyTree of jax.ShapeDtypeStruct describing the shapes and dtypes of the outputs.
scratch_shapes (pallas_core.ScratchShapeTree) – An iterable (possibly nested) of GPUMemoryRef describing the scratch Refs to allocate for this kernel.
compiler_params (pallas_core.CompilerParams | None) – Additional compiler options. See the CompilerParams dataclass for more details.
grid (tuple[int, ...]) – A tuple of integers specifying the size of the kernel grid.
grid_names (tuple[str, ...]) – The axis names of the grid. Must be the same length as grid.
cluster (tuple[int, ...]) – A tuple of integers specifying the size of the kernel cluster.
cluster_names (tuple[str, ...]) – The axis names of the cluster. Must be the same length as cluster.
num_threads (int | None) – The number of threads to launch per block. These are not CUDA threads: on Hopper and Blackwell GPUs, each one corresponds to a warpgroup (four warps, i.e. 128 CUDA threads).
thread_name (str | None) – The axis name used to query the thread index.
**mesh_kwargs (object) – Additional mesh kwargs. See Mesh for more details.
- Returns:
A function that runs the kernel. It takes any number of input operands and returns outputs with the same PyTree structure as out_shape.