Tracing

jax.jit and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type. For a window into tracing, let’s put a few print() statements within a JIT-compiled function and then call the function:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = JitTracer(float32[3,4])
  y = JitTracer(float32[4])
  result = JitTracer(float32[3])
Array([-2.0893092,  1.7771789, -0.3445245], dtype=float32)

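Notice that the print statements execute, but rather than printing the data we passed to the function, they print tracer objects that stand in for them (shown here as JitTracer(float32[3,4]) and so on; the exact text varies between JAX versions, and older releases print something like Traced<ShapedArray(float32[3,4])>).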

These tracer objects are what jax.jit uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the shape and dtype of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([-2.0261726,  6.5171046,  4.605681 ], dtype=float32)
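To see this caching directly, here is a minimal sketch that counts how often the Python body runs. The function g and the counter trace_count are illustrative names, not part of JAX. A Python side effect like the counter increment executes only while JAX traces the function, so it shows when a cached compilation is reused and when a new input shape forces a fresh trace.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # hypothetical counter, used only for illustration

@jax.jit
def g(x, y):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces g
    return jnp.dot(x + 1, y + 1)

g(np.ones((3, 4)), np.ones(4))    # first call: traced and compiled
g(np.zeros((3, 4)), np.zeros(4))  # same shapes and dtypes: cached, not retraced
g(np.ones((5, 4)), np.ones(4))    # new shape for x: traced again
print(trace_count)  # 2
```

Relying on side effects inside jitted code is generally discouraged; the counter is used here only to make tracing visible.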

The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the jax.make_jaxpr transformation:

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0:f32[]
    d:f32[4] = add b 1.0:f32[]
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }
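The object returned by jax.make_jaxpr can also be inspected programmatically. The sketch below assumes the ClosedJaxpr attributes jaxpr.eqns and out_avals; it walks the recorded equations and prints each primitive's name together with the abstract (shape and dtype) values it produces, which is all the information tracing keeps.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x, y):
    return jnp.dot(x + 1, y + 1)

closed = jax.make_jaxpr(f)(np.ones((3, 4), np.float32), np.ones(4, np.float32))

# Each equation records one primitive applied to abstract inputs.
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name, [v.aval for v in eqn.outvars])

# Output abstract values: shapes and dtypes only, no data.
print(closed.out_avals)
```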

Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values (see Control flow and logical operators with JIT). For example, this fails:

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[4], line 5
      1 @jit
      2 def f(x, neg):
      3   return -x if neg else x
----> 5 f(1, True)

    [... skipping hidden 12 frame]

Cell In[4], line 3, in f(x, neg)
      1 @jit
      2 def f(x, neg):
----> 3   return -x if neg else x

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/31867/lib/python3.12/site-packages/jax/_src/core.py:1829, in concretization_function_error.<locals>.error(self, arg)
   1828 def error(self, arg):
-> 1829   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_3460/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:

f(1, False)
Array(1, dtype=int32, weak_type=True)
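Static arguments can also be named with static_argnames instead of positional indices. In the sketch below (h and trace_count are hypothetical names), the counter shows that one trace happens per distinct static value, while calls that differ only in traced values reuse the cached compilation.

```python
from functools import partial
import jax

trace_count = 0  # hypothetical counter, used only for illustration

@partial(jax.jit, static_argnames=("neg",))
def h(x, neg):
    global trace_count
    trace_count += 1  # runs once per trace, i.e. once per distinct neg here
    return -x if neg else x

a = h(1.0, neg=True)   # traced with neg=True
b = h(2.0, neg=True)   # same static value, same shape/dtype: reuses the cache
c = h(1.0, neg=False)  # new static value: traced and compiled again
print(a, b, c, trace_count)  # -1.0 -2.0 1.0 2
```

Because static values become part of the cache key, they must be hashable, and passing many different static values leads to many compilations.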

Static vs traced operations

Just as values can be either static or traced, operations can be static or traced. Static operations are evaluated at compile time in Python; traced operations are compiled and evaluated at run time in XLA.
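One way to see the difference is to look at what ends up in the jaxpr. In this sketch (jaxpr_text, with_numpy, and with_jax_numpy are hypothetical helper names), the same element count is computed once with NumPy and once with jax.numpy. Only the jax.numpy version records a reduce_prod operation; the NumPy computation runs in Python during tracing and leaves just a constant behind.

```python
import jax
import jax.numpy as jnp
import numpy as np

def jaxpr_text(fn, *args):
    # Printed jaxpr, including any nested sub-jaxprs.
    return str(jax.make_jaxpr(fn)(*args))

def with_numpy(x):
    n = np.prod(x.shape)  # static: evaluated by NumPy while tracing
    return x * np.float32(n)

def with_jax_numpy(x):
    n = jnp.prod(jnp.array(x.shape))  # traced: recorded in the jaxpr
    return x * n

x = jnp.ones((2, 3))
print(jaxpr_text(with_numpy, x))
print(jaxpr_text(with_jax_numpy, x))
```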

This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 9
      6   return x.reshape(jnp.array(x.shape).prod())
      8 x = jnp.ones((2, 3))
----> 9 f(x)

    [... skipping hidden 12 frame]

Cell In[7], line 6, in f(x)
      4 @jit
      5 def f(x):
----> 6   return x.reshape(jnp.array(x.shape).prod())

    [... skipping hidden 2 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/31867/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:461, in _compute_newshape(arr, newshape)
    459 except:
    460   newshape = [newshape]
--> 461 newshape = core.canonicalize_shape(newshape)  # type: ignore[arg-type]
    462 neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
    463 if len(neg1s) > 1:

File ~/checkouts/readthedocs.org/user_builds/jax/envs/31867/lib/python3.12/site-packages/jax/_src/core.py:1985, in canonicalize_shape(shape, context)
   1983 except TypeError:
   1984   pass
-> 1985 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [JitTracer(int32[])].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_3460/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = reduce_prod[axes=(0,)] b
    from line /tmp/ipykernel_3460/1983583872.py:6:19 (f)

This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = JitTracer(float32[2,3])
x.shape = (2, 3)
jnp.array(x.shape).prod() = JitTracer(int32[])

Notice that although x is traced, x.shape is a static value. However, when we use jnp.array and jnp.prod on this static value, it becomes a traced value, at which point it cannot be used in a function like reshape() that requires a static input (recall: array shapes must be static).

A useful pattern is to use numpy for operations that should be static (i.e. done at compile-time), and use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

For this reason, a standard convention in JAX programs is to import numpy as np and import jax.numpy as jnp so that both interfaces are available for finer control over whether operations are performed in a static manner (with numpy, once at compile-time) or a traced manner (with jax.numpy, optimized at run-time).
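np.prod is not the only way to keep this shape computation static. As a sketch (flatten_variants is a hypothetical name), the function below uses math.prod from the Python standard library, the size attribute, and reshape's -1 placeholder. All three are resolved in Python during tracing, since x.shape and x.size are static even when x itself is a tracer.

```python
import math
import jax
import jax.numpy as jnp

@jax.jit
def flatten_variants(x):
    # Each target shape is computed in Python during tracing,
    # because x.shape and x.size are static even though x is traced.
    a = x.reshape((math.prod(x.shape),))
    b = x.reshape((x.size,))
    c = x.reshape(-1)  # -1 asks reshape to infer the length
    return a, b, c

a, b, c = flatten_variants(jnp.ones((2, 3)))
print(a.shape, b.shape, c.shape)  # (6,) (6,) (6,)
```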

Understanding which values and operations will be static and which will be traced is a key part of using jax.jit effectively.

Different kinds of JAX values

A tracer value carries an abstract value, e.g., a ShapedArray with information about the shape and dtype of an array. We will refer to such tracers as abstract tracers. Some tracers, e.g., those introduced for the arguments of autodiff transformations, carry ConcreteArray abstract values that also include the regular array data; these are used, for example, to resolve conditionals. We will refer to such tracers as concrete tracers. Computations on concrete tracers, possibly combined with regular values, produce concrete tracers. A concrete value is either a regular value or a concrete tracer.

Typically, a computation that involves at least one tracer value produces a tracer value. There are very few exceptions. One is a computation that can be done entirely with the abstract value carried by a tracer, such as getting the shape of a tracer with a ShapedArray abstract value; the result is then a regular Python value. Another is explicitly converting a concrete tracer to a regular Python type, e.g., int(x) or float(x). (By contrast, x.astype(float) returns another JAX array, so on a tracer it produces another tracer.) A third is bool(x), which produces a Python bool when x is concrete. That case is especially salient because of how often it arises in control flow.

Here is how the transformations introduce abstract or concrete tracers:

  • jax.jit(): introduces abstract tracers for all arguments except those designated by static_argnums (or static_argnames), which remain regular values.

  • jax.pmap(): introduces abstract tracers for all positional arguments except those denoted by static_broadcasted_argnums.

  • jax.vmap(), jax.make_jaxpr(), xla_computation(): introduce abstract tracers for all positional arguments.

  • jax.jvp() and jax.grad() introduce concrete tracers for all positional arguments. An exception is when these transformations are within an outer transformation and the actual arguments are themselves abstract tracers; in that case, the tracers introduced by the autodiff transformations are also abstract tracers.

  • All higher-order control-flow primitives (lax.cond(), lax.while_loop(), lax.fori_loop(), lax.scan()) introduce abstract tracers when they trace the functions passed to them (branches and loop bodies), whether or not a JAX transformation is in progress.
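The last point can be checked without any jit. In this sketch (absolute is a hypothetical helper), the function works on a regular Python float, but when it is passed as a branch of lax.cond it receives an abstract tracer, so its Python if raises jax.errors.TracerBoolConversionError even though the predicate is a plain True.

```python
import jax
from jax import lax

def absolute(x):
    # A Python `if` on x: fine for regular values, fails for abstract tracers.
    return x if x > 0 else -x

print(absolute(-2.0))  # plain Python call: 2.0

raised = False
try:
    # No jit here, yet lax.cond traces both branches with abstract tracers.
    lax.cond(True, absolute, lambda x: x, -2.0)
except jax.errors.TracerBoolConversionError:
    raised = True
print(raised)
```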

All of this is relevant when you have code that can operate only on regular Python values, such as code that has conditional control-flow based on data:

def divide(x, y):
  return x / y if y >= 1. else 0.

If we want to apply jax.jit(), we must specify static_argnums=1 so that y stays a regular value. This is because the boolean expression y >= 1. requires a concrete value (a regular value or a concrete tracer). The same would happen if we explicitly wrote bool(y >= 1.), int(y), or float(y).

Interestingly, jax.grad(divide)(3., 2.) works, and so does jax.grad(divide, argnums=1)(3., 2.), where y itself is traced: jax.grad() introduces concrete tracers, so the conditional is resolved using the concrete value of y.
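A short sketch of both situations for divide (divide_jit is a hypothetical name). With static_argnums=1 the conditional is evaluated in Python for each distinct y; without it, jax.jit traces y abstractly and raises jax.errors.TracerBoolConversionError; and jax.grad with argnums=1 traces y concretely, so the conditional still resolves and the derivative is -x / y**2.

```python
import jax

def divide(x, y):
    return x / y if y >= 1. else 0.

# Under jit, y must be static: `y >= 1.` needs a concrete boolean.
divide_jit = jax.jit(divide, static_argnums=1)
print(divide_jit(3., 2.))   # 1.5
print(divide_jit(3., 0.5))  # 0.0 (the else branch, chosen in Python)

# Without static_argnums, y is an abstract tracer and tracing fails.
jit_failed = False
try:
    jax.jit(divide)(3., 2.)
except jax.errors.TracerBoolConversionError:
    jit_failed = True

# grad with argnums=1 makes y a concrete tracer, so the `if` resolves.
dy = jax.grad(divide, argnums=1)(3., 2.)
print(jit_failed, dy)  # dy = -x / y**2 = -0.75
```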