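jax.tree_util.tree_map

jax.tree_util.tree_map(f, tree, *rest, is_leaf=None)

Alias of jax.tree.map().

Parameters:
f (Callable[..., Any]) - function taking 1 + len(rest) arguments, applied at the corresponding leaves of the pytrees.
tree (Any) - the pytree to map over; each leaf supplies the first positional argument to f.
rest (Any) - additional pytrees, each with the same structure as tree or with tree as a prefix.
is_leaf (Callable[[Any], bool] | None) - optional predicate called at each flattening step; if it returns True, the current object is treated as a leaf and is not traversed further.

Returns: a new pytree with the same structure as tree, where each leaf value is f(x, *xs), x being the leaf of tree and xs the corresponding values from rest.

Return type: Any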
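The sketch below illustrates the three common ways of calling tree_map: mapping over a single tree, mapping over several trees with matching structure, and using is_leaf to stop traversal early. It assumes JAX is installed in a version that provides jax.tree.map; the names params, grads, doubled, updated, lengths and same are purely illustrative.

```python
import jax
import jax.numpy as jnp

# An example pytree: a dict holding an array and a tuple of Python floats.
params = {"w": jnp.array([1.0, 2.0]), "b": (3.0, 4.0)}

# Single tree: f is called once per leaf (the array and each float).
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)

# Multiple trees: f receives the corresponding leaf from each tree,
# so grads must have the same structure as params (or params as a prefix).
grads = {"w": jnp.array([0.5, 0.5]), "b": (1.0, 1.0)}
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# is_leaf: treat tuples as leaves, so f sees the whole tuple (3.0, 4.0)
# instead of its individual elements.
lengths = jax.tree_util.tree_map(
    lambda x: len(x) if isinstance(x, tuple) else x.shape[0],
    params,
    is_leaf=lambda x: isinstance(x, tuple),
)

# Because tree_map is an alias, jax.tree.map produces the same result.
same = jax.tree.map(lambda x: x * 2, params)
```

Note that the output always mirrors the structure of the first tree argument; only the leaf values change, which is why tree_map is convenient for elementwise updates of parameter dictionaries.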