jax.extend.core.subjaxprs


jax.extend.core.subjaxprs(jaxpr)

Generator that yields every subjaxpr found in the params of the equations in jaxpr.eqns. It does not descend recursively into the subjaxprs it finds, so jaxprs nested more than one level deep are not yielded.

Parameters:

jaxpr (Jaxpr) – the jaxpr whose equations' params are searched.

Return type:

Iterator[Jaxpr]