jax.array_garbage_collection_guard

jax.array_garbage_collection_guard = <jax._src.config.State object>

Context manager for the jax_array_garbage_collection_guard config option.

Selects the garbage collection guard level for jax.Array objects.

This option controls what happens when a jax.Array object is garbage collected. Ideally, a jax.Array is freed by Python reference counting as soon as its last reference disappears, rather than by the cyclic garbage collector. An array that is kept alive by a reference cycle holds on to its device memory until the next garbage collection pass.
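The difference between the two ways of freeing an array can be seen with a plain Python reference cycle. The sketch below is illustrative only: `Node`, `make_acyclic`, and `make_cycle` are hypothetical names. It tracks the container holding the array through `weakref` and pauses automatic collection so the timing is deterministic.

```python
import gc
import weakref

import jax.numpy as jnp


class Node:
    """Hypothetical container that holds a jax.Array."""

    def __init__(self, data):
        self.data = data
        self.parent = None


def make_acyclic():
    node = Node(jnp.ones((256, 256)))
    # Only the local variable refers to `node`, so reference counting frees it
    # (and its array) as soon as this function returns.
    return weakref.ref(node)


def make_cycle():
    node = Node(jnp.ones((256, 256)))
    node.parent = node  # self-reference: the refcount never drops to zero
    return weakref.ref(node)


gc.disable()  # pause automatic collection so the timing below is deterministic
try:
    freed_by_refcount = make_acyclic()() is None

    ref = make_cycle()
    alive_until_gc = ref() is not None  # the cycle still holds the array
    gc.collect()  # an explicit collection still runs while gc is disabled
    freed_by_gc = ref() is None
finally:
    gc.enable()

print(freed_by_refcount, alive_until_gc, freed_by_gc)  # True True True
```

Breaking the cycle, for example by setting `node.parent = None` before the last reference goes away, lets reference counting release the array immediately.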

Valid values are:

  • allow: do not log garbage collection of jax.Array objects.

  • log: log an error when a jax.Array is garbage collected.

  • fatal: treat garbage collection of a jax.Array as a fatal error.

The default is allow. Note that not every reference cycle involving a jax.Array is guaranteed to be detected.

Parameters:

new_val (Any) – the guard level to use within the context: allow, log, or fatal.