jax.experimental.pallas.mosaic_gpu.emit_pipeline_warp_specialized
- jax.experimental.pallas.mosaic_gpu.emit_pipeline_warp_specialized(body, *, grid, memory_registers, in_specs=(), out_specs=(), max_concurrent_steps=2, wg_axis, num_compute_wgs, pipeline_state=None, manual_consumed_barriers=False, compute_context=None, memory_thread_idx=None)
Creates a function to emit a warp-specialized pipeline.
The body function should have the following signature (without a carry). consumed_barriers is an optional argument that is only passed if the manual_consumed_barriers argument is True:

    def body(indices, *input_refs, *output_refs, *consumed_barriers) -> None:

or, with carries enabled (via the compute_context argument), where the body returns the next carry:

    def body(indices, *input_refs, *output_refs, *consumed_barriers, carry) -> Carry:

When manual_consumed_barriers is True, the user must arrive on all the consumed barriers from all compute warpgroups at each pipeline step. A minimal usage sketch is given at the end of this page.
- Parameters:
body (Callable[..., None]) – The pipeline body.
grid (pallas_core.TupleGrid) – The grid to use for the pipeline.
memory_registers (int) – The number of registers to reserve for the memory thread. For H100 GPUs, 40 is a reasonable value.
in_specs (BlockSpecPytree) – The block specs for the inputs.
out_specs (BlockSpecPytree) – The block specs for the outputs.
max_concurrent_steps (int) – The maximum number of sequential stages that are active concurrently. Defaults to 2.
wg_axis (str) – The name of the warpgroup axis.
num_compute_wgs (int) – The number of compute warpgroups.
manual_consumed_barriers (bool) – If True, consumed barriers are passed into the body function after the output refs. There is one barrier per input, passed in the same order as the inputs (a sketch appears at the end of this page).
compute_context (ComputeContext | None) – If specified, enables carries in the pipeline and allows a user-specified prologue/epilogue that is only executed in the compute thread. The signature of the pipeline body is modified so that its last argument is the current carry, and it must return the next carry. The compute_context itself should follow the signature of ComputeContext and take a pipeline function as its sole argument; calling that pipeline with the initial carry runs the pipeline and returns the final carry (a sketch appears at the end of this page).
memory_thread_idx (int | None) – The index of the memory thread. If not specified, defaults to the last thread.
pipeline_state (jax.Array | PipelinePipeline | None) – If multiple pipelines with almost the same parameters (only in_specs/out_specs and the body may differ) are evaluated in sequence, this argument can be used to avoid pipeline bubbles between their invocations. The first pipeline in the sequence should use the START state, followed by an arbitrary number of STEADY states, followed by a single STOP state. Note that until the pipeline with STOP is done, the memory thread will not wait for the compute threads to complete and fully consume their work. Any modification of their operands other than invoking another pipeline is disallowed.
Important: to achieve bubble-free execution, also use the manual allocation mode by calling get_allocations on the returned function, passing the result to pl.run_scoped, and passing the provided results to the returned function as an allocations keyword argument. Otherwise, the pipeline function will perform the scoped allocation itself, which can lead to synchronization that can still cause pipeline bubbles.
- Return type:
WarpSpecializedPipeline
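As a rough illustration (not part of the documented API surface beyond the call itself), the sketch below builds a warp-specialized pipeline that scales a tiled input. The tile shape, grid size, the "wg" axis name, and the scaling body are illustrative assumptions, and the surrounding kernel launch is omitted; the returned pipeline is expected to be invoked inside a Mosaic GPU kernel that provides the compute warpgroups plus one memory thread along wg_axis.

    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    TILE = (128, 64)   # Hypothetical tile shape.
    NUM_TILES = 8      # Hypothetical number of pipeline steps.

    def body(indices, x_smem, o_smem):
      # Each compute warpgroup sees one SMEM tile per input and output at
      # every pipeline step; the memory thread performs the GMEM <-> SMEM copies.
      o_smem[...] = x_smem[...] * 2.0

    pipeline = plgpu.emit_pipeline_warp_specialized(
        body,
        grid=(NUM_TILES,),
        memory_registers=40,   # Reasonable value for H100 (see above).
        in_specs=[pl.BlockSpec(TILE, lambda i: (i, 0))],
        out_specs=[pl.BlockSpec(TILE, lambda i: (i, 0))],
        max_concurrent_steps=2,
        num_compute_wgs=2,
        wg_axis="wg",          # Must match the kernel's warpgroup axis name.
    )

    # Inside the kernel body, the pipeline is called with the GMEM references
    # of its inputs and outputs, e.g. `pipeline(x_gmem, o_gmem)`.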
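When compute_context is supplied, the body threads a carry through the pipeline. Below is a hedged sketch of a context that accumulates a running sum; the scalar carry, the jnp.sum reduction, and the way the final carry is consumed in the epilogue are illustrative assumptions.

    import jax.numpy as jnp

    def body(indices, x_smem, o_smem, carry):
      # The current carry is the last argument and the next carry is returned.
      o_smem[...] = x_smem[...]
      return carry + jnp.sum(x_smem[...])

    def compute_context(pipeline):
      # Prologue: runs only in the compute threads, before the pipeline starts.
      init_carry = jnp.float32(0.0)
      # Calling the pipeline with the initial carry runs it and returns the
      # final carry.
      final_carry = pipeline(init_carry)
      # Epilogue: use final_carry here, e.g. write it to an output reference
      # captured from the enclosing kernel (omitted in this sketch).
      del final_carry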
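With manual_consumed_barriers=True, one consumed barrier per input is appended after the output refs, and every compute warpgroup must arrive on each of them at every step once the corresponding input tile is no longer needed. A hedged sketch, assuming a single input and that plgpu.barrier_arrive is the appropriate arrival primitive:

    def body(indices, x_smem, o_smem, x_consumed):
      o_smem[...] = x_smem[...] * 2.0
      # Signal that x_smem may be reused by the memory thread for the next
      # fetch. plgpu.barrier_arrive is assumed to be the arrival helper here.
      plgpu.barrier_arrive(x_consumed)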