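jax.checkpoint_policies.offload_dot_with_no_batch_dims

jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src, offload_dst)

Same as dots_with_no_batch_dims_saveable, but the outputs of dot operations with no batch dimensions are offloaded from offload_src to offload_dst (typically CPU/host memory) rather than kept in accelerator memory, so they do not need to be recomputed in the backward pass. This is a useful heuristic for transformers.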
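A minimal sketch of how this policy might be passed to jax.checkpoint. It assumes the policy factory accepts both offload_src and offload_dst as memory-kind strings, and that "device" and "pinned_host" are valid memory kinds on the target backend; this depends on the JAX version and hardware. The function mlp_block and the array shapes are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp

# Assumption: "device" and "pinned_host" are the memory-kind names understood
# by the backend in use. They may not be available on every platform/version.
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
    offload_src="device", offload_dst="pinned_host"
)


def mlp_block(x, w1, w2):
    # Hypothetical two-layer block. Both 2D matmuls lower to dot_general
    # with no batch dimensions, so the policy selects their outputs.
    h = jnp.tanh(x @ w1)
    return jnp.sum(h @ w2)


# Wrap the block so that rematerialization follows the offload policy.
checkpointed_block = jax.checkpoint(mlp_block, policy=policy)

# The policy only takes effect under differentiation. The gradient function
# is defined here but not called, because running it requires a backend
# that supports the memory kinds assumed above.
grad_fn = jax.jit(jax.grad(checkpointed_block, argnums=(1, 2)))

x = jnp.ones((4, 8))
w1 = jnp.full((8, 16), 0.1)
w2 = jnp.full((16, 2), 0.1)

# The forward pass is unaffected by the policy, so both calls should agree.
plain_out = mlp_block(x, w1, w2)
ckpt_out = checkpointed_block(x, w1, w2)
```

On a backend that supports the assumed memory kinds, calling grad_fn(x, w1, w2) would return the gradients with respect to w1 and w2. The point of choosing offloading over recomputation is to trade host-device transfer bandwidth for accelerator memory and compute.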