jax.experimental.pallas.when

jax.experimental.pallas.when(condition, /)

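Calls the decorated function when the condition is met. The decorated function takes no arguments and is invoked at the point of decoration, so it is not called again afterwards. This makes when a convenient way to guard side-effecting code, such as writes to refs, inside a Pallas kernel.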
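A minimal sketch of a common pattern, assuming a CPU-only setup (hence interpret=True). The kernel zero-initializes an output block on the first grid step only, then accumulates one block of rows per step. The names sum_rows_kernel and sum_row_blocks and the block size of 8 rows are hypothetical and chosen for illustration. They are not part of the Pallas API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def sum_rows_kernel(x_ref, o_ref):
  # Hypothetical kernel: o_ref maps to the same output block on every grid step.
  # Zero it only on the first step. The condition is a traced scalar,
  # so pl.when stages a jax.lax.cond here.
  @pl.when(pl.program_id(0) == 0)
  def _init():
    o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

  # Every step adds its (block_rows, n_cols) slice of x into the output block.
  o_ref[...] += x_ref[...]


def sum_row_blocks(x, block_rows=8):
  # Hypothetical wrapper: sums the n_rows // block_rows row-blocks of x.
  n_rows, n_cols = x.shape
  return pl.pallas_call(
      sum_rows_kernel,
      out_shape=jax.ShapeDtypeStruct((block_rows, n_cols), x.dtype),
      grid=(n_rows // block_rows,),
      in_specs=[pl.BlockSpec(block_shape=(block_rows, n_cols),
                             index_map=lambda i: (i, 0))],
      out_specs=pl.BlockSpec(block_shape=(block_rows, n_cols),
                             index_map=lambda i: (0, 0)),
      interpret=True,  # assumption: run via the Pallas interpreter on CPU
  )(x)


x = jnp.arange(32 * 4, dtype=jnp.float32).reshape(32, 4)
out = sum_row_blocks(x)  # should match x.reshape(4, 8, 4).sum(axis=0)
```

Without the pl.when guard, the output block would be reset on every step and only the last block of rows would survive.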

Parameters:

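condition (bool | Array | ndarray | bool_ | number | int | float | complex | TypedNdArray) – If condition is a Python bool, it is resolved at trace time and this is equivalent to if condition: f(). Otherwise, for example with a traced scalar boolean such as pl.program_id(0) == 0, when produces a jax.lax.cond() with the decorated function as the true branch and a no-op as the false branch.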
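The two dispatch paths can be observed directly. This sketch runs outside any kernel purely for illustration. A Python bool runs or skips the body immediately, while a traced condition is staged as a cond primitive in the jaxpr. The list ran and the function guarded are hypothetical names used only here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

ran = []  # hypothetical log of which bodies actually executed

# Python bool True: resolved immediately, like `if True: body()`.
@pl.when(True)
def _():
  ran.append("static-true")

# Python bool False: the body is never called (and never traced).
@pl.when(False)
def _():
  ran.append("static-false")


def guarded(x):
  # Hypothetical function: x > 0 is a traced scalar, so pl.when stages
  # jax.lax.cond(pred, body, no-op) instead of branching in Python.
  @pl.when(x > 0)
  def _():
    jax.debug.print("x is positive: {}", x)

  return x


jaxpr = jax.make_jaxpr(guarded)(jnp.float32(1.0))
print(ran)                   # ['static-true']
print("cond" in str(jaxpr))  # True
```

Because a traced condition goes through jax.lax.cond, the decorated body is traced even when the predicate later evaluates to False at run time.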

Returns:

A decorator.

Return type:

Callable[[Callable[[], None]], Callable[[], None]]