jax.experimental.pallas.kernel

jax.experimental.pallas.kernel(body=<not-specified>, out_type=(), *, mesh, scratch_types=(), compiler_params=None, interpret=False, cost_estimate=None, debug=False, name=None, metadata=None)

Entry point for creating a Pallas kernel.

This is a convenience wrapper around mpmd_map for executing a kernel over a mesh.

If body is provided, this function wraps it directly, as a decorator applied to body would, and returns the runnable kernel:

def kernel_body(in_ref, out_ref):
  ...
kernel = pl.kernel(kernel_body, mesh=..., out_type=...)

If body is omitted, this function behaves as a decorator factory and will return a decorator that can be used to annotate a kernel body:

@pl.kernel(mesh=..., out_type=...)
def kernel(in_ref, out_ref):
  ...
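
The two calling conventions above follow a common Python pattern: a sentinel default distinguishes "body omitted" from "body passed". The sketch below illustrates only that pattern in plain Python. It is not the Pallas implementation; `make_kernel`, `_NotSpecified`, and the string `"toy-mesh"` are hypothetical stand-ins, and the toy bodies take plain values rather than refs.

```python
import functools


class _NotSpecified:
    """Hypothetical sentinel type marking an omitted `body` argument."""


_NOT_SPECIFIED = _NotSpecified()


def make_kernel(body=_NOT_SPECIFIED, out_type=(), *, mesh):
    """Toy stand-in for the calling convention only; it does not run Pallas."""
    if isinstance(body, _NotSpecified):
        # Factory mode: return a decorator that remembers the keyword arguments.
        return lambda fn: make_kernel(fn, out_type, mesh=mesh)

    @functools.wraps(body)
    def run(*operands):
        # A real kernel would be launched over `mesh`; here we only call the body.
        return body(*operands)

    return run


def double(x):
    return 2 * x


# Direct mode: body is passed positionally and the kernel function is returned.
direct = make_kernel(double, out_type=None, mesh="toy-mesh")


# Factory mode: body is omitted, so the call returns a decorator.
@make_kernel(mesh="toy-mesh", out_type=None)
def triple(x):
    return 3 * x
```

Because `mesh` is keyword-only and has no default in this sketch, both modes must supply it, which mirrors the signature shown at the top of this page.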

For MPMD kernels, you can pass parallel lists of bodies and meshes:

my_kernel = pl.kernel(
    body=[vector_fn, scalar_fn],
    mesh=[v_mesh, s_mesh],
    out_type=...
)

Parameters:
  • body (Callable | Sequence[Callable] | NotSpecified) – The body of the kernel. If provided, this function behaves as a decorator, and if omitted, this function behaves as a decorator factory. Can also be a sequence of callables to be paired with a sequence of meshes.

  • out_type (object | None) – The type of the output. Should be a PyTree of jax.ShapeDtypeStruct or JAX types.

  • mesh (Mesh | Sequence[Mesh]) – The mesh to run the kernel on. Must be a sequence of meshes if body is a sequence of callables.

  • scratch_types (Sequence[ScratchShape | ScratchShapeTree | None] | Mapping[str, ScratchShape | ScratchShapeTree]) – The shapes of the scratch arrays.

  • compiler_params (CompilerParams | None) – The compiler parameters to pass to the backend.

  • interpret (bool) – Whether to run the function in interpret mode.

  • debug (bool) – Whether to output helpful debugging information.

  • cost_estimate (CostEstimate | None) – The cost estimate of the function.

  • name (str | None) – The (optional) name of the kernel.

  • metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.

Returns:

If body is provided, returns a function that runs the kernel. The function takes any number of input operands and returns an output with the same PyTree structure as out_type. If body is omitted, returns a decorator that can be used to annotate a kernel body.
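
As a rough illustration of what "the same PyTree structure as out_type" means, the sketch below builds an out_type from jax.ShapeDtypeStruct leaves and compares its tree structure with a stand-in result. No Pallas kernel is launched here: the dictionary keys, shapes, and the zero-filled `fake_result` are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical output specification: a dict with two array-typed leaves.
out_type = {
    "sums": jax.ShapeDtypeStruct((8, 128), jnp.float32),
    "counts": jax.ShapeDtypeStruct((8,), jnp.int32),
}

# Stand-in for a kernel result: one zero-filled array per specification leaf.
fake_result = jax.tree_util.tree_map(
    lambda spec: jnp.zeros(spec.shape, spec.dtype), out_type
)

# The container layout (dict keys, nesting) matches; only the leaves differ.
same_structure = (
    jax.tree_util.tree_structure(fake_result)
    == jax.tree_util.tree_structure(out_type)
)
```

A caller can therefore index the kernel's output the same way it indexes out_type, for example `result["sums"]` under the layout assumed here.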