jax.experimental.pallas.dot

jax.experimental.pallas.dot(a, b, trans_a=False, trans_b=False, allow_tf32=None, precision=None)

Computes the dot product of two arrays.

The inputs can optionally be transposed before computing the product. Depending on the hardware, this can be cheaper than computing the transpose beforehand.
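As a sketch of the transpose flags, the kernel below computes a.T @ b by passing trans_a=True rather than transposing a beforehand. It then checks the result against jax.numpy.dot applied to an explicitly transposed a. The sketch assumes a JAX version that still exports pl.dot and supports pallas_call(..., interpret=True), which runs the kernel on CPU without an accelerator. The names at_b_kernel and at_b are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def at_b_kernel(a_ref, b_ref, o_ref):
    # trans_a=True transposes a before the product, so this computes a.T @ b
    # without the caller materializing a.T first.
    o_ref[...] = pl.dot(a_ref[...], b_ref[...], trans_a=True)


def at_b(a, b):
    # No grid or BlockSpecs: the kernel sees each array as one whole block.
    return pl.pallas_call(
        at_b_kernel,
        out_shape=jax.ShapeDtypeStruct((a.shape[1], b.shape[1]), jnp.float32),
        interpret=True,  # assumption: interpret mode, so this runs on CPU
    )(a, b)


a = jax.random.normal(jax.random.PRNGKey(0), (32, 16), jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (32, 64), jnp.float32)
out = at_b(a, b)

# Reference: transpose explicitly, then multiply.
ref = jnp.dot(a.T, b)
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4)
```

trans_b=True works the same way on the right-hand operand. For example, pl.dot(x, y, trans_b=True) corresponds to x @ y.T.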

Parameters:
  • a – The left-hand side of the dot product, of shape (..., N).

  • b – The right-hand side of the dot product, of shape (..., N, M).

  • trans_a (bool) – Whether to transpose a before the product.

  • trans_b (bool) – Whether to transpose b before the product.

  • allow_tf32 (bool | None) – Whether to allow TF32 precision for the product. Mutually exclusive with precision: specify at most one of the two.

  • precision – The precision of the dot product, for example a jax.lax.Precision value.
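A minimal sketch of setting precision explicitly, under the same assumptions as any interpret-mode Pallas example: a JAX version that exposes pl.dot, and pallas_call(..., interpret=True) so the kernel runs on CPU. The names matmul_kernel and matmul are illustrative. allow_tf32 stays at its default of None because it is mutually exclusive with precision. The result is compared with jax.numpy.dot at the same precision.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def matmul_kernel(a_ref, b_ref, o_ref):
    # Request the highest precision explicitly; allow_tf32 is left as None
    # because only one of allow_tf32 and precision may be given.
    o_ref[...] = pl.dot(a_ref[...], b_ref[...],
                        precision=jax.lax.Precision.HIGHEST)


def matmul(a, b):
    # Whole-array blocks; interpret=True is an assumption for CPU execution.
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((a.shape[0], b.shape[1]), jnp.float32),
        interpret=True,
    )(a, b)


a = jax.random.normal(jax.random.PRNGKey(2), (16, 32), jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(3), (32, 16), jnp.float32)
out = matmul(a, b)

ref = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4)
```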

See also

jax.numpy.dot()