jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

checkpoint_policies.checkpoint_dots_with_no_batch_dims(*args, **params)

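A rematerialization policy for use as the policy argument of jax.checkpoint (also available as jax.remat). It returns True, meaning the output may be saved as a residual for the backward pass, for dot_general operations that have no batch dimensions, such as ordinary matrix multiplications. For every other operation it returns False, so that value is recomputed during the backward pass when needed. This is a useful heuristic for transformers: the costly weight matmuls are unbatched dot_generals and get saved, while cheaper operations such as elementwise activations are recomputed.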
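A minimal usage sketch follows. The feed-forward block, its shapes, and the random inputs are hypothetical, loosely modeled on a transformer MLP layer. The sketch wraps the block with jax.checkpoint using this policy and checks that the gradients match the un-checkpointed version. The policy changes what is stored versus recomputed, not the result. It also calls jax.ad_checkpoint.print_saved_residuals to list what is kept; that function's printed format varies across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

# Hypothetical feed-forward block, loosely modeled on a transformer MLP layer.
def mlp_block(params, x):
    w1, w2 = params
    h = jnp.tanh(x @ w1)  # x @ w1 is a dot_general with no batch dimensions
    return h @ w2         # so is h @ w2

# Under this policy, outputs of unbatched dot_generals may be saved as
# residuals; everything else (here, tanh) is recomputed on the backward pass.
remat_block = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
)

def loss_remat(params, x):
    return jnp.sum(remat_block(params, x) ** 2)

def loss_plain(params, x):
    return jnp.sum(mlp_block(params, x) ** 2)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (8, 16)), jax.random.normal(k2, (16, 8)))
x = jax.random.normal(k3, (4, 8))

# Rematerialization trades memory for recomputation; gradients should agree.
grads_remat = jax.grad(loss_remat)(params, x)
grads_plain = jax.grad(loss_plain)(params, x)

# Print the residuals saved for the backward pass (format is version-dependent).
print_saved_residuals(loss_remat, params, x)
```

In a real model the policy is usually applied per layer, for example on the function body passed to jax.lax.scan over stacked layers, so that only the matmul outputs of each layer stay in memory.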

Return type:

bool
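Because a policy is just a predicate returning bool, it can be called directly to see how it classifies individual operations. This sketch assumes the calling convention jax.checkpoint uses internally: the primitive first, then descriptions of the operands, then the primitive's parameters as keyword arguments. That convention is an implementation detail rather than a public API. The ShapeDtypeStruct operands are placeholders, because this policy inspects only the primitive and its dimension_numbers parameter.

```python
import jax
import jax.numpy as jnp

policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

# Placeholder operand descriptions (assumed; the policy does not inspect them).
a = jax.ShapeDtypeStruct((4, 8), jnp.float32)
b = jax.ShapeDtypeStruct((8, 16), jnp.float32)
a_batched = jax.ShapeDtypeStruct((2, 4, 8), jnp.float32)
b_batched = jax.ShapeDtypeStruct((2, 8, 16), jnp.float32)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
plain_matmul = policy(
    jax.lax.dot_general_p, a, b,
    dimension_numbers=(((1,), (0,)), ((), ())),
)
batched_matmul = policy(
    jax.lax.dot_general_p, a_batched, b_batched,
    dimension_numbers=(((2,), (1,)), ((0,), (0,))),
)
# A non-dot primitive: always classified as "recompute".
not_a_dot = policy(jax.lax.add_p, a, a)

print(plain_matmul, batched_matmul, not_a_dot)
```

Only the unbatched dot_general should be reported as saveable. In normal use you pass the policy to jax.checkpoint instead of calling it by hand.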