jax.lax.split
- jax.lax.split(operand, sizes, axis=0)
Splits an array along axis.
- Parameters:
  - operand (ArrayLike) – the array to split.
  - sizes (Sequence[DimSize]) – the sizes of the output arrays along axis. The sizes must sum to the size of the axis dimension of operand.
  - axis (int) – the axis along which to split the array. Defaults to 0.
- Returns:
  A sequence of len(sizes) arrays. If sizes is [s1, s2, ...], this function returns chunks of sizes s1, s2, ..., taken along axis.
- Return type:
  Sequence[Array]