Calls the decorated function when the condition is met.
Parameters:
condition (bool | Array | ndarray | bool | number | int | float | complex | TypedNdArray) – If a boolean, this is equivalent to if condition: f(). If an
array, when() produces a jax.lax.cond() with the decorated
function as the true branch.
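
The two cases described above can be illustrated with a minimal sketch. The when() defined here is a hypothetical stand-in written only for this example, not the library's implementation; in particular, the do-nothing false branch is an assumption.

```python
import jax
import jax.numpy as jnp


def when(condition):
    """Hypothetical stand-in mirroring the behaviour described above."""
    def decorator(f):
        if isinstance(condition, bool):
            # Python bool: ordinary Python control flow.
            if condition:
                f()
        else:
            # Array condition: f becomes the true branch of jax.lax.cond.
            # Assumption: the false branch does nothing.
            jax.lax.cond(condition, f, lambda: None)
    return decorator


calls = []


@when(True)
def _():
    calls.append("ran")  # same as: if True: f()


@when(False)
def _():
    calls.append("skipped")  # body is never called


@when(jnp.asarray(3) > 2)
def _():
    # Routed through jax.lax.cond; the message is emitted only if the
    # predicate is true when the computation runs.
    jax.debug.print("array condition was true")
```

Note that with an array condition the Python body of the decorated function is traced by jax.lax.cond rather than simply skipped, so plain Python side effects such as list appends are not a reliable way to observe which branch ran; jax.debug.print is used for that reason.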