jax.tree.static
jax.tree.static(**kwargs) [source]
Convenience wrapper to declare a static pytree attribute.
Arguments are the same as those of
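`dataclasses.field()`, but `static()` will automatically populate `metadata` with `static=True`, as used by `jax.tree_util.register_dataclass()`. Fields declared this way become part of the pytree structure (the treedef) rather than leaves.

Example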
>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
...
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyOp:
...   x: jax.Array
...   y: jax.Array
...   op: str = jax.tree.static(default="add")  # static string field
...
>>> m = MyOp(x=jnp.ones(3), y=jnp.arange(3))
>>> m
MyOp(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyOp[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyOp(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
See also