Ref: mutable arrays for data plumbing and memory control#
JAX Arrays are immutable, representing mathematical values. Immutability can
make code easier to reason about, and is useful for optimized compilation,
parallelization, rematerialization, and transformations like autodiff.
But immutability is constraining too:

- expressiveness: plumbing out intermediate data or maintaining state, e.g. for normalization statistics or metrics, can feel heavyweight;
- performance: it’s more difficult to reason about performance characteristics like memory lifetimes and in-place updates.
Refs can help! They represent mutable arrays that can be read and written
in-place. These array references are compatible with JAX transformations, like
jax.jit and jax.grad:
import jax
import jax.numpy as jnp
x_ref = jax.new_ref(jnp.zeros(3))  # new array ref, with initial value [0., 0., 0.]

@jax.jit
def f():
  x_ref[1] += 1.  # indexed add-update

print(x_ref)  # Ref([0., 0., 0.])
f()
f()
print(x_ref)  # Ref([0., 2., 0.])
Ref([0., 0., 0.], dtype=float32)
Ref([0., 2., 0.], dtype=float32)
The indexing syntax follows NumPy’s. For a Ref called x_ref, we can
read its entire value into an Array by writing x_ref[...], and write its
entire value using x_ref[...] = A for some Array-valued expression A:
def g(x):
  x_ref = jax.new_ref(0.)
  x_ref[...] = jnp.sin(x)
  return x_ref[...]

print(jax.grad(g)(1.0))  # 0.54
0.5403023
Ref is a distinct type from Array, and it comes with some important
constraints and limitations. In particular, indexed reading and writing is just
about the only thing you can do with a Ref. References can’t be passed
where Arrays are expected:
x_ref = jax.new_ref(1.0)

try:
  jnp.sin(x_ref)  # error! can't do math on refs
except Exception as e:
  print(e)
Attempting to pass a Ref Ref{float32[]} to a primitive: sin -- did you forget to unpack ([...]) the ref?
To do math, you need to read the ref’s value first, like jnp.sin(x_ref[...]).
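For instance, a one-line sketch: reading the full value out produces an ordinary Array that jnp.sin accepts.

x_ref = jax.new_ref(1.0)
y = jnp.sin(x_ref[...])  # x_ref[...] reads the whole value as an Array
print(y)  # 0.84147096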
So what can you do with Ref? Read on for the details, and some useful
recipes.
API#
If you’ve ever used
Pallas, then Ref
should look familiar. A big difference is that you can create new Refs
yourself directly using jax.new_ref:
from jax import Array, Ref

def new_ref(init_val: Array) -> Ref:
  """Introduce a new reference with given initial value."""
jax.freeze is its antithesis, invalidating the given ref (so that accessing it
afterwards is an error) and producing its final value:
def freeze(ref: Ref) -> Array:
  """Invalidate given reference and produce its final value."""
In between creating and destroying them, you can perform indexed reads and
writes on refs. You can read and write using the functions jax.ref.get and
jax.ref.swap, but usually you’d just use NumPy-style array indexing syntax:
import types

Index = int | slice | Array | types.EllipsisType
Indexer = Index | tuple[Index, ...]

def get(ref: Ref, idx: Indexer) -> Array:
  """Returns `ref[idx]` for NumPy-style indexer `idx`."""

def swap(ref: Ref, idx: Indexer, val: Array) -> Array:
  """Performs `newval, ref[idx] = ref[idx], val` and returns `newval`."""
Here, Indexer can be any NumPy indexing expression:
x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4))

# int indexing
row = x_ref[0]
x_ref[1] = row

# tuple indexing
val = x_ref[1, 2]
x_ref[2, 3] = val

# slice indexing
col = x_ref[:, 1]
x_ref[0, :3] = col

# advanced int array indexing
vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])]
x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals
As with Arrays, indexing mostly follows NumPy behavior, except for
out-of-bounds indexing, which behaves in the usual way for JAX
Arrays rather than raising an error.
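The bracket syntax desugars to those functional forms. A small sketch using jax.ref.get and jax.ref.swap directly, assuming they behave exactly as the signatures above describe:

x_ref = jax.new_ref(jnp.arange(4.))
print(jax.ref.get(x_ref, ...))     # same as x_ref[...]
old = jax.ref.swap(x_ref, 0, 9.0)  # like x_ref[0] = 9.0, but returns the old value
print(old)    # 0.0
print(x_ref)  # Ref([9., 1., 2., 3.], dtype=float32)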
Pure and impure functions#
A function that takes a ref as an argument (either explicitly or by lexical closure) is considered impure. For example:
# takes ref as an argument => impure
@jax.jit
def impure1(x_ref, y_ref):
  x_ref[...] = y_ref[...]

# closes over ref => impure
y_ref = jax.new_ref(0)

@jax.jit
def impure2(x):
  y_ref[...] = x
If a function only uses refs internally, it is still considered pure. Purity is in the eye of the caller. For example:
# internal refs => still pure
@jax.jit
def pure1(x):
  ref = jax.new_ref(x)
  ref[...] = ref[...] + ref[...]
  return ref[...]
Pure functions, even those that use refs internally, are familiar: for example,
they work with transformations like jax.grad, jax.vmap, jax.shard_map, and
others in the usual way.
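For instance, since pure1 above just doubles its input, we’d expect vmapping it to work like vmapping any other pure function:

print(jax.vmap(pure1)(jnp.arange(3.)))  # [0. 2. 4.]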
Impure functions are sequenced in Python program order.
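Concretely, repeated calls to an impure function take effect in the order they appear; a small sketch:

count_ref = jax.new_ref(0)

@jax.jit
def bump():
  count_ref[...] += 1

bump()
bump()
print(count_ref)  # Ref(2, dtype=int32): the second write saw the first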
Restrictions#
Refs are second-class, in the sense that there are restrictions on their
use:
- Can’t return refs from jit-decorated functions or the bodies of higher-order primitives like jax.lax.scan, jax.lax.while_loop, or jax.lax.cond
- Can’t pass a ref as an argument more than once to jit-decorated functions or higher-order primitives
- Can only freeze in creation scope
- No higher-order refs (refs-to-refs)
For example, these are errors:
x_ref = jax.new_ref(0.)

# can't return refs
@jax.jit
def err1(x_ref):
  x_ref[...] = 5.
  return x_ref  # error!

try:
  err1(x_ref)
except Exception as e:
  print(e)

# can't pass a ref as an argument more than once
@jax.jit
def err2(x_ref, y_ref):
  ...

try:
  err2(x_ref, x_ref)  # error!
except Exception as e:
  print(e)

# can't pass and close over the same ref
@jax.jit
def err3(y_ref):
  y_ref[...] = x_ref[...]

try:
  err3(x_ref)  # error!
except Exception as e:
  print(e)

# can only freeze in creation scope
@jax.jit
def err4(x_ref):
  jax.freeze(x_ref)

try:
  err4(x_ref)  # error!
except Exception as e:
  print(e)
function err1 at /tmp/ipykernel_1387/3915325362.py:4 traced for jit returned a mutable array reference of type Ref{float32[]}, but mutable array references cannot be returned.
The returned mutable array was passed in as the argument x_ref.
only one reference to a mutable array may be passed as an argument to a function, but when tracing err2 at /tmp/ipykernel_1387/3915325362.py:14 for jit the mutable array reference of type Ref{float32[]} appeared at both x_ref and y_ref.
when tracing err3 at /tmp/ipykernel_1387/3915325362.py:23 for jit, a mutable array reference of type Ref{float32[]} was both closed over and passed as the argument y_ref
list index out of range
These restrictions exist to rule out aliasing, where two refs might refer to the same mutable memory, making programs harder to reason about and transform. Weaker restrictions would also suffice, so some of these restrictions may be lifted as we improve JAX’s ability to verify that no aliasing is present.
There are also restrictions stemming from undefined semantics, e.g. in the presence of parallelism or rematerialization:
- Can’t vmap or shard_map a function that closes over refs
- Can’t apply jax.remat / jax.checkpoint to an impure function
For example, here are ways you can and can’t use vmap with impure functions:
# vmap over ref args is okay
def dist(x, y, out_ref):
  assert x.ndim == y.ndim == 1
  assert out_ref.ndim == 0
  out_ref[...] = jnp.sum((x - y) ** 2)

vecs = jnp.arange(12.).reshape(3, 4)
out_ref = jax.new_ref(jnp.zeros((3, 3)))
jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref)  # ok!
print(out_ref)
Ref([[ 0., 64., 256.],
[ 64., 0., 64.],
[256., 64., 0.]], dtype=float32)
# vmap with a closed-over ref is not
x_ref = jax.new_ref(0.)

def err5(x):
  x_ref[...] = x

try:
  jax.vmap(err5)(jnp.arange(3.))  # error!
except Exception as e:
  print(e)
performing a set/swap operation with vmapped value on an unbatched array reference of type Ref{float32[]}. Move the array reference to be an argument to the vmapped function?
The latter is an error because it’s not clear which value x_ref should be
after we run jax.vmap(err5).
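Following the hint in that error message, a sketch of a working variant (ok5 is an illustrative name, not a JAX API): pass a batched ref as an argument, so each vmap instance writes its own slot.

x_ref = jax.new_ref(jnp.zeros(3))

def ok5(x, x_ref):
  x_ref[...] = x

jax.vmap(ok5)(jnp.arange(3.), x_ref)
print(x_ref)  # Ref([0., 1., 2.], dtype=float32)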
Refs and automatic differentiation#
Autodiff can be applied to pure functions as before, even if they use array refs internally. For example:
@jax.jit
def pure2(x):
  ref = jax.new_ref(x)
  ref[...] = ref[...] + ref[...]
  return ref[...]

print(jax.grad(pure2)(3.0))  # 2.0
2.0
Autodiff can also be applied to functions that take array refs as arguments, if those arguments are only used for plumbing and not involved in differentiation:
# error: the value written to the ref is itself being differentiated
def err6(x, some_plumbing_ref):
  y = x + x
  some_plumbing_ref[...] += y
  return y

# fine: stop_gradient keeps the stashed value out of differentiation
def foo(x, some_plumbing_ref):
  y = x + x
  some_plumbing_ref[...] += jax.lax.stop_gradient(y)
  return y
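For instance, a sketch of differentiating foo: the gradient flows through y as usual, while the stashed copy stays out of differentiation.

plumbing_ref = jax.new_ref(0.)
print(jax.grad(foo)(3., plumbing_ref))  # 2.0
print(plumbing_ref)  # Ref(6., ...): holds y = 3. + 3. from the forward pass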
You can combine plumbing refs with custom_vjp to plumb data out of the
backward pass of a differentiated function:
# First, define the helper `stash_grads`:
@jax.custom_vjp
def stash_grads(grads_ref, x):
  return x

def stash_grads_fwd(grads_ref, x):
  return x, grads_ref

def stash_grads_bwd(grads_ref, g):
  grads_ref[...] = g
  return None, g

stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)

# Now, use `stash_grads` to stash intermediate gradients:
def f(x, grads_ref):
  x = jnp.sin(x)
  x = stash_grads(grads_ref, x)
  return x

grads_ref = jax.new_ref(0.)
f(1., grads_ref)
print(grads_ref)
Ref(0., dtype=float32, weak_type=True)
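The direct call above only ran the forward computation, so grads_ref still holds its initial value. Under jax.grad the bwd rule fires, and since stash_grads is the last operation in f, we’d expect the stashed cotangent to be 1.0 (a sketch, following the autodiff rules above):

jax.grad(f)(1., grads_ref)
print(grads_ref)  # Ref(1., ...): the cotangent at the stash_grads call site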
Notice stash_grads_fwd is returning a Ref here. That’s a special
allowance for custom_vjp fwd rules: it’s really syntax for indicating which
ref arguments should be shared by both the fwd and bwd rules. So any refs
returned by a fwd rule must be arguments to that fwd rule.
Refs and performance#
At the top level, when calling jit-decorated functions, Refs obviate
the need for donation, since they are effectively always donated:
@jax.jit
def sin_inplace(x_ref):
  x_ref[...] = jnp.sin(x_ref[...])

x_ref = jax.new_ref(jnp.arange(3.))
print(x_ref.unsafe_buffer_pointer(), x_ref)
sin_inplace(x_ref)
print(x_ref.unsafe_buffer_pointer(), x_ref)
94787283736256 Ref([0., 1., 2.], dtype=float32)
94787283736256 Ref([0. , 0.84147096, 0.9092974 ], dtype=float32)
Here sin_inplace operates in-place, updating the buffer backing x_ref so
that its address stays the same.
Under a jit, you should expect array references to point to fixed buffer
addresses, and for indexed updates to be performed in-place.
Temporary caveat: dispatch from Python to impure jit-compiled functions
that take Ref inputs is currently slower than dispatch to pure
jit-compiled functions, since it takes a less optimized path.
foreach, a new way to write scan#
As you may know, jax.lax.scan is a loop construct with a built-in fixed access
pattern for scanned-over inputs and outputs. The access pattern is built in for
autodiff reasons: if we were instead to slice into immutable inputs directly,
reverse-mode autodiff would end up creating one-hot gradients and summing them
up, which can be asymptotically inefficient. See Sec 5.3.3 of the Dex
paper.
But reading slices of Refs doesn’t have this efficiency problem: when we
apply reverse-mode autodiff, we always generate in-place accumulation
operations. As a result, we no longer need to be constrained by scan’s fixed
access pattern. We can write more flexible loops, e.g. with non-sequential
access.
Moreover, having mutation available allows for some syntax tricks, like in this
recipe for a foreach decorator:
import jax
import jax.numpy as jnp
from jax.lax import scan

def foreach(*args):
  def decorator(body):
    return scan(lambda _, elts: (None, body(*elts)), None, args)[1]
  return decorator

r = jax.new_ref(0)
xs = jnp.arange(10)

@foreach(xs)
def ys(x):
  r[...] += x
  return x * 2

print(r)   # Ref(45, dtype=int32)
print(ys)  # [ 0  2  4  6  8 10 12 14 16 18]
Ref(45, dtype=int32)
[ 0 2 4 6 8 10 12 14 16 18]
Here, the loop runs immediately, updating r in-place and binding ys to be
the mapped result.
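As a sketch of the non-sequential access discussed above (illustrative only, reusing the same foreach recipe): the loop body can read a closed-over ref at arbitrary indices.

x_ref = jax.new_ref(jnp.arange(6.))
idxs = jnp.array([4, 2, 0])  # visit elements out of order

@foreach(idxs)
def gathered(i):
  return x_ref[i] * 2.  # indexed read of a closed-over ref

print(gathered)  # [8. 4. 0.]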