jax.tree_util.is_tree_node

jax.tree_util.is_tree_node(typ)

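Returns True if typ is a PyTree node type, that is, a type whose instances are flattened into child subtrees rather than treated as leaves. This covers registered node types (built-in or custom) as well as namedtuple types.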

Parameters:

typ (type) – The type to check.

Returns:

True if typ is a registered PyTree node type (built-in or custom) or a namedtuple type; False otherwise.

Return type:

bool
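The difference between a node type and a leaf type is easiest to see by flattening values. The sketch below illustrates this using only `jax.tree_util.tree_leaves` and `jax.tree_util.register_pytree_node`. `Pair` and `Point` are hypothetical example types, not part of JAX. Per the description above, `is_tree_node` should return True for `dict`, for the namedtuple type `Point`, and for `Pair` once it has been registered. It should return False for `Pair` before registration.

```python
from collections import namedtuple

import jax


# Hypothetical custom container. Until it is registered, JAX treats
# its instances as opaque leaves.
class Pair:
    def __init__(self, first, second):
        self.first = first
        self.second = second


# Hypothetical namedtuple type; namedtuples are node types without registration.
Point = namedtuple("Point", ["x", "y"])

# Built-in containers and namedtuples are nodes, so flattening descends into them.
print(jax.tree_util.tree_leaves({"a": 1, "b": [2, 3]}))  # [1, 2, 3]
print(jax.tree_util.tree_leaves(Point(4, 5)))  # [4, 5]

# An unregistered class is a leaf, so the instance itself is the only leaf.
p = Pair(6, 7)
print(jax.tree_util.tree_leaves(p) == [p])  # True

# Registering Pair makes it a node type: its fields now become the leaves.
jax.tree_util.register_pytree_node(
    Pair,
    lambda pair: ((pair.first, pair.second), None),  # flatten -> (children, aux_data)
    lambda _aux, children: Pair(*children),  # unflatten(aux_data, children)
)
print(jax.tree_util.tree_leaves(Pair(6, 7)))  # [6, 7]
```

Registration is global and can only be done once per type, so in real code it usually sits at module level next to the class definition.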