jax.numpy.quantile
- jax.numpy.quantile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, weights=None)
Compute the quantile of the data along the specified axis.
JAX implementation of numpy.quantile().
- Parameters:
a (ArrayLike) – N-dimensional array input.
q (ArrayLike) – scalar or 1-dimensional array specifying the desired quantiles. q should contain floating-point values between 0.0 and 1.0.
axis (int | tuple[int, ...] | None) – optional axis or tuple of axes along which to compute the quantile.
out (None) – not implemented by JAX; will error if not None.
overwrite_input (bool) – not implemented by JAX; will error if not False.
method (str) – specify the interpolation method to use. Options are one of ["linear", "lower", "higher", "midpoint", "nearest"]. Default is "linear".
keepdims (bool) – if True, then the returned array will have the same number of dimensions as the input. Default is False.
weights (ArrayLike | None) – keyword-only. Optional array of weights associated with the values in a. Each value in a contributes to the quantile according to its associated weight. The weights array must be broadcastable to the same shape as a. Only works with method="inverted_cdf".
- Returns:
An array containing the specified quantiles along the specified axes.
- Return type:
Array
See also
jax.numpy.nanquantile(): compute the quantile while ignoring NaNs
jax.numpy.percentile(): compute the percentile (0-100)
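As a rough illustration of how the related functions differ, the sketch below compares quantile, nanquantile, and percentile on small arrays. The array values are made up for the example and are not taken from the JAX documentation.

```python
import jax.numpy as jnp

# Illustrative data containing one NaN.
x = jnp.array([1.0, 2.0, jnp.nan, 4.0, 5.0])

# quantile propagates the NaN, while nanquantile skips it.
with_nan = jnp.quantile(x, 0.5)
without_nan = jnp.nanquantile(x, 0.5)

# percentile takes q on a 0-100 scale instead of 0.0-1.0.
clean = jnp.array([1.0, 2.0, 4.0, 5.0])
as_quantile = jnp.quantile(clean, 0.25)
as_percentile = jnp.percentile(clean, 25)

print(with_nan, without_nan)
print(as_quantile, as_percentile)
```

Here nanquantile behaves like quantile applied to the non-NaN values only, and percentile(clean, 25) matches quantile(clean, 0.25).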
Examples
Computing the median and quartiles of an array, with linear interpolation:
>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> q = jnp.array([0.25, 0.5, 0.75])
>>> jnp.quantile(x, q)
Array([2.25, 4.5 , 6.75], dtype=float32)
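To make the linear interpolation in the example above concrete, here is a minimal sketch that reproduces it by hand for a 1-dimensional input. The helper linear_quantile_sketch is hypothetical and written only for illustration; it is not part of JAX and is not claimed to mirror JAX's internal implementation.

```python
import jax.numpy as jnp

def linear_quantile_sketch(x, q):
    """Hypothetical helper: linear-interpolated quantiles of a 1-D array."""
    x_sorted = jnp.sort(x.astype(jnp.float32))
    n = x_sorted.shape[0]
    pos = q * (n - 1)                        # fractional index into sorted data
    lo = jnp.floor(pos).astype(jnp.int32)    # neighbor at or below pos
    hi = jnp.minimum(lo + 1, n - 1)          # neighbor above, clamped to the end
    frac = pos - lo                          # distance from lo, in [0, 1)
    return x_sorted[lo] * (1 - frac) + x_sorted[hi] * frac

x = jnp.arange(10)
q = jnp.array([0.25, 0.5, 0.75])

manual = linear_quantile_sketch(x, q)
builtin = jnp.quantile(x, q)
print(manual, builtin)
```

For q = 0.25 and ten values, the fractional index is 0.25 * 9 = 2.25, so the result lies a quarter of the way from the sorted value at index 2 to the one at index 3.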
Computing the quartiles using nearest-value interpolation:
>>> jnp.quantile(x, q, method='nearest')
Array([2., 4., 7.], dtype=float32)
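The remaining interpolation methods can be compared the same way. The following sketch evaluates a single quantile whose position falls between two data points, so that each method gives a visibly different answer; the variable names are chosen for this example only.

```python
import jax.numpy as jnp

values = jnp.arange(10)
q_single = 0.25  # fractional index 0.25 * 9 = 2.25, between values 2 and 3

lower = jnp.quantile(values, q_single, method='lower')        # value just below
higher = jnp.quantile(values, q_single, method='higher')      # value just above
midpoint = jnp.quantile(values, q_single, method='midpoint')  # mean of the two

print(lower, higher, midpoint)
```

With this input, 'lower' and 'higher' return the neighboring data points 2 and 3, and 'midpoint' returns their average, 2.5.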
Computing weighted quantiles:
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> weights = jnp.array([1, 1, 2, 1, 1])
>>> jnp.quantile(x, 0.5, weights=weights, method='inverted_cdf')
Array(3., dtype=float32)
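The axis and keepdims parameters are not exercised by the examples above. The sketch below shows how they might be used on a small 2-dimensional array; the data and variable names are illustrative assumptions, not part of the original documentation.

```python
import jax.numpy as jnp

data = jnp.arange(12.0).reshape(3, 4)  # 3 rows, 4 columns

# Median of each column: reduces axis 0, leaving shape (4,).
per_column = jnp.quantile(data, 0.5, axis=0)

# Median of each row, keeping the reduced axis with size 1: shape (3, 1).
per_row_kept = jnp.quantile(data, 0.5, axis=1, keepdims=True)

# A 1-D q adds a leading axis: 2 quantiles for each of 3 rows gives shape (2, 3).
multi_q = jnp.quantile(data, jnp.array([0.25, 0.75]), axis=1)

# With axis=None (the default), the quantile is taken over all elements.
overall = jnp.quantile(data, 0.5)

print(per_column.shape, per_row_kept.shape, multi_q.shape, overall)
```

Keeping the reduced axis with keepdims=True is convenient when the result needs to broadcast against the original array, for example data - per_row_kept.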