jax.lax.ragged_dot_general

jax.lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dot_dimension_numbers, precision=None, preferred_element_type=None, group_offset=None)

Ragged matrix multiplication.

Ragged dot takes three arrays (lhs, rhs, and group_sizes) and a ragged_dot_dimension_numbers argument. As with dot_general, lhs and rhs may have arbitrary batch and contracting dimensions. In addition, lhs must have exactly one ragged dimension, and rhs may have at most one group dimension.

Let g be the number of groups in the lhs ragged dimension. Ragged dot has three modes, depending on whether the lhs ragged dimension is a non-contracting, contracting, or batch dimension:

  1. [b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]. Here the ragged dimension is a non-contracting dimension (m) of lhs, and x... are the lhs non-contracting dims outer to the ragged dim.

  2. [b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]. Here the ragged dimension is a contracting dimension (k) of lhs and rhs, and x... are the lhs contracting dims outer to the ragged dim.

  3. [b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]. Here the ragged dimension is a batch dimension (b) of lhs and rhs, and x... are the lhs batch dims outer to the ragged dim.

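The i-th entry along the last axis of group_sizes is the size of group i, that is, the number of consecutive positions along the lhs ragged dimension that belong to it. If group_sizes is passed in with shape [g], it is broadcast to the shape required by the mode, as listed above.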

Parameters:
  • lhs (Array) – an array

  • rhs (Array) – an array

  • group_sizes (Array) – an array with integer element type

  • ragged_dot_dimension_numbers (RaggedDotDimensionNumbers) – a RaggedDotDimensionNumbers object to specify the dot dimension numbers, lhs ragged dimension, and rhs group dimension.

  • precision (PrecisionLike) – Optional. Consistent with the precision argument of jax.lax.dot().

  • preferred_element_type (DTypeLike | None) – Optional. Consistent with the preferred_element_type argument of jax.lax.dot().

  • group_offset (Array | None) – Optional. A (1,)-shaped array indicating the group in group_sizes from which to start computing. If not specified, defaults to [0].

Return type:

Array

Results:

An array with the same shape as the result of dot_general, plus an extra leading dimension of size g when the lhs ragged dimension is a contracting dimension.