jax.experimental.pallas.tpu.BufferedRefBase
- class jax.experimental.pallas.tpu.BufferedRefBase
Abstract interface for BufferedRefs.
- __init__()
Methods
- __init__()
- advance_copy_in_slot([predicate]): Advance the copy-in slot.
- advance_copy_out_slot([predicate]): Advance the copy-out slot.
- advance_wait_in_slot([predicate]): Advance the wait-in slot.
- advance_wait_out_slot([predicate]): Advance the wait-out slot.
- bind_existing_ref(window_ref, indices): For handling VMEM references, the pipeline aliases the existing ref.
- get_dma_slice(src_ty, grid_indices)
- init_slots(): Initialize slot indices.
- load_slots([predicate]): Load slot information into registers.
- save_slots([predicate]): Save slot information from registers.
- unbind_refs()
- with_spec(spec): Returns a new BufferedRefBase with the given block spec.
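The advance_*_slot methods above move slot indices through a set of buffers so that the copy for the next block can be issued while the current block is awaited. The following is a minimal pure-Python sketch of that bookkeeping, assuming double buffering (two buffers) and assuming that "advancing" means stepping to the next buffer index modulo the buffer count, only when the predicate holds. `SlotModel` is a hypothetical class, not part of Pallas; it does not import JAX, and its predicate is a plain Python bool rather than a traced value inside a kernel.

```python
from dataclasses import dataclass


@dataclass
class SlotModel:
    """Hypothetical model of BufferedRef slot bookkeeping (not the Pallas code)."""

    buffer_count: int = 2  # assumption: double buffering
    copy_in_slot: int = 0
    wait_in_slot: int = 0
    copy_out_slot: int = 0
    wait_out_slot: int = 0

    def _next(self, slot: int, predicate: bool) -> int:
        # Step to the next buffer (wrapping around) only when predicate is true.
        return (slot + 1) % self.buffer_count if predicate else slot

    def advance_copy_in_slot(self, predicate: bool = True) -> None:
        self.copy_in_slot = self._next(self.copy_in_slot, predicate)

    def advance_wait_in_slot(self, predicate: bool = True) -> None:
        self.wait_in_slot = self._next(self.wait_in_slot, predicate)

    def advance_copy_out_slot(self, predicate: bool = True) -> None:
        self.copy_out_slot = self._next(self.copy_out_slot, predicate)

    def advance_wait_out_slot(self, predicate: bool = True) -> None:
        self.wait_out_slot = self._next(self.wait_out_slot, predicate)


def simulate_input_pipeline(num_steps: int) -> list:
    """Record (event, block, slot) tuples for a double-buffered input."""
    slots = SlotModel()
    trace = []
    # Prologue: start fetching block 0 into the current copy-in slot.
    trace.append(("start", 0, slots.copy_in_slot))
    slots.advance_copy_in_slot()
    for step in range(num_steps):
        has_next = step + 1 < num_steps
        if has_next:
            # Prefetch the next block into the other buffer.
            trace.append(("start", step + 1, slots.copy_in_slot))
        # Predicated advance: the slot stays put on the last step.
        slots.advance_copy_in_slot(predicate=has_next)
        # Wait for the current block's copy before computing on it.
        trace.append(("wait", step, slots.wait_in_slot))
        slots.advance_wait_in_slot()
    return trace


trace = simulate_input_pipeline(3)
```

In this model each block is awaited in the same slot it was copied into, which is the invariant the separate copy and wait slots exist to maintain.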
Attributes
- block_shape
- buffer_type
- compute_index
- is_accumulator
- is_buffered
- is_input
- is_input_output
- is_manual
- is_output
- spec
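Several of the boolean attributes above read as role flags derived from buffer_type. The sketch below shows one plausible mapping, under the assumption that an input-output buffer counts as both an input and an output and that an accumulator counts as an output. `BufferKind` and `role_flags` are hypothetical names invented for illustration, not the Pallas enum or properties; is_buffered is not modeled.

```python
import enum


class BufferKind(enum.Enum):
    # Hypothetical buffer kinds; assumed, not taken from the Pallas source.
    INPUT = enum.auto()
    OUTPUT = enum.auto()
    ACCUMULATOR = enum.auto()
    INPUT_OUTPUT = enum.auto()
    MANUAL = enum.auto()


def role_flags(kind: BufferKind) -> dict:
    """Derive assumed role flags from a buffer kind."""
    return {
        # Assumption: data flows in for plain inputs and input-output buffers.
        "is_input": kind in (BufferKind.INPUT, BufferKind.INPUT_OUTPUT),
        # Assumption: accumulators are written back, so they count as outputs.
        "is_output": kind in (
            BufferKind.OUTPUT,
            BufferKind.ACCUMULATOR,
            BufferKind.INPUT_OUTPUT,
        ),
        "is_accumulator": kind is BufferKind.ACCUMULATOR,
        "is_input_output": kind is BufferKind.INPUT_OUTPUT,
        "is_manual": kind is BufferKind.MANUAL,
    }
```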