jax.lax.optimization_barrier

jax.lax.optimization_barrier(operand, /)

Prevents the compiler from moving operations across the barrier.

Optimization barriers have a number of possible uses:

  • An optimization barrier ensures that all of its inputs are evaluated before any operator that depends on one of the barrier's outputs. This can be used to enforce a particular order of operations.

    Note that all operands must be used through the barrier for this to work. There are no ordering constraints between an operator that uses one of the barrier's outputs and an operator that directly (not through the barrier) uses one of the barrier's inputs.

  • An optimization barrier prevents common subexpression elimination. This is used by JAX to implement rematerialization.

  • Optimization barriers prevent compiler fusions. That is, the compiler does not fuse operations that come before the barrier into the same kernel as operations that come after it.

JAX does not define derivative or batching rules for an optimization barrier.

Optimization barriers have no effect outside a compiled function.

Parameters:

operand – a pytree of JAX values.

Returns:

A pytree of JAX values, with the same structure and contents as operand.

Examples

Prevents common-subexpression elimination between the two calls to sin:

>>> import jax
>>> def f(x):
...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
>>> jax.jit(f)(0.)
Array(0., dtype=float32, weak_type=True)