jax.experimental.pallas.tpu.InterpretParams#

class jax.experimental.pallas.tpu.InterpretParams(*, detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads=1, vector_clock_size=None, logging_mode=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False, buffer_bounds='logical')[source]#

Parameters for TPU interpret mode.

TPU interpret mode is a way to run Pallas TPU kernels on CPU while simulating a TPU's shared memory (HBM, VMEM, etc.), communication (remote and local DMAs), and synchronization operations (semaphores, barriers, etc.). This mode is intended for debugging and testing.

To run a kernel under TPU interpret mode, pass an instance of InterpretParams as an argument for the interpret parameter of jax.experimental.pallas.pallas_call() or jax.experimental.pallas.core_map().

NOTE: If an exception is raised while interpreting a kernel, you must call reset_tpu_interpret_mode_state() before using TPU interpret mode again in the same process.

Parameters:
  • detect_races (bool)

  • out_of_bounds_reads (Literal['raise', 'uninitialized'])

  • skip_floating_point_ops (bool)

  • uninitialized_memory (Literal['nan', 'zero'])

  • num_cores_or_threads (int)

  • vector_clock_size (int | None)

  • logging_mode (LoggingMode | None)

  • dma_execution_mode (Literal['eager', 'on_wait'])

  • random_seed (int | None)

  • grid_point_recorder (Callable[[tuple[int32, ...], int32], None] | None)

  • allow_hbm_allocation_in_run_scoped (bool)

  • buffer_bounds (Literal['logical', 'padded'])

dma_execution_mode#

If “eager”, DMAs are executed as soon as they are issued. If “on_wait”, DMA reads or writes are only executed when a device is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: “on_wait”.

Type:

Literal[‘eager’, ‘on_wait’]

random_seed#

Seed for the random number generator used during interpretation. Currently, random numbers are used to randomize the grid coordinates along dimensions with ‘parallel’ semantics. Default: None.

Type:

int | None

grid_point_recorder#

Callback that is invoked by the interpreter for each grid point, in the order in which the grid points are traversed. The callback is invoked with two arguments:
  • A tuple of grid coordinates.
  • The local core ID of the core that is processing the grid point.
This callback is intended for inspecting:
  • the randomization of coordinates along grid dimensions with ‘parallel’ semantics, and
  • the mapping of grid points to local (i.e. per-device) cores.
Default: None.

Type:

collections.abc.Callable[[tuple[numpy.int32, ...], numpy.int32], None] | None

allow_hbm_allocation_in_run_scoped#

If True, allows the allocation of HBM buffers (which are then shared across the cores in a device) in run_scoped. While this behavior can be enabled in the interpreter, allocating HBM buffers with run_scoped is not supported when executing Pallas kernels on a real TPU. Default: False.

Type:

bool

buffer_bounds#

If “padded”, reads and writes of a buffer are only considered out-of-bounds when they go beyond the buffer’s ‘padded shape’. The amount of padding is determined by the TPU device kind being simulated by TPU interpret mode (set with jax.sharding.use_abstract_mesh in the context from which the interpreted Pallas kernel is called). Any part of a read that is outside the buffer’s logical shape but inside the padded shape returns uninitialized values (see the “uninitialized_memory” attribute of the superclass SharedInterpretParams). Any part of a write that is outside the buffer’s logical shape but inside the padded shape is ignored. If “logical”, reads and writes are considered out-of-bounds when they fall outside the buffer’s logical shape. Default: “logical”.

Type:

Literal[‘logical’, ‘padded’]

__init__(*, detect_races=False, out_of_bounds_reads='raise', skip_floating_point_ops=False, uninitialized_memory='nan', num_cores_or_threads=1, vector_clock_size=None, logging_mode=None, dma_execution_mode='on_wait', random_seed=None, grid_point_recorder=None, allow_hbm_allocation_in_run_scoped=False, buffer_bounds='logical')#
Parameters:
  • detect_races (bool)

  • out_of_bounds_reads (Literal['raise', 'uninitialized'])

  • skip_floating_point_ops (bool)

  • uninitialized_memory (Literal['nan', 'zero'])

  • num_cores_or_threads (int)

  • vector_clock_size (int | None)

  • logging_mode (LoggingMode | None)

  • dma_execution_mode (Literal['eager', 'on_wait'])

  • random_seed (int | None)

  • grid_point_recorder (Callable[[tuple[int32, ...], int32], None] | None)

  • allow_hbm_allocation_in_run_scoped (bool)

  • buffer_bounds (Literal['logical', 'padded'])

Return type:

None

Methods

__init__(*[, detect_races, ...])

get_vector_clock_size(num_devices)

Returns the number of vector clocks to use for TPU interpret mode.

Attributes

allow_hbm_allocation_in_run_scoped

buffer_bounds

detect_races

dma_execution_mode

grid_point_recorder

logging_mode

num_cores_or_threads

num_cores_per_device

out_of_bounds_reads

random_seed

skip_floating_point_ops

uninitialized_memory

vector_clock_size