jax.experimental.pallas.tpu.BufferedRef
- class jax.experimental.pallas.tpu.BufferedRef(_spec, _buffer_type, window_ref, accum_ref, copy_in_slot, wait_in_slot, copy_out_slot, wait_out_slot, _copy_in_slot_reg, _wait_in_slot_reg, _copy_out_slot_reg, _wait_out_slot_reg, next_fetch_smem, next_fetch_sreg, sem_recvs, sem_sends, swap)
A helper class to automate VMEM double buffering in Pallas pipelines.
- Parameters:
_spec (pl.BlockSpec)
_buffer_type (BufferType)
window_ref (ArrayRef | None)
accum_ref (ArrayRef | None)
copy_in_slot (ArrayRef | None)
wait_in_slot (ArrayRef | None)
copy_out_slot (ArrayRef | None)
wait_out_slot (ArrayRef | None)
next_fetch_smem (Sequence[jax.Array] | None)
next_fetch_sreg (Sequence[jax.Array] | None)
sem_recvs (SemaphoreTuple | None)
sem_sends (SemaphoreTuple | None)
swap (ArrayRef | None)
- buffer_type
Enum indicating whether this is an input, output, or in/out accumulator buffered reference.
- window_ref
A multiple-buffered window that holds the working and dirty buffers that data is copied into and out of. When the BufferedRef targets a VMEM reference, this simply points to the existing ref.
- Type:
ArrayRef | None
- accum_ref
Accumulating buffer used by accumulator BufferedRefs.
- Type:
ArrayRef | None
- copy_in_slot
Current slot of the working buffer to copy into.
- Type:
ArrayRef | None
- copy_out_slot
Current slot of the working buffer to copy out of.
- Type:
ArrayRef | None
- wait_in_slot
Current slot of the working buffer on which to wait for an incoming copy.
- Type:
ArrayRef | None
- wait_out_slot
Current slot of the working buffer on which to wait for an outgoing copy.
- Type:
ArrayRef | None
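The four slot fields above can be pictured as counters that rotate through the buffers. The following is a minimal plain-Python sketch of that idea, not the library's implementation: `SlotCounter` is a hypothetical name, and the assumption that a slot index equals a cumulative count modulo `buffer_count` is an illustration suggested by the `cumulative_*` and `current_*_slot` attributes rather than something this page states.

```python
from dataclasses import dataclass


@dataclass
class SlotCounter:
    """Hypothetical stand-in for one slot field such as copy_in_slot."""
    buffer_count: int = 2
    cumulative: int = 0  # number of operations issued so far (assumed model)

    @property
    def current_slot(self) -> int:
        # Assumption: the active slot is the cumulative count modulo buffer_count.
        return self.cumulative % self.buffer_count

    def advance(self, predicate: bool = True) -> None:
        # Only move to the next slot when the predicate holds.
        if predicate:
            self.cumulative += 1


copy_in = SlotCounter()
wait_in = SlotCounter()

# Prologue: the fetch for step 0 went into slot 0, so the copy slot moves on.
copy_in.advance()

trace = []
for step in range(4):
    # Each step prefetches into copy_in's slot and waits on wait_in's slot.
    trace.append((step, copy_in.current_slot, wait_in.current_slot))
    copy_in.advance()
    wait_in.advance()

print(trace)
```

In this model the copy slot always runs one step ahead of the wait slot, which is what lets a fetch for the next block overlap with work on the current one.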
- next_fetch_smem
Holds the next grid indices to fetch for lookahead. This is the SMEM backing buffer used to persist state between pipeline invocations.
- Type:
Sequence[jax.Array] | None
- next_fetch_sreg
Holds the next grid indices to fetch for lookahead. This is the register state used to track the indices within the pipeline loop.
- Type:
Sequence[jax.Array] | None
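To make "next grid indices" concrete, here is a small plain-Python sketch of stepping a multi-dimensional grid index forward by one. The function name `next_grid_indices` is hypothetical, and the iteration order (last dimension varies fastest, wrapping to zeros after the final step) is an assumption for illustration; this page does not specify how the pipeline orders grid steps.

```python
def next_grid_indices(indices, grid):
    """Return the indices of the following grid step, wrapping at the end."""
    out = list(indices)
    # Increment the innermost (last) dimension first and carry outward.
    for dim in reversed(range(len(grid))):
        out[dim] += 1
        if out[dim] < grid[dim]:
            return tuple(out)
        out[dim] = 0  # this dimension overflowed; carry into the next one
    return tuple(out)  # stepped past the final grid point, back to all zeros


print(next_grid_indices((0, 2), (2, 3)))  # (1, 0)
```

A lookahead scheme would keep a tuple like this per buffer so that the block to prefetch can differ from the block currently being computed on.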
- sem_recvs
Multiple-buffered semaphores for input DMAs.
- Type:
SemaphoreTuple | None
- sem_sends
Multiple-buffered semaphores for output DMAs.
- Type:
SemaphoreTuple | None
- memory_space
Passthrough property for the BlockSpec’s memory_space.
- is_input_output
Whether this BufferedRef is an input/output without automatic accumulation.
- swap
Tracks whether the BufferedRef slots need to be swapped before the next copy.
- Type:
ArrayRef | None
- __init__(_spec, _buffer_type, window_ref, accum_ref, copy_in_slot, wait_in_slot, copy_out_slot, wait_out_slot, _copy_in_slot_reg, _wait_in_slot_reg, _copy_out_slot_reg, _wait_out_slot_reg, next_fetch_smem, next_fetch_sreg, sem_recvs, sem_sends, swap)
- Parameters:
_spec (pl.BlockSpec)
_buffer_type (BufferType)
window_ref (ArrayRef | None)
accum_ref (ArrayRef | None)
copy_in_slot (ArrayRef | None)
wait_in_slot (ArrayRef | None)
copy_out_slot (ArrayRef | None)
wait_out_slot (ArrayRef | None)
next_fetch_smem (Sequence[jax.Array] | None)
next_fetch_sreg (Sequence[jax.Array] | None)
sem_recvs (SemaphoreTuple | None)
sem_sends (SemaphoreTuple | None)
swap (ArrayRef | None)
- Return type:
None
Methods
- __init__(_spec, _buffer_type, window_ref, ...)
- accumulate(): Add into the current slot.
- accumulator(spec, dtype_or_type[, buffer_count])
- advance_copy_in_slot([predicate]): Switch to the next copy slot.
- advance_copy_out_slot([predicate]): Switch to the next copy slot.
- advance_wait_in_slot([predicate]): Switch to the next wait slot.
- advance_wait_out_slot([predicate]): Switch to the next wait slot.
- bind_existing_ref(window_ref, indices): For handling VMEM references, the pipeline aliases the existing ref.
- buffer_types()
- compute_slice(grid_indices): Compute DMA slice from grid indices.
- copy_in(src_ref, grid_indices): Starts copy of HBM dma slice into the current slot.
- copy_out(dst_ref, grid_indices): Starts copy of HBM dma slice from the current slot.
- create(spec, dtype_or_type, buffer_type, ...): Create a BufferedRef.
- get_dma_slice(src_ty, grid_indices)
- init_slots(): Initialize slot indices.
- input(spec, dtype_or_type[, buffer_count])
- input_output(spec, dtype_or_type[, buffer_count])
- load_slots([predicate]): Load slot information into registers.
- output(spec, dtype_or_type[, buffer_count])
- save_slots([predicate]): Save slot information from registers.
- set_accumulator([init]): Set accumulator or zero it out to initialize.
- unbind_refs()
- wait_in(src_ref, grid_indices): Waits for input copy to finish.
- wait_out(dst_ref, grid_indices): Waits for output copy to finish.
- with_next_fetch([next_fetch])
- with_slot_index([copy_in_slot, ...]): Returns a new BufferedRef with the given slot index.
- with_spec(spec): Returns a new BufferedRef with the given block spec.
Attributes
- buffer_count: Returns the number of buffers used for multiple buffering.
- cumulative_copy_in: The cumulative number of copy_ins issued on this buffer.
- cumulative_copy_out: The cumulative number of copy_outs issued on this buffer.
- cumulative_wait_in: The cumulative number of wait_ins issued on this buffer.
- cumulative_wait_out: The cumulative number of wait_outs issued on this buffer.
- current_copy_in_slot: Index in multiple buffer corresponding to the current slot.
- current_copy_out_slot: Index in multiple buffer corresponding to the current copy slot.
- current_wait_in_slot: Index in multiple buffer corresponding to the current wait slot.
- current_wait_out_slot: Index in multiple buffer corresponding to the current wait slot.
- is_buffered: Whether this buffer is multiple-buffered.
- is_manual
- next_fetch_indices: Returns the next grid indices to fetch from if using lookahead.
- use_lookahead: Whether this buffer allows lookahead for fetching blocks.
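
As a rough illustration of the ordering that methods such as copy_in and wait_in imply, the sketch below stages input blocks through a two-slot window in plain Python. It is a synchronous toy model, not Pallas code: `run_double_buffered`, `window`, `copy_slot`, and `wait_slot` are hypothetical names, list assignment stands in for an asynchronous DMA, and no real semaphores or VMEM are involved.

```python
def run_double_buffered(blocks, body, buffer_count=2):
    """Apply `body` to each block, staging inputs through a small slot window."""
    window = [None] * buffer_count  # stand-in for window_ref
    copy_slot = 0                   # stand-in for copy_in_slot
    wait_slot = 0                   # stand-in for wait_in_slot
    outputs = []
    if not blocks:
        return outputs

    # Prologue: start the "copy in" of the first block.
    window[copy_slot] = blocks[0]
    copy_slot = (copy_slot + 1) % buffer_count

    for step in range(len(blocks)):
        # Prefetch the next block into the other slot before computing.
        if step + 1 < len(blocks):
            window[copy_slot] = blocks[step + 1]
            copy_slot = (copy_slot + 1) % buffer_count
        # "Wait" for the current block to arrive, then compute on it.
        current = window[wait_slot]
        wait_slot = (wait_slot + 1) % buffer_count
        outputs.append(body(current))
    return outputs


result = run_double_buffered([[1, 2], [3, 4], [5, 6]],
                             lambda block: [2 * x for x in block])
print(result)  # [[2, 4], [6, 8], [10, 12]]
```

On real hardware the point of this ordering is that the prefetch and the computation can overlap; the toy model only shows which slot each step reads and writes.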