jax.experimental.pallas.tpu.make_pipeline_allocations

jax.experimental.pallas.tpu.make_pipeline_allocations(*refs, in_specs=(), out_specs=(), should_accumulate_out=False, needs_swap_ref=True, grid=None)

Create BufferedRefs for the pipeline.

This function creates BufferedRefs for an inner pipeline. Because these refs can be created once at the top level of a Pallas call, they can be reused across multiple invocations of the inner pipeline.
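The benefit of creating the refs at the top level is that allocation happens once, while the inner pipeline may run many times. The plain-Python model below (not Pallas code) sketches that pattern. `ScratchBuffer`, `CountingAllocator`, and `run_inner_pipeline` are hypothetical names that only count allocations; they do not reflect how Pallas actually manages TPU memory.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ScratchBuffer:
    """Hypothetical stand-in for a scratch buffer; not a Pallas type."""
    name: str
    data: List[int] = field(default_factory=lambda: [0] * 4)


class CountingAllocator:
    """Hypothetical allocator that records how many times it is called."""

    def __init__(self) -> None:
        self.calls = 0

    def allocate(self, names: List[str]) -> List[ScratchBuffer]:
        self.calls += 1
        return [ScratchBuffer(n) for n in names]


def run_inner_pipeline(step: int, buffers: List[ScratchBuffer]) -> None:
    """Hypothetical inner pipeline: writes the step index into each buffer."""
    for buf in buffers:
        buf.data[step % len(buf.data)] = step


# Pattern 1: allocate once at the "top level" and reuse on every invocation.
hoisted = CountingAllocator()
shared = hoisted.allocate(["in_buf", "out_buf"])
for step in range(3):
    run_inner_pipeline(step, shared)

# Pattern 2: allocate fresh buffers on every invocation.
per_call = CountingAllocator()
for step in range(3):
    run_inner_pipeline(step, per_call.allocate(["in_buf", "out_buf"]))

print(hoisted.calls, per_call.calls)  # 1 3
```

Hoisting the allocation out of the loop yields one allocation for three invocations; allocating inside the loop yields three.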

Parameters:
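  • refs – the refs to be pipelined, one for each block spec in in_specs and out_specs.

  • in_specs – Pallas block specs for the pipeline inputs.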

  • out_specs – Pallas block specs for the pipeline outputs.

  • should_accumulate_out – booleans indicating which outputs should be treated as accumulators.

  • needs_swap_ref – whether a swap slots tracker needs to be allocated.

  • grid – grid to use for the pipeline.

Returns:

A list of BufferedRefs, one for each ref specified by in_specs and out_specs.
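A plain-Python model of the documented return contract may help: one buffered ref per block spec, with should_accumulate_out deciding which outputs are accumulators. Everything here is hypothetical: `BufferSketch` and `sketch_allocations` are not part of JAX, block specs are represented as strings, and two behaviors are assumptions rather than documented facts: that a single boolean applies to every output, and that input refs precede output refs in the result. The model ignores refs, needs_swap_ref, and grid.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass(frozen=True)
class BufferSketch:
    """Hypothetical stand-in for a BufferedRef: records a spec and its role."""
    spec: str
    role: str  # "input", "output", or "accumulator"


def sketch_allocations(
    in_specs: Sequence[str] = (),
    out_specs: Sequence[str] = (),
    should_accumulate_out: Union[bool, Sequence[bool]] = False,
) -> List[BufferSketch]:
    # ASSUMPTION: a single bool is broadcast to every output spec.
    if isinstance(should_accumulate_out, bool):
        flags = [should_accumulate_out] * len(out_specs)
    else:
        flags = list(should_accumulate_out)
    if len(flags) != len(out_specs):
        raise ValueError("need one accumulate flag per output spec")
    # ASSUMPTION: inputs come first, then outputs, each in spec order.
    inputs = [BufferSketch(s, "input") for s in in_specs]
    outputs = [
        BufferSketch(s, "accumulator" if acc else "output")
        for s, acc in zip(out_specs, flags)
    ]
    return inputs + outputs


refs = sketch_allocations(
    in_specs=["x_spec", "w_spec"],
    out_specs=["y_spec", "grad_spec"],
    should_accumulate_out=[False, True],
)
print([r.role for r in refs])
# ['input', 'input', 'output', 'accumulator']
```

With two input specs and two output specs the model returns four entries, and only the output whose flag is True is tagged as an accumulator.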