Repro implementation details

Compared to a regular compiler, the challenge of repro extraction for JAX is that JAX does not receive its input as a data structure that we can save. Instead, we have to augment the JAX tracing mechanism to track which JAX API calls the user program makes and which user functions JAX calls while tracing the program. The repro tracker (in tracker.py) constructs a representation of the call tree. The repro emitter (in emitter.py) then outputs a pure JAX program that would produce the same call tree.

How do we track?

First, we wrap the JAX API functions that take user functions as arguments, e.g., jax.jit and jax.vmap. We do this by adding a repro_api_name to the existing traceback_util.api_boundary annotation. This annotation was already present in most places where we needed it, but we had to add it in a few places where it was missing, e.g., in lax.loops.while_loop. Whenever one of these annotated APIs is called, we scan its arguments for callables and wrap those as well. One goal is to emit repro code for these callables.
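The wrapping step can be illustrated with a minimal, pure-Python sketch. It does not use the real traceback_util.api_boundary; the names tracked_api, wrap_callable, WrappedUserFunc, and apply_twice are hypothetical and exist only for this illustration. It assumes that only top-level positional and keyword arguments are scanned, which may differ from what the tracker actually does.

```python
import functools


class WrappedUserFunc:
    """Hypothetical wrapper marking a callable that was passed to an API."""

    def __init__(self, fun):
        self.fun = fun

    def __call__(self, *args, **kwargs):
        return self.fun(*args, **kwargs)


def wrap_callable(value):
    """Wrap a callable argument once; leave every other value untouched."""
    if callable(value) and not isinstance(value, WrappedUserFunc):
        return WrappedUserFunc(value)
    return value


def tracked_api(repro_api_name):
    """Hypothetical stand-in for an API boundary carrying a repro_api_name."""

    def decorator(api_fun):
        @functools.wraps(api_fun)
        def wrapper(*args, **kwargs):
            # Scan the arguments for callables and wrap them.
            args = tuple(wrap_callable(a) for a in args)
            kwargs = {k: wrap_callable(v) for k, v in kwargs.items()}
            return api_fun(*args, **kwargs)

        wrapper.repro_api_name = repro_api_name
        return wrapper

    return decorator


@tracked_api("toy.apply_twice")
def apply_twice(fun, x):
    # Toy API that takes a user function, in the spirit of jax.jit or jax.vmap.
    return fun(fun(x))
```

In this sketch, a call such as apply_twice(lambda v: v * 3, 2) receives the lambda already wrapped in a WrappedUserFunc, while the non-callable argument 2 passes through unchanged.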

We use the class tracker.Func to wrap callables of interest. They are of several kinds:

  • JAX API functions. These are constructed for the JAX API entry points annotated with repro_api_name.

  • USER functions. These are constructed for callables passed to JAX API functions.

  • JAX non-API functions. These are constructed for callables returned by JAX API functions, e.g., the value returned by jax.jit. Note: this kind of function will go away; see below.

When one of the tracker functions is called, we construct a tracker.Call object that references the Func that was called, along with the actual arguments and results of the call (these may be tracers, constants, or even non-JAX values for static arguments). The call objects for user functions have a body, which is a list of the calls to JAX functions that the user function makes.

Furthermore, we modified the core.Primitive._true_bind method to call into the repro source code with the primitive and its arguments. If this call happens while we are in a user call, we record the primitive.

Thus, the call objects for a user function will contain a list of calls to JAX functions and to primitives.

Dealing with JAX caches

How do we emit?

TO EXPLAIN …

How do we reduce?

TO DO …