jax.lax.top_k
- jax.lax.top_k(operand, k, *, axis=-1)
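Returns the top k values and their indices along the specified axis of operand.
- Parameters:
operand – array from which to select the top k entries.
k – number of top entries to return.
axis – axis along which to select the top k entries. Defaults to -1, the last axis.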
- Returns:
A tuple (values, indices), where values is an array containing the top k values along the specified axis, and indices is an array containing the indices of those values along that axis.
- Return type:
tuple[Array, Array]

values[..., i, ...] is the i-th largest entry in operand along the specified axis, and its index is indices[..., i, ...]. If two elements are equal, the lower-index element appears first.
Examples
Find the largest three values, and their indices, within an array:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10.,  9.,  6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)
```
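One way to read the ordering guarantees is as "stable descending sort, then keep the first k". The minimal sketch below models that with a hypothetical helper, top_k_reference, which is not part of JAX. It assumes a floating-point input (so negation reverses the ordering exactly), works along the last axis only, and relies on jnp.argsort using a stable sort by default, so equal values keep their lower-index-first order.

```python
import jax
import jax.numpy as jnp

def top_k_reference(x, k):
    """Hypothetical sort-based model of top_k along the last axis.

    Sorting -x in ascending order puts the largest values of x first; the
    stable sort keeps equal values in their original index order.
    """
    idx = jnp.argsort(-x, axis=-1)[..., :k]       # indices of the k largest
    vals = jnp.take_along_axis(x, idx, axis=-1)   # gather the matching values
    return vals, idx

x = jnp.array([9., 3., 6., 4., 10.])
ref_vals, ref_idx = top_k_reference(x, 3)
vals, idx = jax.lax.top_k(x, 3)
```

The helper sorts the entire axis, so it is meant only as a readable model for checking results on small inputs, not as a replacement for jax.lax.top_k.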