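jax.tree_util.tree_structure

jax.tree_util.tree_structure(tree, is_leaf=None)

Alias of jax.tree.structure(). Returns the structure of a pytree (its PyTreeDef) without its leaf values.

Parameters:
- tree (Any): the pytree whose structure is computed.
- is_leaf (None | Callable[[Any], bool]): optional predicate. Any subtree for which it returns True is treated as a leaf and is not traversed further.

Return type: PyTreeDef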
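The sketch below assumes a recent JAX installation. The tree contents and variable names are illustrative only. It shows that `tree_structure` keeps only the shape of a nested container. Two trees that differ only in their leaf values produce equal `PyTreeDef`s, and `jax.tree_util.tree_unflatten` can use that structure to rebuild a tree from a flat list of leaves. The last call shows how `is_leaf` stops traversal. Because this page describes the function as an alias, `jax.tree.structure(tree)` should behave the same way.

```python
import jax

# Illustrative nested pytree: a dict holding a list and a tuple.
# None is an empty pytree node, so it contributes no leaves.
tree = {"w": [1.0, 2.0], "b": (3.0, None)}

# Compute only the structure (a PyTreeDef); leaf values are not kept.
treedef = jax.tree_util.tree_structure(tree)
print(treedef)
print(treedef.num_leaves)  # 3

# Trees that differ only in leaf values share the same structure.
same_shape = {"w": [0.0, 0.0], "b": (0.0, None)}
assert jax.tree_util.tree_structure(same_shape) == treedef

# The structure can rebuild a tree from a flat list of leaves.
leaves = jax.tree_util.tree_leaves(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
assert rebuilt == tree

# is_leaf stops traversal: here every list counts as a single leaf.
list_as_leaf = jax.tree_util.tree_structure(
    tree, is_leaf=lambda x: isinstance(x, list)
)
print(list_as_leaf.num_leaves)  # 2
```

Comparing structures this way is a cheap check that two pytrees are compatible, for example before mapping a function over both of them. It avoids comparing their leaf values.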