jax.set_mesh

class jax.set_mesh(mesh)

Sets a concrete mesh in a thread-local context.

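jax.set_mesh has dual behavior: you can use it either as a global setter or as a context manager. Called as a plain function, it sets the mesh for the current thread until it is changed again. Used in a with statement, it restores the previously set mesh when the block exits.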

When a mesh is in context via jax.set_mesh, you can pass raw PartitionSpecs to all APIs that accept a sharding as an argument. jax.set_mesh is also required to enable explicit sharding mode: https://docs.jax.dev/en/latest/parallel.html

For example:

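import jax

mesh = jax.make_mesh((2,), ('x',))  # 1-D mesh of 2 devices along axis 'x' (needs 2 devices)
jax.set_mesh(mesh)  # use the API as a global setter

with jax.set_mesh(mesh):  # use the API as a context manager
  ...  # code here runs with `mesh` as the current mesh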

Note: jax.set_mesh can only be used outside of jax.jit.

Parameters:

mesh (Mesh | None)

__init__(mesh)
Parameters:

mesh (Mesh | None)

Methods

__init__(mesh)

Attributes

prev_abstract_mesh

prev_mesh
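The attribute names prev_mesh and prev_abstract_mesh hint at how one class can act as both a global setter and a context manager: record the previous value, install the new one, and restore the old one on exit. The following is a toy, pure-Python sketch of that pattern, not JAX's implementation; set_value, current_value, and _state are hypothetical names, and the swap is assumed to happen in __init__ so that a bare call already takes effect.

```python
import threading

# Hypothetical thread-local slot standing in for JAX's internal mesh context.
_state = threading.local()

def current_value():
  """Return the value set in this thread, or None if nothing was set."""
  return getattr(_state, "value", None)

class set_value:
  """Toy analogue of jax.set_mesh (lowercase name mimics it on purpose)."""

  def __init__(self, value):
    # Swap eagerly: constructing the object is enough to install `value`.
    self.prev_value = current_value()
    _state.value = value

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Restore whatever was active before this object was constructed.
    _state.value = self.prev_value
    return False  # do not suppress exceptions

set_value("global")          # global setter: stays in effect for this thread
with set_value("scoped"):    # context manager: temporary override
  assert current_value() == "scoped"
assert current_value() == "global"
```

Because the swap happens at construction rather than in __enter__, a bare call and a with statement share one code path; the with form only adds the restore step in __exit__.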