jax.experimental.pallas.tpu.make_pipeline_allocations
- jax.experimental.pallas.tpu.make_pipeline_allocations(*refs, in_specs=(), out_specs=(), should_accumulate_out=False, needs_swap_ref=True, grid=None)
Create BufferedRefs for the pipeline.
This function creates the buffered refs for an inner pipeline at the top level of a pallas_call kernel, so that the same buffers can be reused across multiple invocations of the inner pipeline.
- Parameters:
refs – the input refs followed by the output refs, matched positionally with in_specs and out_specs.
in_specs – input Pallas block specs.
out_specs – output Pallas block specs.
should_accumulate_out – booleans indicating which outputs should be treated as accumulators.
needs_swap_ref – whether to allocate a swap-slot tracker.
grid – the grid to use for the pipeline.
- Returns:
A list of BufferedRefs, one for each block spec in in_specs and out_specs, in the same order (inputs first, then outputs).
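The sketch below illustrates the reuse pattern described above: an outer kernel keeps its operands in HBM, builds the allocations once, and runs the same inner pipeline twice inside a single pl.run_scoped region. It is a minimal sketch, assuming a recent JAX release in which pl.BlockSpec takes (block_shape, index_map), the function returned by pltpu.emit_pipeline accepts an allocations= keyword, and pl.run_scoped can allocate the BufferedRefs this function returns. Running the kernel requires a TPU. The names BLOCK, _double, _outer_kernel and double_both are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Hypothetical tile size: one (8, 128) tile per inner pipeline step.
BLOCK = (8, 128)


def _double(x_vmem, o_vmem):
  # Inner pipeline body: operates on one VMEM block at a time.
  o_vmem[...] = x_vmem[...] * 2


def _outer_kernel(x_hbm, y_hbm, ox_hbm, oy_hbm):
  # All four refs stay in HBM (memory_space=ANY); the inner pipeline
  # copies blocks into VMEM and writes results back.
  grid = (x_hbm.shape[0] // BLOCK[0], x_hbm.shape[1] // BLOCK[1])
  spec = pl.BlockSpec(BLOCK, lambda i, j: (i, j))
  pipeline = pltpu.emit_pipeline(
      _double, grid=grid, in_specs=[spec], out_specs=[spec])

  # Describe the buffers once, using x/ox as representatives.
  # Assumption: y/oy share the same dtype and block spec, so reuse is valid.
  allocs = pltpu.make_pipeline_allocations(
      x_hbm, ox_hbm, in_specs=[spec], out_specs=[spec], grid=grid)

  def run_both(allocations):
    # Both invocations share the same scoped buffers and semaphores.
    pipeline(x_hbm, ox_hbm, allocations=allocations)
    pipeline(y_hbm, oy_hbm, allocations=allocations)

  # run_scoped allocates the described buffers for the duration of run_both.
  pl.run_scoped(run_both, allocs)


@jax.jit
def double_both(x, y):
  any_spec = pl.BlockSpec(memory_space=pltpu.ANY)
  return pl.pallas_call(
      _outer_kernel,
      out_shape=(jax.ShapeDtypeStruct(x.shape, x.dtype),
                 jax.ShapeDtypeStruct(y.shape, y.dtype)),
      in_specs=[any_spec, any_spec],
      out_specs=(any_spec, any_spec),
  )(x, y)
```

On a TPU, calling double_both(x, y) on same-dtype float32 arrays whose shape is a multiple of (8, 128) should return (2 * x, 2 * y). Sharing allocations in this way only makes sense when every invocation uses the same block specs and dtypes, since the buffers are sized from them.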