jax.experimental.pallas.core_map

jax.experimental.pallas.core_map(mesh, *, compiler_params=None, interpret=False, debug=False, cost_estimate=None, name=None, metadata=None)

Runs a function on a mesh, mapping it over the devices in the mesh.

The function must be stateful: it takes no inputs and returns no outputs, but it can, for example, mutate closed-over Refs.

Parameters:
  • mesh – The mesh to run the function on.

  • compiler_params (Any | None) – The compiler parameters to pass to the backend.

  • interpret (bool) – Whether to run the function in interpret mode.

  • debug (bool) – Whether to output helpful debugging information.

  • cost_estimate (CostEstimate | None) – The cost estimate of the function.

  • name (str | None) – The (optional) name of the kernel.

  • metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.
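Because `metadata` is typed `dict[str, str]` and is serialized as JSON in the HLO, a dictionary can be checked before it is passed. The helper below is hypothetical and not part of Pallas; it only mirrors the documented type constraint and round-trips the value through `json` to confirm it serializes.

```python
import json


def check_kernel_metadata(metadata):
    """Hypothetical pre-check mirroring the documented dict[str, str] | None type."""
    if metadata is None:
        return None
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be a dict[str, str] or None")
    for key, value in metadata.items():
        if not isinstance(key, str) or not isinstance(value, str):
            raise TypeError(f"metadata entry {key!r}: keys and values must be str")
    # Round-trip through JSON, the format the metadata is serialized to.
    return json.loads(json.dumps(metadata))
```

The checked dictionary could then be passed alongside the other keyword arguments, for example `core_map(mesh, name="add_one", metadata=check_kernel_metadata({"owner": "docs"}))`, where `mesh`, the kernel name, and the metadata contents are illustrative assumptions.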