jax.extend.linear_util module

Callable()

StoreException

WrappedFun(f, f_transformed, transforms, ...)

Represents a function f to which transforms are to be applied.
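The following is a hypothetical, pure-Python model of the idea behind WrappedFun. It is not JAX's implementation. It stores f, a composed callable, and a record of the transforms applied so far, and runs none of them until call_wrapped is invoked (named after the WrappedFun method of that name). The names ToyWrappedFun, toy_wrap_init, add_one, and scale are invented for illustration. The assumed transform shape (a transform receives the inner callable, then its static arguments, then the call arguments) is a simplification.

```python
from functools import partial
from typing import Any, Callable, Optional


class ToyWrappedFun:
    """Hypothetical stand-in: f, a composed callable, and a transform record."""

    def __init__(self, f: Callable, f_transformed: Callable, transforms: tuple):
        self.f = f                          # the original, untransformed function
        self.f_transformed = f_transformed  # f with all transforms layered on
        self.transforms = transforms        # description of what was applied

    def wrap(self, transform: Callable, static_args: tuple) -> "ToyWrappedFun":
        # Nothing runs here: the transform is layered around the current
        # composed callable and recorded, returning a new wrapper.
        return ToyWrappedFun(
            self.f,
            partial(transform, self.f_transformed, *static_args),
            self.transforms + ((transform, static_args),),
        )

    def call_wrapped(self, *args: Any) -> Any:
        # Only now do f and every layered transform actually execute.
        return self.f_transformed(*args)


def toy_wrap_init(f: Callable, params: Optional[dict] = None) -> ToyWrappedFun:
    # Keyword parameters are bound up front (a loose analogue of `params`).
    bound = partial(f, **(params or {}))
    return ToyWrappedFun(f, bound, ())


def add_one(f, *args):
    # Hypothetical transform: call the inner function, then post-process.
    return f(*args) + 1


def scale(f, factor, *args):
    # Hypothetical transform with one static argument, `factor`.
    return f(*args) * factor


base = toy_wrap_init(lambda x, y: x * y, {"y": 3})
wrapped = base.wrap(add_one, ()).wrap(scale, (10,))
result = wrapped.call_wrapped(2)  # (2 * 3 + 1) * 10 == 70
```

In this model the most recently added transform runs outermost. Because wrapping only records work, the same base function can be re-wrapped in different ways without re-running anything.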

cache(call, *[, explain])

Memoization decorator for functions taking a WrappedFun as first argument.
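A minimal sketch of the memoization idea, assuming the cache key combines the identity of the underlying function, the description of the applied transforms, and the remaining call arguments. A real cache must also account for any other state that affects the result. This sketch ignores that and keeps strong references to f. ToyWrappedFun, toy_cache, and expensive_trace are hypothetical names, not part of jax.extend.linear_util.

```python
from typing import Any, Callable


class ToyWrappedFun:
    # Hypothetical minimal stand-in: the underlying f plus a hashable
    # description of the transforms applied to it.
    def __init__(self, f: Callable, transforms: tuple = ()):
        self.f = f
        self.transforms = transforms


def toy_cache(call: Callable) -> Callable:
    """Memoize call(fun, *args) on (fun.f, fun.transforms, args)."""
    memo: dict = {}

    def memoized(fun: ToyWrappedFun, *args: Any) -> Any:
        # Two distinct wrapper objects with the same f, transforms, and
        # arguments share one cache entry.
        key = (fun.f, fun.transforms, args)
        if key not in memo:
            memo[key] = call(fun, *args)
        return memo[key]

    return memoized


calls = []  # records each time the expensive body actually runs


@toy_cache
def expensive_trace(fun: ToyWrappedFun, n: int) -> list:
    calls.append(n)
    return [fun.f(i) for i in range(n)]


def square(x):
    return x * x


a = ToyWrappedFun(square)
b = ToyWrappedFun(square)  # different wrapper object, same f and transforms
expensive_trace(a, 3)      # miss: runs the body
expensive_trace(b, 3)      # hit: equal key
expensive_trace(ToyWrappedFun(square, ("some_transform",)), 3)  # miss
```

Keying on the transforms as well as f matters: the same function with a different transform stack is a different computation and must not reuse a cached result.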

merge_linear_aux(aux1, aux2)

transformation

transformation2

Adds one more transformation to a WrappedFun.
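A hypothetical sketch of the "add one more transformation" step. It assumes a transform is a plain function that receives the inner callable, then any static arguments, then the call arguments, and returns the (possibly post-processed) result. ToyWrappedFun, toy_transformation2, and clip_output are illustrative names only. Consult the JAX source for the real signature and behavior.

```python
from functools import partial
from typing import Any, Callable, NamedTuple


class ToyWrappedFun(NamedTuple):
    # Hypothetical stand-in: the composed callable and a record of transforms.
    f_transformed: Callable
    transforms: tuple = ()

    def call_wrapped(self, *args: Any) -> Any:
        return self.f_transformed(*args)


def toy_transformation2(transform: Callable) -> Callable:
    """Turn transform(f, *static_args, *args) into a wrapper-to-wrapper step."""

    def apply(fun: ToyWrappedFun, *static_args: Any) -> ToyWrappedFun:
        # Layer the transform around the current callable; record its name.
        return ToyWrappedFun(
            partial(transform, fun.f_transformed, *static_args),
            fun.transforms + (transform.__name__,),
        )

    return apply


@toy_transformation2
def clip_output(f, lo, hi, *args):
    # Call the inner function, then clamp its result to [lo, hi].
    return max(lo, min(hi, f(*args)))


fun = ToyWrappedFun(lambda x: 2 * x)
clipped = clip_output(fun, 0, 10)  # lo=0 and hi=10 are static arguments
result = clipped.call_wrapped(7)   # 2 * 7 == 14, clamped to 10
```

Each application returns a new wrapper and leaves the original untouched, so fun.call_wrapped(7) still returns 14.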

transformation_with_aux

transformation_with_aux2

Adds one more transformation with auxiliary output to a WrappedFun.
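A hypothetical sketch of a transformation with auxiliary output. It assumes the transform writes its side result into a write-once store while the function runs, and that the caller receives the new function together with a thunk that reads the store afterwards. The toy ToyStoreException is modeled on what the listed StoreException's name suggests (reading an empty store or writing a full one is an error). That is an assumption here, not documented behavior. For brevity the sketch wraps plain callables instead of a WrappedFun. ToyStore, toy_transformation_with_aux2, and split_output are invented names.

```python
from functools import partial
from typing import Any, Callable, Tuple


class ToyStoreException(Exception):
    """Hypothetical analogue of StoreException for this sketch."""


class ToyStore:
    # A write-once slot that the transform fills while the function runs.
    _EMPTY = object()

    def __init__(self) -> None:
        self._val = ToyStore._EMPTY

    def store(self, val: Any) -> None:
        if self._val is not ToyStore._EMPTY:
            raise ToyStoreException("store occupied")
        self._val = val

    @property
    def val(self) -> Any:
        if self._val is ToyStore._EMPTY:
            raise ToyStoreException("store empty")
        return self._val


def toy_transformation_with_aux2(transform: Callable) -> Callable:
    """Like a plain transformation, but also return a thunk for the aux value."""

    def apply(f: Callable, *static_args: Any) -> Tuple[Callable, Callable[[], Any]]:
        out_store = ToyStore()
        # The store is passed to the transform as its first static argument.
        new_f = partial(transform, f, out_store, *static_args)
        return new_f, lambda: out_store.val

    return apply


@toy_transformation_with_aux2
def split_output(f, store, *args):
    # Return the main result; stash the side output in the store.
    main, extra = f(*args)
    store.store(extra)
    return main


g, aux = split_output(lambda x: (x + 1, f"saw {x}"))
result = g(4)  # runs the function and fills the store
side = aux()   # "saw 4"; calling aux() before g(4) would raise
```

The thunk defers reading the auxiliary value until after the call, which is why the value is returned as a function rather than directly.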

wrap_init(f[, params, debug_info])

Wraps function f as a WrappedFun, suitable for transformation.