jax.extend.linear_util.transformation_with_aux2

jax.extend.linear_util.transformation_with_aux2 = functools.partial(<class 'functools.partial'>, <function transformation_with_aux2>)
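The repr above shows that the function is curried. It is stored as functools.partial(functools.partial, f), so calling it with only its leading arguments returns another partial that waits for the remaining ones. In JAX's own code this is what lets transformation_with_aux2 be applied as a decorator to a generator function, with the WrappedFun supplied later. The minimal pure-Python sketch below shows the currying mechanism only. The names curry and add are hypothetical illustrations and are not part of jax.extend.

```python
from functools import partial


def curry(f):
    # Hypothetical helper reproducing the repr above: partial(partial, f).
    # Calling the result with some arguments yields partial(f, *those_args).
    return partial(partial, f)


def add(x, y, z):
    # Toy function used only to demonstrate the currying behaviour.
    return x + y + z


curried_add = curry(add)
add_one = curried_add(1)       # equivalent to partial(add, 1)
print(add_one(2, 3))           # add(1, 2, 3) -> 6
print(curry(add)(1, 2)(3))     # add(1, 2, 3) -> 6
```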

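Adds one more transformation with auxiliary output to a WrappedFun. Returns the transformed WrappedFun together with a zero-argument callable that yields the auxiliary output once the transformed function has been called.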

Return type:

tuple[WrappedFun, Callable[[], Any]]
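To illustrate the (WrappedFun, thunk) return type, here is a pure-Python model of the pattern. It does not import JAX. Store, MiniWrappedFun, transformation_with_aux_sketch and sum_outputs are hypothetical stand-ins, and the argument order passed to the generator (wrapped function, then store, then static arguments) is an assumption modeled on JAX's internal usage, not a documented contract. The point it shows is that the auxiliary value is written into a store while the transformed function runs, so the returned thunk is only meaningful after the WrappedFun has been called.

```python
from functools import partial
from typing import Any, Callable


class Store:
    """Hypothetical single-assignment box holding the auxiliary output."""

    _EMPTY = object()

    def __init__(self):
        self._val = Store._EMPTY

    def store(self, val):
        # The aux output may be written only once per call.
        if self._val is not Store._EMPTY:
            raise RuntimeError("auxiliary output already stored")
        self._val = val

    @property
    def val(self):
        # Reading before the wrapped function has run is an error.
        if self._val is Store._EMPTY:
            raise RuntimeError("auxiliary output not available; call the function first")
        return self._val


class MiniWrappedFun:
    """Hypothetical stand-in for WrappedFun: just holds a callable."""

    def __init__(self, f):
        self.f = f

    def call_wrapped(self, *args):
        return self.f(*args)


def _transformation_with_aux_impl(
    gen, fun: MiniWrappedFun, *gen_static_args
) -> tuple[MiniWrappedFun, Callable[[], Any]]:
    # Each application gets a fresh store; the thunk reads it lazily.
    out_store = Store()
    wrapped = MiniWrappedFun(partial(gen, fun.f, out_store, *gen_static_args))
    return wrapped, lambda: out_store.val


# Curried like the real object's repr, so it can be used as a decorator.
transformation_with_aux_sketch = partial(partial, _transformation_with_aux_impl)


@transformation_with_aux_sketch
def sum_outputs(f, store, scale, *args):
    # Example transformation: return scale * sum of the outputs and
    # record how many outputs there were as the auxiliary value.
    outs = f(*args)
    store.store(len(outs))
    return scale * sum(outs)


base = MiniWrappedFun(lambda x, y: (x, y, x * y))
wrapped, aux = sum_outputs(base, 10)   # decorated generator applied to a "WrappedFun"
print(wrapped.call_wrapped(2, 3))      # 10 * (2 + 3 + 6) = 110
print(aux())                           # 3
```

In this sketch, calling aux() before wrapped.call_wrapped raises, which mirrors why the auxiliary output is returned behind a thunk: it carries information that only becomes known once the transformed function has actually run.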