jax.experimental.pallas.tpu.reset_tpu_interpret_mode_state

jax.experimental.pallas.tpu.reset_tpu_interpret_mode_state()

Resets all global, shared state used by TPU interpret mode.

TPU interpret mode uses global, shared state for simulating memory buffers and semaphores, for race detection, etc., when interpreting a kernel. Normally, this shared state is cleaned up after a kernel is interpreted.

However, if an exception is raised while interpreting a kernel, the shared state is not cleaned up, so that the simulated TPU state can be examined for debugging purposes. In this case, call this function to reset the shared state before interpreting any further kernels.
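The recovery pattern described above could be wrapped roughly as follows. This is a minimal sketch, not part of the Pallas API: `run_interpreted`, `_default_reset`, and the `kernel_call`, `reset`, and `inspect` parameters are hypothetical names introduced for illustration. It assumes `kernel_call` is a zero-argument callable that runs a kernel in TPU interpret mode. The only library call used is `reset_tpu_interpret_mode_state()`, which is imported lazily.

```python
from typing import Any, Callable, Optional


def _default_reset() -> None:
    # Imported lazily so that this helper can be loaded without JAX installed.
    from jax.experimental.pallas import tpu as pltpu

    pltpu.reset_tpu_interpret_mode_state()


def run_interpreted(
    kernel_call: Callable[[], Any],
    reset: Optional[Callable[[], None]] = None,
    inspect: Optional[Callable[[BaseException], None]] = None,
) -> Any:
    """Run an interpreted kernel and reset the shared state if it fails.

    kernel_call: hypothetical zero-argument callable that interprets a kernel.
    reset: callable that clears the shared state (defaults to the JAX function).
    inspect: optional hook called with the exception before the reset happens.
    """
    reset = _default_reset if reset is None else reset
    try:
        return kernel_call()
    except Exception as exc:
        try:
            # The shared state is still in place here, so debug before resetting.
            if inspect is not None:
                inspect(exc)
        finally:
            # Reset so that later kernels start from clean shared state.
            reset()
        raise
```

On success the helper does not call `reset`, since the text above says the shared state is normally cleaned up after interpretation. On failure the original exception is re-raised after the reset, so callers still see the error.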