Key concepts#
This section briefly introduces some key concepts of the JAX package.
Transformations#
Along with functions to operate on arrays, JAX includes a number of transformations which operate on JAX functions. These include:
jax.jit(): Just-in-time (JIT) compilation; see Just-in-time compilation
jax.vmap(): Vectorizing transform; see Automatic vectorization
jax.grad(): Gradient transform; see Automatic differentiation
as well as several others. Transformations accept a function as an argument, and return a new transformed function. For example, here’s how you might JIT-compile a simple SELU function:
import jax
import jax.numpy as jnp
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05
Often you’ll see transformations applied using Python’s decorator syntax for convenience:
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
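The other transformations follow the same pattern. As a quick sketch using the selu function defined above, jax.vmap() returns a function that maps selu over a leading batch axis, and jax.grad() returns a function that computes its derivative with respect to a scalar input:
# Map selu over a batch of inputs along a leading axis.
batched_selu = jax.vmap(selu)
print(batched_selu(jnp.arange(3.0)))

# Compute the derivative of selu with respect to a scalar input.
dselu_dx = jax.grad(selu)
print(dselu_dx(1.0))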
Tracing#
The magic behind transformations is the notion of a Tracer. Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order to extract the sequence of operations that the function encodes.
You can see this by printing any array value within transformed JAX code; for example:
@jax.jit
def f(x):
  print(x)
  return x + 1
x = jnp.arange(5)
result = f(x)
JitTracer(int32[5])
The value printed is not the array x, but a Tracer instance that
represents essential attributes of x, such as its shape and dtype. By executing
the function with traced values, JAX can determine the sequence of operations encoded
by the function before those operations are actually executed: transformations like
jit(), vmap(), and grad() can then map this sequence
of input operations to a transformed sequence of operations.
Static vs traced operations: Just as values can be either static or traced, operations can be static or traced. Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
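For instance, in this small sketch the array's shape is a static Python tuple available while tracing, whereas its values are traced:
@jax.jit
def g(x):
  print("static shape:", x.shape)  # static: a concrete Python tuple at trace time
  print("traced value:", x)        # traced: an abstract Tracer standing in for the values
  return x.reshape(x.shape[0] * x.shape[1])  # shape arithmetic happens statically in Python

result = g(jnp.ones((2, 3)))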
For more details, see Tracing.
Jaxprs#
JAX has its own intermediate representation for sequences of operations, known as a jaxpr. A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations.
For example, consider the selu function we defined above:
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
We can use the jax.make_jaxpr() utility to convert this function into a jaxpr
given a particular input:
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0:f32[]
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558:f32[] c
    e:f32[5] = sub d 1.6699999570846558:f32[]
    f:f32[5] = jit[
      name=_where
      jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
          f:f32[5] = select_n b e a
        in (f,) }
    ] b a e
    g:f32[5] = mul 1.0499999523162842:f32[] f
  in (g,) }
Comparing this to the Python function definition, we see that it encodes the precise sequence of operations that the function represents. We’ll go into more depth about jaxprs later in JAX internals: The jaxpr language.
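You can also inspect the jaxpr of a transformed function to see how a transformation rewrites this sequence of operations; as a sketch, here is how you would view the jaxpr of the gradient of selu (the printed jaxpr will vary between JAX versions, so the output is not shown here):
# The jaxpr of jax.grad(selu) records the differentiated sequence of
# operations; the exact output depends on your JAX version.
print(jax.make_jaxpr(jax.grad(selu))(1.0))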
Pytrees#
JAX functions and transformations fundamentally operate on arrays, but in practice it is convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the pytree abstraction to treat such collections in a uniform manner.
Here are some examples of objects that can be treated as pytrees:
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
        [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
  a: int
  b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
JAX has a number of general-purpose utilities for working with pytrees; for example,
the function jax.tree.map() can be used to apply a function to every leaf in a
tree, and jax.tree.reduce() can be used to apply a reduction across the leaves
of a tree.
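For example, here is a minimal sketch using the dictionary of parameters defined above:
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# Apply a function to every leaf of the tree, preserving its structure.
doubled = jax.tree.map(lambda x: x * 2, params)

# Reduce over the leaves; here, count the total number of elements.
n_elements = jax.tree.reduce(lambda acc, leaf: acc + jnp.size(leaf), params, 0)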
You can learn more in the Pytrees tutorial.
JAX API layering: NumPy, lax & XLA#
All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler. If you look at the source of jax.numpy, you’ll see that all the operations are eventually expressed in terms of functions defined in jax.lax. While jax.numpy is a high-level wrapper that provides a familiar interface, you can think of jax.lax as a stricter, but often more powerful, lower-level API for working with multi-dimensional arrays.
For example, while jax.numpy will implicitly promote arguments to allow operations between mixed data types, jax.lax will not:
import jax.numpy as jnp
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[10], line 2
1 from jax import lax
----> 2 lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
[... internal JAX frames skipped ...]
TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).
If using jax.lax directly, you’ll have to do type promotion explicitly in such cases:
lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)
Along with this strictness, jax.lax provides efficient APIs for a number of operations that are more general than those offered by NumPy.
For example, consider a 1D convolution, which can be expressed in NumPy this way:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
Under the hood, this NumPy operation is translated to a much more general convolution implemented by lax.conv_general_dilated:
from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See Convolutions in JAX for more detail on JAX convolutions).
At their heart, all jax.lax operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by XLA:ConvWithGeneralPadding.
Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation.
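One way to see this directly is to inspect the compiler input that jit() produces for a function; as a rough sketch (the printed StableHLO text will vary across JAX versions and backends):
# Lower the jitted selu function for a concrete input and print the
# StableHLO (XLA input) program it produces.
x = jnp.arange(5.0)
print(jax.jit(selu).lower(x).as_text())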