jax.experimental.pallas.kernel
- jax.experimental.pallas.kernel(body=<not-specified>, out_type=(), *, mesh, scratch_types=(), compiler_params=None, interpret=False, cost_estimate=None, debug=False, name=None, metadata=None)
Entry point for creating a Pallas kernel.
This is a convenience wrapper around mpmd_map for executing a kernel over a mesh.

If body is provided, this function behaves as a decorator applied directly to body (mesh is a required keyword argument in both forms):

```python
def kernel_body(in_ref, out_ref):
  ...

kernel = pl.kernel(kernel_body, out_type=..., mesh=...)
```
If body is omitted, this function behaves as a decorator factory and returns a decorator that can be used to annotate a kernel body:

```python
@pl.kernel(mesh=..., out_type=...)
def kernel(in_ref, out_ref):
  ...
```
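The two calling conventions above follow a common Python pattern: a sentinel default on body decides whether to return a runner right away or a decorator that waits for the body. The sketch below is a minimal pure-Python illustration of that pattern, not the Pallas implementation. It assumes only what the signature shows (body defaults to a <not-specified> sentinel and mesh is keyword-only); kernel_like, _NotSpecified, and the tuple the runner returns are hypothetical, and plain strings stand in for real meshes.

```python
import functools


class _NotSpecified:
  """Hypothetical sentinel marking that `body` was not passed."""

  def __repr__(self):
    return "<not-specified>"


NOT_SPECIFIED = _NotSpecified()


def kernel_like(body=NOT_SPECIFIED, out_type=(), *, mesh, **options):
  """Hypothetical stand-in for the calling convention of pl.kernel.

  Nothing is compiled or launched; the returned runner only reports
  which body and mesh it would use and how many operands it received.
  """
  if body is NOT_SPECIFIED:
    # Decorator-factory form: keep the arguments and wait for the body.
    return functools.partial(
        kernel_like, out_type=out_type, mesh=mesh, **options)

  def run(*operands):
    # Stand-in for running `body` over `mesh` with the given operands.
    return (body.__name__, mesh, len(operands))

  return run


# Direct form: pass the body as the first argument.
def kernel_body(in_ref, out_ref):
  ...

direct = kernel_like(kernel_body, out_type=None, mesh="mesh-a")


# Decorator-factory form: omit the body, then decorate.
@kernel_like(mesh="mesh-b", out_type=None)
def decorated(in_ref, out_ref):
  ...
```

Using functools.partial for the factory form keeps every keyword argument in one place, so both forms end up calling the same code path with the same options.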
For MPMD kernels, you can pass parallel lists of bodies and meshes:
```python
my_kernel = pl.kernel(
    body=[vector_fn, scalar_fn],
    mesh=[v_mesh, s_mesh],
    out_type=...,
)
```
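One way to read "parallel lists" is that the i-th body runs over the i-th mesh, so the two lists must have the same length. The sketch below shows that pairing rule as a hypothetical helper, pair_bodies_with_meshes; it is an assumption for illustration, not a function Pallas exposes, and plain strings stand in for real mesh objects.

```python
def pair_bodies_with_meshes(body, mesh):
  """Hypothetical helper: normalize (body, mesh) into (body, mesh) pairs.

  Mirrors the documented rule that `mesh` must be a sequence of meshes
  when `body` is a sequence of callables.
  """
  if callable(body):
    # Single-body kernel: one body runs over one mesh.
    return [(body, mesh)]
  if not isinstance(mesh, (list, tuple)):
    raise TypeError("mesh must be a sequence when body is a sequence")
  bodies, meshes = list(body), list(mesh)
  if len(bodies) != len(meshes):
    raise ValueError(f"got {len(bodies)} bodies but {len(meshes)} meshes")
  # Position i of `body` is paired with position i of `mesh`.
  return list(zip(bodies, meshes))


# Usage with placeholder bodies and string "meshes".
def vector_fn(*refs):
  ...

def scalar_fn(*refs):
  ...

pairs = pair_bodies_with_meshes([vector_fn, scalar_fn], ["v_mesh", "s_mesh"])
```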
- Parameters:
body (Callable | Sequence[Callable] | NotSpecified) – The body of the kernel. If provided, this function behaves as a decorator, and if omitted, this function behaves as a decorator factory. Can also be a sequence of callables to be paired with a sequence of meshes.
out_type (object | None) – The type of the output. Should be a PyTree of jax.ShapeDtypeStruct or JAX types.
mesh (Mesh | Sequence[Mesh]) – The mesh to run the kernel on. Must be a sequence of meshes if body is a sequence of callables.
scratch_types (Sequence[ScratchShape | ScratchShapeTree | None] | Mapping[str, ScratchShape | ScratchShapeTree]) – The shapes of the scratch arrays.
compiler_params (CompilerParams | None) – The compiler parameters to pass to the backend.
interpret (bool) – Whether to run the function in interpret mode.
debug (bool) – Whether to print out helpful debugging information.
cost_estimate (CostEstimate | None) – The cost estimate of the function.
name (str | None) – The (optional) name of the kernel.
metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.
- Returns:
If body is provided, returns a function that runs the kernel. That function takes any number of input operands and returns an output with the same PyTree structure as out_type. If body is omitted, returns a decorator that can be used to annotate a kernel body.
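To make the return contract concrete, here is a sequential, single-process emulation in plain Python. It assumes, as the kernel_body(in_ref, out_ref) example suggests, that the body receives input refs followed by output refs; scratch refs and the mesh are ignored. FakeShapeDtype (a stand-in for jax.ShapeDtypeStruct), FakeRef, emulate_kernel, and the tree helpers are all hypothetical. Dict keys are flattened in sorted order, as JAX's tree utilities do, so the output refs arrive in that order.

```python
import math
from typing import NamedTuple


class FakeShapeDtype(NamedTuple):
  """Hypothetical stand-in for jax.ShapeDtypeStruct (shape + dtype name)."""
  shape: tuple
  dtype: str


class FakeRef:
  """Hypothetical mutable flat buffer standing in for a Pallas ref."""

  def __init__(self, values):
    self.values = list(values)


def _flatten(tree):
  # Leaves are FakeShapeDtype; containers are tuple, list, or dict.
  if isinstance(tree, FakeShapeDtype):
    return [tree]
  if isinstance(tree, dict):
    return [leaf for k in sorted(tree) for leaf in _flatten(tree[k])]
  return [leaf for sub in tree for leaf in _flatten(sub)]


def _unflatten(tree, leaves):
  # `leaves` is an iterator consumed in the order _flatten produced.
  if isinstance(tree, FakeShapeDtype):
    return next(leaves)
  if isinstance(tree, dict):
    return {k: _unflatten(tree[k], leaves) for k in sorted(tree)}
  return type(tree)(_unflatten(sub, leaves) for sub in tree)


def emulate_kernel(body, out_type):
  """Hypothetical emulation: call `body(*in_refs, *out_refs)` once and
  return the outputs with the same structure as `out_type`."""
  specs = _flatten(out_type)

  def run(*operands):
    in_refs = [FakeRef(x) for x in operands]
    # One zero-filled output buffer per leaf of out_type.
    out_refs = [FakeRef([0] * math.prod(s.shape)) for s in specs]
    body(*in_refs, *out_refs)
    return _unflatten(out_type, iter([ref.values for ref in out_refs]))

  return run


# Usage: one input, one output leaf.
def add_one(x_ref, y_ref):
  for i, v in enumerate(x_ref.values):
    y_ref.values[i] = v + 1

add_one_kernel = emulate_kernel(add_one, FakeShapeDtype((3,), "int32"))
result = add_one_kernel([1, 2, 3])  # [2, 3, 4]
```

A real kernel runs the body across the mesh and works on device buffers; the emulation only shows how the output structure mirrors out_type.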