jax.experimental.pallas.kernel
- jax.experimental.pallas.kernel(body=<not-specified>, out_shape=None, *, mesh, scratch_shapes=(), compiler_params=None, interpret=False, cost_estimate=None, debug=False, name=None, metadata=None)
Entry point for creating a Pallas kernel.
This is a convenience wrapper around core_map for executing a kernel over a mesh and run_scoped for allocating scratch memory.

If body is provided, this function behaves as a decorator applied directly to body:

    def kernel_body(in_ref, out_ref):
      ...

    kernel = pl.kernel(kernel_body, out_shape=..., mesh=...)

If body is omitted, this function behaves as a decorator factory and returns a decorator that can be used to annotate a kernel body:

    @pl.kernel(out_shape=..., mesh=...)
    def kernel(in_ref, out_ref):
      ...
- Parameters:
body (Callable | NotSpecified) – The body of the kernel. If provided, this function behaves as a decorator, and if omitted, this function behaves as a decorator factory.
out_shape (object | None) – The shape of the output. Should be a PyTree of jax.ShapeDtypeStruct or jax.Array objects.
mesh (Mesh) – The mesh to run the kernel on.
scratch_shapes (Sequence[ScratchShape | ScratchShapeTree] | Mapping[str, ScratchShape | ScratchShapeTree]) – The shapes of the scratch arrays.
compiler_params (CompilerParams | None) – The compiler parameters to pass to the backend.
interpret (bool) – Whether to run the function in interpret mode.
debug (bool) – Whether to print helpful debugging information.
cost_estimate (CostEstimate | None) – The cost estimate of the function.
name (str | None) – The (optional) name of the kernel.
metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.
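As a sketch of how some of these arguments might be prepared, the block below builds out_shape, the figures a cost estimate typically carries, and a metadata dict for a hypothetical elementwise kernel computing out = x * y + 1. The kernel, the variable names, and the cost model are assumptions for illustration. The block stops short of calling pl.kernel because constructing a mesh depends on the target backend. In recent JAX releases the three cost figures can be wrapped as pl.CostEstimate(flops=..., transcendentals=..., bytes_accessed=...), but check that signature against your installed version.

```python
import json

import jax
import jax.numpy as jnp

# Hypothetical elementwise kernel: out = x * y + 1 on float32 (m, n) inputs.
m, n = 1024, 512
x = jnp.ones((m, n), jnp.float32)
y = jnp.ones((m, n), jnp.float32)

# out_shape: a single struct yields a single output; a tuple or dict of
# structs yields a result with that same PyTree structure. Per the parameter
# docs, an existing jax.Array such as `x` is also accepted.
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
out_shapes_pair = (
    jax.ShapeDtypeStruct((m, n), jnp.float32),
    jax.ShapeDtypeStruct((m,), jnp.int32),
)

# Back-of-the-envelope cost figures for the hypothetical kernel.
elems = m * n
flops = 2 * elems  # one multiply and one add per element
transcendentals = 0  # no exp, log, tanh, etc.
bytes_accessed = 3 * elems * x.dtype.itemsize  # read x, read y, write out

# metadata must be dict[str, str]; it is serialized as JSON in the HLO.
metadata = {"variant": "fused-mul-add", "owner": "example-team"}
metadata_json = json.dumps(metadata, sort_keys=True)
```

For these sizes the model gives 1,048,576 flops and 6,291,456 bytes accessed (three float32 arrays of 524,288 elements each).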
- Returns:
If body is provided, returns a function that runs the kernel. It takes any number of input operands and returns an output with the same PyTree structure as out_shape. If body is omitted, returns a decorator that can be used to annotate a kernel body.
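The contract between the body and the returned function can be mimicked with a toy, host-only model. Everything here is hypothetical (`ToyRef`, `toy_run`, and out_shape simplified to an integer length or a tuple of lengths); it only assumes, as the documented examples suggest, that the body receives input refs followed by output refs and writes its results into the output refs in place rather than returning them.

```python
class ToyRef:
    """Hypothetical stand-in for a Pallas ref: a mutable box around a list."""

    def __init__(self, value):
        self.value = list(value)

    def __getitem__(self, idx):
        return self.value[idx]

    def __setitem__(self, idx, v):
        self.value[idx] = v


def toy_run(body, out_shape, *operands):
    """Call `body(*in_refs, *out_refs)` and return outputs shaped like out_shape."""
    single = not isinstance(out_shape, tuple)
    shapes = (out_shape,) if single else out_shape
    in_refs = [ToyRef(op) for op in operands]
    out_refs = [ToyRef([0] * length) for length in shapes]
    body(*in_refs, *out_refs)  # the body mutates out_refs; its return is ignored
    outs = tuple(ref.value for ref in out_refs)
    # Same structure as out_shape: a bare value for one output, a tuple otherwise.
    return outs[0] if single else outs


def increment(x_ref, o_ref):
    for i in range(len(x_ref.value)):
        o_ref[i] = x_ref[i] + 1


def split_sign(x_ref, pos_ref, neg_ref):
    for i, v in enumerate(x_ref.value):
        pos_ref[i] = max(v, 0)
        neg_ref[i] = min(v, 0)
```

Here toy_run(increment, 3, [1, 2, 3]) returns a single list, while toy_run(split_sign, (3, 3), [-1, 0, 2]) returns a tuple of two lists, mirroring how a tuple out_shape produces a tuple of outputs.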