jax.checkpoint_policies.dots_with_no_batch_dims_saveable

checkpoint_policies.dots_with_no_batch_dims_saveable(*args, **params)

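Checkpoint policy for jax.checkpoint (also available as jax.remat). It marks the outputs of dot_general operations that have no batch dimensions as saveable. Those outputs are stored as residuals for the backward pass, and all other intermediate values are recomputed. This is a useful heuristic for transformers. The weight matmuls that contract activations of shape [batch, seq, d_model] with a [d_model, d_ff] parameter typically have no dot_general batch dimensions, so their outputs are saved. The attention products usually carry batch and head axes as dot_general batch dimensions, so they are recomputed.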
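A minimal usage sketch, assuming a small two-layer MLP. The function mlp, the weights w1 and w2, and all shapes are hypothetical and chosen only for illustration. The policy is passed through the policy argument of jax.checkpoint. Gradients are unchanged compared with the un-checkpointed function. Only what is stored versus recomputed differs, and jax.ad_checkpoint.print_saved_residuals can be used to inspect that.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

# Hypothetical two-layer MLP; names and shapes are illustrative only.
def mlp(x, w1, w2):
    h = jnp.tanh(x @ w1)           # 2-D @ 2-D: dot_general without batch dims
    return jnp.sum(jnp.tanh(h @ w2))

# Non-batched matmul outputs are saved as residuals;
# elementwise ops such as tanh are recomputed in the backward pass.
mlp_remat = jax.checkpoint(
    mlp, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, 8))
w1 = jax.random.normal(k2, (8, 16))
w2 = jax.random.normal(k3, (16, 2))

# Gradients with and without checkpointing (they should agree numerically).
grads = jax.grad(mlp_remat, argnums=(1, 2))(x, w1, w2)
grads_ref = jax.grad(mlp, argnums=(1, 2))(x, w1, w2)

# Prints which values the checkpointed function keeps for the backward pass.
print_saved_residuals(mlp_remat, x, w1, w2)
```

The policy trades memory for compute. Matmul results, which are expensive to recompute, are kept, while cheap elementwise results are rebuilt during differentiation.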

Return type:

bool
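To make the bool return value concrete, jax.checkpoint calls a policy for each primitive application in the traced function. It passes the primitive as the first argument and the primitive's parameters as keyword arguments, and it saves the output when the policy returns True. The following is a hypothetical hand-written policy, my_dots_no_batch_policy, that sketches equivalent logic. It assumes the primitive object is jax.lax.dot_general_p and that its dimension_numbers parameter has the form ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)). It is an illustration, not the library's source. The function block and its shapes are also hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical policy sketching the documented behavior (not the library source).
def my_dots_no_batch_policy(prim, *_, **params):
    # Only dot_general applications are candidates for saving.
    if prim is lax.dot_general_p:
        # Assumed layout: ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
        _, (lhs_batch, rhs_batch) = params["dimension_numbers"]
        # Save only when neither operand has batch dimensions.
        return not lhs_batch and not rhs_batch
    # Every other primitive is recomputed in the backward pass.
    return False

def block(x, w):
    h = jnp.einsum("bsd,df->bsf", x, w)      # no shared batch axis -> saved
    s = jnp.einsum("bqf,bkf->bqk", h, h)     # "b" is a batch dim -> recomputed
    return jnp.sum(jnp.tanh(s))

x = jnp.full((2, 3, 4), 0.1)
w = jnp.full((4, 4), 0.1)

g_custom = jax.grad(jax.checkpoint(block, policy=my_dots_no_batch_policy), argnums=1)(x, w)
g_builtin = jax.grad(
    jax.checkpoint(block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable),
    argnums=1,
)(x, w)
```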