jax.experimental.pallas.tpu.force_tpu_interpret_mode

jax.experimental.pallas.tpu.force_tpu_interpret_mode(params=InterpretParams(detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads_per_device=1, vector_clock_size=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False))

Context manager that forces TPU interpret mode under its dynamic context.

TPU interpret mode is a way to run Pallas TPU kernels on CPU while simulating a TPU’s shared memory (HBM, VMEM, etc.), communication (remote and local DMAs), and synchronization operations (semaphores, barriers, etc.). This mode is intended for debugging and testing. See InterpretParams for additional information.

Parameters:

params (InterpretParams) – an instance of InterpretParams. Any call to jax.experimental.pallas.pallas_call() or jax.experimental.pallas.core_map() that is traced under this context manager runs with interpret=params. When params is not None, those calls run in TPU interpret mode.