jax.export.register_pytree_node_serialization
- jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)
Registers a custom PyTree node for serialization and deserialization.
You must call this function before you can serialize and deserialize PyTree nodes for types that are not supported natively. We serialize PyTree nodes for the in_tree and out_tree fields of Exported, which are part of the exported function’s calling convention.
This function must be called after calling jax.tree_util.register_pytree_node() (except for collections.namedtuple, which does not require a call to register_pytree_node).
- Parameters:
nodetype (type[T]) – the type whose PyTree nodes we want to serialize. It is an error to attempt to register multiple serializations for a nodetype.
serialized_name (str) – a string that will be present in the serialization and will be used to look up the registration during deserialization. It is an error to attempt to register multiple serializations for a serialized_name.
serialize_auxdata (_SerializeAuxData) – serialize the PyTree auxdata (returned by the flatten_func argument to jax.tree_util.register_pytree_node()).
deserialize_auxdata (_DeserializeAuxData) – deserialize the auxdata that was serialized by serialize_auxdata.
from_children (_BuildFromChildren | None) – if present, this is a function that takes the result of deserialize_auxdata along with some children and creates an instance of nodetype. This is similar to the unflatten_func passed to jax.tree_util.register_pytree_node(). If not present, we look up and use the unflatten_func. This argument is needed for collections.namedtuple, which does not have a register_pytree_node call, but it can also be useful for overriding the registered unflatten_func. Note that the result of from_children is only used with jax.tree_util.tree_structure() to construct a proper PyTree node; it is not used to construct the outputs of the serialized function.
- Returns:
the same type passed as nodetype, so that this function can be used as a class decorator.
- Return type:
type[T]