jax.lax.ragged_dot
- jax.lax.ragged_dot(lhs, rhs, group_sizes, precision=None, preferred_element_type=None, group_offset=None)
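Ragged matrix multiplication. The rows of lhs are split into g consecutive groups whose sizes are given by group_sizes, and the rows in group i are multiplied by rhs[i].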
- Parameters:
lhs (Array) – (m, k) shaped array.
rhs (Array) – (g, k, n) shaped array.
group_sizes (Array) – (g,) shaped array with integer element type, where g denotes the number of groups. The i-th element gives the size of the i-th group.
precision (PrecisionLike) – Optional. Consistent with the precision argument for jax.lax.dot().
preferred_element_type (DTypeLike | None) – Optional. Consistent with the preferred_element_type argument for jax.lax.dot().
group_offset (Array | None) – Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0].
- Return type:
Array
- Results:
(m, n) shaped array with preferred_element_type element type.
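Example: a minimal sketch of the grouping semantics, assuming a JAX release in which jax.lax.ragged_dot is available. The shapes, the input values, and the helper ragged_dot_reference are illustrative assumptions and not part of the JAX API. The helper spells out the per-group matrix products in NumPy so the result can be compared against the library call.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative sizes (assumptions): m rows, k contraction dim, n columns, g groups.
m, k, n, g = 6, 3, 4, 3
lhs = jnp.arange(m * k, dtype=jnp.float32).reshape(m, k)
rhs = jnp.arange(g * k * n, dtype=jnp.float32).reshape(g, k, n)

# Rows 0..0 form group 0, rows 1..3 form group 1, rows 4..5 form group 2.
# The sizes sum to m in this example.
group_sizes = jnp.array([1, 3, 2], dtype=jnp.int32)

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)
print(out.shape)  # (6, 4), i.e. (m, n)


def ragged_dot_reference(lhs, rhs, group_sizes):
    """Hypothetical NumPy reference: multiply each row group by its own rhs[i]."""
    lhs = np.asarray(lhs)
    rhs = np.asarray(rhs)
    result = np.zeros((lhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
    start = 0
    for i, size in enumerate(np.asarray(group_sizes)):
        size = int(size)
        # Rows [start, start + size) belong to group i and use rhs[i].
        result[start:start + size] = lhs[start:start + size] @ rhs[i]
        start += size
    return result


expected = ragged_dot_reference(lhs, rhs, group_sizes)
print(np.allclose(np.asarray(out), expected, rtol=1e-2))
```

The reference loop is only meant to make the semantics explicit. It runs on the host and is not how JAX implements the operation.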