jax.checkpoint_policies.offload_dot_with_no_batch_dims

checkpoint_policies.offload_dot_with_no_batch_dims(offload_src, offload_dst)

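Same as dots_with_no_batch_dims_saveable, but the outputs of dot operations with no batch dimensions are offloaded to the offload_dst memory space (for example, CPU host memory) instead of being kept on the device. All other intermediates are recomputed in the backward pass.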

This is a useful heuristic for transformers.
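
A minimal sketch of how this policy might be passed to jax.checkpoint. The function names layer and loss, the array shapes, and the memory-kind strings "device" and "pinned_host" are illustrative assumptions, not values taken from this page. Because the accepted parameters may differ between JAX releases (offload_dst, and in some releases also offload_src), the sketch inspects the installed signature instead of hard-coding it. The gradient is only traced, not executed, since whether a given memory kind is available depends on the backend.

```python
import inspect

import jax
import jax.numpy as jnp

# Assumed memory-kind names; only the parameters that the installed
# JAX version actually accepts are passed through.
candidates = {"offload_src": "device", "offload_dst": "pinned_host"}
factory = jax.checkpoint_policies.offload_dot_with_no_batch_dims
accepted = inspect.signature(factory).parameters
policy = factory(**{k: v for k, v in candidates.items() if k in accepted})


def layer(x, w):
    # x @ w is a dot_general with no batch dimensions, so the policy marks
    # its output as offloadable; sin is recomputed in the backward pass.
    return jnp.sin(x @ w)


def loss(w, x):
    # Hypothetical scalar loss wrapping the checkpointed layer.
    return jnp.sum(jax.checkpoint(layer, policy=policy)(x, w))


x = jnp.ones((4, 8))
w = jnp.full((8, 3), 0.1)

# The forward pass alone saves no residuals, so nothing is offloaded here.
value = loss(w, x)

# Trace (but do not run) the gradient to inspect where transfers are inserted.
grad_jaxpr = jax.make_jaxpr(jax.grad(loss))(w, x)
print(grad_jaxpr)
```

Running the gradient itself, for example with jax.jit(jax.grad(loss))(w, x), requires a backend that supports the chosen memory kinds, so it is left out of the sketch.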