jax.experimental.pallas.tpu.InterpretParams#
- class jax.experimental.pallas.tpu.InterpretParams(*, detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads_per_device=1, vector_clock_size=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False)[source]#
Parameters for TPU interpret mode.
TPU interpret mode is a way to run Pallas TPU kernels on CPU, while simulating a TPU's shared memory (HBM, VMEM, etc.), communication (remote and local DMAs), and synchronization operations (semaphores, barriers, etc.). This mode is intended for debugging and testing.
To run a kernel under TPU interpret mode, pass an instance of InterpretParams as an argument for the interpret parameter of jax.experimental.pallas.pallas_call() or jax.experimental.pallas.core_map(); a minimal sketch is shown after the parameter list below.
NOTE: If an exception is raised while interpreting a kernel, you must call reset_tpu_interpret_mode_state() before using TPU interpret mode again in the same process.
- Parameters:
detect_races (bool)
out_of_bounds_reads (Literal['raise', 'uninitialized'])
skip_floating_point_ops (bool)
uninitialized_memory (Literal['nan', 'zero'])
num_cores_or_threads_per_device (int)
vector_clock_size (int | None)
dma_execution_mode (Literal['eager', 'on_wait'])
random_seed (int | None)
grid_point_recorder (Callable[[tuple[int32, ...], int32], None] | None)
allow_hbm_allocation_in_run_scoped (bool)
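For example, here is a minimal sketch of running a kernel under TPU interpret mode; the kernel, array shapes, and the detect_races setting below are illustrative choices, not requirements of the API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_one_kernel(x_ref, o_ref):
    # A trivial elementwise kernel, used only to exercise interpret mode.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)

# Runs on CPU, simulating TPU memory spaces, DMAs, and semaphores.
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=pltpu.InterpretParams(detect_races=True),
)(x)
```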
- dma_execution_mode#
If 'eager', DMAs are executed as soon as they are issued. If 'on_wait', DMA reads or writes are only executed when a device is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: 'on_wait'.
- Type:
Literal['eager', 'on_wait']
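As a sketch of where this setting matters (the kernel below is illustrative and assumes the pltpu.make_async_copy and scratch-shape pattern from the Pallas TPU guides), 'eager' executes the copy when start() is issued, while the default 'on_wait' defers it until the kernel waits on the DMA semaphore:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref, scratch_ref, sem):
    # Issue a local DMA from the input block into VMEM scratch, then wait on it.
    copy = pltpu.make_async_copy(x_ref, scratch_ref, sem)
    copy.start()
    copy.wait()  # Under 'on_wait', the copy is only performed here.
    o_ref[...] = scratch_ref[...]

x = jnp.ones((8, 128), jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[
        pltpu.VMEM((8, 128), jnp.float32),
        pltpu.SemaphoreType.DMA,
    ],
    interpret=pltpu.InterpretParams(dma_execution_mode="eager"),
)(x)
```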
- random_seed#
Seed for the random number generator used during interpretation. Currently, random numbers are used to randomize the grid coordinates along dimensions with 'parallel' semantics. Default: None.
- Type:
int | None
- grid_point_recorder#
Callback that is invoked by the interpreter for each grid point, in the order in which the grid points are traversed. The callback is invoked with two arguments:
- A tuple of grid coordinates.
- The local core ID of the core that is processing the grid point.
This callback is intended for inspecting:
- the randomization of coordinates along grid dimensions with 'parallel' semantics, and
- the mapping of grid points to local (i.e. per-device) cores.
Default: None.
- Type:
collections.abc.Callable[[tuple[numpy.int32, ...], numpy.int32], None] | None
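As an illustrative sketch (the recorder function name and the visited list are made up for this example), a recorder can log the traversal order and local-core assignment:

```python
from jax.experimental.pallas import tpu as pltpu

visited = []

def record_grid_point(coords, local_core_id):
    # coords: a tuple of numpy.int32 grid coordinates for this grid point.
    # local_core_id: numpy.int32 ID of the per-device core processing it.
    visited.append((tuple(int(c) for c in coords), int(local_core_id)))

params = pltpu.InterpretParams(
    grid_point_recorder=record_grid_point,
    random_seed=0,  # Make the randomized 'parallel' traversal reproducible.
)
# Pass `params` as the `interpret` argument of pallas_call(); after the call,
# `visited` holds one entry per grid point, in traversal order.
```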
- allow_hbm_allocation_in_run_scoped#
If True, allows the allocation of HBM buffers (which are then shared across the cores in a device) in run_scoped. While this behavior can be enabled in the interpreter, allocating HBM buffers with run_scoped is not supported when executing Pallas kernels on a real TPU. Default: False.
- Type:
bool
- __init__(*, detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads_per_device=1, vector_clock_size=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False)#
- Parameters:
detect_races (bool)
out_of_bounds_reads (Literal['raise', 'uninitialized'])
skip_floating_point_ops (bool)
uninitialized_memory (Literal['nan', 'zero'])
num_cores_or_threads_per_device (int)
vector_clock_size (int | None)
dma_execution_mode (Literal['eager', 'on_wait'])
random_seed (int | None)
grid_point_recorder (Callable[[tuple[int32, ...], int32], None] | None)
allow_hbm_allocation_in_run_scoped (bool)
- Return type:
None
Methods
- __init__(*[, detect_races, ...])
- get_uninitialized_array(shape, dtype)
- get_vector_clock_size(num_devices): Returns the number of vector clocks to use.
- pad_to_block_dimension(value, block_shape): Pads values so the shape evenly divides into block dimensions.
Attributes
- detect_races
- num_cores_or_threads_per_device
- num_cores_per_device
- out_of_bounds_reads
- skip_floating_point_ops
- uninitialized_memory
- vector_clock_size