jax.experimental.pallas.tpu.BufferedRef
- class jax.experimental.pallas.tpu.BufferedRef(_spec, _buffer_type, _buffer_count, _grid_rank, window_ref, copy_in_slot, wait_in_slot, copy_out_slot, wait_out_slot, next_fetch, sem_recvs, sem_sends, tiling, is_trivial_windowing=False, has_allocated_buffer=False)
A helper class to automate VMEM double buffering (or, more generally, multiple buffering) in Pallas pipelines.
- Parameters:
_spec (pl.BlockSpec)
_buffer_type (BufferType)
_buffer_count (int)
_grid_rank (int | None)
window_ref (ArrayRef | None)
next_fetch (Sequence[jax.Array] | None)
sem_recvs (SemaphoreTuple | None)
sem_sends (SemaphoreTuple | None)
tiling (Tiling | None)
is_trivial_windowing (bool)
has_allocated_buffer (bool)
- buffer_type
Enum indicating whether this is an input, output, or in/out buffered reference.
- window_ref
A multiple-buffer holding the working and dirty buffers that data is copied into and out of. For a BufferedRef targeting a VMEM reference, this simply points to the existing ref.
- Type:
ArrayRef | None
- next_fetch
Holds the next grid indices to fetch for lookahead. This is the register state used to track the indices within the pipeline loop.
- Type:
Sequence[jax.Array] | None
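The lookahead state described above amounts to "the grid indices of the step after this one". The following is a minimal pure-Python sketch of that idea, not the library's implementation. The helper name `next_grid_indices` is hypothetical, and the sketch assumes the grid is traversed with the last axis varying fastest and wraps around after the final step. The real pipeline tracks these indices as `jax.Array` values inside the pipeline loop.

```python
def next_grid_indices(indices, grid):
    """Return the grid indices that follow `indices` (hypothetical helper).

    Assumes the last grid axis varies fastest and that the sequence
    wraps back to all zeros after the final grid step.
    """
    nxt = list(indices)
    # Walk from the innermost (last) axis outward, carrying on overflow.
    for axis in reversed(range(len(grid))):
        nxt[axis] += 1
        if nxt[axis] < grid[axis]:
            return tuple(nxt)
        nxt[axis] = 0  # overflow: reset this axis and carry outward
    return tuple(nxt)  # every axis overflowed, so we wrapped to the start


# Enumerate a 2 x 3 grid by repeatedly asking for the next indices.
grid = (2, 3)
step = (0, 0)
visited = []
for _ in range(6):
    visited.append(step)
    step = next_grid_indices(step, grid)
```

With lookahead, a pipeline could issue the fetch for `next_grid_indices(step, grid)` while the kernel body is still working on `step`.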
- sem_recvs
Multiple buffered semaphores for input DMAs.
- Type:
SemaphoreTuple | None
- sem_sends
Multiple buffered semaphores for output DMAs.
- Type:
SemaphoreTuple | None
- tiling
The tiling to assume for the buffers.
- Type:
Tiling | None
- has_allocated_buffer
Whether the reference has an allocated buffer because it is in a different memory space than the source ref.
- Type:
bool
- __init__(_spec, _buffer_type, _buffer_count, _grid_rank, window_ref, copy_in_slot, wait_in_slot, copy_out_slot, wait_out_slot, next_fetch, sem_recvs, sem_sends, tiling, is_trivial_windowing=False, has_allocated_buffer=False)
- Parameters:
_spec (pl.BlockSpec)
_buffer_type (BufferType)
_buffer_count (int)
_grid_rank (int | None)
window_ref (ArrayRef | None)
next_fetch (Sequence[jax.Array] | None)
sem_recvs (SemaphoreTuple | None)
sem_sends (SemaphoreTuple | None)
tiling (Tiling | None)
is_trivial_windowing (bool)
has_allocated_buffer (bool)
- Return type:
None
Methods
__init__(_spec, _buffer_type, _buffer_count, ...)
advance_copy_in_slot([predicate]): Switch to the next copy slot.
advance_copy_out_slot([predicate]): Switch to the next copy slot.
advance_wait_in_slot([predicate]): Switch to the next wait slot.
advance_wait_out_slot([predicate]): Switch to the next wait slot.
bind_existing_ref(window_ref, indices): For handling VMEM references, the pipeline aliases the existing ref.
compute_slice(grid_indices): Compute DMA slice from grid indices.
copy_in(src_ref, grid_indices): Starts copy of HBM dma slice into the current slot.
copy_out(dst_ref, grid_indices): Starts copy of HBM dma slice from the current slot.
create(spec, dtype_or_type, buffer_type, ...): Create a BufferedRef.
get_dma_slice(src_ty, grid_indices)
initialize_slots(): Initializes slots to 0.
input(spec, dtype_or_type[, buffer_count])
input_output(spec, dtype_or_type[, buffer_count])
output(spec, dtype_or_type[, buffer_count])
unbind_refs()
wait_in(src_ref, grid_indices): Waits for input copy to finish.
wait_out(dst_ref, grid_indices): Waits for output copy to finish.
with_next_fetch([next_fetch])
with_slot_index([copy_in_slot, ...]): Returns a new BufferedRef with the given slot index.
with_spec(spec): Returns a new BufferedRef with the given block spec.
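To make "compute DMA slice from grid indices" concrete, here is a small pure-Python sketch of the simplest case. It is an illustration under stated assumptions, not the behavior of `compute_slice` or `get_dma_slice`: the helper name `block_slices` is hypothetical, the input is assumed to be per-dimension block indices (as a block spec's index map would produce for plain blocked indexing), and the last block is simply clamped to the array edge. The real methods also have to account for details such as tiling, which this sketch ignores.

```python
def block_slices(block_indices, block_shape, array_shape):
    """Map per-dimension block indices to slices of the full array.

    Hypothetical helper: block `i` along a dimension with block size `b`
    is assumed to cover elements [i * b, (i + 1) * b), clamped to the
    array extent.
    """
    slices = []
    for idx, block, dim in zip(block_indices, block_shape, array_shape):
        start = idx * block
        # Clamp the final block so the slice never runs past the array edge.
        stop = min(start + block, dim)
        slices.append(slice(start, stop))
    return tuple(slices)


# Block (1, 2) of a (16, 512) array split into (8, 128) blocks.
window = block_slices((1, 2), (8, 128), (16, 512))
```

Here `window` selects rows 8 to 15 and columns 256 to 383, which is the region a copy for that grid step would move between the source array and the current buffer slot.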
Attributes
block_shape
buffer_count: Returns the number of buffers used for multiple buffering.
compute_index
cumulative_copy_in: The cumulative number of copy_ins issued on this buffer.
cumulative_copy_out: The cumulative number of copy_outs issued on this buffer.
cumulative_wait_in: The cumulative number of wait_ins issued on this buffer.
cumulative_wait_out: The cumulative number of wait_outs issued on this buffer.
current_copy_in_slot: Index in multiple buffer corresponding to the current slot.
current_copy_out_slot: Index in multiple buffer corresponding to the current copy slot.
current_ref: Returns the current working slice of the double-buffer.
current_wait_in_slot: Index in multiple buffer corresponding to the current wait slot.
current_wait_out_slot: Index in multiple buffer corresponding to the current wait slot.
has_indirect: Whether any block dimension uses indirect indexing.
is_buffered: Whether this buffer is multiple-buffered.
is_input
is_input_output
is_manual
is_output
next_fetch_indices: Returns the next grid indices to fetch from if using lookahead.
use_lookahead: Whether this buffer allows lookahead for fetching blocks.
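The cumulative counters and "current slot" attributes listed above suggest a simple bookkeeping scheme, sketched below as a toy model in pure Python. Everything here is an assumption for illustration: the class `ToySlotCounter` is hypothetical, the slot is assumed to be the cumulative count modulo the buffer count, and the prefetch ordering in the loop is one plausible schedule rather than the schedule the real pipeline uses. No DMAs or semaphores are modeled.

```python
class ToySlotCounter:
    """Toy stand-in for one copy or wait counter (hypothetical)."""

    def __init__(self, buffer_count):
        self.buffer_count = buffer_count
        self.cumulative = 0  # total number of operations issued so far

    @property
    def current_slot(self):
        # Assumption: the active slot is the cumulative count modulo
        # the number of buffers.
        return self.cumulative % self.buffer_count

    def advance(self, predicate=True):
        # Only move to the next slot when the predicate holds, mirroring
        # the optional `predicate` argument of the advance_* methods.
        if predicate:
            self.cumulative += 1


num_steps = 4
copy_in = ToySlotCounter(buffer_count=2)
wait_in = ToySlotCounter(buffer_count=2)

copy_in.advance()  # prologue: the first block was fetched into slot 0

trace = []
for step in range(num_steps):
    has_next = step + 1 < num_steps
    # Prefetch the next block into the copy slot, if there is a next block.
    copy_slot = copy_in.current_slot if has_next else None
    copy_in.advance(has_next)
    # Consume the block for this step from the wait slot.
    trace.append((wait_in.current_slot, copy_slot))
    wait_in.advance()
```

Each entry of `trace` is `(slot being consumed, slot being filled)`. With two buffers the two roles alternate between slots 0 and 1, which is the double-buffering pattern: one slot is read by the kernel while the other receives the next block.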