jax.experimental.pallas.tpu.CompilerParams

class jax.experimental.pallas.tpu.CompilerParams(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=TC, disable_bounds_checks=False, disable_semaphore_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None, needs_layout_passes=None, fuse_transposed_lhs_in_matmul=False)

Mosaic TPU compiler parameters.

Parameters:
  • dimension_semantics (tuple[DimensionSemantics, ...] | None)

  • allow_input_fusion (tuple[bool, ...] | None)

  • vmem_limit_bytes (int | None)

  • collective_id (int | None)

  • has_side_effects (bool | SideEffectType)

  • flags (dict[str, Any] | None)

  • internal_scratch_in_bytes (int | None)

  • serialization_format (int)

  • kernel_type (CoreType)

  • disable_bounds_checks (bool)

  • disable_semaphore_checks (bool)

  • skip_device_barrier (bool)

  • allow_collective_id_without_custom_barrier (bool)

  • shape_invariant_numerics (bool)

  • use_tc_tiling_on_sc (bool | None)

  • needs_layout_passes (bool | None)

  • fuse_transposed_lhs_in_matmul (bool)

dimension_semantics

One entry per grid dimension of the kernel: “parallel” for dimensions whose iterations can execute in any order, or “arbitrary” for dimensions whose iterations must execute sequentially.

Type:

tuple[DimensionSemantics, ...] | None

allow_input_fusion

One boolean per kernel argument, indicating whether input fusion is allowed for that argument.

Type:

tuple[bool, ...] | None

vmem_limit_bytes

Overrides the default VMEM limit for a kernel. This must be used together with the --xla_tpu_scoped_vmem_limit_kib=N flag, where N KiB (N * 1024 bytes) is greater than vmem_limit_bytes.

Type:

int | None

collective_id

Indicates which barrier semaphore to use for the kernel. Note that using the same collective_id does not guarantee that the same barrier semaphore will be allocated between kernels.

Type:

int | None

has_side_effects

Set to True to prevent XLA from applying common subexpression elimination (CSE) to the kernel.

Type:

bool | SideEffectType

flags

A dictionary of command line flags for the kernel.

Type:

dict[str, Any] | None

internal_scratch_in_bytes

The size of the internal scratch space used by Mosaic.

Type:

int | None

serialization_format

The serialization format for the kernel body.

Type:

int

kernel_type

Specifies whether the kernel is meant to run on the TensorCore or on one of the SparseCores.

Type:

CoreType

disable_bounds_checks

Disable bounds checks in the kernel.

Type:

bool

disable_semaphore_checks

Disable semaphore checks in the kernel.

Type:

bool

skip_device_barrier

Skip the default device barrier for the kernel.

Type:

bool

allow_collective_id_without_custom_barrier

Allow the use of collective_id without a custom barrier.

Type:

bool

use_tc_tiling_on_sc

Use TensorCore tiling for SparseCore. This flag is only used for SC_*_SUBCORE kernels.

Type:

bool | None

needs_layout_passes

Whether to use vector layout inference passes. This flag is temporary and will eventually be removed.

Type:

bool

fuse_transposed_lhs_in_matmul

Hint to the compiler to attempt to fuse a transposed LHS into the MXU matmul when the transposed layout of the LHS is specified in the matmul operation itself, e.g. jnp.einsum('km,kn->mn', lhs, rhs). When the transposition is performed separately from the multiplication (e.g. jnp.matmul(lhs.T, rhs)), this flag does not affect the compiler’s decision (it might still fuse if that is obviously profitable). This flag is best-effort: the fusion is performed only when the compiler determines it is feasible. The fusion is also not always profitable, so the flag should be used sparingly.

Type:

bool

__init__(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=TC, disable_bounds_checks=False, disable_semaphore_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None, needs_layout_passes=None, fuse_transposed_lhs_in_matmul=False)
Parameters:
  • dimension_semantics (Sequence[DimensionSemantics] | None)

  • allow_input_fusion (Sequence[bool] | None)

  • vmem_limit_bytes (int | None)

  • collective_id (int | None)

  • has_side_effects (bool | SideEffectType)

  • flags (Mapping[str, Any] | None)

  • internal_scratch_in_bytes (int | None)

  • serialization_format (int)

  • kernel_type (CoreType)

  • disable_bounds_checks (bool)

  • disable_semaphore_checks (bool)

  • skip_device_barrier (bool)

  • allow_collective_id_without_custom_barrier (bool)

  • shape_invariant_numerics (bool)

  • use_tc_tiling_on_sc (bool | None)

  • needs_layout_passes (bool | None)

  • fuse_transposed_lhs_in_matmul (bool)

Methods

__init__([dimension_semantics, ...])

replace(**changes)

Return a new object replacing specified fields with new values.

Attributes