jax.experimental.pallas.core_map
- jax.experimental.pallas.core_map(mesh, *, compiler_params=None, interpret=False, debug=False, cost_estimate=None, name=None, metadata=None)
Runs a function on a mesh, mapping it over the devices in the mesh. core_map returns a decorator; applying that decorator to a function runs the function on the mesh.
The function should be stateful: it takes no inputs and returns no outputs, but it can mutate closed-over Refs, for example.
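The calling convention described above (a decorator around a zero-argument body that communicates only by mutating closed-over state) can be illustrated with a plain-Python toy model. This is a sketch under stated assumptions, not the Pallas implementation: `Ref`, `core_index`, and `toy_core_map` are hypothetical names introduced here, the "cores" run sequentially, and no JAX tracing or compilation happens.

```python
from typing import Callable, List


# Hypothetical stand-in for a mutable Pallas Ref: a small wrapper around a list.
class Ref:
    def __init__(self, values: List[int]):
        self.values = list(values)

    def __getitem__(self, i: int) -> int:
        return self.values[i]

    def __setitem__(self, i: int, v: int) -> None:
        self.values[i] = v


# Index of the core currently executing the body (toy model only).
_current_core = 0


def core_index() -> int:
    """Hypothetical helper: which core of the toy mesh is running the body."""
    return _current_core


def toy_core_map(num_cores: int) -> Callable[[Callable[[], None]], None]:
    """Plain-Python model of core_map's calling convention, NOT the Pallas API.

    Returns a decorator that immediately runs a zero-argument body once per
    core (sequentially here) and rejects bodies that return a value.
    """
    def decorator(body: Callable[[], None]) -> None:
        global _current_core
        for core in range(num_cores):
            _current_core = core
            if body() is not None:
                raise TypeError("core_map bodies must not return outputs")
    return decorator


# Usage: two "cores" each double their half of x into out.
x = Ref([1, 2, 3, 4])
out = Ref([0, 0, 0, 0])
NUM_CORES = 2
CHUNK = len(x.values) // NUM_CORES


@toy_core_map(NUM_CORES)
def _():
    c = core_index()
    for i in range(c * CHUNK, (c + 1) * CHUNK):
        out[i] = x[i] * 2


print(out.values)  # [2, 4, 6, 8]
```

Because the body has no inputs or outputs, each core selects its share of the work from its own index and writes results through the closed-over `out`; the decorated name `_` is only a placeholder.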
- Parameters:
mesh – The mesh to run the function on.
compiler_params (Any | None) – The compiler parameters to pass to the backend.
interpret (bool) – Whether to run the function in interpret mode.
debug (bool) – Whether to print helpful debugging information.
cost_estimate (CostEstimate | None) – The cost estimate of the function.
name (str | None) – The (optional) name of the kernel.
metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.
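Because metadata is typed as dict[str, str] and serialized as JSON, any non-string value has to be converted to a string before it is passed in. The sketch below illustrates that constraint with a hypothetical helper, `serialize_kernel_metadata`, which is not part of Pallas and does not reproduce how Pallas embeds the JSON in the HLO; it only checks the documented type and encodes the dictionary with Python's standard `json` module.

```python
import json
from typing import Dict, Optional


def serialize_kernel_metadata(metadata: Optional[Dict[str, str]]) -> Optional[str]:
    """Hypothetical helper, not part of Pallas: validate metadata against the
    documented dict[str, str] type and encode it as JSON."""
    if metadata is None:
        return None
    for key, value in metadata.items():
        if not isinstance(key, str) or not isinstance(value, str):
            raise TypeError(f"metadata entry {key!r} must map str to str")
    # sort_keys makes the encoded string deterministic across runs.
    return json.dumps(metadata, sort_keys=True)


# Non-string values (e.g. a block size) must be converted to strings first.
print(serialize_kernel_metadata({"variant": "tiled", "block_size": str(128)}))
# {"block_size": "128", "variant": "tiled"}
```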