jax.extend.core.check_jaxpr

jax.extend.core.check_jaxpr(jaxpr)

Checks well-formedness of a jaxpr.

Specifically, it checks that:

- variables that are read are bound beforehand;
- each variable has a consistent type throughout the jaxpr;
- variable type annotations are compatible with their binding expression.

Raises JaxprTypeError if jaxpr is determined to be invalid; returns None otherwise.

Parameters:

jaxpr (Jaxpr) – the jaxpr to check.