jax.experimental.pallas.tpu.emit_pipeline_with_allocations
- jax.experimental.pallas.tpu.emit_pipeline_with_allocations(body, *, grid, in_specs=(), out_specs=(), should_accumulate_out=False)
Creates a Pallas pipeline function together with a function that prepares the pipeline's allocations at the top level.
- Parameters:
body – the Pallas kernel to set up the pipeline for.
grid – a Pallas grid definition.
in_specs – the input Pallas block specs.
out_specs – the output Pallas block specs.
should_accumulate_out – booleans indicating which outputs should be treated as accumulators.
- Returns:
- An (emit_pipeline, make_allocations) pair of functions, where
emit_pipeline is the Pallas pipeline function.
make_allocations is a function that creates the buffered refs for the inner pipeline. The buffered refs can be created once at the top level of a Pallas call and then reused across multiple invocations of the inner pipeline.
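A minimal sketch of building the function pair described above. The kernel `add_one_kernel`, the grid size, and the block shape are illustrative assumptions rather than anything this reference page prescribes. The snippet only constructs the two functions; running the pipeline itself happens inside a Pallas call on a TPU and is not shown.

```python
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


# Hypothetical inner kernel: it receives one block of the input and one
# block of the output, already staged by the pipeline.
def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0


# Assumed layout: a (512, 128) array processed as four (128, 128) blocks.
block_spec = pl.BlockSpec(block_shape=(128, 128), index_map=lambda i: (i, 0))

# Build the pair. Nothing is traced or executed at this point.
pipeline, make_allocations = pltpu.emit_pipeline_with_allocations(
    add_one_kernel,
    grid=(4,),
    in_specs=(block_spec,),
    out_specs=(block_spec,),
)
```

Here `pipeline` plays the role of emit_pipeline in the return description, and `make_allocations` is the helper that produces the buffered refs to set up once at the top level of the enclosing Pallas call.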