jax.export module
jax.export is a library for exporting and serializing JAX functions
for persistent archival.
See the Exporting and serialization documentation.
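The basic round trip this module supports can be sketched as follows. This is a minimal sketch, assuming the default backend (e.g. CPU) and a toy function `f` chosen only for illustration. It exports a jitted function for a fixed input shape, serializes the result to bytes, deserializes it, and calls the restored `Exported`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def f(x):
    # Toy function used only for illustration.
    return jnp.sin(x) * 2.0

# Export for a float32[4] argument; no concrete data is needed.
exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((4,), jnp.float32))

# Serialize to bytes for persistent archival ...
blob = exp.serialize()
# ... and later restore the Exported from those bytes.
restored = export.deserialize(blob)

x = np.arange(4, dtype=np.float32)
print(restored.call(x))
```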
Classes
- class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)
A JAX function lowered to StableHLO.
- Parameters:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
uses_global_constants (bool)
- in_tree
a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX function. The actual lowering does not depend on the in_tree, but it can be used to invoke the exported function with the same argument structure.
- Type:
tree_util.PyTreeDef
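A minimal sketch of how `in_tree` relates to the flat `in_avals`, assuming a hypothetical function `f` with one positional and one keyword argument. Unflattening `in_avals` with `in_tree` recovers the `(args, kwargs)` structure, and `call` accepts that same structure.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def f(x, *, scale):
    # Hypothetical function with one positional and one keyword argument.
    return x * scale

spec = jax.ShapeDtypeStruct((3,), jnp.float32)
exp = export.export(jax.jit(f))(spec, scale=spec)

# in_avals is flat; in_tree restores the (args, kwargs) structure.
args_avals, kwargs_avals = jax.tree_util.tree_unflatten(exp.in_tree, exp.in_avals)
print(args_avals, kwargs_avals)

# Invoke with the same argument structure used at export time.
res = exp.call(np.ones(3, np.float32), scale=np.full(3, 2.0, np.float32))
```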
- in_avals
the flat tuple of input abstract values. May contain dimension expressions in the shapes.
- Type:
tuple[core.ShapedArray, …]
- out_tree
a PyTreeDef describing the result of the lowered JAX function.
- Type:
tree_util.PyTreeDef
- out_avals
the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in in_avals.
- Type:
tuple[core.ShapedArray, …]
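A sketch of dimension expressions in `in_avals` and `out_avals`. It assumes `jax.export.symbolic_shape` (not documented in this section) to declare a symbolic batch dimension `b`; the same `Exported` is then called with two concrete batch sizes.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Declare a float32[b, 4] input with a symbolic batch dimension "b".
shape = export.symbolic_shape("b, 4")
spec = jax.ShapeDtypeStruct(shape, jnp.float32)
exp = export.export(jax.jit(lambda x: jnp.sum(x, axis=1)))(spec)

print(exp.in_avals)   # input shape contains the dimension variable b
print(exp.out_avals)  # output shape is expressed in terms of b

# The same Exported works for any concrete batch size.
outs = {batch: exp.call(np.ones((batch, 4), np.float32)) for batch in (2, 5)}
for batch, out in outs.items():
    print(batch, out.shape)
```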
- in_shardings_hlo
the flattened input shardings, a sequence as long as in_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh. See in_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.
- Type:
tuple[HloSharding | None, …]
- out_shardings_hlo
the flattened output shardings, a sequence as long as out_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh. See out_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.
- Type:
tuple[HloSharding | None, …]
- platforms
a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: ‘tpu’, ‘cpu’, ‘cuda’, ‘rocm’. See https://docs.jax.dev/en/latest/export/export.html#cross-platform-and-multi-platform-export.
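A sketch of multi-platform export using the `platforms` keyword of `jax.export.export`, assuming the code runs on a CPU host. Lowering for `'tpu'` does not require a TPU to be present; at call time the CPU lowering is selected.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Lower for CPU and TPU; assumes this runs on a CPU host.
exp = export.export(jax.jit(jnp.cos), platforms=("cpu", "tpu"))(
    jax.ShapeDtypeStruct((3,), jnp.float32))
print(exp.platforms)

# On a CPU host, call dispatches to the CPU lowering.
res = exp.call(np.arange(3, dtype=np.float32))
```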
- ordered_effects
the ordered effects present in the serialized module. This is present from serialization version 9. See https://docs.jax.dev/en/latest/export/export.html#module-calling-convention for the calling convention in presence of ordered effects.
- Type:
tuple[effects.Effect, …]
- unordered_effects
the unordered effects present in the serialized module. This is present from serialization version 9.
- Type:
tuple[effects.Effect, …]
- calling_convention_version
a version number for the calling convention of the exported module. See more versioning details at https://docs.jax.dev/en/latest/export/export.html#calling-convention-versions.
- Type:
int
- module_kept_var_idx
the sorted indices of the arguments among in_avals that must be passed to the module. The other arguments have been dropped because they are not used.
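A sketch illustrating `module_kept_var_idx`, assuming `jax.jit`'s default `keep_unused=False`, so that an unused argument can be dropped from the module. `in_avals` still describes both arguments and `call` still takes both; only the indices listed in `module_kept_var_idx` are passed to the module.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def f(x, y):
    del y  # y is intentionally unused
    return x * 2.0

spec = jax.ShapeDtypeStruct((2,), jnp.float32)
exp = export.export(jax.jit(f))(spec, spec)

print(len(exp.in_avals))        # both arguments are still described
print(exp.module_kept_var_idx)  # indices actually passed to the module

# call still takes both arguments and drops the unused one itself.
res = exp.call(np.ones(2, np.float32), np.zeros(2, np.float32))
```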
- uses_global_constants
whether the mlir_module_serialized uses shape polymorphism or multi-platform export. This may be because in_avals contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation.
- Type:
bool
- disabled_safety_checks
a list of descriptors of safety checks that have been disabled at export time. See the docstring for DisabledSafetyCheck.
- Type:
Sequence[DisabledSafetyCheck]
- _get_vjp
an optional function that takes the current exported function and returns the exported VJP function. The VJP function takes a flat list of arguments, starting with the primal arguments and followed by a cotangent argument for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs.
See a description of the calling convention for the mlir_module() method at https://docs.jax.dev/en/latest/export/export.html#module-calling-convention.
- call(*args, **kwargs)
Call an exported function from a JAX program.
- Parameters:
args – the positional arguments to pass to the exported function. This should be a pytree of arrays with the same pytree structure as the arguments for which the function was exported.
kwargs – the keyword arguments to pass to the exported function.
- Returns: a pytree of result arrays, with the same structure as the results of the exported function.
The invocation supports reverse-mode AD, and all the features supported by exporting: shape polymorphism, multi-platform, device polymorphism. See the examples in the [JAX export documentation](https://docs.jax.dev/en/latest/export/export.html).
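A minimal sketch of calling an `Exported` from JAX code under `jax.jit` and `jax.grad`, assuming a scalar float32 input. The squared sine is used only so that the gradient has a simple closed form, sin(2x).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((), jnp.float32))

x = jnp.float32(0.5)
# Exported.call can be used inside jax.jit ...
y = jax.jit(exp.call)(x)

# ... and differentiated in reverse mode: d/dx sin(x)**2 = sin(2x).
def loss(v):
    return exp.call(v) ** 2

g = jax.grad(loss)(x)
print(y, g)
```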
- in_shardings_jax(mesh)
Creates Shardings corresponding to self.in_shardings_hlo.
The Exported object stores in_shardings_hlo as HloShardings, which are independent of a mesh or set of devices. This method constructs Shardings that can be used in JAX APIs such as jax.jit() or jax.device_put().
Example usage:
>>> import jax
>>> import numpy as np
>>> from jax import export, sharding
>>> # The outputs below assume 8 CPU devices.
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
...     )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
...                          exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
 Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
 Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
 Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
 Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
 Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
 Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
 Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
- Parameters:
mesh (mesh_lib.Mesh)
- Return type:
Sequence[sharding.Sharding | None]
- out_shardings_jax(mesh)
Creates Shardings corresponding to self.out_shardings_hlo.
See the documentation for in_shardings_jax.
- Parameters:
mesh (mesh_lib.Mesh)
- Return type:
Sequence[sharding.Sharding | None]
- serialize(vjp_order=0)
Serializes an Exported.
- Parameters:
vjp_order (int) – The maximum VJP order to include. E.g., the value 2 means that we serialize the primal function and two orders of the vjp function. This should allow 2nd-order reverse-mode differentiation of the deserialized function, i.e., jax.grad(jax.grad(f)).
- Return type:
bytearray
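A sketch of `serialize` with `vjp_order=2`, assuming the same scalar `jnp.sin` example as elsewhere on this page. After deserialization, first- and second-order gradients can be taken and should be close to cos(x) and -sin(x).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((), jnp.float32))

# Include the VJP and the VJP of the VJP in the serialized form.
blob = exp.serialize(vjp_order=2)
restored = export.deserialize(blob)

x = jnp.float32(0.5)
d1 = jax.grad(restored.call)(x)            # expected to be close to cos(x)
d2 = jax.grad(jax.grad(restored.call))(x)  # expected to be close to -sin(x)
print(d1, d2)
```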
- class jax.export.DisabledSafetyCheck(_impl)
A safety check that should be skipped on (de)serialization.
Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, e.g., as a sequence of string attributes to jax.export.Exported or of tf.XlaCallModuleOp.
When using jax2tf, you can disable more deserialization safety checks by passing TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform.
- Parameters:
_impl (str)
- classmethod custom_call(target_name)
Allows the serialization of a call target not known to be stable.
Has effect only on serialization.
- Parameters:
target_name (str) – the name of the custom call target to allow.
- Return type:
DisabledSafetyCheck
- is_custom_call()
Returns the custom call target allowed by this directive.
- Return type:
str | None
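A sketch of disabling the custom-call check via the `disabled_checks` keyword of `jax.export.export`, assuming a hypothetical custom call target named `my_custom_target` (the toy function below does not actually use it). The directive is recorded on the resulting `Exported`.

```python
import jax
import jax.numpy as jnp
from jax import export

# "my_custom_target" is a hypothetical custom call target name.
chk = export.DisabledSafetyCheck.custom_call("my_custom_target")
print(chk.is_custom_call())  # my_custom_target

# Attach the directive at export time; it is recorded on the Exported.
exp = export.export(jax.jit(jnp.tanh), disabled_checks=[chk])(
    jax.ShapeDtypeStruct((2,), jnp.float32))
print(exp.disabled_safety_checks)
```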
Functions
- Exports a JAX function for persistent serialization.
- Deserializes an Exported.
- An integer constant.
- An integer constant.
- Retrieves the default export platform.
- Registers a custom PyTree node for serialization and deserialization.
- Registers a namedtuple for serialization and deserialization.
Constants
- jax.export.minimum_supported_serialization_version
The minimum supported serialization version; see Calling convention versions.
- jax.export.maximum_supported_serialization_version
The maximum supported serialization version; see Calling convention versions.
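A sketch comparing an `Exported`'s `calling_convention_version` against the supported range. It assumes the module exposes the bounds as `jax.export.minimum_supported_calling_convention_version` and `jax.export.maximum_supported_calling_convention_version`; the constant names documented above may differ across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(jnp.abs))(jax.ShapeDtypeStruct((2,), jnp.float32))

# Assumed attribute names for the supported version range.
lo = export.minimum_supported_calling_convention_version
hi = export.maximum_supported_calling_convention_version
print(lo, exp.calling_convention_version, hi)

# A freshly exported module is expected to fall within the supported range.
assert lo <= exp.calling_convention_version <= hi
```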