jax.experimental.pallas.tpu.BufferedRefBase

class jax.experimental.pallas.tpu.BufferedRefBase

Abstract interface for BufferedRefs.

__init__()

Methods

__init__()

advance_copy_in_slot([predicate])

Advance the copy in slot.

advance_copy_out_slot([predicate])

Advance the copy out slot.

advance_wait_in_slot([predicate])

Advance the wait in slot.

advance_wait_out_slot([predicate])

Advance the wait out slot.

bind_existing_ref(window_ref, indices)

For handling VMEM references, the pipeline aliases the existing ref.

get_dma_slice(src_ty, grid_indices)

init_slots()

Initialize slot indices.

load_slots([predicate])

Load slot information into registers.

save_slots([predicate])

Save slot information from registers.

unbind_refs()

with_spec(spec)

Returns a new BufferedRefBase with the given block spec.
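The advance_* methods and init_slots manage which buffer slot is used for copying data in and which slot is waited on. A toy model of double buffering can make this easier to picture. The sketch below is hypothetical and is not the Pallas implementation. The class ToySlots, the function simulate, the fixed buffer_count of 2, and the rule that each advance moves a slot index forward modulo the buffer count only when predicate is true are all illustrative assumptions. The model also uses plain Python integers rather than traced JAX values.

```python
from dataclasses import dataclass, replace


# Hypothetical toy model, NOT the Pallas implementation: it only illustrates
# separate "copy" and "wait" slot indices cycling over a fixed number of buffers.
@dataclass(frozen=True)
class ToySlots:
    buffer_count: int = 2  # assumption: double buffering
    copy_in: int = 0       # slot that the next inbound copy writes into
    wait_in: int = 0       # slot whose inbound copy is waited on next

    def advance_copy_in_slot(self, predicate: bool = True) -> "ToySlots":
        # Assumed rule: move to the next slot (mod buffer_count) only if predicate holds.
        if not predicate:
            return self
        return replace(self, copy_in=(self.copy_in + 1) % self.buffer_count)

    def advance_wait_in_slot(self, predicate: bool = True) -> "ToySlots":
        # Same assumed rule, applied to the wait-side index.
        if not predicate:
            return self
        return replace(self, wait_in=(self.wait_in + 1) % self.buffer_count)


def simulate(num_steps: int) -> list[tuple[str, int, int]]:
    """Returns (event, block, slot) tuples for a toy double-buffered input loop."""
    slots = ToySlots()  # loosely analogous to init_slots(): both indices start at 0
    log = []
    # Prologue: start copying block 0 before the loop begins.
    log.append(("copy", 0, slots.copy_in))
    slots = slots.advance_copy_in_slot()
    for i in range(num_steps):
        # Prefetch block i + 1 (if there is one) while block i is consumed.
        has_next = i + 1 < num_steps
        if has_next:
            log.append(("copy", i + 1, slots.copy_in))
        slots = slots.advance_copy_in_slot(predicate=has_next)
        # Wait for block i in its slot, then (conceptually) compute on it.
        log.append(("wait", i, slots.wait_in))
        slots = slots.advance_wait_in_slot()
    return log


for event in simulate(3):
    print(event)
```

In this toy run, each block is waited on in the same slot it was copied into, while the copy for the following block targets the other slot. Passing predicate=False on the last iteration keeps the copy index from advancing when there is no further block to fetch.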

Attributes

block_shape

buffer_type

compute_index

is_accumulator

is_buffered

is_input

is_input_output

is_manual

is_output

spec