jax.experimental.pallas.mosaic_gpu.tcgen05_mma
jax.experimental.pallas.mosaic_gpu.tcgen05_mma(acc, a, b, barrier=None, *, a_scale=None, b_scale=None, a_sparse_metadata=None, accumulate=True, collective_axis=None)
Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell).
If run in collective mode, acc, a (LHS), and b (RHS) should correspond to half of the total inputs to the MMA, where acc and a (LHS) are split in half along the rows and b (RHS) is split along the columns like so:

    -----------    -----------   -----------
    |  ACC1   |    |  LHS1   |   |    |    |
    ----------- += ----------- @ |RHS1|RHS2|
    |  ACC2   |    |  LHS2   |   |    |    |
    -----------    -----------   -----------
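The split described above can be illustrated with a small NumPy model. This is only a sketch of the arithmetic implied by the diagram, not the Pallas API: the helper name `collective_mma_reference` is hypothetical, and the model assumes that each of the two blocks in the cluster holds one row-half of ACC and LHS plus one column-half of RHS, and that both RHS halves together form the full right-hand side.

```python
import numpy as np


def collective_mma_reference(acc, lhs, rhs):
    """Hypothetical NumPy model of the collective-mode data split.

    It models only the arithmetic of the diagram; it does not touch
    TMEM, SMEM, barriers, or any Pallas primitive.
    """
    m, n = lhs.shape[0], rhs.shape[1]
    assert m % 2 == 0 and n % 2 == 0, "rows of LHS and columns of RHS must split evenly"

    # ACC and LHS are split in half along the rows.
    acc_halves = [acc[: m // 2], acc[m // 2:]]
    lhs_halves = [lhs[: m // 2], lhs[m // 2:]]
    # RHS is split in half along the columns.
    rhs_halves = [rhs[:, : n // 2], rhs[:, n // 2:]]

    # Assumption: the MMA sees both RHS halves side by side.
    full_rhs = np.concatenate(rhs_halves, axis=1)

    # ACC_i += LHS_i @ [RHS1 | RHS2] for each half i.
    updated = [a + l @ full_rhs for a, l in zip(acc_halves, lhs_halves)]
    return np.concatenate(updated, axis=0)
```

Under these assumptions, stacking the two updated accumulator halves gives the same result as the unsplit update `acc + lhs @ rhs`.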
To use the block-scaled matrix-multiply, provide a_scale and b_scale operands (they must be both present or both unspecified).
Parameters:
acc (_Ref) – The accumulator. Must be a TMEM Ref.
a (_Ref) – The left-hand side. Must be a TMEM/SMEM Ref.
b (_Ref) – The right-hand side. Must be an SMEM Ref.
barrier (_Ref | None) – Optional barrier Ref for synchronizing with the tensor core. Must have orders_tensor_core set to True. If not specified, the MMA completion should be explicitly observed by calling jax.experimental.pallas.mosaic_gpu.tcgen05_commit_arrive().
a_scale (_Ref | None) – An optional scale for the a operand. Must be a TMEM Ref if present.
b_scale (_Ref | None) – An optional scale for the b operand. Must be a TMEM Ref if present.
a_sparse_metadata (_Ref | None) – An optional sparse metadata for the a operand. Must be a TMEM Ref if present.
accumulate (bool | jax.Array) – Whether to accumulate into acc or overwrite it.
collective_axis (str | None) – The name of the cluster axis along which to perform a collective MMA. The cluster axis should have a size of exactly 2, and must be on the minormost cluster axis.
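As a rough illustration of two of the rules above, the following NumPy sketch models the accumulate flag and the requirement that a_scale and b_scale are given together. Both helper names (`check_scale_operands`, `mma_reference`) are hypothetical and are not part of Pallas. The sketch deliberately ignores memory spaces, asynchrony, barriers, and the scaling arithmetic itself, none of which are specified here.

```python
import numpy as np


def check_scale_operands(a_scale, b_scale):
    """Hypothetical check: scales must be both present or both unspecified."""
    if (a_scale is None) != (b_scale is None):
        raise ValueError("a_scale and b_scale must be both present or both unspecified")


def mma_reference(acc, a, b, accumulate=True):
    """Hypothetical model of the accumulate flag only.

    accumulate may be a Python bool or a 0-d boolean array, mirroring the
    documented `bool | jax.Array` type.
    """
    prod = a @ b
    # accumulate=True adds into acc; accumulate=False overwrites it.
    return np.where(accumulate, acc + prod, prod)
```

For example, with `acc` filled with ones, `mma_reference(acc, a, b, accumulate=False)` returns just `a @ b`, while the default returns `acc + a @ b`.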