jax.experimental.pallas.tpu.CompilerParams
- class jax.experimental.pallas.tpu.CompilerParams(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=KernelType.TC, disable_bounds_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None)
Mosaic TPU compiler parameters.
- Parameters:
dimension_semantics (tuple[DimensionSemantics, ...] | None)
vmem_limit_bytes (int | None)
collective_id (int | None)
has_side_effects (bool | SideEffectType)
internal_scratch_in_bytes (int | None)
serialization_format (int)
kernel_type (KernelType)
disable_bounds_checks (bool)
skip_device_barrier (bool)
allow_collective_id_without_custom_barrier (bool)
shape_invariant_numerics (bool)
use_tc_tiling_on_sc (bool | None)
- dimension_semantics
The dimension semantics of the kernel, one entry per grid dimension. Each entry is either "parallel", for a dimension whose iterations can execute in any order, or "arbitrary", for a dimension whose iterations must execute sequentially.
- Type:
tuple[DimensionSemantics, …] | None
- allow_input_fusion
A list of booleans indicating whether input fusion is allowed for each argument.
- vmem_limit_bytes
Overrides the default VMEM limit for a kernel. Note that this must be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes.
- Type:
int | None
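The constraint N*1kib > vmem_limit_bytes can be worked out with a little integer arithmetic. The sketch below computes the smallest N that satisfies it and formats the corresponding flag. The helper name `min_scoped_vmem_limit_kib` and the 64 MiB limit are hypothetical values chosen for illustration; how the flag is passed to XLA depends on your setup and is not shown here.

```python
KIB = 1024  # bytes per KiB


def min_scoped_vmem_limit_kib(vmem_limit_bytes: int) -> int:
    """Return the smallest integer N such that N * 1 KiB > vmem_limit_bytes."""
    if vmem_limit_bytes < 0:
        raise ValueError("vmem_limit_bytes must be non-negative")
    # Floor division gives the largest N with N * KIB <= limit; one more
    # makes the inequality strict.
    return vmem_limit_bytes // KIB + 1


# Hypothetical limit of 64 MiB, used only for illustration.
vmem_limit_bytes = 64 * 1024 * 1024
n = min_scoped_vmem_limit_kib(vmem_limit_bytes)
xla_flag = f"--xla_tpu_scoped_vmem_limit_kib={n}"
```

The same `vmem_limit_bytes` value would then be passed to `CompilerParams(vmem_limit_bytes=...)`. Any N at or above the computed minimum satisfies the stated inequality.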
- collective_id
Indicates which barrier semaphore to use for the kernel. Note that using the same collective_id does not guarantee that the same barrier semaphore will be allocated between kernels.
- Type:
int | None
- kernel_type
Specifies whether the kernel is meant to run on the TensorCore or on one of the SparseCores.
- Type:
KernelType
- allow_collective_id_without_custom_barrier
Allow the use of collective_id without a custom barrier.
- Type:
bool
- use_tc_tiling_on_sc
Use TensorCore tiling for SparseCore. This flag is only used for SC_*_SUBCORE kernels.
- Type:
bool | None
- __init__(dimension_semantics=None, allow_input_fusion=None, vmem_limit_bytes=None, collective_id=None, has_side_effects=False, flags=None, internal_scratch_in_bytes=None, serialization_format=1, kernel_type=KernelType.TC, disable_bounds_checks=False, skip_device_barrier=False, allow_collective_id_without_custom_barrier=False, shape_invariant_numerics=True, use_tc_tiling_on_sc=None)
- Parameters:
dimension_semantics (Sequence[DimensionSemantics] | None)
allow_input_fusion (Sequence[bool] | None)
vmem_limit_bytes (int | None)
collective_id (int | None)
has_side_effects (bool | SideEffectType)
flags (Mapping[str, Any] | None)
internal_scratch_in_bytes (int | None)
serialization_format (int)
kernel_type (KernelType)
disable_bounds_checks (bool)
skip_device_barrier (bool)
allow_collective_id_without_custom_barrier (bool)
shape_invariant_numerics (bool)
use_tc_tiling_on_sc (bool | None)
Methods
- __init__([dimension_semantics, ...])
- replace(**changes): Return a new object replacing specified fields with new values.
Attributes
- BACKEND
- shape_invariant_numerics