jax.lax.ragged_dot

jax.lax.ragged_dot(lhs, rhs, group_sizes, precision=None, preferred_element_type=None, group_offset=None)

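Ragged matrix multiplication. The m rows of lhs are split into g contiguous groups whose sizes are given by group_sizes, and the rows of group i are multiplied by the (k, n) matrix rhs[i].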
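To make the grouping concrete, the sketch below compares jax.lax.ragged_dot against a plain loop that slices consecutive row blocks out of lhs and multiplies each by its rhs[i]. The helper ragged_dot_reference is hypothetical and written only for illustration; it assumes sum(group_sizes) == m and uses integer-valued float32 inputs so both paths should agree exactly.

```python
import numpy as np
import jax
import jax.numpy as jnp

HIGHEST = jax.lax.Precision.HIGHEST

def ragged_dot_reference(lhs, rhs, group_sizes):
    """Hypothetical loop-based helper: multiply each contiguous block of lhs
    rows by the rhs matrix of its group (assumes sum(group_sizes) == m)."""
    blocks = []
    start = 0
    for i, size in enumerate(np.asarray(group_sizes).tolist()):
        rows = lhs[start:start + size]                              # (size, k)
        blocks.append(jnp.matmul(rows, rhs[i], precision=HIGHEST))  # (size, n)
        start += size
    return jnp.concatenate(blocks, axis=0)                          # (m, n)

m, k, n, g = 6, 4, 3, 3
lhs = jnp.arange(m * k, dtype=jnp.float32).reshape(m, k)
rhs = jnp.arange(g * k * n, dtype=jnp.float32).reshape(g, k, n)
group_sizes = jnp.array([1, 3, 2], dtype=jnp.int32)  # sums to m

result = jax.lax.ragged_dot(lhs, rhs, group_sizes, precision=HIGHEST)
expected = ragged_dot_reference(lhs, rhs, group_sizes)
print(result.shape)  # (6, 3)
print(bool(jnp.allclose(result, expected)))
```

Here row 0 uses rhs[0], rows 1 to 3 use rhs[1], and rows 4 to 5 use rhs[2].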

Parameters:
  • lhs (Array) – (m, k) shaped array.

  • rhs (Array) – (g, k, n) shaped array.

  • group_sizes (Array) – (g,) shaped array with integer element type, where g denotes the number of groups. The i-th element gives the size of the i-th group, that is, the number of consecutive rows of lhs that belong to it.

  • precision (PrecisionLike) – Optional. Consistent with the precision argument of jax.lax.dot().

  • preferred_element_type (DTypeLike | None) – Optional. Consistent with the preferred_element_type argument of jax.lax.dot().

  • group_offset (Array | None) – Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0].
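Because groups are contiguous row ranges, rows usually have to be sorted by group before the call. The sketch below assumes a hypothetical routing in which each of 5 rows is assigned one of 3 groups (expert_ids is invented for illustration). It sorts rows by group, builds group_sizes with jnp.bincount, calls ragged_dot, and then undoes the sort. group_offset is left at its default.

```python
import jax
import jax.numpy as jnp

num_groups, k, n = 3, 4, 2  # g = num_groups
tokens = jnp.arange(5 * k, dtype=jnp.float32).reshape(5, k)  # (m, k)
# rhs[e] is (e + 1) * ones((k, n)), so results are easy to check by hand.
weights = jnp.ones((num_groups, k, n), dtype=jnp.float32)
weights = weights * jnp.arange(1, num_groups + 1, dtype=jnp.float32)[:, None, None]
expert_ids = jnp.array([2, 0, 2, 1, 0], dtype=jnp.int32)  # hypothetical assignment

# ragged_dot expects rows of the same group to be contiguous, so sort by group.
order = jnp.argsort(expert_ids)
sorted_tokens = tokens[order]
group_sizes = jnp.bincount(expert_ids, length=num_groups).astype(jnp.int32)  # [2, 1, 2]

out_sorted = jax.lax.ragged_dot(sorted_tokens, weights, group_sizes)  # (5, n)
# Undo the sort so that row i of `out` corresponds to tokens[i].
out = jnp.zeros_like(out_sorted).at[order].set(out_sorted)
print(out.shape)  # (5, 2)
```

Each row of out equals tokens[i] @ weights[expert_ids[i]], which is what a per-row gather followed by a batched matmul would give, without materializing a (m, k, n) weight tensor.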

Return type:

Array

Returns:

(m, n) shaped array with element type preferred_element_type, when specified.
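As a quick check of the result shape and dtype, the sketch below multiplies bfloat16 inputs with and without preferred_element_type=jnp.float32. The sizes are arbitrary; it assumes, as with jax.lax.dot(), that the default output dtype follows the input dtype.

```python
import jax
import jax.numpy as jnp

m, k, n = 4, 3, 2
lhs = jnp.ones((m, k), dtype=jnp.bfloat16)
rhs = jnp.ones((2, k, n), dtype=jnp.bfloat16)  # g = 2 groups
group_sizes = jnp.array([1, 3], dtype=jnp.int32)

out_default = jax.lax.ragged_dot(lhs, rhs, group_sizes)
out_f32 = jax.lax.ragged_dot(lhs, rhs, group_sizes,
                             preferred_element_type=jnp.float32)
print(out_default.shape, out_default.dtype)
print(out_f32.shape, out_f32.dtype)  # (4, 2) float32
```

Every entry is a sum of k ones, so both outputs contain the value 3.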