jax.experimental.pallas.mosaic_gpu.CompilerParams
- class jax.experimental.pallas.mosaic_gpu.CompilerParams(*, approx_math=False, dimension_semantics=None, max_concurrent_steps=1, unsafe_no_auto_barriers=False, reduction_scratch_bytes=2048, profile_space=0, profile_dir='', profile_trace_scope=TraceScope.WARPGROUP, lowering_semantics=LoweringSemantics.Lane)
Mosaic GPU compiler parameters.
- Parameters:
- approx_math
If True, the compiler is allowed to use approximate implementations of some math operations, e.g. exp. Defaults to False.
- Type:
bool
- dimension_semantics
A list of dimension semantics, one entry per grid dimension of the kernel. Each entry is either “parallel” for a dimension that can execute in any order, or “sequential” for a dimension that must be executed sequentially.
- Type:
Sequence[DimensionSemantics] | None
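As a rough illustration of the one-entry-per-grid-dimension rule, the sketch below checks a candidate list against a grid shape in plain Python. The helper `check_dimension_semantics` and the example grid are hypothetical and not part of Pallas. The sketch assumes that the list length must equal the grid rank, which is how the description above reads.

```python
from typing import Sequence

# The two values named in the description above.
ALLOWED_SEMANTICS = ("parallel", "sequential")


def check_dimension_semantics(grid: Sequence[int], semantics: Sequence[str]) -> None:
    """Hypothetical helper: require one valid semantics entry per grid dimension."""
    # Assumption: the list length must equal the number of grid dimensions.
    if len(semantics) != len(grid):
        raise ValueError(f"expected {len(grid)} entries, got {len(semantics)}")
    for i, s in enumerate(semantics):
        if s not in ALLOWED_SEMANTICS:
            raise ValueError(f"dimension {i}: unknown semantics {s!r}")


# Example: a 2D grid whose first axis may run in any order and whose
# second axis must run in order.
example_grid = (8, 4)
example_semantics = ["parallel", "sequential"]
check_dimension_semantics(example_grid, example_semantics)
```

A list that passes this check could then be supplied as `dimension_semantics`.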
- max_concurrent_steps
The maximum number of sequential stages that are active concurrently. Defaults to 1.
- Type:
int
- delay_release
The number of steps to wait before reusing the input/output references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you’ll want to set it to 1 if you don’t await the WGMMA in the body.
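The constraint between `delay_release` and `max_concurrent_steps` can be sketched as a small plain-Python check. The function `check_pipeline_settings` is hypothetical and only mirrors the rule stated above. The lower bounds (at least one step, non-negative delay) are assumptions and are not stated in the text.

```python
def check_pipeline_settings(max_concurrent_steps: int, delay_release: int = 0) -> None:
    """Hypothetical check of the documented rule delay_release < max_concurrent_steps."""
    if max_concurrent_steps < 1:  # assumption: at least one active step
        raise ValueError("max_concurrent_steps must be at least 1")
    if delay_release < 0:  # assumption: a negative delay is not meaningful
        raise ValueError("delay_release must be non-negative")
    if delay_release >= max_concurrent_steps:  # rule stated in the docs above
        raise ValueError(
            f"delay_release={delay_release} must be strictly smaller than "
            f"max_concurrent_steps={max_concurrent_steps}"
        )


# With the default max_concurrent_steps=1, only delay_release=0 is valid.
check_pipeline_settings(max_concurrent_steps=1, delay_release=0)

# Using delay_release=1 therefore needs max_concurrent_steps of at least 2.
check_pipeline_settings(max_concurrent_steps=2, delay_release=1)
```

One consequence of the rule: following the suggestion to set `delay_release` to 1 also means raising `max_concurrent_steps` above its default of 1.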
- unsafe_no_auto_barriers
If True, Pallas will never automatically insert barrier instructions that ensure synchronous semantics of loads and stores. At the moment, the insertion is done conservatively and might regress performance. There are (at least) two conditions that must be satisfied for the use of this flag to be safe. First, no memory region is ever both read from and written to by the same thread (async copies are performed by background threads and do not count towards this rule). Second, no thread ever calls commit_smem(), reads from the committed SMEM, and then issues an async copy overwriting that region (this is a very artificial and highly unlikely scenario).
- Type:
bool
- reduction_scratch_bytes
The number of shared memory bytes to reserve as scratch space for cross-warp reductions. The higher this value, the more registers can be reduced in parallel. 2 * 128 * 6 * 4 = 6144 bytes is typically a good value in order to extract most of the potential gains on H100 and B200.
- Type:
int
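The arithmetic behind the suggested value can be checked directly. The snippet below only evaluates the product quoted above and compares it with the default from the class signature; the variable names are illustrative, and the meaning of the individual factors is not spelled out in the text, so none is assumed here.

```python
# Default taken from the class signature (reduction_scratch_bytes=2048).
default_scratch_bytes = 2048

# Product quoted in the description as a typically good value.
suggested_scratch_bytes = 2 * 128 * 6 * 4

# How much larger the suggested value is than the default.
ratio = suggested_scratch_bytes / default_scratch_bytes

print(suggested_scratch_bytes, ratio)
```

The suggested value works out to three times the default, so kernels dominated by cross-warp reductions may be worth trying with the larger setting.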
- profile_space
The number of profiler events that can be collected in a single invocation. It is undefined behavior if a thread collects more events than this.
- Type:
int
- profile_trace_scope
The scope at which traces are collected (WARP or WARPGROUP).
- Type:
TraceScope
- __init__(*, approx_math=False, dimension_semantics=None, max_concurrent_steps=1, unsafe_no_auto_barriers=False, reduction_scratch_bytes=2048, profile_space=0, profile_dir='', profile_trace_scope=TraceScope.WARPGROUP, lowering_semantics=LoweringSemantics.Lane)
- Return type:
None
Methods
__init__(*[, approx_math, ...])
Attributes