jax.extend.core.check_jaxpr
- jax.extend.core.check_jaxpr(jaxpr)
Checks well-formedness of a jaxpr.
Specifically, it checks that:
- variables that are read are bound beforehand;
- each variable has a consistent type throughout the jaxpr;
- variable type annotations are compatible with their binding expressions.
Raises JaxprTypeError if the jaxpr is found to be invalid; returns None otherwise.
- Parameters:
jaxpr (Jaxpr): the jaxpr to check.