jax.experimental.pallas.tpu.BufferedRefBase

class jax.experimental.pallas.tpu.BufferedRefBase

Abstract interface for BufferedRefs.

__init__()

Methods

__init__()

advance_copy_in_slot([predicate])

Advance the copy-in slot.

advance_copy_out_slot([predicate])

Advance the copy-out slot.

advance_wait_in_slot([predicate])

Advance the wait-in slot.

advance_wait_out_slot([predicate])

Advance the wait-out slot.

bind_existing_ref(window_ref, indices)

For VMEM references, the pipeline aliases the existing ref rather than allocating a separate buffer.

get_dma_slice(src_ty, grid_indices)

initialize_slots()

Initializes slots to 0.

unbind_refs()

with_spec(spec)

Returns a new BufferedRefBase with the given block spec.
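The slot methods above describe the bookkeeping behind multi-buffered pipelining: each ref tracks which buffer is currently being copied into, waited on, copied out of, and so on, and advances those indices as the pipeline steps. The following is a minimal, dependency-free sketch of that idea, not the library's implementation. The class SlotState, its field names, the buffer_count of 2 (double buffering), and the default predicate=True are all hypothetical assumptions. Like with_spec, each method returns a new object instead of mutating in place. Inside a real Pallas kernel the predicate would be a traced boolean, so the Python if/else would have to become a traced select (for example jnp.where).

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class SlotState:
    """Hypothetical stand-in for a BufferedRef's slot bookkeeping.

    Field names mirror the documented methods; the real class's fields,
    types, and buffer count may differ.
    """
    buffer_count: int = 2  # assumption: double buffering
    copy_in_slot: int = 0
    wait_in_slot: int = 0
    copy_out_slot: int = 0
    wait_out_slot: int = 0

    def initialize_slots(self) -> "SlotState":
        # Reset every slot index to 0, as initialize_slots() is documented to do.
        return dataclasses.replace(self, copy_in_slot=0, wait_in_slot=0,
                                   copy_out_slot=0, wait_out_slot=0)

    def _advance(self, field: str, predicate: bool) -> "SlotState":
        current = getattr(self, field)
        # Step to the next buffer (wrapping around) only when predicate holds.
        new = (current + 1) % self.buffer_count if predicate else current
        return dataclasses.replace(self, **{field: new})

    def advance_copy_in_slot(self, predicate: bool = True) -> "SlotState":
        return self._advance("copy_in_slot", predicate)

    def advance_wait_in_slot(self, predicate: bool = True) -> "SlotState":
        return self._advance("wait_in_slot", predicate)

    def advance_copy_out_slot(self, predicate: bool = True) -> "SlotState":
        return self._advance("copy_out_slot", predicate)

    def advance_wait_out_slot(self, predicate: bool = True) -> "SlotState":
        return self._advance("wait_out_slot", predicate)


s = SlotState()
s = s.advance_copy_in_slot()       # copy_in_slot: 0 -> 1
s = s.advance_copy_in_slot()       # wraps around: 1 -> 0
s = s.advance_wait_in_slot(False)  # predicate false: wait_in_slot stays 0
```

Returning new objects keeps the state a plain value, which is what lets it be threaded through a traced loop carry instead of being mutated.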

Attributes

block_shape

buffer_type

compute_index

has_allocated_buffer

Returns True if the reference has an allocated buffer outside the loop.

has_indirect

Whether any block dimension uses indirect indexing.

is_buffered

is_input

is_input_output

is_manual

is_output

is_trivial_windowing

Whether the reference uses trivial windowing.

spec
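Several of the boolean attributes (is_input, is_output, is_input_output, is_manual) read as views over the single buffer_type attribute. The sketch below shows one way such flags could be derived from a buffer-type enum. The BufferType members and the rule that an input-output buffer counts as both an input and an output are assumptions for illustration; consult the JAX source for the actual enum and properties.

```python
import enum


class BufferType(enum.Enum):
    # Hypothetical enum; the real module's members and values may differ.
    INPUT = enum.auto()
    OUTPUT = enum.auto()
    INPUT_OUTPUT = enum.auto()
    MANUAL = enum.auto()


class RefFlags:
    """Hypothetical: derives is_* flags from a single buffer_type."""

    def __init__(self, buffer_type: BufferType):
        self.buffer_type = buffer_type

    @property
    def is_input(self) -> bool:
        # Assumption: an input-output buffer also counts as an input.
        return self.buffer_type in (BufferType.INPUT, BufferType.INPUT_OUTPUT)

    @property
    def is_output(self) -> bool:
        # Assumption: an input-output buffer also counts as an output.
        return self.buffer_type in (BufferType.OUTPUT, BufferType.INPUT_OUTPUT)

    @property
    def is_input_output(self) -> bool:
        return self.buffer_type is BufferType.INPUT_OUTPUT

    @property
    def is_manual(self) -> bool:
        return self.buffer_type is BufferType.MANUAL


io = RefFlags(BufferType.INPUT_OUTPUT)
print(io.is_input, io.is_output, io.is_manual)
```

Deriving the flags from one field keeps them mutually consistent: a ref cannot, for example, be marked manual and input at the same time.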