jax.experimental.pallas.tpu.emit_pipeline
- jax.experimental.pallas.tpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), should_accumulate_out=False, core_axis=None, core_axis_name=None, dimension_semantics=None, trace_scopes=True, no_pipelining=False)
Creates a function that emits a manual Pallas pipeline.
It has the same semantics as pallas_call, but is meant to be called inside a pallas_call kernel to nest grids. This is useful when communication and computation need separate windowing strategies.
The additional argument should_accumulate_out specifies which outputs are automatically accumulated into, both within and across pipeline invocations.
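A minimal sketch of the nesting described above, assuming a TPU backend and a JAX version that provides `pltpu.ANY` and keyword-style `pl.BlockSpec(block_shape=..., index_map=...)` (newer releases may also spell the memory space `pl.ANY`). The outer pallas_call leaves both arrays in HBM, and emit_pipeline streams `(bm, bn)` tiles through VMEM for the inner body. The function name `add_one_tiled` and the default tile size are illustrative, and the array dimensions are assumed to be divisible by the tile dimensions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def add_one_tiled(x, block_shape=(128, 128)):
  """Hypothetical helper: adds 1 to `x` using a nested emit_pipeline."""
  bm, bn = block_shape
  # Assumes x.shape is divisible by block_shape.
  inner_grid = (x.shape[0] // bm, x.shape[1] // bn)
  spec = pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, j: (i, j))

  def inner_kernel(x_vmem, o_vmem):
    # Runs once per inner grid step on VMEM tiles; emit_pipeline handles
    # the HBM -> VMEM copy of the input and the VMEM -> HBM copy of the output.
    o_vmem[...] = x_vmem[...] + 1

  def outer_kernel(x_hbm, o_hbm):
    # The outer refs stay in HBM (memory_space=pltpu.ANY below).
    # Refs are passed positionally: inputs first, then outputs.
    pltpu.emit_pipeline(
        inner_kernel,
        grid=inner_grid,
        in_specs=[spec],
        out_specs=[spec],
    )(x_hbm, o_hbm)

  return pl.pallas_call(
      outer_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
  )(x)
```

On a TPU, `add_one_tiled(jnp.ones((256, 512), jnp.float32))` should return an array of 2.0s; the inner grid in that case is `(2, 4)`. The outer pallas_call has no grid of its own here, so all tiling is done by the inner pipeline.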
- Parameters:
body – the Pallas kernel to set up the pipeline for.
grid (tuple[int | jax.Array, ...]) – a Pallas grid definition.
in_specs – input Pallas block specs.
out_specs – output Pallas block specs.
should_accumulate_out (bool) – booleans indicating which outputs should be treated as accumulators.
core_axis (int | None) – optional int, the index of the core axis; if set, the grid is partitioned along that axis.
core_axis_name (str | None) – optional str, the name of the core axis; if set, the grid is partitioned along that axis.
dimension_semantics (tuple[GridDimensionSemantics, ...] | None) – optional tuple of GridDimensionSemantics (e.g. PARALLEL or ARBITRARY).
trace_scopes (bool) – optional bool, indicates whether to annotate each region in the pipeline using named_scope.
no_pipelining (bool) – if True, turns off pipelining so that all copies are made synchronously. This is useful for debugging multiple-buffering-related bugs.
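One way to use no_pipelining for debugging is to run the same pipeline twice, once with multiple buffering and once with every copy synchronous, and compare the results: if the two disagree, the problem is more likely in the buffering than in the body. The sketch below assumes a TPU backend and a JAX version that provides `pltpu.ANY` and keyword-style `pl.BlockSpec(block_shape=..., index_map=...)`; the helper names `scale_tiled` and `pipelining_agrees`, the doubling body, and the tile size are hypothetical, and the array dimensions are assumed to be divisible by the tile dimensions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def scale_tiled(x, *, no_pipelining=False, block_shape=(128, 128)):
  """Hypothetical helper: doubles `x` via a nested emit_pipeline."""
  bm, bn = block_shape
  spec = pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, j: (i, j))

  def inner_kernel(x_vmem, o_vmem):
    o_vmem[...] = x_vmem[...] * 2

  def outer_kernel(x_hbm, o_hbm):
    pltpu.emit_pipeline(
        inner_kernel,
        grid=(x.shape[0] // bm, x.shape[1] // bn),
        in_specs=[spec],
        out_specs=[spec],
        # When True, every HBM <-> VMEM copy is synchronous.
        no_pipelining=no_pipelining,
    )(x_hbm, o_hbm)

  return pl.pallas_call(
      outer_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
  )(x)


def pipelining_agrees(x):
  """Returns True if the pipelined and synchronous runs produce equal output."""
  pipelined = scale_tiled(x)
  synchronous = scale_tiled(x, no_pipelining=True)
  return bool(jnp.array_equal(pipelined, synchronous))
```

The synchronous run is slower because no copy overlaps with compute, so it is meant as a correctness reference rather than a production setting.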