jax.experimental.pallas.tpu.BufferedRef

class jax.experimental.pallas.tpu.BufferedRef(_spec, _buffer_type, window_ref, accum_ref, copy_in_slot, wait_in_slot, copy_out_slot, wait_out_slot, _copy_in_slot_reg, _wait_in_slot_reg, _copy_out_slot_reg, _wait_out_slot_reg, next_fetch_smem, next_fetch_sreg, sem_recvs, sem_sends, swap)

A helper class that automates VMEM double buffering in Pallas pipelines.
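To make the idea of double buffering concrete, here is a minimal pure-Python sketch. It is a conceptual model only, not the Pallas implementation: the helper `run_double_buffered` and its arguments are hypothetical, and the model is sequential, whereas real DMAs run asynchronously alongside the compute.

```python
def run_double_buffered(blocks, body):
    """Conceptual model of a two-slot pipeline (hypothetical helper)."""
    slots = [None, None]          # two buffers standing in for VMEM slots
    results = []
    if not blocks:
        return results
    slots[0] = blocks[0]          # prologue: fetch the first block
    for step in range(len(blocks)):
        cur = step % 2            # slot the body reads during this step
        nxt = (step + 1) % 2      # slot the next fetch writes into
        if step + 1 < len(blocks):
            slots[nxt] = blocks[step + 1]   # models the copy-in for the next step
        results.append(body(slots[cur]))    # compute on the current slot
    return results


outputs = run_double_buffered([[1, 2], [3, 4], [5, 6]], sum)
```

The point of the two slots is that the fetch for step `k + 1` never touches the slot the body is reading at step `k`.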

Parameters:
  • _spec (pl.BlockSpec)

  • _buffer_type (BufferType)

  • window_ref (ArrayRef | None)

  • accum_ref (ArrayRef | None)

  • copy_in_slot (ArrayRef | None)

  • wait_in_slot (ArrayRef | None)

  • copy_out_slot (ArrayRef | None)

  • wait_out_slot (ArrayRef | None)

  • _copy_in_slot_reg (int | jax.Array | None)

  • _wait_in_slot_reg (int | jax.Array | None)

  • _copy_out_slot_reg (int | jax.Array | None)

  • _wait_out_slot_reg (int | jax.Array | None)

  • next_fetch_smem (Sequence[jax.Array] | None)

  • next_fetch_sreg (Sequence[jax.Array] | None)

  • sem_recvs (SemaphoreTuple | None)

  • sem_sends (SemaphoreTuple | None)

  • swap (ArrayRef | None)

spec

The Pallas BlockSpec for this buffer.

buffer_type

Enum indicating whether this is an input, output, or in/out accumulator buffered reference.

window_ref

A multiple-buffered ref holding the working and dirty buffers that data is copied into and out of. For a BufferedRef targeting a VMEM reference, this simply points to the existing ref.

Type:

ArrayRef | None

accum_ref

Accumulating buffer used by accumulator BufferedRefs.

Type:

ArrayRef | None
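As a rough illustration of how an accumulating buffer can work together with the `set_accumulator` and `accumulate` methods listed further down, the sketch below models them in plain Python. This is an assumption-laden model rather than the library code: `AccumModel` and its list-based buffers are hypothetical, and the behavior shown (zero on init, otherwise stash the current slot, then add the stash back into the current slot) is inferred from the method summaries.

```python
class AccumModel:
    """Hypothetical stand-in for an accumulator BufferedRef."""

    def __init__(self, size):
        self.current = [0.0] * size   # models the current working slot
        self.accum = [0.0] * size     # models accum_ref

    def set_accumulator(self, init=False):
        # Assumed behavior: zero the accumulator on init, otherwise
        # stash the current slot's contents in it.
        if init:
            self.accum = [0.0] * len(self.accum)
        else:
            self.accum = list(self.current)

    def accumulate(self):
        # Add the stashed partial result into the current slot.
        self.current = [c + a for c, a in zip(self.current, self.accum)]


m = AccumModel(2)
m.set_accumulator(init=True)      # first visit: start from zeros
m.current = [1.0, 2.0]            # kernel body writes a partial result
m.accumulate()                    # adds zeros, so the slot is unchanged
m.set_accumulator()               # stash the result before the slot is reused
m.current = [10.0, 20.0]          # next partial result
m.accumulate()                    # fold the stashed result back in
total = m.current
```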

copy_in_slot

Current copy-in slot for the working buffer.

Type:

ArrayRef | None

copy_out_slot

Current copy-out slot for the working buffer.

Type:

ArrayRef | None

wait_in_slot

Current wait-in slot for the working buffer.

Type:

ArrayRef | None

wait_out_slot

Current wait-out slot for the working buffer.

Type:

ArrayRef | None
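The four slot attributes above are advanced independently by the `advance_*_slot` methods, each of which takes an optional predicate. The following sketch models one such counter in plain Python. `SlotCounter` is a hypothetical name, and the relationship it assumes (slot index equals the cumulative count modulo `buffer_count`) is an assumption based on the `cumulative_*` and `current_*_slot` attribute summaries, not a statement of the library's internals.

```python
class SlotCounter:
    """Hypothetical model of one slot counter (e.g. copy_in_slot)."""

    def __init__(self, buffer_count=2):
        self.buffer_count = buffer_count
        self.cumulative = 0           # total number of advances so far

    @property
    def current_slot(self):
        # Assumption: slot index = cumulative count modulo buffer count.
        return self.cumulative % self.buffer_count

    def advance(self, predicate=True):
        # Only advance when the predicate holds, mirroring the optional
        # `predicate` argument of the advance_* methods.
        if predicate:
            self.cumulative += 1


copy_in = SlotCounter(buffer_count=2)
seen = []
for will_fetch in [True, True, False, True]:
    seen.append(copy_in.current_slot)
    copy_in.advance(predicate=will_fetch)
```

Keeping a separate counter per operation lets a copy be issued into one slot while a wait is still pending on another.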

next_fetch_smem

Holds the next grid indices to fetch for lookahead. This is the SMEM backing buffer used to persist state between pipeline invocations.

Type:

Sequence[jax.Array] | None

next_fetch_sreg

Holds the next grid indices to fetch for lookahead. This is the register state used to track the indices within the pipeline loop.

Type:

Sequence[jax.Array] | None
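The split between a persistent backing buffer and loop-local register state can be sketched as follows. This is an illustrative model only: `next_grid_indices` is a hypothetical helper, the row-major stepping order is an assumption, and a plain list and tuple stand in for SMEM and registers. The actual lookahead logic in Pallas may choose the next indices differently.

```python
def next_grid_indices(indices, grid):
    """Advance `indices` by one step in row-major order, wrapping at the end."""
    out = list(indices)
    for axis in reversed(range(len(grid))):
        out[axis] += 1
        if out[axis] < grid[axis]:
            return tuple(out)
        out[axis] = 0             # carry into the next-slower axis
    return tuple(out)             # wrapped around to all zeros


smem_state = [0, 0]               # models next_fetch_smem: survives between invocations
reg = tuple(smem_state)           # "load" into register state at pipeline start
trace = []
for _ in range(4):
    trace.append(reg)             # indices that would be fetched at this step
    reg = next_grid_indices(reg, grid=(2, 3))
smem_state[:] = reg               # "save" back so the next invocation can resume
```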

sem_recvs

Multiple-buffered semaphores for input DMAs.

Type:

SemaphoreTuple | None

sem_sends

Multiple-buffered semaphores for output DMAs.

Type:

SemaphoreTuple | None
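One way to picture why the semaphores are multiple-buffered is to give each slot its own semaphore, so that a completed copy into one slot cannot satisfy a wait on another. The sketch below uses Python's `threading.Semaphore` purely as an analogy. The helpers `dma_done` and `try_wait` are hypothetical, and TPU DMA semaphores are not Python objects.

```python
import threading

buffer_count = 2
# One semaphore per slot, standing in for sem_recvs.
sem_recvs = [threading.Semaphore(0) for _ in range(buffer_count)]


def dma_done(slot):
    sem_recvs[slot].release()     # the copy into `slot` has landed


def try_wait(slot):
    # Non-blocking wait: True only if a copy into this slot has completed.
    return sem_recvs[slot].acquire(blocking=False)


dma_done(1)                       # only the copy into slot 1 finishes
slot0_ready = try_wait(0)         # nothing has landed in slot 0
slot1_ready = try_wait(1)         # slot 1 is ready
slot1_again = try_wait(1)         # the signal was consumed by the previous wait
```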

block_shape

Passthrough property for the BlockSpec’s block_shape.

compute_index

Passthrough property for the BlockSpec’s compute_index.

memory_space

Passthrough property for the BlockSpec’s memory_space.
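To show how a block shape and an index function can combine into the slice that a DMA targets (the job of `compute_slice(grid_indices)` in the method list), here is a simplified pure-Python sketch. The standalone function below is hypothetical and assumes plain blocked indexing with an integer block size in every dimension. It ignores squeezed dimensions and partial edge blocks, which the real method may handle.

```python
def compute_slice(block_shape, index_map, grid_indices):
    """Map grid indices to per-dimension slices (hypothetical helper)."""
    block_indices = index_map(*grid_indices)
    return tuple(
        slice(i * b, (i + 1) * b)     # block i covers [i*b, (i+1)*b)
        for i, b in zip(block_indices, block_shape)
    )


# A (128, 256) block over a 2-D grid; this index map swaps the grid axes.
slices = compute_slice((128, 256), lambda i, j: (j, i), grid_indices=(1, 2))
```

Here grid position `(1, 2)` maps to block indices `(2, 1)`, so the slice covers rows 256 to 383 and columns 256 to 511 of the full array.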

current_ref

Points to the current working slice of the double buffer.

is_input

Whether this BufferedRef acts as a pipeline input.

is_output

Whether this BufferedRef acts as a pipeline output.

is_accumulator

Whether this BufferedRef is an accumulator.

is_input_output

Whether this BufferedRef is an input/output without automatic accumulation.

swap

Tracks whether the BufferedRef slots need to be swapped before the next copy.

Type:

ArrayRef | None

__init__(_spec, _buffer_type, window_ref, accum_ref, copy_in_slot, wait_in_slot, copy_out_slot, wait_out_slot, _copy_in_slot_reg, _wait_in_slot_reg, _copy_out_slot_reg, _wait_out_slot_reg, next_fetch_smem, next_fetch_sreg, sem_recvs, sem_sends, swap)
Parameters:
  • _spec (pl.BlockSpec)

  • _buffer_type (BufferType)

  • window_ref (ArrayRef | None)

  • accum_ref (ArrayRef | None)

  • copy_in_slot (ArrayRef | None)

  • wait_in_slot (ArrayRef | None)

  • copy_out_slot (ArrayRef | None)

  • wait_out_slot (ArrayRef | None)

  • _copy_in_slot_reg (int | jax.Array | None)

  • _wait_in_slot_reg (int | jax.Array | None)

  • _copy_out_slot_reg (int | jax.Array | None)

  • _wait_out_slot_reg (int | jax.Array | None)

  • next_fetch_smem (Sequence[jax.Array] | None)

  • next_fetch_sreg (Sequence[jax.Array] | None)

  • sem_recvs (SemaphoreTuple | None)

  • sem_sends (SemaphoreTuple | None)

  • swap (ArrayRef | None)

Return type:

None

Methods

__init__(_spec, _buffer_type, window_ref, ...)

accumulate()

Add into the current slot.

accumulator(spec, dtype_or_type[, buffer_count])

advance_copy_in_slot([predicate])

Switch to the next copy slot.

advance_copy_out_slot([predicate])

Switch to the next copy slot.

advance_wait_in_slot([predicate])

Switch to the next wait slot.

advance_wait_out_slot([predicate])

Switch to the next wait slot.

bind_existing_ref(window_ref, indices)

For handling VMEM references, the pipeline aliases the existing ref.

buffer_types()

compute_slice(grid_indices)

Compute DMA slice from grid indices.

copy_in(src_ref, grid_indices)

Starts the copy of an HBM DMA slice into the current slot.

copy_out(dst_ref, grid_indices)

Starts the copy from the current slot into an HBM DMA slice.

create(spec, dtype_or_type, buffer_type, ...)

Create a BufferedRef.

get_dma_slice(src_ty, grid_indices)

init_slots()

Initialize slot indices.

input(spec, dtype_or_type[, buffer_count])

input_output(spec, dtype_or_type[, buffer_count])

load_slots([predicate])

Load slot information into registers.

output(spec, dtype_or_type[, buffer_count])

save_slots([predicate])

Save slot information from registers.

set_accumulator([init])

Set accumulator or zero it out to initialize.

unbind_refs()

wait_in(src_ref, grid_indices)

Waits for input copy to finish.

wait_out(dst_ref, grid_indices)

Waits for output copy to finish.

with_next_fetch([next_fetch])

with_slot_index([copy_in_slot, ...])

Returns a new BufferedRef with the given slot index.

with_spec(spec)

Returns a new BufferedRef with the given block spec.

Attributes

block_shape

buffer_count

Returns the number of buffers used for multiple buffering.

buffer_type

compute_index

cumulative_copy_in

The cumulative number of copy_ins issued on this buffer.

cumulative_copy_out

The cumulative number of copy_outs issued on this buffer.

cumulative_wait_in

The cumulative number of wait_ins issued on this buffer.

cumulative_wait_out

The cumulative number of wait_outs issued on this buffer.

current_copy_in_slot

Index in multiple buffer corresponding to the current copy slot.

current_copy_out_slot

Index in multiple buffer corresponding to the current copy slot.

current_ref

current_wait_in_slot

Index in multiple buffer corresponding to the current wait slot.

current_wait_out_slot

Index in multiple buffer corresponding to the current wait slot.

is_accumulator

is_buffered

Whether this buffer is multiple-buffered.

is_input

is_input_output

is_manual

is_output

next_fetch_indices

Returns the next grid indices to fetch from if using lookahead.

spec

use_lookahead

Whether this buffer allows lookahead for fetching blocks.

window_ref

accum_ref

copy_in_slot

wait_in_slot

copy_out_slot

wait_out_slot

next_fetch_smem

next_fetch_sreg

sem_recvs

sem_sends

swap