jax.experimental.pallas.kernel

jax.experimental.pallas.kernel(body=<not-specified>, out_shape=None, *, mesh, scratch_shapes=(), compiler_params=None, interpret=False, cost_estimate=None, debug=False, name=None, metadata=None)

Entry point for creating a Pallas kernel.

This is a convenience wrapper around core_map for executing a kernel over a mesh and run_scoped for allocating scratch memory.

If body is provided, this function behaves as a decorator:

from jax.experimental import pallas as pl

def kernel_body(in_ref, out_ref):
  ...

kernel = pl.kernel(kernel_body, out_shape=..., mesh=...)

If body is omitted, this function behaves as a decorator factory and will return a decorator that can be used to annotate a kernel body:

@pl.kernel(out_shape=..., mesh=...)
def kernel(in_ref, out_ref):
  ...

Parameters:
  • body (Callable | NotSpecified) – The body of the kernel. If provided, this function behaves as a decorator, and if omitted, this function behaves as a decorator factory.

  • out_shape (object | None) – The shape of the output. Should be a PyTree of jax.ShapeDtypeStruct or jax.Array objects.

  • mesh (Mesh) – The mesh to run the kernel on.

  • scratch_shapes (Sequence[ScratchShape | ScratchShapeTree] | Mapping[str, ScratchShape | ScratchShapeTree]) – The shapes of the scratch arrays.

  • compiler_params (CompilerParams | None) – The compiler parameters to pass to the backend.

  • interpret (bool) – Whether to run the function in interpret mode.

  • debug (bool) – Whether to output helpful debugging information.

  • cost_estimate (CostEstimate | None) – The cost estimate of the function.

  • name (str | None) – The (optional) name of the kernel.

  • metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.

Returns:

If body is provided, returns a function that runs the kernel. The function takes any number of input operands and returns an output with the same PyTree structure as out_shape. If body is omitted, returns a decorator that can be used to annotate a kernel body.
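
The two calling modes described above (body provided versus body omitted) follow a common Python pattern: a sentinel default tells the function whether it received a callable or should return a decorator. The sketch below illustrates that dispatch in plain Python. It is not the Pallas implementation: toy_kernel, the NotSpecified class defined here, the config attribute, and the string placeholders for out_shape and mesh are all hypothetical stand-ins, and the toy bodies return values instead of writing to output refs as a real kernel body would.

```python
import functools


class NotSpecified:
    """Hypothetical sentinel type marking an omitted `body` argument."""


_NOT_SPECIFIED = NotSpecified()


def toy_kernel(body=_NOT_SPECIFIED, out_shape=None, *, mesh, name=None):
    """Toy stand-in (not Pallas) showing decorator vs. decorator-factory dispatch."""
    if isinstance(body, NotSpecified):
        # Factory mode: no body yet, so return a decorator that binds the
        # remaining arguments and waits for the kernel body.
        return lambda fn: toy_kernel(fn, out_shape, mesh=mesh, name=name)

    @functools.wraps(body)
    def run(*operands):
        # A real kernel would be launched over `mesh`; this sketch simply
        # calls the body on the operands.
        return body(*operands)

    # Record the configuration so both call styles can be compared.
    run.config = {
        "out_shape": out_shape,
        "mesh": mesh,
        "name": name or body.__name__,
    }
    return run


# Direct mode: pass the body as the first argument.
def double(x):
    return 2 * x


direct = toy_kernel(double, out_shape="f32[8]", mesh="toy-mesh")


# Factory mode: omit the body and use the result as a decorator.
@toy_kernel(out_shape="f32[8]", mesh="toy-mesh")
def triple(x):
    return 3 * x
```

Both styles produce the same kind of callable. The factory form is convenient when the kernel body is defined only once, while the direct form lets one body be wrapped several times with different settings.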