jax.tree.static

jax.tree.static(**kwargs)

Convenience wrapper to declare a static pytree attribute.

Arguments are the same as those of dataclasses.field(), but static() automatically sets static=True in the field's metadata, which is the marker jax.tree_util.register_dataclass() uses to identify static fields.
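To make the metadata convention concrete, here is a minimal standard-library sketch of what a wrapper of this kind could look like. The helper name static_sketch is hypothetical, and this is not JAX's actual implementation. In particular, the way it merges caller-supplied metadata is an assumption.

```python
import dataclasses


def static_sketch(**kwargs):
    """Hypothetical stand-in for a static-field helper (not JAX's code)."""
    # Assumption: keep any caller-supplied metadata and add the static flag.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = True
    return dataclasses.field(metadata=metadata, **kwargs)


@dataclasses.dataclass
class MyOp:
    x: float
    op: str = static_sketch(default="add")


# Inspect the per-field metadata that a registration step could read.
static_flags = {
    f.name: f.metadata.get("static", False) for f in dataclasses.fields(MyOp)
}
```

Here static_flags marks op as static and x as not, while MyOp(x=1.0).op still takes the default "add" because the remaining keyword arguments are passed through to dataclasses.field() unchanged.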

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
...
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyOp:
...   x: jax.Array
...   y: jax.Array
...   op: str = jax.tree.static(default="add")  # static string field
...
>>> m = MyOp(x=jnp.ones(3), y=jnp.arange(3))
>>> m
MyOp(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyOp[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyOp(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')