jax.experimental.pallas.tpu.BufferedRefBase
- class jax.experimental.pallas.tpu.BufferedRefBase
Abstract interface for BufferedRefs.
- __init__()
Methods
__init__()
advance_copy_in_slot([predicate]): Advance the copy in slot.
advance_copy_out_slot([predicate]): Advance the copy out slot.
advance_wait_in_slot([predicate]): Advance the wait in slot.
advance_wait_out_slot([predicate]): Advance the wait out slot.
bind_existing_ref(window_ref, indices): For handling VMEM references, the pipeline aliases the existing ref.
get_dma_slice(src_ty, grid_indices)
initialize_slots(): Initializes slots to 0.
unbind_refs()
with_spec(spec): Returns a new BufferedRefBase with the given block spec.
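
The method summaries do not spell out what "advancing" a slot means or how the optional predicate is used. The following is a minimal pure-Python sketch of one plausible reading, not the Pallas implementation: the class ToySlots, its buffer_count field, and the wrap-around rule are all assumptions made for illustration. It assumes that each slot is an index into a small ring of buffers, that advancing moves to the next index modulo the buffer count, and that a false predicate leaves the slot unchanged.

```python
from dataclasses import dataclass


@dataclass
class ToySlots:
    """Hypothetical stand-in (not the Pallas class): tracks slot counters."""

    buffer_count: int = 2  # assumed number of buffers, e.g. double buffering
    copy_in: int = 0
    wait_in: int = 0

    def initialize_slots(self) -> None:
        # Mirrors the documented summary "Initializes slots to 0."
        self.copy_in = 0
        self.wait_in = 0

    def _advance(self, name: str, predicate: bool = True) -> None:
        # Assumption: advancing steps to the next buffer and wraps around;
        # a false predicate leaves the slot where it is.
        if predicate:
            setattr(self, name, (getattr(self, name) + 1) % self.buffer_count)

    def advance_copy_in_slot(self, predicate: bool = True) -> None:
        self._advance("copy_in", predicate)

    def advance_wait_in_slot(self, predicate: bool = True) -> None:
        self._advance("wait_in", predicate)


slots = ToySlots(buffer_count=2)
slots.initialize_slots()
slots.advance_copy_in_slot()        # 0 -> 1
slots.advance_copy_in_slot(False)   # predicate is false: stays at 1
slots.advance_copy_in_slot()        # wraps around: 1 -> 0
slots.advance_wait_in_slot()        # 0 -> 1
print(slots.copy_in, slots.wait_in)
```

In this toy model the copy and wait counters move independently, which is one way a pipeline could start the next transfer while still waiting on the previous one. The real class operates on traced values inside a kernel, so its bookkeeping will differ from this plain-integer version.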
Attributes
block_shape
buffer_type
compute_index
has_allocated_buffer: Returns True if the reference has an allocated buffer outside loop.
has_indirect: Whether any block dimension uses indirect indexing.
is_buffered
is_input
is_input_output
is_manual
is_output
is_trivial_windowing: Whether the reference uses trivial windowing.
spec