jax.experimental.pallas.tpu.CompilerParams

class jax.experimental.pallas.tpu.CompilerParams(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=KernelType.TC, disable_bounds_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None)

Mosaic TPU compiler parameters.

Parameters:
  • dimension_semantics (tuple[DimensionSemantics, ...] | None)

  • allow_input_fusion (tuple[bool, ...] | None)

  • vmem_limit_bytes (int | None)

  • collective_id (int | None)

  • has_side_effects (bool | SideEffectType)

  • flags (dict[str, Any] | None)

  • internal_scratch_in_bytes (int | None)

  • serialization_format (int)

  • kernel_type (KernelType)

  • disable_bounds_checks (bool)

  • skip_device_barrier (bool)

  • allow_collective_id_without_custom_barrier (bool)

  • shape_invariant_numerics (bool)

  • use_tc_tiling_on_sc (bool | None)

dimension_semantics

A sequence with one entry per grid dimension of the kernel. Use “parallel” for dimensions whose iterations can execute in any order, or “arbitrary” for dimensions whose iterations must execute sequentially.

Type:

tuple[DimensionSemantics, …] | None

allow_input_fusion

A sequence of booleans, one per kernel input, indicating whether XLA may fuse the computation that produces that input into the kernel.

Type:

tuple[bool, …] | None

vmem_limit_bytes

Overrides the default VMEM limit for a kernel. Note that this must be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag, where N KiB (that is, N * 1024 bytes) is greater than vmem_limit_bytes.

Type:

int | None

collective_id

Indicates which barrier semaphore the kernel uses. Note that using the same collective_id in different kernels does not guarantee that they are allocated the same barrier semaphore.

Type:

int | None

has_side_effects

Set to True to prevent XLA from deduplicating the kernel via common subexpression elimination (CSE).

Type:

bool | SideEffectType

flags

A dictionary of command-line flags for the kernel.

Type:

dict[str, Any] | None

internal_scratch_in_bytes

The size, in bytes, of the internal scratch space used by Mosaic.

Type:

int | None

serialization_format

The serialization format for the kernel body.

Type:

int

kernel_type

Specifies whether the kernel runs on the TensorCore or on one of the SparseCores.

Type:

KernelType

disable_bounds_checks

Disable bounds checks in the kernel.

Type:

bool

skip_device_barrier

Skip the default device barrier for the kernel.

Type:

bool

allow_collective_id_without_custom_barrier

Allow the use of collective_id without a custom barrier.

Type:

bool

use_tc_tiling_on_sc

Use TensorCore tiling for SparseCore. This flag is only used for SC_*_SUBCORE kernels.

Type:

bool | None

__init__(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=KernelType.TC, disable_bounds_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None)

Parameters:
  • dimension_semantics (Sequence[DimensionSemantics] | None)

  • allow_input_fusion (Sequence[bool] | None)

  • vmem_limit_bytes (int | None)

  • collective_id (int | None)

  • has_side_effects (bool | SideEffectType)

  • flags (Mapping[str, Any] | None)

  • internal_scratch_in_bytes (int | None)

  • serialization_format (int)

  • kernel_type (KernelType)

  • disable_bounds_checks (bool)

  • skip_device_barrier (bool)

  • allow_collective_id_without_custom_barrier (bool)

  • shape_invariant_numerics (bool)

  • use_tc_tiling_on_sc (bool | None)

Methods

__init__([dimension_semantics, ...])

replace(**changes)

Return a new object replacing specified fields with new values.

Attributes

BACKEND

allow_collective_id_without_custom_barrier

allow_input_fusion

collective_id

dimension_semantics

disable_bounds_checks

flags

has_side_effects

internal_scratch_in_bytes

kernel_type

serialization_format

shape_invariant_numerics

skip_device_barrier

use_tc_tiling_on_sc

vmem_limit_bytes