jax.set_mesh
- class jax.set_mesh(mesh)
Sets a concrete mesh in a thread-local context.
jax.set_mesh has dual behavior: you can use it as a global setter or as a context manager.
When a mesh is in context via jax.set_mesh, you can pass raw PartitionSpecs to all APIs that accept sharding as an argument. Using jax.set_mesh is also required for enabling explicit sharding mode: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html
For example:
    import jax

    mesh = jax.make_mesh((2,), ('x',))  # needs 2 devices
    jax.set_mesh(mesh)  # use the API as a global setter

    with jax.set_mesh(mesh):  # use the API as a context manager
      ...
Note:
jax.set_mesh can only be used outside of jax.jit.
- Parameters:
mesh (mesh_lib.Mesh)
Methods
__init__(mesh)
Attributes
prev_abstract_mesh
prev_mesh