jax.experimental.pallas.tpu.set_tpu_interpret_mode
- jax.experimental.pallas.tpu.set_tpu_interpret_mode(params=InterpretParams(detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads_per_device=1, vector_clock_size=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False))
- Parameters:
params (InterpretParams)