jax.tree.reduce
- jax.tree.reduce(function, tree, initializer=<jax._src.tree_util.Unspecified object>, is_leaf=None)
Call functools.reduce() over the leaves of a tree, visiting them in the order produced by jax.tree.leaves().
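Because the leaves are folded left to right in flattening order, the call can be read as "flatten, then fold". The snippet below is a minimal sketch under that reading, using a made-up example tree; it mirrors the behaviour rather than reproducing the library source.

```python
import functools
import operator

import jax

# Made-up example tree; dict keys are flattened in sorted order (a, b, c).
tree = {"a": 1, "b": (2, 3), "c": [4, 5, 6]}

# Manual version: flatten the pytree into its leaves, then fold them.
leaves = jax.tree.leaves(tree)
manual = functools.reduce(operator.add, leaves)

# jax.tree.reduce folds the same leaves with the same function.
via_jax = jax.tree.reduce(operator.add, tree)
```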
- Parameters:
function (Callable[[T, Any], T]) – the reduction function, called as function(accumulated, leaf) for each leaf in turn
tree (Any) – the pytree to reduce over
initializer (T | tree_util.Unspecified) – the optional initial value; if given, it is used as the starting accumulator and is returned unchanged when the tree has no leaves
is_leaf (Callable[[Any], bool] | None) – an optional function called at each flattening step. It should return a boolean: True stops flattening at the current object and treats the whole subtree as a single leaf, while False lets flattening traverse into it.
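A minimal sketch of the two optional parameters, using small made-up trees: initializer seeds the fold (and supplies the result for an empty tree), and is_leaf changes what counts as a leaf.

```python
import operator

import jax

# The fold starts at 100 instead of at the first leaf: 100 + 1 + 2 + 3.
total = jax.tree.reduce(operator.add, [1, 2, 3], initializer=100)

# With an initializer, a tree that has no leaves reduces to the initializer.
empty_total = jax.tree.reduce(operator.add, [], initializer=0)

# is_leaf stops flattening at tuples, so each tuple reaches the reduction
# function as one leaf: this counts 2 leaves rather than 4 integers.
pairs = [(1, 2), (3, 4)]
count = jax.tree.reduce(
    lambda acc, leaf: acc + 1,
    pairs,
    initializer=0,
    is_leaf=lambda x: isinstance(x, tuple),
)
```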
- Returns:
the reduced value.
- Return type:
T
Examples
>>> import jax
>>> import operator
>>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])
21
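A slightly larger sketch, assuming a hypothetical parameter dictionary of jax.numpy arrays (the names dense, out, w, b and the shapes are illustrative only). Mapping each leaf to a scalar first and then reducing is one way to compute whole-tree statistics.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree; names and shapes are illustrative only.
params = {
    "dense": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "out": {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)},
}

# Map each array to its element count, then sum the counts: 12 + 4 + 8 + 2.
sizes = jax.tree.map(lambda x: x.size, params)
num_params = jax.tree.reduce(lambda a, b: a + b, sizes)

# Largest absolute value across all leaves, folding scalars with jnp.maximum.
per_leaf_max = jax.tree.map(lambda x: jnp.max(jnp.abs(x)), params)
max_abs = jax.tree.reduce(jnp.maximum, per_leaf_max)
```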
Notes
Tip: You can exclude leaves from the reduction by first mapping them to None using jax.tree.map(). None is treated as an empty pytree node rather than a leaf, so those entries are skipped by the reduction.
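The tip can be sketched as follows, with a made-up tree and a hypothetical filter that drops negative values before summing.

```python
import operator

import jax

tree = {"a": 1, "b": -2, "c": [3, -4]}

# Replace negative leaves with None; None is an empty pytree node, so those
# positions contribute no leaves to the mapped tree.
positives_only = jax.tree.map(lambda x: x if x > 0 else None, tree)

# Only the remaining leaves (1 and 3) are summed.
total = jax.tree.reduce(operator.add, positives_only)
```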