jax.lax.dot_general

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, *, out_sharding=None)

Alias of jax.lax.dot().

Prefer calling jax.lax.dot() directly. Note, however, that jax.lax.dot() requires every argument after lhs and rhs to be passed by keyword rather than by position.
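As an illustration of the calling conventions, the sketch below expresses an ordinary matrix product with dot_general, once with dimension_numbers passed positionally and once by keyword. The keyword form is the style that jax.lax.dot() requires. The sketch calls only dot_general, so it does not depend on which JAX version's jax.lax.dot() signature is installed; the array names and shapes are illustrative assumptions.

```python
import jax.numpy as jnp
from jax import lax

# Two ordinary matrices: (2, 3) x (3, 4) -> (2, 4).
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.arange(12.0).reshape(3, 4)

# Contract axis 1 of x with axis 0 of y; no batch dimensions.
dims = (((1,), (0,)), ((), ()))

# Positional form, accepted by dot_general.
out_positional = lax.dot_general(x, y, dims)

# Keyword form, the style jax.lax.dot() requires after lhs and rhs.
out_keyword = lax.dot_general(x, y, dimension_numbers=dims)

print(out_positional.shape)  # (2, 4)
```

Both calls compute the same result as `x @ y`; writing the keyword form from the start makes a later switch to jax.lax.dot() a rename rather than a rewrite.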

Parameters:
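  • lhs (ArrayLike): the left-hand operand array.

  • rhs (ArrayLike): the right-hand operand array.

  • dimension_numbers (DotDimensionNumbers): a tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), where each entry is a sequence of axis indices.

  • precision (PrecisionLike): optional. Either None (the backend's default precision), a single Precision enum value such as Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST, or a pair of Precision values giving the precision for lhs and rhs respectively.

  • preferred_element_type (DTypeLike | None): optional. Either None (the default accumulation type for the input types) or a dtype in which to accumulate results and return the result.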
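To make the dimension_numbers layout concrete, the following sketch writes a batched matrix multiply, (B, M, K) x (B, K, N) -> (B, M, N), as a dot_general call and checks it against an equivalent jnp.einsum. The shapes and the choice of Precision.HIGHEST are illustrative assumptions, not requirements of the API.

```python
import jax.numpy as jnp
from jax import lax

B, M, K, N = 2, 3, 4, 5
lhs = jnp.arange(B * M * K, dtype=jnp.float32).reshape(B, M, K)
rhs = jnp.arange(B * K * N, dtype=jnp.float32).reshape(B, K, N)

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)):
# contract lhs axis 2 with rhs axis 1; treat axis 0 of both as batch.
dimension_numbers = (((2,), (1,)), ((0,), (0,)))

out = lax.dot_general(
    lhs,
    rhs,
    dimension_numbers=dimension_numbers,
    precision=lax.Precision.HIGHEST,  # request the most accurate algorithm
)

# Same contraction written as an einsum, for comparison.
ref = jnp.einsum("bmk,bkn->bmn", lhs, rhs)
print(out.shape)  # (2, 3, 5)
```

Batch axes must have matching sizes in lhs and rhs, and so must each pair of contracting axes; here both the batch size (B) and the contracting size (K) agree.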

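Returns:

An array whose leading dimensions are the (shared) batch dimensions, followed by the non-contracting, non-batch dimensions of lhs, and finally the non-contracting, non-batch dimensions of rhs.

Return type:

Array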
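The output of dot_general places the batch dimensions first, then the remaining (non-contracting, non-batch) dimensions of lhs, then those of rhs. The sketch below shows this ordering with an rhs that has two free axes, and also uses preferred_element_type to accumulate int8 inputs into an int32 result. The shapes and dtypes are illustrative assumptions.

```python
import jax.numpy as jnp
from jax import lax

# lhs axes: (batch=2, free=3, contract=4)
# rhs axes: (batch=2, contract=4, free=5, free=6)
lhs = jnp.ones((2, 3, 4), dtype=jnp.int8)
rhs = jnp.ones((2, 4, 5, 6), dtype=jnp.int8)

out = lax.dot_general(
    lhs,
    rhs,
    dimension_numbers=(((2,), (1,)), ((0,), (0,))),
    preferred_element_type=jnp.int32,  # accumulate and return in int32
)

# Batch dims (2,), then lhs free dims (3,), then rhs free dims (5, 6).
print(out.shape, out.dtype)  # (2, 3, 5, 6) int32
```

Each output element is a sum of 4 products of ones, so every entry equals 4. Requesting int32 here avoids accumulating in the narrow input type.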