Generating reproducers for JAX errors#
WARNING: this code is experimental, expect changes or deletion.
Have you encountered a hard-to-debug JAX error in a large user program, perhaps using several other libraries on top of JAX? Do you believe that there is a small and pure JAX program, without additional layers of libraries, that reproduces the same error? The reproducer tool aims to help you produce such a program.
Summary:
Usage#
If the JAX_REPRO_DIR is set to a directory, and JAX encounters an
uncaught exception under a JAX API call, e.g., jax.jit, it will attempt
to save in that directory a Python source file that when executed should
reproduce the error. That source file is standalone except for
importing jax.
JAX will track the sequence of nested JAX API calls, capturing all
user-functions that JAX traces, their calls to JAX APIs, and so on.
In large JAX programs we have seen the call depth grow to 30 or so.
If an uncaught exception arises, then we save a repro that should result in
the same call tree, and hopefully can reproduce the error.
One can get the path and source code of the saved repro by
calling repro.last_saved_repro().
The above use case is the “implicit” repro generation, when all you need
to do is to set JAX_REPRO_DIR, and encounter an error.
You can also generate repros “explicitly”, even in absence of errors,
by using the repro.collect function transformation:
fn = repro.collect(fn, repro_name="name")
# Once we trace `fn` we save `$JAX_REPRO_DIR/name_<counter>.py`
In the explicit usage mode, all tracked JAX API invocations are saved in the repro. In the implicit use (save on uncaught exception), successful JAX calls at the top-level are not retained.
One can think of this mechanism as a way to stage out a pure JAX program
from a large JAX program.
This is somewhat similar to staging out a Jaxpr with
jax.jit(f).trace(*args).jaxpr (or the old jax.make_jaxpr), except that:
it produces Python source code rather that a dump of a Jaxpr, which should be more readable, more editable, and can be executed directly, e.g., in a debugger.
it works even if there are errors encountered before a Jaxprs is produced; the repro source may reproduce the error.
the repro is higher-level than the Jaxpr. E.g., instead of seeing the
lax.scan_pprimitive with its low-level details, you will see a call tolax.scan. The higher-order primitives in JAX often have complicated parameters, and sometimes even references to Python callables. Furthermore, some JAX transformations, e.g.,vmaporjvp, do not stage a Jaxpr, and the first Jaxpr produced will reflect the result of the transformations. In contrast, the repro source will contains calls tojax.vmapandjax.jvp.In Jaxpr the arguments are passed in a flat list, while in repros we retain standard PyTrees, with the same dictionary keys as in the user program.
Configuration options#
This section is very likely to change.
There are two main configuration flags:
JAX_REPRO_DIRdenotes the directory where reproducers are saved. A non-empty value also triggers the tracking of the call tree, so that a reproducer is saved on error. It can bespongefor use in internal Google tests.JAX_REPRO_FLAGScontains comma-separated flags that configure details of repro generation. You can specify a flag without a value, in which case it takes a default value, e.g.,True, or you can specify a value using=value. For example,log_calls,log_traceback_frames=10.log_calls(default 0). An integer value that controls the repro tracking logging (for debugging the repro module). The recognized values are: 0 (no logging, default), 1 (log all calls except the JAX primitive.bind), 2 (log all calls).log_call_details(default “”). A sequence of call ids for which to log more details (for debugging the repro module). E.g.,log_call_details=3+5+6.error_mode(default “defer”). Configures the handling of repro collection and generation errors. The possible values are:“ignore”
“log” – the errors are logged as
logging.error. Each error message containslog_traceback_framesstack frames.“defer” – the errors are logged and at the end of the explicit repro collection a
repro.ReproErrorwill be generated.“raise” – a
repro.ReproErroris raised when the first error appears.
log_traceback_frames(default 40) how many frames from the traceback to show.fake_array_threshold(default 128) arrays with.size()larger than this value are replaced withnp.oneswith the right shape and dtype. Smaller arrays are emitted asnparray literals.
Limitations#
So many …
Not all errors will be reproduced by this mechanism:
Errors from numpy won’t be captured,
Errors from the jax.numpy layer are not captured. E.g., the rank checking happens in the jax.numpy layer, so it won’t be reproduced by this version of repros, but the shape checking happens after binding the JAX primitives, so it will be reproducible.
The repro contains some argument preprocessing, e.g., the
static_argnumsare removed. Errors involving the handling ofstatic_argnumswon’t be reproduced.We currently do not try to preserve the exact data arrays, and for arrays larger than a threshold we will use np.ones.
jax.named_callis not handled currently (treated as a noop)We do not capture some of the context managers that control the lowered code, e.g.,
set_xla_metadata. We do captureset_meshthough.
During call tracking we attempt to alter the execution as little as possible. There are a few known differences though:
JAX cache tracing is foiled, so you are going to see functions traced multiple times. E.g., JAX will normally avoid tracing a function repeatedly with similar arguments, but when repros are enabled this cache is disabled. This will result in slower tracing, and for functions with side-effects it may even alter the result or side-effects of the program.
when using
jax.custom_vjp, when we do higher-order differentiation the custom fwd and bwd functions are being called more than once, and we assume that the subsequent calls will produce the same Jaxpr as earlier ones.we emit jax.Array as np.ndarray (e.g., loosing sharding)
if we don’t recognize the jax.checkpoint policy param, we print a warning and we use
dots_saveableas a replacement.
Design#
JAX higher-order APIs#
There are a few different categories of JAX APIs. Here are some examples:
higher-order functions that return only arrays, e.g.
jax.lax.cond: Arr x (Arr -> Arr) x (Arr -> Arr) -> Arr.higher-order functions that return only a single function e.g.
jax.vmap: (Arr -> Arr) -> (Arr -> Arr).higher-order function that return arrays and functions e.g.
jax.vjp: (Arr -> Arr) x Arr -> Arr x (Arr -> Arr).
APIs that return only arrays#
The simplest higher-order JAX APIs, such as jax.lax.cond and
jax.lax.scan, return only arrays, not functions. For example, the type of
jax.lax.cond is:
jax.lax.cond: Arr x (Arr -> Arr) x (Arr -> Arr) -> Arr
Once we intercept a call to jax.lax.cond we assume that its callable arguments
are user functions for which we need to synthesize Python source code. We wrap
those functions so that we can intercept when they are called, and then we call
the actual jax.lax.cond. We will notice when JAX traces the wrapped user
functions and we record the arguments. We also record all the calls to
the first-order JAX primitives made from user functions.
It is possible that the user functions call
jax.lax.cond themselves. Therefore, we record a call tree with the root being
the top-level call to jax.lax.cond. On each path through the tree, JAX API
calls alternate with calls to user functions. The leaves of the call tree are
calls to first-order JAX primitives. We call this process that happens during
tracing “tracking”.
If we need to emit a repro (on an uncaught exception or if explicitly requested) we can turn the call tree into source code. For more details, see Repro emitting.
APIs that return only a single function#
A more complicated form of higher-order APIs are function transformations,
which return a single first-order function. For example, jax.vmap,
jax.jit, jax.grad, jax.shard_map, and many others.
The type of jax.vmap is:
jax.vmap: (Arr -> Arr) -> (Arr -> Arr)
These are more complicated to handle because the result of jax.vmap(fun)
may be invoked multiple times, and each invocation may trigger
re-tracing of the user function fun. In the most general case, there
could be conditionals inside fun that result in different executions.
We must be ready to generate different function definitions for fun for
each invocation of jax.vmap(fun)(...).
We reduce these function transformations to the previous case by uncurrying
them. For example, we rewrite jax.vmap to a two-argument function vmap_call:
vmap_call: (Arr -> Arr) x Arr -> Arr
Thus a program:
def fun(...): ...
vf = jax.vmap(fun)
vf(a1)
vf(a2)
would be rewritten as
vmap_call(fun, a1)
vmap_call(fun, a2)
We handle this program in a similar way to jax.lax.cond (or other APIs that
return arrays). This allows us to generate a separate source for fun for
each of the calls.
We assume that for a JAX API that returns to the user program a single
function, the only thing that the user program can do with it is to call it.
This is not tecnically true in JAX, e.g., the return from jax.jit has some
extra fields for manipulating caches, or for invoking the AOT APIs, .lower(),
.trace(), etc. We have a few bespoke workarounds to handle those cases.
APIs that return a tuple of arrays and functions#
A small number of JAX APIs return tuples of functions and arrays. For example,
jax.vjp:
jax.vjp: (Arr -> Arr) x Arr -> Arr x (Arr -> Arr)
jax.vjp returns the primal output for the function passed as the first
argument at the point specified by the second argument, and it also returns
the function to compute the cotangent of the primal input given the
cotangent of the output.
The general case#
For the general case of a JAX higher-order API jaxapi of the form:
*out_arrays, *out_funcs = jaxapi(*in_arrays, *in_funcs)
where out_arrays is non-empty or out_funcs is not a singleton.
We make the following assumptions (see later for a discussion of some
exceptions):
jaxapiis annotated withapi_boundarywith an optionalmap_user_func_argsto identify thein_funcsfunction arguments.if one of these annotations is missing, we are going to see user code calling directly into the tracing machinery. The repro tracking machinery will give an error.
the
in_funcsare considered references to user functions for which we must generate reproducers.jaxapiinvokes each user function at most once, for tracing (with tracer arguments).jaxapidoes not leak the user functions to the user code, i.e., the wrapped user functions it took as arguments can only be invoked internally byjaxapi.in an earlier implementation of the handling of
fuser.push_block_specI forgot to traverse the returnedpallas.BlockSpecto wrap theindex_mapfunctions and I ended up with spurious calls to wrapped user functions directly from user code.
if
out_funcsis not empty, then they are first-order functions. They are considered JAX functions, and as such are not allowed to further invoke the user functions passed tojaxapi.
There are a few other more complicated cases, e.g., custom_vjp, or the
fuser APIs, but they can be reduced to the cases above with a bit of argument
and result rewriting.
Trampolines#
For many higher-order JAX APIs we use a system of trampolines to rewrite their
calls to helper functions that help the reproducers. These trampolines are
defined in trampolines.api_trampolines (indexed by repro_api_name) and
are processed by api_boundary. The trampolines typically redirect the call to
helper functions defined in repro_api.py. The reproducer code will contain
references to the functions in repro_api.py, not to the original JAX APIs.
For example, for jax.lax.cond we use a trampoline repro_api.cond for the
sole purpose of being able to specify wrap_user_func_args without putting it
in the main JAX sources. So, in the main sources we annotate jax.lax.cond
with api_boundary(repro_api_name="jax.lax.cond"). Then we define a trampoline
that redirects the call to repro_api.cond, on which we put
a api_boundary(repro_api_name="cond", wrap_user_func_args=...).
The reproducer source contains calls to repro_api.cond, which allows us
to synthesize reproducers for reproducers (this is handy, see below).
Another class of trampolines are the uncurrying trampolines, which we use
for APIs that return a single function, such as jax.vmap. The trampoline
for jax.vmap will rewrite the call to repro_api.vmap_call.
We explain here the naming convention for the repro_api_name parameter
passed to api_boundary and used by the trampolines:
repro_api_nameis the name by which we expect user code to reference the function. It starts withjax.(orpallas.orfuser.). In any case, there is generally a.in the name.Some of the APIs do not have a trampoline, e.g.,
lax.scatter_add. The reproducer code for this uses the exactrepro_api_name. The generated repro sources import the filerepro_runtime.pyto provide the right names.The name of the
repro_apitrampoline function is formed by taking therepro_api_name, removing thejax.prefix if present, and perhaps append_callif it is an uncurried trampoline. E.g.,jax.vmapbecomesrepro_api.vmap_call.The functions in
repro_api.pyalso have their ownapi_boundaryannotations, usingrepro_api_nameas the name of the function inrepro_api.py, e.g.,vmap.
Handling caches#
One of the trickiest issues is to account for the JAX internal caches. JAX has several levels of caches, e.g., to avoid tracing a function multiple times for the same types and shapes of arguments.
Consider the following example:
def fun(x): ...
j_fun = jax.jit(fun)
y0 = j_fun(0)
y1 = j_fun(0.)
y2 = j_fun(1) # hits the `fun` tracing cache for `j_fun(0)`
which is rewritted using trampolines to:
def fun(x): ...
y0 = jit_call(fun, 0)
y1 = jit_call(fun, 0.)
y2 = jit_call(fun, 1)
On the 3rd call to jit_call we will not observe a tracing call to fun, and
we don’t know which of several previous invocations of fun to use here.
The current solution is to foil the caches. The call to jit_call will
create a fresh wrapper for fun which will ensure that each separate call
to jit_call will miss the cache at least once.
TODO:
A significant issue with this solution is that tracing in presence of reproducers will be slower, possibly much slower.
Also, we would generate a much bigger reproducer source, with possibly many identical copies generated for a user function. This could be fixed by detecting and squashing identical definitions.
Handling PyTrees#
For each call, we store the arguments and results. These would be tracers for the user functions, but may be arrays or other objects. We cannot store these directly in the call trees:
If these are user PyTrees we cannot generate code to construct them in the repro, because the constructors could depend on user code that we cannot trace. To remedy this, we normalize arguments and results by traversing them and keeping the standard containers (tuples, lists, dictionaries, None) and flattening one level of the user PyTrees as tuples.
However, we keep some JAX PyTrees for which we have special emitter rules, e.g., lax.GatherDimensionNumbers.
if these are mutable values, they may be mutated by the user program. Normalization ensures that we make copies of lists and dictionaries.
We considered completely flattening the arguments and results, but that
turned out to be unnecessary and removes one level of readability of the
repros, e.g., by turning nested dictionaries with user-readable key names
into very large tuples. Furthermore, this would require some processing of
parameters such as nondiff_argnums to refer to flat tuple elements.
Handling statics#
A few JAX APIs allow “static” values that are not captured by JAX tracing.
E.g., jax.jit can declare that some arguments are static, which means that
they are passed to the user function but are not captured in the Jaxpr.
For reproducers we must also not capture statics because some of them may
involve complex user data structures that we don’t want to have in a
standalone reproducer.
For example, in the following program:
class MyData: ...
def f(x: jax.Array, extra: MyData) -> jax.Array: ...
jax.jit(f, static_argnums=(1,))(x)
we would like to see in the reproducer something like:
def f(x: jax.Array) -> jax.Array: ...
jit_call(f, x)
TODO: explain how we deal with statics.
Data structure#
Tracking collects the following data structure (defined in tracker.py):
a
FunctionDefrepresents the definition of a user function that was passed to one of the higher-order JAX APIs, e.g., tojax.jit, or even torepro.collect. It contains abodyfield with the list ofStatements that were encountered during the JAX tracing of that function.a
Statementrepresents a call to a bind of a JAX primitive (e.g.,add_p.bind(...)), or a call to a JAX API (e.g.jax.sin(...)). There is a top-level statement, which is either the call torepro.collectwhen we are doing explicit collection, or the top-level JAX API call being traced (for implicit collection).
Both FunctionDef and Statement are subclasses of Call which provides
common fields:
parent- for aStatementthis is theFunctionDefthat contains it, for aFunctionDefthis is theStatementthat called the higher-order JAX function that had this user-function as one of its arguments.id- a unique index countingCallconstructors, for debugging.level- 1 more than the level of theparent. We logCalls as “[. ]”. You can enable this logging with log_calls.args,kwargs,resultthese are normalized objects that were used at the call site. These are typically PyTress of tracers.
We also use object of type Func to represent wrapped functions:
either a JAX API, in which case
api_nameis the name of the API andis_user==False.or a user function that was passed to a higher-order API, in which case
is_user==Trueandfunction_defpoints to theFunctionDef(once JAX calls the function).of a JAX function that was returned from a higher-order API that is not curried, e.g.,
jax.vjp.
We currently maintain the following invariants and assumptions:
The
StatementandFunctionDefform a call tree with aStatementat the root, and alternatingStatementandFunctionDefon any path. The leaves of the tree areStatements representing calls to first-order primitives.All higher-order JAX APIs are annotated with
traceback_util.api_boundarywith arepro_api_namefield. We currently support about 50 such API functions. If we forget to annotate a higher-order API, and user-level code calls it, you may get several kinds of errors:“USER function calls directly into tracing” - if a user-level function calls this unannotated API and JAX starts tracing some function.
“Binding primitive scan_p containing Jaxprs or functions, for f” - This would happen, e.g., if
jax.scanwere not annotated.
We expect that when a higher-order JAX API is called with a user function, that user function is invoked exactly once, to be traced (with
jax.core.Tracerarguments). If the user function is called more than once you will see a warning “Ignoring additional invocation to USER function”.It is possible that we don’t see a call for some USER functions, if JAX does not trace them. E.g., in
jax.custom_vjpwe pass a user function for the forward pass and one for the backwards, but the latter will never be traced if we do not differentiate the function.If the USER function passed to a higher-order API leaks the USER function back to the user-level code, you may see the additional invocation warning or a “USER call made from USER parent” error message.
This is true for the JAX APIs, but in one earlier attempt we have tried to wrap the
BlockSpec.index_mapas a USER function in theBlockSpecconstructor. This was a problem because theBlockSpecis accessible to the user code and there are programs that call the index map themselves. A further problem was that this could be done before the repro collection starts withrepro.collector, which resets the tracking state and we end up reusing the ids for theBlockSpec.index_mapobjects in later calls. This confused the repro generation code.
User functions can call
jax.numpyandlaxoperations, which will result in binding one or more JAX primitives. We intercept the calls tocore.Primitive.bindto collect such calls.User functions cannot bind directly higher-order primitives, e.g.,
scan_p. We expect that they call instead one of the JAX higher-order APIs, e.g.,jax.scan.We ignore all calls to JAX functions (primitives or JAX APIs) if they are made from a
Statement. This can happen when we are inside JAX and we calljax.jitorlax.scaninternally. The same mechanism will ignore binding of higher-order primitives, e.g.,lax.scan_p, because those are bound always from a JAX API call (since we annotate all higher-order JAX APIs).
Repro emitting#
We generate code for a FunctionDef by emitting a sequence of calls either
to bind a first-order primitive, or to functions like repro_api.vmap_call.
We introduce function arguments and local variables for all tracers that appear
as args for the FunctionDef and as the result of the Statements in the
body. These are then passed as the args of Statement and the result of
the enclosing FunctionDef.
JAX primitives can take a variety of parameters of internal types, such as
jax.NamedSharding. Most of the emitter.py module defines functions that
can emit code to construct these internal data structures.
Emitting starts from the top-level Statement, which typically is a call
to a JAX API such as repro_api.jit_call. As we emit the code for the
arguments we encounter USER Func objects, for which we must generate code
for their FunctionDef.
The main difficulty when emitting is to handle nested function definitions. In the following emitted repro the names have a numeric suffix that denotes the order in which they are generated:
def sin3(x4):
v5 = repro_api.bind("sin", x4)
return v5
def f1(x2):
y6 = jit_call(sin3, x2)
def g7(y8):
v9 = y8 + x2 # x2 is external
return v9
z10 = jit_call(g7, y6)
return z10
v11 = jit_call(f1, 5)
As we emit the FunctionDef of f1, we encounter first a call to jit_call(sin),
and we generate the body for sin. Since this body does not have external
references we can emit it at top-level.
For the call jit_call(g7) we emit the body of g7 and we observe an
external reference to x2. We cannot lift the body of g7 to top level,
it must be as high as possible while its references are in scope.
Development#
Until I can merge this somewhere upstream I am evolving it in a personal branch. I keep it as a stack of commits, with the base one adding the infrastructure and a bunch of basic APIs (jit, grad, vmap, etc.) and subsequent ones adding more exotic APIs: various Pallas APIs, fuser, hijax. Those are still in flux. It is also useful to see what it takes to add support for a new API.