jax.experimental.pallas.mosaic_gpu.BlockSpec
- class jax.experimental.pallas.mosaic_gpu.BlockSpec(block_shape=None, index_map=None, pipeline_mode=None, transforms=(), delay_release=0, collective_axes=(), *, memory_space=None)
A GPU-specific BlockSpec.
- Parameters:
- transforms
A sequence of transforms that will be applied to the reference.
- Type:
Sequence[MemoryRefTransform]
- delay_release
Used during pipelining to delay the release of a slot's resources after the slot is used in the computation.
- Type:
int
- collective_axes
When set, all blocks along the specified axes must execute the same sequence of pipeline operations (the only exception being the index_map of non-collective BlockSpecs), and all of them must return the same block from the index_map for this operand. This enables the pipelining helpers to use collective async copies, which can improve performance.
- Type:
tuple[Hashable, …]
- __init__(block_shape=None, index_map=None, pipeline_mode=None, transforms=(), delay_release=0, collective_axes=(), *, memory_space=None)
Methods
- __init__([block_shape, index_map, ...])
- replace(**changes): Return a new object replacing specified fields with new values.
- to_block_mapping(origin, array_aval, *, ...)

Attributes
- block_shape
- index_map
- memory_space
- pipeline_mode