jax.experimental.pallas.tpu.emit_pipeline_with_allocations

jax.experimental.pallas.tpu.emit_pipeline_with_allocations(body, *, grid, in_specs=(), out_specs=(), should_accumulate_out=False)

Creates a Pallas pipeline function together with a function that prepares the pipeline's buffer allocations at the top level of a kernel.

Parameters:
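  • body – the Pallas kernel body to pipeline; it is called once per grid step on the current blocks of the inputs and outputs.

  • grid – the Pallas grid (a tuple of iteration counts) that the pipeline steps through.

  • in_specs – Pallas block specs describing how each input is split into blocks.

  • out_specs – Pallas block specs describing how each output is split into blocks.

  • should_accumulate_out – a boolean, or one boolean per output, indicating which outputs should be treated as accumulators.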

Returns:
A (emit_pipeline, make_allocations) function pair, where
  • emit_pipeline is the Pallas pipeline function.

  • make_allocations is a function that creates the buffered refs for the inner pipeline. These can be created once at the top level of a pallas_call and then reused across multiple invocations of the inner pipeline.
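The sketch below illustrates the intended pattern: build the pair once, allocate the buffers once inside the outer kernel, and run the inner pipeline twice on the same buffers. It assumes a JAX version that exports emit_pipeline_with_allocations, and it assumes the calling convention that emit_pipeline appears to use internally when no allocations are given: make_allocations(*refs) returns scratch types for pl.run_scoped, and the pipeline accepts them through an allocations= keyword. Neither convention is documented on this page. The kernel names, the (128, 128) block shape, and the two-array reuse are hypothetical, and actually executing the pallas_call requires a TPU.

```python
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

BLOCK = (128, 128)  # hypothetical block shape
GRID = (4,)         # 4 row-blocks, so arrays of shape (512, 128)


def add_one_body(x_ref, o_ref):
  # Inner body: runs once per grid step on the current VMEM blocks.
  o_ref[...] = x_ref[...] + 1


# Build the pipeline and its allocation helper once.
pipeline, make_allocations = pltpu.emit_pipeline_with_allocations(
    add_one_body,
    grid=GRID,
    in_specs=[pl.BlockSpec(block_shape=BLOCK, index_map=lambda i: (i, 0))],
    out_specs=[pl.BlockSpec(block_shape=BLOCK, index_map=lambda i: (i, 0))],
)


def outer_kernel(x_hbm, y_hbm, ox_hbm, oy_hbm):
  # Assumption: make_allocations takes refs like those the pipeline will see
  # and returns scratch types that pl.run_scoped can allocate.
  def run(allocations):
    # Both invocations reuse the same buffers. This relies on x/y (and
    # ox/oy) having the same shape and dtype.
    pipeline(x_hbm, ox_hbm, allocations=allocations)
    pipeline(y_hbm, oy_hbm, allocations=allocations)

  pl.run_scoped(run, make_allocations(x_hbm, ox_hbm))


@jax.jit
def add_one_to_both(x, y):
  # Requires a TPU backend. The arrays stay in HBM (memory space ANY) so the
  # inner pipeline issues its own copies. The spelling of the ANY memory
  # space has varied across JAX versions.
  any_spec = pl.BlockSpec(memory_space=pl.ANY)
  shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return pl.pallas_call(
      outer_kernel,
      out_shape=(shape, shape),
      in_specs=[any_spec, any_spec],
      out_specs=(any_spec, any_spec),
  )(x, y)
```

Under these assumptions, on a TPU, calling add_one_to_both on two float32 arrays of shape (512, 128) should return each input plus one. The design point is that a plain inner pipeline called without allocations would set up its buffers again on every call. Here, the buffers are created once by make_allocations and shared by both invocations.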