jax.experimental.pallas.mosaic_gpu.CompilerParams#

class jax.experimental.pallas.mosaic_gpu.CompilerParams(*, approx_math=False, dimension_semantics=None, max_concurrent_steps=1, unsafe_no_auto_barriers=False, reduction_scratch_bytes=2048, profile_space=0, profile_dir='', profile_trace_scope=TraceScope.WARPGROUP, lowering_semantics=LoweringSemantics.Lane)[source]#

Mosaic GPU compiler parameters.

Parameters:
  • approx_math (bool)

  • dimension_semantics (Sequence[DimensionSemantics] | None)

  • max_concurrent_steps (int)

  • unsafe_no_auto_barriers (bool)

  • reduction_scratch_bytes (int)

  • profile_space (int)

  • profile_dir (str)

  • profile_trace_scope (TraceScope)

  • lowering_semantics (mgpu.core.LoweringSemantics)

approx_math#

If True, the compiler is allowed to use approximate implementations of some math operations, e.g. exp. Defaults to False.

Type:

bool

dimension_semantics#

A list of dimension semantics for each grid dimension of the kernel. Either “parallel” for dimensions that can execute in any order, or “sequential” for dimensions that must be executed sequentially.

Type:

Sequence[DimensionSemantics] | None

max_concurrent_steps#

The maximum number of sequential stages that are active concurrently. Defaults to 1.

Type:

int

delay_release#

The number of steps to wait before reusing the input/output references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you’ll want to set it to 1 if you don’t await the WGMMA in the body.

unsafe_no_auto_barriers#

If True, Pallas will never automatically insert the barrier instructions that ensure synchronous semantics of loads and stores. At the moment, the insertion is done conservatively and might regress performance. For this flag to be safe to use, (at least) two conditions must hold:

  • No memory region is ever both read and written by the same thread (async copies are performed by background threads and do not count towards this rule).

  • No thread ever calls commit_smem(), reads from the committed SMEM, and then issues an async copy overwriting that region (a very artificial and highly unlikely scenario).

Type:

bool

reduction_scratch_bytes#

The number of shared memory bytes to reserve as scratch space for cross-warp reductions. The higher this value, the more registers can be reduced in parallel. 2 * 128 * 6 * 4 = 6144 bytes is typically a good value for extracting most of the potential gains on H100 and B200.

Type:

int

profile_space#

The number of profiler events that can be collected in a single invocation. It is undefined behavior if a thread collects more events than this.

Type:

int

profile_dir#

The directory to which profiling traces will be written.

Type:

str

profile_trace_scope#

The scope at which traces are collected: either TraceScope.WARP or TraceScope.WARPGROUP. Defaults to TraceScope.WARPGROUP.

Type:

TraceScope

lowering_semantics#

The lowering semantics to use, e.g. LoweringSemantics.Lane. Defaults to LoweringSemantics.Lane.

Type:

mgpu.core.LoweringSemantics

__init__(*, approx_math=False, dimension_semantics=None, max_concurrent_steps=1, unsafe_no_auto_barriers=False, reduction_scratch_bytes=2048, profile_space=0, profile_dir='', profile_trace_scope=TraceScope.WARPGROUP, lowering_semantics=LoweringSemantics.Lane)#
Parameters:
  • approx_math (bool)

  • dimension_semantics (Sequence[DimensionSemantics] | None)

  • max_concurrent_steps (int)

  • unsafe_no_auto_barriers (bool)

  • reduction_scratch_bytes (int)

  • profile_space (int)

  • profile_dir (str)

  • profile_trace_scope (TraceScope)

  • lowering_semantics (mgpu.core.LoweringSemantics)

Return type:

None

Methods

__init__(*[, approx_math, ...])
