jax.experimental.pallas.tpu.InterpretParams

class jax.experimental.pallas.tpu.InterpretParams(*, detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads_per_device=1, vector_clock_size=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False)

Parameters for TPU interpret mode.

TPU interpret mode is a way to run Pallas TPU kernels on CPU while simulating a TPU’s shared memory (HBM, VMEM, etc.), communication (remote and local DMAs), and synchronization operations (semaphores, barriers, etc.). This mode is intended for debugging and testing.

To run a kernel under TPU interpret mode, pass an instance of InterpretParams as an argument for the interpret parameter of jax.experimental.pallas.pallas_call() or jax.experimental.pallas.core_map().

NOTE: If an exception is raised while interpreting a kernel, you must call reset_tpu_interpret_mode_state() before using TPU interpret mode again in the same process.

Parameters:
  • detect_races (bool)

  • out_of_bounds_reads (Literal['raise', 'uninitialized'])

  • skip_floating_point_ops (bool)

  • uninitialized_memory (Literal['nan', 'zero'])

  • num_cores_or_threads_per_device (int)

  • vector_clock_size (int | None)

  • dma_execution_mode (Literal['eager', 'on_wait'])

  • random_seed (int | None)

  • grid_point_recorder (Callable[[tuple[int32, ...], int32], None] | None)

  • allow_hbm_allocation_in_run_scoped (bool)

dma_execution_mode

If “eager”, DMAs are executed as soon as they are issued. If “on_wait”, DMA reads or writes are only executed when a device is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: “on_wait”.

Type:

Literal['eager', 'on_wait']

random_seed

Seed for the random number generator used during interpretation. Currently, random numbers are used to randomize the grid coordinates along dimensions with ‘parallel’ semantics. Default: None.

Type:

int | None

grid_point_recorder

Callback that is invoked by the interpreter for each grid point in the order in which the grid points are traversed. The callback is invoked with two arguments:

  • A tuple of grid coordinates.

  • The local core ID of the core that is processing the grid point.

This callback is intended for inspecting

  • the randomization of coordinates along grid dimensions with ‘parallel’ semantics and

  • the mapping of grid points to local (i.e. per-device) cores.

Default: None.

Type:

collections.abc.Callable[[tuple[numpy.int32, ...], numpy.int32], None] | None

allow_hbm_allocation_in_run_scoped

If True, allows the allocation of HBM buffers (which are then shared across the cores in a device) in run_scoped. While this behavior can be enabled in the interpreter, allocating HBM buffers with run_scoped is not supported when executing Pallas kernels on a real TPU. Default: False.

Type:

bool

__init__(*, detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads_per_device=1, vector_clock_size=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False)

Parameters:
  • detect_races (bool)

  • out_of_bounds_reads (Literal['raise', 'uninitialized'])

  • skip_floating_point_ops (bool)

  • uninitialized_memory (Literal['nan', 'zero'])

  • num_cores_or_threads_per_device (int)

  • vector_clock_size (int | None)

  • dma_execution_mode (Literal['eager', 'on_wait'])

  • random_seed (int | None)

  • grid_point_recorder (Callable[[tuple[int32, ...], int32], None] | None)

  • allow_hbm_allocation_in_run_scoped (bool)

Return type:

None

Methods

__init__(*[, detect_races, ...])

get_uninitialized_array(shape, dtype)

get_vector_clock_size(num_devices)

Returns the number of vector clocks to use.

pad_to_block_dimension(value, block_shape)

Pads values so the shape evenly divides into block dimensions.

Attributes

allow_hbm_allocation_in_run_scoped

detect_races

dma_execution_mode

grid_point_recorder

num_cores_or_threads_per_device

num_cores_per_device

out_of_bounds_reads

random_seed

skip_floating_point_ops

uninitialized_memory

vector_clock_size