jax.extend.linear_util.transformation_with_aux2
- jax.extend.linear_util.transformation_with_aux2 = functools.partial(<class 'functools.partial'>, <function transformation_with_aux2>)
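Adds one more transformation with auxiliary output to a WrappedFun. It returns the transformed WrappedFun together with a zero-argument callable that retrieves the auxiliary output once the transformed function has been called.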
- Parameters:
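fun (WrappedFun) – the function to transform.
use_eq_store (bool) – if True, the auxiliary output may be stored more than once, provided every later value compares equal to the first one.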
- Return type:
tuple[WrappedFun, Callable[[], Any]]
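The following is a simplified toy model of the pattern, written in plain Python so it runs without JAX. ToyStore, ToyWrappedFun, _toy_transformation_with_aux, and scale_and_count are hypothetical names invented for illustration; they are not part of jax.extend.linear_util. The real WrappedFun also tracks its transforms, params, and debug information and supports use_eq_store, none of which this sketch models. It assumes, as a simplification, that the transformation function receives the next callable and a store as its first two arguments. What it does mirror is the curried form shown in the signature above, partial(partial, impl), and the contract implied by the return type: the transformation writes its auxiliary output into a store during the call, and the returned thunk reads it back afterwards.

```python
import functools
from typing import Any, Callable

_EMPTY = object()  # sentinel meaning "nothing stored yet"


class ToyStore:
    """Single-assignment box for a transformation's auxiliary output (toy model)."""

    def __init__(self) -> None:
        self._val: Any = _EMPTY

    def store(self, val: Any) -> None:
        if self._val is not _EMPTY:
            raise RuntimeError("store occupied")
        self._val = val

    @property
    def val(self) -> Any:
        if self._val is _EMPTY:
            raise RuntimeError("store empty: call the transformed function first")
        return self._val


class ToyWrappedFun:
    """Hypothetical stand-in for WrappedFun: it only holds the composed callable."""

    def __init__(self, f_transformed: Callable[..., Any]) -> None:
        self.f_transformed = f_transformed

    def call_wrapped(self, *args: Any) -> Any:
        return self.f_transformed(*args)


def _toy_transformation_with_aux(gen, fun: ToyWrappedFun, *static_args):
    """Layer `gen` on top of `fun`; return (new fun, thunk reading the aux output)."""
    out_store = ToyStore()
    inner = fun.f_transformed
    # gen is called as gen(next_callable, store, *static_args, *call_args).
    new_fun = ToyWrappedFun(functools.partial(gen, inner, out_store, *static_args))
    return new_fun, lambda: out_store.val


# Curried like the documented object: partial(partial, impl)(gen) == partial(impl, gen),
# so the toy can be applied as a decorator to a transformation function.
toy_transformation_with_aux = functools.partial(functools.partial, _toy_transformation_with_aux)


@toy_transformation_with_aux
def scale_and_count(f, store, scale, *args):
    out = f(*args)
    store.store(len(out))  # auxiliary output: how many results f produced
    return [scale * x for x in out]


base = ToyWrappedFun(lambda x, y: [x + y, x * y])
wrapped, get_count = scale_and_count(base, 10)  # 10 is a static argument
result = wrapped.call_wrapped(2, 3)
count = get_count()  # only valid after the call above
print(result, count)  # [50, 60] 2
```

In this toy, calling get_count() before wrapped.call_wrapped(...) raises because nothing has been stored yet, and calling wrapped a second time also raises, since ToyStore accepts only one value.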