jax.experimental.pallas.mosaic_gpu.BlockSpec

class jax.experimental.pallas.mosaic_gpu.BlockSpec(block_shape=None, index_map=None, pipeline_mode=None, transforms=(), delay_release=0, collective_axes=(), *, memory_space=None)

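A GPU-specific BlockSpec for the Mosaic GPU Pallas backend. In addition to the standard Pallas fields (block_shape, index_map, pipeline_mode, memory_space), it accepts transforms, delay_release, and collective_axes.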

Parameters:
  • block_shape (Sequence[BlockDim | int | None] | None)

  • index_map (Callable[..., Any] | None)

  • pipeline_mode (Buffered | None)

  • transforms (Sequence[MemoryRefTransform])

  • delay_release (int)

  • collective_axes (tuple[Hashable, ...])

  • memory_space (Any | None)
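The blocked indexing convention behind block_shape and index_map can be modeled in plain Python: for each point of the grid, index_map returns block indices (not element offsets), and each block index is scaled by the corresponding block dimension to find the region of the array the block covers. This is a stand-alone sketch, not JAX code; the helper block_slices is hypothetical, and it assumes block_shape evenly divides the array shape.

```python
from itertools import product
from typing import Callable, Sequence


def block_slices(
    array_shape: Sequence[int],
    block_shape: Sequence[int],
    index_map: Callable[..., tuple[int, ...]],
    grid: Sequence[int],
) -> dict[tuple[int, ...], tuple[slice, ...]]:
    """Map every grid point to the array region its block covers (hypothetical helper)."""
    for dim, blk in zip(array_shape, block_shape):
        if dim % blk:
            raise ValueError("this model assumes block_shape evenly divides array_shape")
    regions = {}
    for grid_idx in product(*(range(g) for g in grid)):
        block_idx = index_map(*grid_idx)  # block indices, not element offsets
        regions[grid_idx] = tuple(
            slice(b * blk, (b + 1) * blk) for b, blk in zip(block_idx, block_shape)
        )
    return regions


# A (256, 512) operand split into (128, 128) blocks over a (2, 4) grid.
regions = block_slices((256, 512), (128, 128), lambda i, j: (i, j), grid=(2, 4))
print(regions[(1, 2)])  # (slice(128, 256, None), slice(256, 384, None))
```

Changing the index_map to lambda i, j: (0, j) would make every row of the grid read the same row of blocks, which is how an operand is reused across grid points.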

transforms

A sequence of transforms applied to the reference, for example tiling (TilingTransform) or swizzling (SwizzleTransform) of the block's buffer.

Type:

Sequence[MemoryRefTransform]
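To make "a sequence of transforms" concrete, the plain-Python sketch below tracks only how a reference's shape changes as transforms are applied one after another. Two things here are assumptions rather than documented behavior: that transforms apply in the order listed, and that tiling splits the trailing dimensions into a grid of tiles followed by the tile itself (so (128, 128) with tiling (8, 64) becomes (16, 2, 8, 64)). The names tile_shape and apply_transforms are hypothetical, and real transforms also change the memory layout, which this model ignores.

```python
from functools import reduce
from typing import Callable, Sequence

Shape = tuple[int, ...]
ShapeTransform = Callable[[Shape], Shape]


def tile_shape(tiling: Shape) -> ShapeTransform:
    """Assumed tiling rule: trailing dims become (num_tiles..., tile...)."""
    def apply(shape: Shape) -> Shape:
        n = len(tiling)
        lead, tail = shape[:-n], shape[-n:]
        if any(d % t for d, t in zip(tail, tiling)):
            raise ValueError(f"tiling {tiling} does not divide {tail}")
        return lead + tuple(d // t for d, t in zip(tail, tiling)) + tuple(tiling)
    return apply


def apply_transforms(shape: Shape, transforms: Sequence[ShapeTransform]) -> Shape:
    """Apply shape transforms left to right, in the order they are listed."""
    return reduce(lambda s, t: t(s), transforms, shape)


print(apply_transforms((128, 128), [tile_shape((8, 64))]))  # (16, 2, 8, 64)
```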

delay_release

Used during pipelining to delay the release of a slot's resources after the slot has been used in the computation.

Type:

int
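The trade-off delay_release controls can be seen in a toy ring-buffer model: step s computes on slot s % num_slots, the slot is released at the end of step s + delay_release, and only then can the copy for the next step that reuses the slot (step s + num_slots) begin. This model and its names (num_slots, overlap_steps) are illustrative assumptions, not the pipelining helpers' implementation. Under it, each step of delay removes one step of copy/compute overlap, and delay_release must be smaller than the number of slots.

```python
def overlap_steps(num_slots: int, delay_release: int) -> int:
    """Compute steps available to hide a refill copy in a toy ring-buffer pipeline.

    Step s computes on slot s % num_slots. The slot is released at the end of
    step s + delay_release, and the next step that needs it is s + num_slots,
    so the refill copy can overlap only the steps strictly in between.
    """
    if num_slots < 1 or delay_release < 0:
        raise ValueError("num_slots must be >= 1 and delay_release >= 0")
    if delay_release >= num_slots:
        # The slot would be released only after the step that reuses it has started.
        raise ValueError("delay_release must be smaller than num_slots in this model")
    return num_slots - delay_release - 1


for slots, delay in [(2, 0), (2, 1), (3, 1)]:
    print(slots, delay, overlap_steps(slots, delay))
# 2 0 1   (classic double buffering)
# 2 1 0   (the refill cannot overlap any compute)
# 3 1 1   (an extra slot buys the overlap back)
```

In this model a nonzero delay is only worth paying for when the computation keeps reading a slot after its step's body returns, for instance through an asynchronous operation that is awaited later; that reading of the intended use is our interpretation.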

collective_axes

When set, all blocks along the specified axes must execute the same sequence of pipeline operations (the only permitted difference being the index_map of non-collective BlockSpecs), and they must all return the same block from the index_map for this operand. This allows the pipelining helpers to use collective async copies, which can improve performance.

Type:

tuple[Hashable, …]
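The collective_axes requirement can be checked mechanically: at every pipeline step, all blocks along the collective axes must request the same block of the operand. The plain-Python sketch below models a hypothetical matmul in which two blocks along one axis split the M dimension. Both fetch the same tile of B at each K step, so B's BlockSpec would meet the requirement, while A's would not. The function can_be_collective, the constants, and the (position, step) index-map signature are illustrative assumptions; in this model the block's position along the collective axis is passed explicitly.

```python
from typing import Callable

# Model signature: (position along the collective axis, pipeline step) -> block index.
IndexMap = Callable[[int, int], tuple[int, ...]]


def can_be_collective(index_map: IndexMap, axis_size: int, num_steps: int) -> bool:
    """True if all blocks along the axis request the same block at every step."""
    return all(
        len({index_map(pos, step) for pos in range(axis_size)}) == 1
        for step in range(num_steps)
    )


M_TILE_BASE = 0  # hypothetical: first M tile owned by this group of blocks
N_TILE = 3       # hypothetical: N tile computed by this group of blocks


def a_index_map(pos: int, k: int) -> tuple[int, int]:
    # Each block along the axis owns a different M tile of A.
    return (M_TILE_BASE + pos, k)


def b_index_map(pos: int, k: int) -> tuple[int, int]:
    # Every block along the axis needs the same (k, n) tile of B.
    return (k, N_TILE)


print(can_be_collective(a_index_map, axis_size=2, num_steps=4))  # False
print(can_be_collective(b_index_map, axis_size=2, num_steps=4))  # True
```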

__init__(block_shape=None, index_map=None, pipeline_mode=None, transforms=(), delay_release=0, collective_axes=(), *, memory_space=None)
Parameters:
  • block_shape (Sequence[BlockDim | int | None] | None)

  • index_map (Callable[..., Any] | None)

  • pipeline_mode (Buffered | None)

  • transforms (Sequence[MemoryRefTransform])

  • delay_release (int)

  • collective_axes (tuple[Hashable, ...])

  • memory_space (Any | None)

Return type:

None

Methods

__init__([block_shape, index_map, ...])

replace(**changes)

Return a new object replacing specified fields with new values.

to_block_mapping(origin, array_aval, *, ...)
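The replace method listed above returns a modified copy rather than mutating the spec it is called on. Because the real class requires JAX, the sketch below uses a hypothetical stand-in dataclass, ToyBlockSpec, that mirrors a subset of BlockSpec's field names and delegates to dataclasses.replace; we assume this matches the method's documented behavior of returning a new object with the specified fields replaced.

```python
from __future__ import annotations

import dataclasses
from typing import Any, Callable, Hashable, Sequence


@dataclasses.dataclass(frozen=True)
class ToyBlockSpec:
    """Hypothetical stand-in mirroring a subset of BlockSpec's fields."""
    block_shape: Sequence[int] | None = None
    index_map: Callable[..., Any] | None = None
    transforms: Sequence[Any] = ()
    delay_release: int = 0
    collective_axes: tuple[Hashable, ...] = ()

    def replace(self, **changes: Any) -> ToyBlockSpec:
        # Build a new object; fields not named in `changes` are copied unchanged.
        return dataclasses.replace(self, **changes)


base = ToyBlockSpec(block_shape=(128, 64), index_map=lambda k: (0, k))
delayed = base.replace(delay_release=1)
print(base.delay_release, delayed.delay_release)  # 0 1
print(delayed.block_shape)  # (128, 64)
```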

Attributes

block_shape

collective_axes

delay_release

index_map

memory_space

pipeline_mode

transforms