jax.set_mesh

class jax.set_mesh(mesh)

Sets a concrete mesh in a thread-local context.

jax.set_mesh has dual behavior. You can use it as a global setter or as a context manager.

When a mesh is in context via jax.set_mesh, you can pass raw PartitionSpecs to all APIs that accept a sharding as an argument. jax.set_mesh is also required to enable explicit sharding mode: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html

For example:

import jax

# A 1-D mesh of 2 devices along an axis named 'x'.
mesh = jax.make_mesh((2,), ('x',))
jax.set_mesh(mesh)  # use the API as a global setter

with jax.set_mesh(mesh):  # use the API as a context manager
  ...

Note: jax.set_mesh can only be used outside of jax.jit.

Parameters:

mesh (jax.sharding.Mesh): the concrete mesh to set.

__init__(mesh)
Parameters:

mesh (jax.sharding.Mesh): the concrete mesh to set.

Methods

__init__(mesh)

Attributes

prev_abstract_mesh

prev_mesh