jax.experimental.pallas.mosaic_gpu.tcgen05_mma

jax.experimental.pallas.mosaic_gpu.tcgen05_mma(acc, a, b, barrier=None, *, a_scale=None, b_scale=None, a_sparse_metadata=None, accumulate=True, collective_axis=None)

Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell).

If run in collective mode, acc, a (LHS), and b (RHS) should each correspond to half of the total inputs to the MMA: acc and a (LHS) are split in half along the rows, and b (RHS) is split in half along the columns, like so:

-----------    -----------   -----------
|  ACC1   |    |  LHS1   |   |    |    |
----------- += ----------- @ |RHS1|RHS2|
|  ACC2   |    |  LHS2   |   |    |    |
-----------    -----------   -----------
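
The split in the diagram can be checked numerically with a small host-side model. The sketch below is plain NumPy, not a Pallas kernel: the toy sizes and all variable names are illustrative assumptions, and it only models which slices of the operands each half of the collective MMA holds and what the combined result should equal.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 8, 4, 6  # toy sizes chosen for illustration only

acc = rng.standard_normal((m, n)).astype(np.float32)
a = rng.standard_normal((m, k)).astype(np.float32)
b = rng.standard_normal((k, n)).astype(np.float32)

# Result of the full, non-collective operation: acc += a @ b.
expected = acc + a @ b

# Operand halves as in the diagram: acc and a split by rows, b split by columns.
acc_halves = np.split(acc, 2, axis=0)  # ACC1, ACC2
a_halves = np.split(a, 2, axis=0)      # LHS1, LHS2
b_halves = np.split(b, 2, axis=1)      # RHS1, RHS2

# Each accumulator half is updated with its LHS half times the whole RHS,
# which is the two RHS halves placed side by side.
rhs_full = np.concatenate(b_halves, axis=1)
result = np.concatenate(
    [acc_i + a_i @ rhs_full for acc_i, a_i in zip(acc_halves, a_halves)],
    axis=0,
)

collective_matches_full = bool(np.allclose(result, expected, atol=1e-5))
```

Stacking the two updated accumulator halves reproduces the full `acc + a @ b`, which is why each participant only needs to supply half of each operand.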

To use the block-scaled matrix-multiply, provide the a_scale and b_scale operands. They must either both be present or both be unspecified.
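
As a rough illustration of what block scaling means numerically, here is a host-side NumPy sketch. It is not the Pallas API, and the block size of 32 and the scale shapes are assumptions made for the example, since this page does not state the layout the hardware expects. The helper name `block_scaled_matmul` is hypothetical.

```python
import numpy as np

BLOCK = 32  # ASSUMPTION: number of consecutive K elements sharing one scale

rng = np.random.default_rng(0)
m, k, n = 4, 64, 8  # toy sizes; k must be a multiple of BLOCK
a = rng.standard_normal((m, k)).astype(np.float32)
b = rng.standard_normal((k, n)).astype(np.float32)
# ASSUMED layout: one scale per row of a (column of b) per block along K.
a_scale = rng.uniform(0.5, 2.0, size=(m, k // BLOCK)).astype(np.float32)
b_scale = rng.uniform(0.5, 2.0, size=(k // BLOCK, n)).astype(np.float32)


def block_scaled_matmul(a, b, a_scale=None, b_scale=None):
    """Hypothetical reference model: scale each K-block, then multiply."""
    # The scales come as a pair: both given or both omitted.
    if (a_scale is None) != (b_scale is None):
        raise ValueError("a_scale and b_scale must both be given or both omitted")
    if a_scale is None:
        return a @ b
    # Expand each per-block scale across the BLOCK elements it covers.
    a_scaled = a * np.repeat(a_scale, BLOCK, axis=1)
    b_scaled = b * np.repeat(b_scale, BLOCK, axis=0)
    return a_scaled @ b_scaled


out = block_scaled_matmul(a, b, a_scale, b_scale)
```

The pairing check mirrors the rule above: passing only one of the two scales is rejected rather than silently treated as an unscaled operand.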

Parameters:
  • acc (_Ref) – The accumulator. Must be a TMEM Ref.

  • a (_Ref) – The left-hand side. Must be a TMEM/SMEM Ref.

  • b (_Ref) – The right-hand side. Must be an SMEM Ref.

  • barrier (_Ref | None) – Optional barrier Ref for synchronizing with the tensor core. Must have orders_tensor_core set to True. If not specified, the MMA completion should be explicitly observed by calling jax.experimental.pallas.mosaic_gpu.tcgen05_commit_arrive().

  • a_scale (_Ref | None) – An optional scale for the a operand. Must be a TMEM Ref if present.

  • b_scale (_Ref | None) – An optional scale for the b operand. Must be a TMEM Ref if present.

  • a_sparse_metadata (_Ref | None) – Optional sparse metadata for the a operand. Must be a TMEM Ref if present.

  • accumulate (bool | jax.Array) – Whether to accumulate into acc or overwrite it.

  • collective_axis (str | None) – The name of the cluster axis along which to perform a collective MMA. The cluster axis should have a size of exactly 2 and must be the minormost cluster axis.
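
The effect of the accumulate flag can be modeled on the host with NumPy. The sketch below is not a Pallas kernel: `mma_reference` is a hypothetical stand-in that only mimics the documented "accumulate into acc or overwrite it" behavior, and the tile sizes are illustrative. It shows one common way such a flag might be used, a loop over K tiles where the first step overwrites the accumulator and later steps add to it.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n, tile_k = 4, 12, 5, 4  # toy sizes; k is a multiple of tile_k
a = rng.standard_normal((m, k)).astype(np.float32)
b = rng.standard_normal((k, n)).astype(np.float32)


def mma_reference(acc, a_tile, b_tile, accumulate=True):
    """Hypothetical model: add a_tile @ b_tile to acc, or overwrite acc."""
    product = a_tile @ b_tile
    return acc + product if accumulate else product


# Start from NaNs to show that the first step never reads the old contents.
acc = np.full((m, n), np.nan, dtype=np.float32)
for ki in range(k // tile_k):
    ks = slice(ki * tile_k, (ki + 1) * tile_k)
    # Overwrite on the first K tile, accumulate on every later one.
    acc = mma_reference(acc, a[:, ks], b[ks, :], accumulate=(ki > 0))
```

Because the first step overwrites, the accumulator does not need to be zeroed beforehand. In a traced kernel the loop index would be an array value, which is presumably why accumulate also accepts a jax.Array.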