jax.experimental.pallas.tpu.emit_pipeline
- jax.experimental.pallas.tpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), tiling=None, core_axis=None, core_axis_name=None, dimension_semantics=None, trace_scopes=True, no_pipelining=False, _explicit_indices=False)
Creates a function that emits a manual Pallas pipeline.
This has the same semantics as pallas_call, but it is meant to be called inside a pallas_call kernel to nest grids. This is useful when communication and computation need separate windowing strategies.
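A minimal sketch of the nesting pattern, assuming a recent JAX release and a TPU backend. The outer pallas_call leaves its operands in HBM (memory_space=pltpu.ANY; some releases spell this pltpu.TPUMemorySpace.ANY or pl.ANY), and emit_pipeline streams (128, 128) tiles into VMEM for the inner body. The names add_tile, outer_kernel and nested_add, the tile size, and the elementwise add are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Illustrative tile size; a multiple of the (8, 128) float32 TPU tile.
BLOCK = (128, 128)


def add_tile(x_ref, y_ref, o_ref):
    # Inner body: on each pipeline step it sees one VMEM tile per operand.
    o_ref[...] = x_ref[...] + y_ref[...]


def outer_kernel(x_hbm, y_hbm, o_hbm):
    # The outer refs stay in HBM, so emit_pipeline is responsible for
    # copying tiles into VMEM and copying result tiles back out.
    m, n = x_hbm.shape
    spec = pl.BlockSpec(block_shape=BLOCK, index_map=lambda i, j: (i, j))
    pltpu.emit_pipeline(
        add_tile,
        grid=(m // BLOCK[0], n // BLOCK[1]),
        in_specs=[spec, spec],
        out_specs=spec,
    )(x_hbm, y_hbm, o_hbm)


@jax.jit
def nested_add(x, y):
    # Assumption: pltpu.ANY names the "leave it where it is" memory space.
    any_spec = pl.BlockSpec(memory_space=pltpu.ANY)
    return pl.pallas_call(
        outer_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        in_specs=[any_spec, any_spec],
        out_specs=any_spec,
    )(x, y)


# The demo only runs on a TPU backend; elsewhere the definitions just load.
if jax.default_backend() == "tpu":
    x = jnp.ones((512, 512), jnp.float32)
    y = jnp.full((512, 512), 2.0, jnp.float32)
    result = nested_add(x, y)  # every element should equal 3.0
```

Here the outer pallas_call has no grid of its own, so all tiling lives in the emit_pipeline call; a kernel that also communicates could window its transfers differently at the outer level.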
- Parameters:
body – the Pallas kernel to set up the pipeline for.
grid (tuple[int | jax.Array, ...]) – a Pallas grid definition.
in_specs – input Pallas block specs.
out_specs – output Pallas block specs.
tiling (Tiling | None) – optional tiling to assume for the refs.
core_axis (tuple[int, ...] | int | None) – optional int or tuple of int identifying the core axis (or axes) by index; if set, the grid is partitioned along that core axis.
core_axis_name (tuple[str, ...] | str | None) – optional str or tuple of str; like core_axis, but identifies the core axis (or axes) by name.
dimension_semantics (tuple[GridDimensionSemantics, ...] | None) – optional tuple of GridDimensionSemantics, one per grid dimension (e.g. PARALLEL or ARBITRARY).
trace_scopes (bool) – optional bool indicating whether to annotate each region of the pipeline with named_scope.
no_pipelining (bool) – if True, turns off pipelining so that all copies are synchronous. This is useful for debugging bugs related to multiple buffering.
_explicit_indices (bool) – if True, the body receives the iteration indices as its first argument. This parameter is meant for internal use only.
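To make the grid, in_specs and out_specs semantics concrete without a TPU, here is a hypothetical, purely sequential NumPy model of pallas_call-style blocked windowing. It is not part of JAX and performs no real pipelining or overlap. Specs are modeled as (block_shape, index_map) pairs, the body receives the grid indices as its first argument (similar to what _explicit_indices=True describes), and an output block stays resident while consecutive grid steps map to the same block index, mirroring the usual Pallas TPU behavior that makes accumulation over an innermost reduction axis work. Output buffers are zero-filled only for determinism; a real kernel should treat a newly visited output block as uninitialized.

```python
import itertools

import numpy as np


def blocked_window(block_shape, block_idx):
    # Blocked indexing: block index b along a dimension of block size s
    # covers elements [b * s, (b + 1) * s).
    return tuple(slice(b * s, (b + 1) * s) for b, s in zip(block_idx, block_shape))


def run_grid_reference(body, grid, in_specs, out_specs, inputs, outputs):
    """Hypothetical sequential model of blocked grid windowing (not a JAX API).

    Each spec is a (block_shape, index_map) pair. Grid indices are visited
    in row-major order, so the last grid axis varies fastest.
    """
    resident = [None] * len(outputs)  # per output: (block_idx, buffer) or None
    for idx in itertools.product(*(range(g) for g in grid)):
        # Copy each input window into a local buffer (stand-in for VMEM).
        in_blocks = [x[blocked_window(shape, imap(*idx))].copy()
                     for x, (shape, imap) in zip(inputs, in_specs)]
        out_blocks = []
        for k, (o, (shape, imap)) in enumerate(zip(outputs, out_specs)):
            block_idx = tuple(imap(*idx))
            if resident[k] is None or resident[k][0] != block_idx:
                if resident[k] is not None:
                    # Block index changed: write the finished block back.
                    prev_idx, buf = resident[k]
                    o[blocked_window(shape, prev_idx)] = buf
                resident[k] = (block_idx, np.zeros(shape, o.dtype))
            out_blocks.append(resident[k][1])
        body(idx, *in_blocks, *out_blocks)
    # Flush the last resident block of every output.
    for o, (shape, _), res in zip(outputs, out_specs, resident):
        if res is not None:
            o[blocked_window(shape, res[0])] = res[1]


def matmul_body(idx, x_blk, y_blk, o_blk):
    _, _, k = idx
    if k == 0:
        o_blk[...] = 0.0  # first visit to this output block: reset the accumulator
    o_blk[...] += x_blk @ y_blk


x = np.arange(64.0).reshape(8, 8)
y = np.arange(64.0).reshape(8, 8)[::-1]
out = np.empty((8, 8))
run_grid_reference(
    matmul_body,
    grid=(2, 2, 2),  # (i, j, k); k is the reduction axis and varies fastest
    in_specs=[((4, 4), lambda i, j, k: (i, k)), ((4, 4), lambda i, j, k: (k, j))],
    out_specs=[((4, 4), lambda i, j, k: (i, j))],
    inputs=[x, y],
    outputs=[out],
)
assert np.allclose(out, x @ y)
```

Since no_pipelining only changes when copies happen, not what the body computes, a real kernel run with and without it would be expected to agree; comparing both against a simple reference like this is one way to narrow down a multiple-buffering bug.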