jax.experimental.pallas.tpu.force_tpu_interpret_mode
jax.experimental.pallas.tpu.force_tpu_interpret_mode(params=InterpretParams(detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads_per_device=1, vector_clock_size=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False))
Context manager that forces TPU interpret mode under its dynamic context.
TPU interpret mode is a way to run Pallas TPU kernels on CPU while simulating a TPU’s shared memory (HBM, VMEM, etc.), communication (remote and local DMAs), and synchronization operations (semaphores, barriers, etc.). This mode is intended for debugging and testing. See InterpretParams for additional information.
Parameters:
params (InterpretParams) – an instance of InterpretParams. Any call to jax.experimental.pallas.pallas_call() or jax.experimental.pallas.core_map() that is traced under this context manager will be run with interpret=params. When params is not None, those calls run in TPU interpret mode.