jax.experimental.pallas.tpu.CompilerParams#
- class jax.experimental.pallas.tpu.CompilerParams(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=TC, disable_bounds_checks=False, disable_semaphore_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None, needs_layout_passes=None, fuse_transposed_lhs_in_matmul=False)[source]#
Mosaic TPU compiler parameters.
- Parameters:
dimension_semantics (tuple[DimensionSemantics, ...] | None)
allow_input_fusion (Sequence[bool] | None)
vmem_limit_bytes (int | None)
collective_id (int | None)
has_side_effects (bool | SideEffectType)
flags (Mapping[str, Any] | None)
internal_scratch_in_bytes (int | None)
serialization_format (int)
kernel_type (CoreType)
disable_bounds_checks (bool)
disable_semaphore_checks (bool)
skip_device_barrier (bool)
allow_collective_id_without_custom_barrier (bool)
shape_invariant_numerics (bool)
use_tc_tiling_on_sc (bool | None)
needs_layout_passes (bool | None)
fuse_transposed_lhs_in_matmul (bool)
- dimension_semantics#
A list of dimension semantics for each grid dimension of the kernel. Either “parallel” for dimensions that can execute in any order, or “arbitrary” for dimensions that must be executed sequentially.
- Type:
tuple[DimensionSemantics, …] | None
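For illustration, a minimal sketch of passing dimension_semantics through the compiler_params argument of pallas_call; the kernel, shapes, and grid here are hypothetical, not part of this API's documentation:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_one_kernel(x_ref, o_ref):
    # Each grid step processes one (256, 256) block independently.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((1024, 1024), jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(4, 4),
    in_specs=[pl.BlockSpec((256, 256), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((256, 256), lambda i, j: (i, j)),
    compiler_params=pltpu.CompilerParams(
        # Both grid dimensions may execute in any order; a dimension that
        # must run sequentially (e.g. a reduction) would be "arbitrary".
        dimension_semantics=("parallel", "parallel"),
    ),
)(x)
```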
- allow_input_fusion#
A list of booleans indicating whether input fusion is allowed for each argument.
- Type:
Sequence[bool] | None
- vmem_limit_bytes#
Overrides the default VMEM limit for a kernel. Note that this must be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag, with N*1KiB > vmem_limit_bytes.
- Type:
int | None
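A hedged sketch of raising the per-kernel VMEM limit; the byte count is illustrative, and the matching XLA flag noted above must still be set when launching the process, e.g. through LIBTPU_INIT_ARGS:

```python
from jax.experimental.pallas import tpu as pltpu

# Request 128 MiB of VMEM for this kernel. This only takes effect if the
# process was started with --xla_tpu_scoped_vmem_limit_kib=N where
# N * 1024 bytes exceeds this value (here, N >= 131072).
params = pltpu.CompilerParams(vmem_limit_bytes=128 * 1024 * 1024)
```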
- collective_id#
Indicates which barrier semaphore to use for the kernel. Note that using the same collective_id does not guarantee that the same barrier semaphore will be allocated between kernels.
- Type:
int | None
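As a sketch, two kernels that should not share a barrier semaphore can request distinct ids; the id values themselves are arbitrary:

```python
from jax.experimental.pallas import tpu as pltpu

# Distinct collective_ids request distinct barrier semaphores. Per the note
# above, reusing an id does not guarantee the same semaphore is allocated
# between kernels.
send_params = pltpu.CompilerParams(collective_id=0)
recv_params = pltpu.CompilerParams(collective_id=1)
```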
- kernel_type#
Specifies whether the kernel is meant to run on the TensorCore or on one of the SparseCores.
- Type:
CoreType
- allow_collective_id_without_custom_barrier#
Allow the use of collective_id without a custom barrier.
- Type:
bool
- use_tc_tiling_on_sc#
Use TensorCore tiling for SparseCore. This flag is only used for SC_*_SUBCORE kernels.
- Type:
bool | None
- needs_layout_passes#
Whether to use vector layout inference passes. This flag is temporary and will eventually be removed.
- Type:
bool | None
- fuse_transposed_lhs_in_matmul#
Hint to the compiler to attempt to fuse a transposed LHS into the MXU when users express the transposition through the matmul itself, e.g. jnp.einsum('km,kn->mn', lhs, rhs). When the transposition is performed separately from the multiplication (e.g. jnp.matmul(lhs.T, rhs)), this flag does not affect the compiler's decision (it may still fuse if obviously profitable). Note that this flag is best-effort: the fusion is only performed when the compiler determines it is feasible. The fusion is also not always profitable, and so should be used sparingly.
- Type:
bool
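A hedged sketch of the pattern this hint targets; the kernel and shapes are hypothetical. The transposition is expressed inside the contraction itself rather than materialized beforehand:

```python
import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

def matmul_kernel(lhs_ref, rhs_ref, o_ref):
    # The LHS is consumed in its transposed ('km') layout. With the hint
    # set, the compiler may fuse this transpose into the MXU operation
    # when it determines the fusion is feasible.
    o_ref[...] = jnp.einsum("km,kn->mn", lhs_ref[...], rhs_ref[...])

params = pltpu.CompilerParams(fuse_transposed_lhs_in_matmul=True)
```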
- __init__(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=TC, disable_bounds_checks=False, disable_semaphore_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None, needs_layout_passes=None, fuse_transposed_lhs_in_matmul=False)[source]#
- Parameters:
dimension_semantics (Sequence[DimensionSemantics] | None)
allow_input_fusion (Sequence[bool] | None)
vmem_limit_bytes (int | None)
collective_id (int | None)
has_side_effects (bool | SideEffectType)
flags (Mapping[str, Any] | None)
internal_scratch_in_bytes (int | None)
serialization_format (int)
kernel_type (CoreType)
disable_bounds_checks (bool)
disable_semaphore_checks (bool)
skip_device_barrier (bool)
allow_collective_id_without_custom_barrier (bool)
shape_invariant_numerics (bool)
use_tc_tiling_on_sc (bool | None)
needs_layout_passes (bool | None)
fuse_transposed_lhs_in_matmul (bool)
Methods
__init__([dimension_semantics, ...])

replace(**changes)
Return a new object replacing specified fields with new values.
Attributes
shape_invariant_numerics