jax.lax.ragged_dot_general
- jax.lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dot_dimension_numbers, precision=None, preferred_element_type=None, group_offset=None)
Ragged matrix multiplication.
Ragged dot takes three arrays (lhs, rhs, and group_sizes) and a ragged_dot_dimension_numbers argument. As with dot_general, lhs and rhs may have arbitrary batch and contracting dimensions. Additionally, lhs must have exactly one ragged dimension, and rhs may have at most one group dimension. Let g be the number of groups in the lhs ragged dimension. Ragged dot has three modes, depending on the kind of the lhs ragged dimension. Each mode is written as lhs shape, rhs shape, group_sizes shape -> result shape:
- [b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]. Here the ragged dimension is a non-contracting dimension (m) of lhs, and x... are the lhs non-contracting dims outer to the ragged dim.
- [b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]. Here the ragged dimension is a contracting dimension (k) of lhs and rhs, and x... are the lhs contracting dims outer to the ragged dim.
- [b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]. Here the ragged dimension is a batch dimension (b) of lhs and rhs, and x... are the lhs batch dims outer to the ragged dim.
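For the simplest instance of the first mode (no batch dimensions, a single m, k, and n), the semantics can be sketched as a plain NumPy loop: consecutive row blocks of lhs, sized by group_sizes, are each multiplied by the matching rhs matrix. The helper ragged_noncontracting_ref is hypothetical (it is not part of JAX) and follows one reading of the shape rule above rather than JAX's implementation. It leaves rows past sum(group_sizes) as zeros, which is an assumption, so the usage below keeps the groups covering all m rows.

```python
import numpy as np


def ragged_noncontracting_ref(lhs, rhs, group_sizes):
    """Hypothetical reference: lhs [m, k], rhs [g, k, n], group_sizes [g] -> [m, n].

    Rows past sum(group_sizes) stay zero in this sketch (an assumption).
    """
    m, _ = lhs.shape
    g, _, n = rhs.shape
    out = np.zeros((m, n), dtype=np.result_type(lhs, rhs))
    start = 0
    for i in range(g):
        stop = start + int(group_sizes[i])
        # Rows [start, stop) of lhs belong to group i and use rhs[i].
        out[start:stop] = lhs[start:stop] @ rhs[i]
        start = stop
    return out


lhs = np.arange(12, dtype=np.float32).reshape(4, 3)   # m=4, k=3
rhs = np.stack([np.eye(3, 2, dtype=np.float32),        # group 0: [k, n] = [3, 2]
                2 * np.eye(3, 2, dtype=np.float32)])   # group 1
group_sizes = np.array([1, 3], dtype=np.int32)  # row 0 -> group 0, rows 1..3 -> group 1
out = ragged_noncontracting_ref(lhs, rhs, group_sizes)
print(out.shape)  # (4, 2)
```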
If group_sizes is passed in with shape [g], it is broadcast to the full group_sizes shape given for the mode above.
- Parameters:
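A minimal NumPy sketch of this broadcasting rule for the first mode with one batch dimension. Here the full group_sizes shape is [b, g] (x... is empty because m is the only lhs non-contracting dimension). The helper batched_ragged_noncontracting_ref is hypothetical, and the sketch assumes a [g]-shaped group_sizes applies the same row split to every batch element.

```python
import numpy as np


def batched_ragged_noncontracting_ref(lhs, rhs, group_sizes):
    """Hypothetical reference: lhs [b, m, k], rhs [g, b, k, n] -> out [b, m, n].

    group_sizes may be [b, g] (one row split per batch element) or [g].
    """
    b, m, _ = lhs.shape
    g, _, _, n = rhs.shape
    # Broadcast a [g]-shaped group_sizes to the full [b, g] shape, so every
    # batch element uses the same split of its m rows.
    sizes = np.broadcast_to(group_sizes, (b, g))
    out = np.zeros((b, m, n), dtype=np.result_type(lhs, rhs))
    for bi in range(b):
        start = 0
        for gi in range(g):
            stop = start + int(sizes[bi, gi])
            # Rows [start, stop) of batch element bi use group gi's matrix.
            out[bi, start:stop] = lhs[bi, start:stop] @ rhs[gi, bi]
            start = stop
    return out


rng = np.random.default_rng(0)
lhs = rng.standard_normal((2, 5, 3))     # b=2, m=5, k=3
rhs = rng.standard_normal((2, 2, 3, 4))  # g=2, b=2, k=3, n=4
out_flat = batched_ragged_noncontracting_ref(lhs, rhs, np.array([2, 3]))
out_full = batched_ragged_noncontracting_ref(lhs, rhs, np.array([[2, 3], [2, 3]]))
print(out_flat.shape)                   # (2, 5, 4)
print(np.allclose(out_flat, out_full))  # True
```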
lhs (Array) – an array
rhs (Array) – an array
group_sizes (Array) – an array with integer element type
ragged_dot_dimension_numbers (RaggedDotDimensionNumbers) – a RaggedDotDimensionNumbers object to specify the dot dimension numbers, lhs ragged dimension, and rhs group dimension.
precision (PrecisionLike) – Optional. Consistent with the precision argument for jax.lax.dot().
preferred_element_type (DTypeLike | None) – Optional. Consistent with the preferred_element_type argument for jax.lax.dot().
group_offset (Array | None) – Optional. A (1,)-shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0].
- Return type:
Array
- Results:
An array whose shape is the same as that produced by dot_general, with an extra leading dimension of size g in the case where the lhs ragged dimension is a contracting dimension.
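The extra leading g dimension in the contracting mode can be seen in a small NumPy reference. ragged_contracting_ref is a hypothetical helper covering only the case without batch dimensions; it sketches the semantics described above and is not JAX's implementation. One consequence worth noting: when the groups cover all of k, each slice of k is contracted by exactly one group, so summing the result over the leading axis gives the ordinary product lhs @ rhs.

```python
import numpy as np


def ragged_contracting_ref(lhs, rhs, group_sizes):
    """Hypothetical reference: lhs [m, k], rhs [k, n], group_sizes [g] -> out [g, m, n]."""
    m, _ = lhs.shape
    _, n = rhs.shape
    g = len(group_sizes)
    out = np.zeros((g, m, n), dtype=np.result_type(lhs, rhs))
    start = 0
    for i in range(g):
        stop = start + int(group_sizes[i])
        # Group i contracts only over its own slice [start, stop) of k.
        out[i] = lhs[:, start:stop] @ rhs[start:stop]
        start = stop
    return out


rng = np.random.default_rng(0)
lhs = rng.standard_normal((4, 5))   # m=4, k=5
rhs = rng.standard_normal((5, 2))   # k=5, n=2
group_sizes = np.array([2, 1, 2])   # g=3, covers all of k

out = ragged_contracting_ref(lhs, rhs, group_sizes)
print(out.shape)  # (3, 4, 2): leading dimension of size g
print(np.allclose(out.sum(axis=0), lhs @ rhs))  # True
```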