jax.experimental.pallas.tpu.set_tpu_interpret_mode
- jax.experimental.pallas.tpu.set_tpu_interpret_mode(params=InterpretParams(detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads=1, vector_clock_size=None, logging_mode=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False, buffer_bounds='logical'))
- Parameters:
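params (InterpretParams) – configuration for the TPU interpreter, covering options such as race detection (detect_races), handling of out-of-bounds reads (out_of_bounds_reads), treatment of uninitialized memory (uninitialized_memory), and when DMAs execute (dma_execution_mode). Defaults to InterpretParams() with the field values shown in the signature.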