jax.experimental.pallas.dot
- jax.experimental.pallas.dot(a, b, trans_a=False, trans_b=False, allow_tf32=None, precision=None)
Computes the dot product of two arrays.
The inputs can optionally be transposed before computing the product. Depending on the hardware, this can be cheaper than computing the transpose beforehand.
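The sketch below shows how the transposition flags work. It computes a @ b.T by passing trans_b=True instead of transposing b explicitly. It makes a few assumptions. The kernel runs through pl.pallas_call with interpret=True so that it executes on CPU. No grid is used, so each ref holds the whole array. The operands are 2-D to keep to the simplest case. The names dot_trans_b_kernel and matmul_trans_b are hypothetical helpers for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def dot_trans_b_kernel(a_ref, b_ref, o_ref):
    # Hypothetical kernel: with no grid, each ref holds the whole array.
    # trans_b=True contracts the last axis of `a` with the last axis of `b`,
    # i.e. it computes a @ b.T without materializing b.T first.
    o_ref[...] = pl.dot(a_ref[...], b_ref[...], trans_b=True)


def matmul_trans_b(a, b):
    # a: (M, K), b: (N, K) -> output: (M, N)
    out_shape = jax.ShapeDtypeStruct((a.shape[0], b.shape[0]), jnp.float32)
    # interpret=True runs the kernel with the Pallas interpreter (works on CPU).
    return pl.pallas_call(dot_trans_b_kernel, out_shape=out_shape, interpret=True)(a, b)


a = jnp.arange(8 * 16, dtype=jnp.float32).reshape(8, 16) / 100.0
b = jnp.arange(32 * 16, dtype=jnp.float32).reshape(32, 16) / 100.0
out = matmul_trans_b(a, b)
print(out.shape)  # (8, 32)
```

Whether trans_b=True is actually cheaper than an explicit transpose depends on the backend. On the interpreter it only demonstrates the semantics.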
- Parameters:
  - a – The left-hand side of the dot product, of shape (..., N).
  - b – The right-hand side of the dot product, of shape (..., N, M).
  - trans_a (bool) – Whether to transpose a before the product.
  - trans_b (bool) – Whether to transpose b before the product.
  - allow_tf32 (bool | None) – Whether to use tf32 precision. Mutually exclusive with precision.
  - precision – Specifies the precision of the dot product.
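The sketch below combines trans_a=True with an explicit precision. Because allow_tf32 and precision are mutually exclusive, only precision is set and allow_tf32 is left at its default of None. It makes a few assumptions. The kernel runs with interpret=True on CPU, and the operands are 2-D. jax.lax.Precision.HIGHEST is simply one valid choice of precision. The names dot_trans_a_kernel and matmul_trans_a are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def dot_trans_a_kernel(a_ref, b_ref, o_ref):
    # Hypothetical kernel: trans_a=True contracts the first axis of `a`
    # with the first axis of `b`, i.e. it computes a.T @ b.
    # Only `precision` is passed; `allow_tf32` stays None because the two
    # arguments are mutually exclusive.
    o_ref[...] = pl.dot(
        a_ref[...],
        b_ref[...],
        trans_a=True,
        precision=jax.lax.Precision.HIGHEST,
    )


def matmul_trans_a(a, b):
    # a: (K, M), b: (K, N) -> output: (M, N)
    out_shape = jax.ShapeDtypeStruct((a.shape[1], b.shape[1]), jnp.float32)
    # interpret=True runs the kernel with the Pallas interpreter (works on CPU).
    return pl.pallas_call(dot_trans_a_kernel, out_shape=out_shape, interpret=True)(a, b)


a2 = jnp.arange(16 * 8, dtype=jnp.float32).reshape(16, 8) / 100.0
b2 = jnp.arange(16 * 4, dtype=jnp.float32).reshape(16, 4) / 100.0
out2 = matmul_trans_a(a2, b2)
print(out2.shape)  # (8, 4)
```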
See also