jax.experimental.pallas.tpu.emit_pipeline

jax.experimental.pallas.tpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), tiling=None, core_axis=None, core_axis_name=None, dimension_semantics=None, trace_scopes=True, no_pipelining=False, _explicit_indices=False)

Creates a function that emits a manual Pallas pipeline.

The resulting pipeline has the same semantics as pallas_call, but it is meant to be called inside a pallas_call kernel, nesting one grid inside another. This is useful when communication and computation need separate windowing strategies.
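A minimal sketch of that nesting, assuming a recent JAX release running on a TPU backend (it uses the `pl.BlockSpec(block_shape, index_map)` argument order and the `pl.ANY` memory space; older releases spell the latter `pltpu.TPUMemorySpace.ANY`). The names `add_block`, `outer_kernel`, `add_pipelined` and `BLOCK_ROWS` are hypothetical. The outer pallas_call leaves both operands in HBM, and emit_pipeline inside the kernel streams row tiles into VMEM, invoking the body once per step of its own grid.

```python
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

BLOCK_ROWS = 128  # hypothetical tile height, chosen only for illustration


def add_block(x_ref, y_ref, o_ref):
    # Inner body: sees one (BLOCK_ROWS, cols) tile of each operand in VMEM.
    o_ref[...] = x_ref[...] + y_ref[...]


def outer_kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref):
    # The outer kernel receives whole arrays left in HBM (memory_space=ANY)
    # and hands them to an inner pipeline with its own grid and block specs.
    rows, cols = x_hbm_ref.shape
    spec = pl.BlockSpec((BLOCK_ROWS, cols), lambda i: (i, 0))
    pltpu.emit_pipeline(
        add_block,
        grid=(rows // BLOCK_ROWS,),
        in_specs=[spec, spec],
        out_specs=[spec],
    )(x_hbm_ref, y_hbm_ref, o_hbm_ref)


@jax.jit
def add_pipelined(x, y):
    # Assumes equal shapes and a row count divisible by BLOCK_ROWS.
    assert x.shape == y.shape and x.shape[0] % BLOCK_ROWS == 0
    hbm = pl.BlockSpec(memory_space=pl.ANY)  # keep operands in HBM
    return pl.pallas_call(
        outer_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        in_specs=[hbm, hbm],
        out_specs=hbm,
    )(x, y)
```

On a TPU, `add_pipelined(x, y)` for two `(1024, 512)` float32 arrays should agree with `x + y`; the outer call has no grid, so all tiling decisions live in the inner emit_pipeline call.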

Parameters:
  • body – the Pallas kernel to set up the pipeline for.

  • grid (tuple[int | jax.Array, ...]) – a Pallas grid definition.

  • in_specs – input Pallas block specs.

  • out_specs – output Pallas block specs.

  • tiling (Tiling | None) – optional tiling to assume for the refs.

  • core_axis (tuple[int, ...] | int | None) – optional int or tuple of ints indicating whether to partition the grid along the core axis.

  • core_axis_name (tuple[str, ...] | str | None) – optional str or tuple of strings indicating whether to partition the grid along the core axis.

  • dimension_semantics (tuple[GridDimensionSemantics, ...] | None) – optional tuple of GridDimensionSemantics values (e.g. PARALLEL or ARBITRARY).

  • trace_scopes (bool) – optional bool indicating whether to annotate each region of the pipeline with named_scope.

  • no_pipelining (bool) – If True, turns off pipelining and makes all copies synchronous. This is useful for debugging bugs related to multiple buffering.

  • _explicit_indices (bool) – If True, the body will receive the iteration indices as its first argument. This parameter is meant for internal use only.
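Some of these parameters matter mainly when tuning or debugging. The following sketch again assumes a recent JAX release on a TPU backend; `make_scale`, `scale_block`, `outer_kernel` and `BLOCK_ROWS` are hypothetical names. It marks the single grid axis as `pltpu.PARALLEL` through `dimension_semantics` and exposes `no_pipelining` as a debug switch, so the same op can be rerun with synchronous copies when a multiple-buffering bug is suspected.

```python
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

BLOCK_ROWS = 128  # hypothetical tile height, chosen only for illustration


def scale_block(x_ref, o_ref):
    # Inner body: doubles one VMEM tile.
    o_ref[...] = x_ref[...] * 2


def make_scale(debug: bool = False):
    """Hypothetical helper: builds a doubling op whose inner pipeline can be
    switched to synchronous copies for debugging."""

    def outer_kernel(x_hbm_ref, o_hbm_ref):
        rows, cols = x_hbm_ref.shape
        spec = pl.BlockSpec((BLOCK_ROWS, cols), lambda i: (i, 0))
        pltpu.emit_pipeline(
            scale_block,
            grid=(rows // BLOCK_ROWS,),
            in_specs=[spec],
            out_specs=[spec],
            # Tiles are independent, so the single grid axis is marked PARALLEL.
            dimension_semantics=(pltpu.PARALLEL,),
            # With debug=True, pipelining is off and every copy is synchronous.
            no_pipelining=debug,
        )(x_hbm_ref, o_hbm_ref)

    @jax.jit
    def scale(x):
        hbm = pl.BlockSpec(memory_space=pl.ANY)  # keep operands in HBM
        return pl.pallas_call(
            outer_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
            in_specs=[hbm],
            out_specs=hbm,
        )(x)

    return scale
```

Comparing `make_scale(debug=True)(x)` with `make_scale()(x)` on a TPU is one way to check whether a wrong result comes from the buffering schedule rather than from the body itself.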