Autodidax2, part 1: JAX from scratch, again
If you want to understand how JAX works you could try reading the code. But the code is complicated, often for no good reason. This notebook presents a stripped-back version without the cruft. It's a minimal version of JAX from first principles. Enjoy!
Main idea: context-sensitive interpretation
JAX is two things:
- a set of primitive operations (roughly the NumPy API)
- a set of interpreters over those primitives (compilation, AD, etc.)
In this minimal version of JAX we'll start with just two primitive operations, addition and multiplication, and we'll add interpreters one by one. Suppose we have a user-defined function like this:
```python
def foo(x):
    return mul(x, add(x, 3.0))
```
We want to be able to interpret foo in different ways without changing its
implementation: we want to evaluate it on concrete values, differentiate it,
stage it out to an IR, compile it and so on.
Here's how we'll do it. For each of these interpretations we'll define an
Interpreter object with a rule for handling each primitive operation. We'll
keep track of the current interpreter using a global context variable. The
user-facing functions add and mul will dispatch to the current
interpreter. At the beginning of the program the current interpreter will be
the "evaluating" interpreter which just evaluates the operations on ordinary
concrete data. Here's what this all looks like so far.
```python
from enum import Enum, auto
from contextlib import contextmanager
from typing import Any

# The full (closed) set of primitive operations
class Op(Enum):
    add = auto()  # addition on floats
    mul = auto()  # multiplication on floats

# Interpreters have rules for handling each primitive operation.
class Interpreter:
    def interpret_op(self, op: Op, args: tuple[Any, ...]):
        raise NotImplementedError("subclass should implement this")

# Our first interpreter is the "evaluating interpreter" which performs ordinary
# concrete evaluation.
class EvalInterpreter(Interpreter):
    def interpret_op(self, op, args):
        assert all(isinstance(arg, float) for arg in args)
        match op:
            case Op.add:
                x, y = args
                return x + y
            case Op.mul:
                x, y = args
                return x * y
            case _:
                raise ValueError(f"Unrecognized primitive op: {op}")

# The current interpreter is initially the evaluating interpreter.
current_interpreter = EvalInterpreter()

# A context manager for temporarily changing the current interpreter
@contextmanager
def set_interpreter(new_interpreter):
    global current_interpreter
    prev_interpreter = current_interpreter
    try:
        current_interpreter = new_interpreter
        yield
    finally:
        current_interpreter = prev_interpreter

# The user-facing functions `mul` and `add` dispatch to the current interpreter.
def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))
```
At this point we can call foo with ordinary concrete inputs and see the
results:
```python
print(foo(2.0))
```
10.0
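Because foo only ever calls add and mul, making a different interpreter current changes what foo does without touching its code. The following minimal sketch illustrates this with a hypothetical LoggingInterpreter, which is not part of the original notebook. It records each primitive it sees and then delegates the arithmetic to the interpreter that was current before it. The snippet re-declares a trimmed copy of the machinery above so that it runs on its own.

```python
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any

# Trimmed, standalone copy of the machinery above.
class Op(Enum):
    add = auto()
    mul = auto()

class EvalInterpreter:
    def interpret_op(self, op, args):
        x, y = args
        return x + y if op is Op.add else x * y

current_interpreter: Any = EvalInterpreter()

@contextmanager
def set_interpreter(new_interpreter):
    global current_interpreter
    prev, current_interpreter = current_interpreter, new_interpreter
    try:
        yield
    finally:
        current_interpreter = prev

def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))

def foo(x):
    return mul(x, add(x, 3.0))

# Hypothetical interpreter: records each primitive it sees, then hands the
# actual computation to the interpreter that was current when it was created.
class LoggingInterpreter:
    def __init__(self, prev_interpreter):
        self.prev_interpreter = prev_interpreter
        self.log = []

    def interpret_op(self, op, args):
        self.log.append(f"{op.name}{args}")
        # Make the previous interpreter current while it runs, so any add/mul
        # calls inside its rule don't dispatch back to this interpreter.
        with set_interpreter(self.prev_interpreter):
            return self.prev_interpreter.interpret_op(op, args)

logger = LoggingInterpreter(current_interpreter)
with set_interpreter(logger):
    result = foo(2.0)
print(result)      # 10.0
print(logger.log)  # ['add(2.0, 3.0)', 'mul(2.0, 5.0)']
```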
Aside: forward-mode automatic differentiation
For our second interpreter we're going to try forward-mode automatic differentiation (AD). Here's a quick introduction to forward-mode AD in case this is the first time you've come across it. Otherwise, skip ahead to the "JVP Interpreter" section.
Suppose we're interested in the derivative of foo(x) evaluated at x=2.0.
We could approximate it with finite differences:
```python
print((foo(2.00001) - foo(2.0)) / 0.00001)
```
7.000009999913458
The answer is close to 7.0 as expected. But computing it this way required two evaluations of the function (not to mention the roundoff error and truncation error). Here's a funny thing though. We can almost get the answer with a single evaluation:
```python
print(foo(2.00001))
```
10.0000700001
The answer we're looking for, 7.0, is right there in the insignificant digits!
Here's one way to think about what's happening. The initial argument to foo,
2.00001, carries two pieces of data: a "primal" value, 2.0, and a "tangent"
value, 1.0. The representation of this primal-tangent pair, 2.00001, is
the sum of the two, with the tangent scaled by a small fixed epsilon, 1e-5.
Ordinary evaluation of foo(2.00001) propagates this primal-tangent pair,
producing 10.0000700001 as the result. The primal and tangent components are
well separated in scale so we can visually interpret the result as the
primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at
the end.
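The same bookkeeping can be checked numerically. The minimal sketch below uses a plain-float stand-in for foo, assumed here so the snippet runs on its own. In real arithmetic foo(2 + eps) = 10 + 7*eps + eps**2 exactly. Subtracting the primal and dividing by eps therefore recovers the tangent, and the leftover is the eps**2 term.

```python
# Plain-float stand-in for foo (an assumption for this sketch): x * (x + 3.0).
# Algebraically, foo(2 + eps) = 10 + 7*eps + eps**2.
def foo(x: float) -> float:
    return x * (x + 3.0)

eps = 1e-5
packed = foo(2.0 + 1.0 * eps)                  # primal 2.0, tangent 1.0, folded together
primal_part = 10.0                             # foo(2.0)
tangent_part = (packed - primal_part) / eps    # 7 + eps, up to rounding: ~7.00001
leftover = packed - (primal_part + 7.0 * eps)  # the eps**2 term, ~1e-10
print(tangent_part, leftover)
```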
The idea with forward-mode differentiation is to do the same thing but exactly and explicitly (eyeballing floats doesn't really scale). We'll represent the primal-tangent pair as an actual pair instead of folding them both into a single floating-point number. For each primitive operation we'll have a rule that describes how to propagate these primal-tangent pairs. Let's work out the rules for our two primitives.
Addition is easy. Consider x + y where x = xp + xt * eps and y = yp + yt * eps
("p" for "primal", "t" for "tangent"):
```
x + y = (xp + xt * eps) + (yp + yt * eps)
      =   (xp + yp)           # primal component
        + (xt + yt) * eps     # tangent component
```
The result is a first-order polynomial in eps and we can read off the
primal-tangent pair as (xp + yp, xt + yt).
Multiplication is more interesting:
```
x * y = (xp + xt * eps) * (yp + yt * eps)
      =   (xp * yp)                    # primal component
        + (xp * yt + xt * yp) * eps    # tangent component
        + (xt * yt) * eps * eps        # quadratic component, vanishes in the eps->0 limit
```
Now we have a second-order polynomial. But as eps goes to zero the
quadratic term vanishes and our primal-tangent pair
is just (xp * yp, xp * yt + xt * yp).
(In our earlier example eps was finite, so this term did not vanish: with
xt = yt = 1 and eps = 1e-5 it contributes 1e-10, which is the "truncation
error" we saw.)
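These rules are polynomial identities in eps, so we can sanity-check them exactly with Python's fractions.Fraction, which avoids floating-point rounding altogether. In the sketch below, add_parts and mul_parts are illustrative helper names of our own. The inputs mirror foo at x = 2.0, where x = 2 + eps and y = x + 3 = 5 + eps.

```python
from fractions import Fraction

def add_parts(xp, xt, yp, yt):
    # (primal, tangent) of x + y
    return xp + yp, xt + yt

def mul_parts(xp, xt, yp, yt):
    # (primal, tangent, quadratic coefficient) of x * y
    return xp * yp, xp * yt + xt * yp, xt * yt

eps = Fraction(1, 100_000)          # the same 1e-5 as before, but exact
xp, xt = Fraction(2), Fraction(1)   # x = 2 + eps
yp, yt = Fraction(5), Fraction(1)   # y = x + 3 = 5 + eps
x, y = xp + xt * eps, yp + yt * eps

sp, st = add_parts(xp, xt, yp, yt)
assert x + y == sp + st * eps                  # addition: no leftover term

p, t, q = mul_parts(xp, xt, yp, yt)
assert x * y == p + t * eps + q * eps * eps    # multiplication: exact identity
print(p, t, q * eps * eps)  # 10 7 1/10000000000
```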
Putting this into code, we can write down the forward-AD rules for addition
and multiplication and express foo in terms of these:
```python
from dataclasses import dataclass

# A primal-tangent pair is conventionally called a "dual number"
@dataclass
class DualNumber:
    primal: float
    tangent: float

def add_dual(x: DualNumber, y: DualNumber) -> DualNumber:
    return DualNumber(x.primal + y.primal, x.tangent + y.tangent)

def mul_dual(x: DualNumber, y: DualNumber) -> DualNumber:
    return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)

def foo_dual(x: DualNumber) -> DualNumber:
    return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))

print(foo_dual(DualNumber(2.0, 1.0)))
```
DualNumber(primal=10.0, tangent=7.0)
That works! But rewriting foo to use the _dual versions of addition and
multiplication was a bit tedious. Let's get back to the main program and use
our interpretation machinery to do the rewrite automatically.
JVP Interpreter
We'll set up a new interpreter called JVPInterpreter ("JVP" for
"Jacobian-vector product") which propagates these dual numbers instead of
ordinary values. JVPInterpreter has rules for add and mul that operate
on dual numbers. It casts constant arguments to dual numbers as needed by
calling JVPInterpreter.lift. In our manually rewritten version above we did
that by replacing the literal 3.0 with DualNumber(3.0, 0.0).
```python
# This is like DualNumber above except that it also has a pointer to the
# interpreter it belongs to, which is needed to avoid "perturbation confusion"
# in higher-order differentiation.
@dataclass
class TaggedDualNumber:
    interpreter: Interpreter
    primal: float
    tangent: float

class JVPInterpreter(Interpreter):
    def __init__(self, prev_interpreter: Interpreter):
        # We keep a pointer to the interpreter that was current when this
        # interpreter was first invoked. That's the context in which our
        # rules should run.
        self.prev_interpreter = prev_interpreter

    def interpret_op(self, op, args):
        args = tuple(self.lift(arg) for arg in args)
        with set_interpreter(self.prev_interpreter):
            match op:
                case Op.add:
                    # Notice that we use `add` and `mul` here, which are the
                    # interpreter-dispatching functions defined earlier.
                    x, y = args
                    return self.dual_number(
                        add(x.primal, y.primal),
                        add(x.tangent, y.tangent))
                case Op.mul:
                    x, y = args  # already lifted above
                    return self.dual_number(
                        mul(x.primal, y.primal),
                        add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))
                case _:
                    raise ValueError(f"Unrecognized primitive op: {op}")

    def dual_number(self, primal, tangent):
        return TaggedDualNumber(self, primal, tangent)

    # Lift a constant value (constant with respect to this interpreter) to
    # a TaggedDualNumber.
    def lift(self, x):
        if isinstance(x, TaggedDualNumber) and x.interpreter is self:
            return x
        else:
            return self.dual_number(x, 0.0)

def jvp(f, primal, tangent):
    jvp_interpreter = JVPInterpreter(current_interpreter)
    dual_number_in = jvp_interpreter.dual_number(primal, tangent)
    with set_interpreter(jvp_interpreter):
        result = f(dual_number_in)
    dual_number_out = jvp_interpreter.lift(result)
    return dual_number_out.primal, dual_number_out.tangent
```
```python
# Let's try it out:
print(jvp(foo, 2.0, 1.0))

# Because we were careful to consider nesting interpreters, higher-order AD
# works out of the box:
def derivative(f, x):
    _, tangent = jvp(f, x, 1.0)
    return tangent

def nth_order_derivative(n, f, x):
    if n == 0:
        return f(x)
    else:
        return derivative(lambda x: nth_order_derivative(n-1, f, x), x)
```
(10.0, 7.0)
```python
print(nth_order_derivative(0, foo, 2.0))
```
10.0
```python
print(nth_order_derivative(1, foo, 2.0))
```
7.0
```python
print(nth_order_derivative(2, foo, 2.0))
```
2.0
```python
# The rest are zero because `foo` is only a second-order polynomial
print(nth_order_derivative(3, foo, 2.0))
```
0.0
```python
print(nth_order_derivative(4, foo, 2.0))
```
0.0
There are some subtleties worth discussing. First, how do you tell if
something is constant with respect to differentiation? It's tempting to say
"it's a constant if and only if it's not a dual number". But actually dual
numbers created by a different JVPInterpreter also need to be considered
constants with respect to the JVPInterpreter we're currently handling. That's
why we need the x.interpreter is self check in JVPInterpreter.lift. This
comes up in higher-order differentiation when there are multiple JVPInterpreters
in scope. The sort of bug where you accidentally interpret a dual number from
a different interpreter as non-constant is sometimes called "perturbation
confusion" in the literature. Here's an example program that would have given
the wrong answer if JVPInterpreter.lift didn't have the x.interpreter is self
check.
```python
def f(x):
    # g is constant in its (ignored) argument `y`. Its derivative should be zero
    # but our AD will mess it up if we don't distinguish perturbations from
    # different interpreters.
    def g(y):
        return x
    should_be_zero = derivative(g, 0.0)
    return mul(x, should_be_zero)

print(derivative(f, 0.0))
```
0.0
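To make the failure concrete, here is a condensed, self-contained sketch of the JVP machinery. It adds a hypothetical check_tag switch, which is not in the original, to turn the x.interpreter is self test on or off. With the check disabled, the inner derivative mistakes the outer perturbation for its own. The derivative of f then comes out as 1.0 instead of 0.0.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

# Condensed, standalone copy of the JVP machinery above (ops named by strings),
# plus a hypothetical `check_tag` switch controlling JVPInterpreter.lift.
class EvalInterpreter:
    def interpret_op(self, op, args):
        x, y = args
        return x + y if op == "add" else x * y

current_interpreter: Any = EvalInterpreter()

@contextmanager
def set_interpreter(new_interpreter):
    global current_interpreter
    prev, current_interpreter = current_interpreter, new_interpreter
    try:
        yield
    finally:
        current_interpreter = prev

def add(x, y): return current_interpreter.interpret_op("add", (x, y))
def mul(x, y): return current_interpreter.interpret_op("mul", (x, y))

@dataclass
class TaggedDualNumber:
    interpreter: Any
    primal: Any
    tangent: Any

class JVPInterpreter:
    def __init__(self, prev_interpreter, check_tag):
        self.prev_interpreter = prev_interpreter
        self.check_tag = check_tag

    def lift(self, x):
        # With check_tag=False, a dual number from *any* JVPInterpreter is
        # treated as if it belonged to this one: the perturbation-confusion bug.
        if isinstance(x, TaggedDualNumber) and (x.interpreter is self or not self.check_tag):
            return x
        return TaggedDualNumber(self, x, 0.0)

    def interpret_op(self, op, args):
        x, y = (self.lift(arg) for arg in args)
        with set_interpreter(self.prev_interpreter):
            if op == "add":
                return TaggedDualNumber(self, add(x.primal, y.primal), add(x.tangent, y.tangent))
            return TaggedDualNumber(self, mul(x.primal, y.primal),
                                    add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

def derivative(f, x, check_tag):
    interp = JVPInterpreter(current_interpreter, check_tag)
    with set_interpreter(interp):
        result = f(TaggedDualNumber(interp, x, 1.0))
    return interp.lift(result).tangent

def make_f(check_tag):
    # Same program as `f` above, parameterized by the lift behavior.
    def f(x):
        def g(y):
            return x  # constant in y
        should_be_zero = derivative(g, 0.0, check_tag)
        return mul(x, should_be_zero)
    return f

print(derivative(make_f(True), 0.0, True))    # 0.0 (correct)
print(derivative(make_f(False), 0.0, False))  # 1.0 (perturbation confusion)
```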
Another subtlety: the add and mul rules in JVPInterpreter.interpret_op describe
addition and multiplication on dual numbers in terms of addition and
multiplication on the primal and tangent components. But we don't use ordinary
+ and * for this. Instead we use our own add and mul functions, which
dispatch to the current interpreter. Before calling them we set the current
interpreter to be the previous interpreter, i.e. the interpreter that was
current when JVPInterpreter was first invoked. If we didn't do this we'd
have an infinite recursion, with add and mul dispatching to
JVPInterpreter endlessly. The advantage of using our own add and mul instead
of ordinary + and * is that it means we can nest these interpreters and do
higher-order AD.
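The recursion is easy to reproduce. The sketch below uses a hypothetical ForwardingInterpreter whose rules, like JVPInterpreter's, are written with add and mul. A switch_context flag, also hypothetical, controls whether it makes the previous interpreter current before calling them. Without the switch, add and mul dispatch straight back to the same interpreter until Python raises RecursionError.

```python
from contextlib import contextmanager
from typing import Any

class EvalInterpreter:
    def interpret_op(self, op, args):
        x, y = args
        return x + y if op == "add" else x * y

current_interpreter: Any = EvalInterpreter()

@contextmanager
def set_interpreter(new_interpreter):
    global current_interpreter
    prev, current_interpreter = current_interpreter, new_interpreter
    try:
        yield
    finally:
        current_interpreter = prev

def add(x, y): return current_interpreter.interpret_op("add", (x, y))
def mul(x, y): return current_interpreter.interpret_op("mul", (x, y))
def foo(x): return mul(x, add(x, 3.0))

class ForwardingInterpreter:
    """Hypothetical interpreter whose rules are written with add/mul."""
    def __init__(self, prev_interpreter, switch_context):
        self.prev_interpreter = prev_interpreter
        self.switch_context = switch_context

    def interpret_op(self, op, args):
        rule = add if op == "add" else mul
        if self.switch_context:
            with set_interpreter(self.prev_interpreter):
                return rule(*args)  # dispatches to the previous interpreter
        return rule(*args)          # dispatches right back to this interpreter

def run(switch_context):
    interp = ForwardingInterpreter(current_interpreter, switch_context)
    with set_interpreter(interp):
        return foo(2.0)

print(run(True))  # 10.0
try:
    run(False)
    recursed = False
except RecursionError:
    recursed = True
print(recursed)   # True
```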
At this point you might be wondering: have we just reinvented operator
overloading? Python overloads the infix ops + and * to dispatch to the
argument's __add__ and __mul__. Could we have just used that mechanism
instead of this whole interpreter business? Yes, actually. Indeed, the earlier
automatic differentiation (AD) literature uses the term "operator overloading"
to describe this style of AD implementation. One detail is that we can't rely
exclusively on Python built-in overloading because that only lets us overload
a handful of built-in infix ops whereas we eventually want to overload
NumPy-level operations like sin and cos. So we need our own mechanism.
But there's a more important difference: our dispatch is based on context whereas traditional Python-style overloading is based on data. This is actually a recent development for JAX. The earliest versions of JAX looked more like traditional data-based overloading. An interpreter (a "trace" in JAX jargon) for an operation would be chosen based on data attached to the arguments to that operation. We've gradually made the interpreter-dispatch decision rely more and more on context rather than data (omnistaging [link], stackless [link]). The reason to prefer context-based interpretation over data-based interpretation is that it makes the implementation much simpler.
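For contrast, here is a minimal sketch of the traditional data-based style. It uses a hypothetical Dual class, separate from DualNumber above, whose __add__ and __mul__ carry the forward-mode rules. The types of the operands decide which rule runs, not a current interpreter. Because Dual carries no interpreter tag, nesting it would be prone to the perturbation confusion described earlier.

```python
from dataclasses import dataclass

@dataclass
class Dual:
    primal: float
    tangent: float

    @staticmethod
    def lift(v) -> "Dual":
        # Plain numbers are constants: zero tangent.
        return v if isinstance(v, Dual) else Dual(float(v), 0.0)

    def __add__(self, other) -> "Dual":
        o = Dual.lift(other)
        return Dual(self.primal + o.primal, self.tangent + o.tangent)

    def __mul__(self, other) -> "Dual":
        o = Dual.lift(other)
        return Dual(self.primal * o.primal,
                    self.primal * o.tangent + self.tangent * o.primal)

    # Addition and multiplication commute, so the reflected versions
    # (e.g. 3.0 + x) can reuse the same rules.
    __radd__ = __add__
    __rmul__ = __mul__

def foo_infix(x):
    return x * (x + 3.0)

print(foo_infix(Dual(2.0, 1.0)))  # Dual(primal=10.0, tangent=7.0)
print(foo_infix(2.0))             # 10.0
```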
All that said, we do also want to take advantage of Python's built-in
overloading mechanism. That way we get the syntactic convenience of using
infix operators + and * instead of writing out add(..) and mul(..).
But we'll put that aside for now.
3. Staging to an untyped IR
The two program transformations we've seen so far, evaluation and JVP, both traverse the input program from top to bottom. They visit the operations one by one in the same order as ordinary evaluation. A convenient thing about top-to-bottom transformations is that they can be implemented eagerly, or "online", meaning that we can evaluate the program from top to bottom and perform the necessary transformations as we go. We never look at the entire program at once.
But not all transformations work this way. For example, dead-code elimination requires traversing from bottom to top, collecting usage statistics on the way up and eliminating pure operations whose results have no uses. Another bottom-to-top transformation is AD transposition, which we use to implement reverse-mode AD. For these we need to first "stage" the program into an IR (intermediate representation), a data structure representing the program, which we can then traverse in any order we like. Building this IR from a Python program will be the goal of our third and final interpreter.
First, let's define the IR. We'll start with an untyped ANF (A-normal form) IR. A function (we call IR functions "jaxprs" in JAX) will have a list of formal parameters, a list of operations, and a return value. Each argument to an operation must be an "atom", which is either a variable or a literal. The return value of the function is also an atom.
```python
Var = str           # Variables are just strings in this untyped IR
Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals

# Equation - a single line in our IR like `z = mul(x, y)`
@dataclass
class Equation:
    var: Var                # The variable name of the result
    op: Op                  # The primitive operation we're applying
    args: tuple[Atom, ...]  # The arguments we're applying the primitive operation to

# We call an IR function a "Jaxpr", for "JAX expression"
@dataclass
class Jaxpr:
    parameters: list[Var]      # The function's formal parameters (arguments)
    equations: list[Equation]  # The body of the function, a list of instructions/equations
    return_val: Atom           # The function's return value

    def __str__(self):
        lines = []
        lines.append(', '.join(self.parameters) + ' ->')
        for eqn in self.equations:
            args_str = ', '.join(str(arg) for arg in eqn.args)
            lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
        # str() is needed in case the return value is a float literal
        lines.append(str(self.return_val))
        return '\n'.join(lines)
```
To build the IR from a Python function we define a StagingInterpreter that
takes each operation and adds it to a growing list of all the operations we've
seen so far:
```python
class StagingInterpreter(Interpreter):
    def __init__(self):
        self.equations = []    # A mutable list of all the ops we've seen so far
        self.name_counter = 0  # Counter for generating unique names

    def fresh_var(self):
        self.name_counter += 1
        return "v_" + str(self.name_counter)

    def interpret_op(self, op, args):
        binder = self.fresh_var()
        self.equations.append(Equation(binder, op, args))
        return binder

def build_jaxpr(f, num_args):
    interpreter = StagingInterpreter()
    # A list, to match the `parameters: list[Var]` annotation on Jaxpr
    parameters = [interpreter.fresh_var() for _ in range(num_args)]
    with set_interpreter(interpreter):
        result = f(*parameters)
    return Jaxpr(parameters, interpreter.equations, result)
```
Now we can construct an IR for a Python program and print it out:
```python
print(build_jaxpr(foo, 1))
```
```
v_1 ->
  v_2 = Op.add(v_1, 3.0)
  v_3 = Op.mul(v_1, v_2)
v_3
```
We can also evaluate our IR by writing an explicit interpreter that traverses the operations one by one:
```python
def eval_jaxpr(jaxpr, args):
    # An environment mapping variables to values
    env = dict(zip(jaxpr.parameters, args))
    def eval_atom(x): return env[x] if isinstance(x, Var) else x
    for eqn in jaxpr.equations:
        op_args = tuple(eval_atom(x) for x in eqn.args)
        env[eqn.var] = current_interpreter.interpret_op(eqn.op, op_args)
    return eval_atom(jaxpr.return_val)

print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
```
10.0
We've written this interpreter in terms of current_interpreter.interpret_op
which means we've done a full round-trip: interpretable Python program to IR
to interpretable Python program. Since the result is "interpretable" we can
differentiate it again, or stage it out or anything we like:
```python
print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
```
(10.0, 7.0)
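As another example of "anything we like", the sketch below composes the interpreters in the opposite order from the line above. It stages out the JVP of foo, so the resulting IR computes the derivative of foo. It re-declares a condensed, self-contained copy of the machinery above, with ops named by strings and if in place of match. The name dfoo is ours.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

# Condensed copy of the machinery above; ops are named by strings for brevity.
class EvalInterpreter:
    def interpret_op(self, op, args):
        x, y = args
        return x + y if op == "add" else x * y

current_interpreter: Any = EvalInterpreter()

@contextmanager
def set_interpreter(new_interpreter):
    global current_interpreter
    prev, current_interpreter = current_interpreter, new_interpreter
    try:
        yield
    finally:
        current_interpreter = prev

def add(x, y): return current_interpreter.interpret_op("add", (x, y))
def mul(x, y): return current_interpreter.interpret_op("mul", (x, y))
def foo(x): return mul(x, add(x, 3.0))

@dataclass
class TaggedDualNumber:
    interpreter: Any
    primal: Any
    tangent: Any

class JVPInterpreter:
    def __init__(self, prev_interpreter):
        self.prev_interpreter = prev_interpreter
    def lift(self, x):
        if isinstance(x, TaggedDualNumber) and x.interpreter is self:
            return x
        return TaggedDualNumber(self, x, 0.0)
    def interpret_op(self, op, args):
        x, y = (self.lift(arg) for arg in args)
        with set_interpreter(self.prev_interpreter):
            if op == "add":
                return TaggedDualNumber(self, add(x.primal, y.primal), add(x.tangent, y.tangent))
            return TaggedDualNumber(self, mul(x.primal, y.primal),
                                    add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

def jvp(f, primal, tangent):
    interp = JVPInterpreter(current_interpreter)
    with set_interpreter(interp):
        result = f(TaggedDualNumber(interp, primal, tangent))
    out = interp.lift(result)
    return out.primal, out.tangent

@dataclass
class Equation:
    var: str
    op: str
    args: tuple

@dataclass
class Jaxpr:
    parameters: list
    equations: list
    return_val: Any

class StagingInterpreter:
    def __init__(self):
        self.equations, self.name_counter = [], 0
    def fresh_var(self):
        self.name_counter += 1
        return f"v_{self.name_counter}"
    def interpret_op(self, op, args):
        var = self.fresh_var()
        self.equations.append(Equation(var, op, args))
        return var

def build_jaxpr(f, num_args):
    interp = StagingInterpreter()
    parameters = [interp.fresh_var() for _ in range(num_args)]
    with set_interpreter(interp):
        result = f(*parameters)
    return Jaxpr(parameters, interp.equations, result)

def eval_jaxpr(jaxpr, args):
    env = dict(zip(jaxpr.parameters, args))
    def read(a): return env[a] if isinstance(a, str) else a
    for eqn in jaxpr.equations:
        op_args = tuple(read(a) for a in eqn.args)
        env[eqn.var] = current_interpreter.interpret_op(eqn.op, op_args)
    return read(jaxpr.return_val)

# Stage out the derivative of foo: the JVP rules run on top of the staging
# interpreter, so the tangent computation is recorded as equations.
dfoo = build_jaxpr(lambda x: jvp(foo, x, 1.0)[1], 1)
print(len(dfoo.equations))       # 6 (no constant folding: add(1.0, 0.0) is kept)
print(eval_jaxpr(dfoo, (2.0,)))  # 7.0, i.e. foo'(2) = 2*2 + 3
```

The staged program still contains literal-only equations such as add(1.0, 0.0), because nothing in this sketch folds constants. Cleaning these up is the kind of whole-program pass that staging to an IR makes possible.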
Up next…
That's it for part one of this tutorial. We've done two primitives, three interpreters and the tracing mechanism that weaves them together. In the next part we'll add types other than floats, error handling, compilation, reverse-mode AD and higher-order primitives. Note that the second part is structured differently. Rather than trying to have a top-to-bottom order that obeys both code dependencies (e.g. data structures need to be defined before they're used) and pedagogical dependencies (concepts need to be introduced before they're implemented) we're going with a single file that can be approached in any order.