jax.numpy.nanquantile

jax.numpy.nanquantile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, weights=None)

Compute the quantile of the data along the specified axis, ignoring NaNs.

JAX implementation of numpy.nanquantile().

Parameters:
  • a (ArrayLike) – N-dimensional array input.

  • q (ArrayLike) – scalar or 1-dimensional array specifying the desired quantiles. q should contain floating-point values between 0.0 and 1.0.

  • axis (int | tuple[int, ...] | None) – optional axis or tuple of axes along which to compute the quantile. If None (default), the quantile is computed over the flattened array.

  • out (None) – not implemented by JAX; raises an error if not None.

  • overwrite_input (bool) – not implemented by JAX; raises an error if not False.

  • method (str) – specifies the interpolation method to use. Must be one of ["linear", "lower", "higher", "midpoint", "nearest"], except when weights is given, in which case "inverted_cdf" is required. Default is "linear".

  • keepdims (bool) – if True, then the returned array will have the same number of dimensions as the input. Default is False.

  • weights (ArrayLike | None) – keyword-only. Optional array of weights for each element in a. Values with higher weights contribute more to the quantile calculation. The weights array must be broadcastable to the shape of a along the specified axis. NaN values in a are ignored when computing the quantiles. Weighted quantiles are currently only supported when method="inverted_cdf".
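
As a rough illustration of how axis and keepdims affect the shape of the result, the following sketch applies nanquantile() to a small 2D array. The array a and the variable names are made up for this example; only the documented parameters are used.

```python
import jax.numpy as jnp

# Hypothetical 2x4 input with one NaN in each row.
a = jnp.array([[1.0, jnp.nan, 3.0, 5.0],
               [2.0, 4.0, jnp.nan, 8.0]])

# axis=None (the default): a single quantile over all non-NaN values.
overall = jnp.nanquantile(a, 0.5)

# axis=1: one quantile per row; the reduced axis is dropped.
per_row = jnp.nanquantile(a, 0.5, axis=1)

# keepdims=True keeps the reduced axis with length 1.
per_row_kept = jnp.nanquantile(a, 0.5, axis=1, keepdims=True)

# A 1-dimensional q adds a leading axis, one entry per requested quantile.
quartiles = jnp.nanquantile(a, jnp.array([0.25, 0.5, 0.75]), axis=1)

print(overall.shape, per_row.shape, per_row_kept.shape, quartiles.shape)
print(per_row)
```

Here the per-row medians are 3 and 4, because the NaN in each row is skipped before the quantile is taken.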

Returns:

An array containing the specified quantiles along the specified axes.

Return type:

Array

Examples

Computing the median and quartiles of a 1D array:

>>> import jax.numpy as jnp
>>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6])
>>> q = jnp.array([0.25, 0.5, 0.75])

Because x contains a NaN, jax.numpy.quantile() returns NaN for every quantile, while nanquantile() ignores the NaN:

>>> jnp.quantile(x, q)
Array([nan, nan, nan], dtype=float32)
>>> jnp.nanquantile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
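
The method parameter only matters when the requested quantile falls between two data points. The sketch below compares the five unweighted methods on the same data; the choice of q=0.4 is an assumption made for illustration, picked so that the quantile position lies strictly between two values.

```python
import jax.numpy as jnp

x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6])

# After dropping the NaN there are 7 values (0..6), so q=0.4 falls at
# fractional position 0.4 * (7 - 1) = 2.4, between the values 2 and 3.
methods = ["linear", "lower", "higher", "midpoint", "nearest"]
results = {m: jnp.nanquantile(x, 0.4, method=m) for m in methods}

for name, value in results.items():
    print(f"{name:>8}: {float(value)}")
```

With this input, "linear" gives approximately 2.4, "lower" and "nearest" give 2, "higher" gives 3, and "midpoint" gives 2.5.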

Computing weighted quantiles while ignoring NaNs:

>>> x = jnp.array([1, 2, jnp.nan, 4, 5])
>>> weights = jnp.array([1, 1, 1, 2, 1])
>>> jnp.nanquantile(x, 0.5, weights=weights, method='inverted_cdf')
Array(4., dtype=float32)
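
To see where the value 4 comes from, the weighted inverted-CDF rule can be worked through by hand. The helper below, weighted_inverted_cdf_sketch, is a hypothetical name introduced for this illustration. It handles 1D input only and is not necessarily how JAX implements weights internally; it just picks the smallest non-NaN value whose normalized cumulative weight reaches q.

```python
import jax.numpy as jnp

def weighted_inverted_cdf_sketch(x, weights, q):
    """Hypothetical 1D helper: weighted inverted-CDF quantile, ignoring NaNs."""
    valid = ~jnp.isnan(x)
    # Give NaN entries zero weight so they never contribute to the CDF.
    w = jnp.where(valid, weights, 0.0)
    # Sort by value, pushing NaN entries to the end.
    order = jnp.argsort(jnp.where(valid, x, jnp.inf))
    x_sorted = x[order]
    w_sorted = w[order]
    # Normalized cumulative weight, i.e. the empirical weighted CDF.
    cdf = jnp.cumsum(w_sorted) / jnp.sum(w_sorted)
    # Index of the first sorted value whose CDF is >= q.
    idx = jnp.searchsorted(cdf, q, side="left")
    return x_sorted[idx]

x = jnp.array([1, 2, jnp.nan, 4, 5])
weights = jnp.array([1, 1, 1, 2, 1])

# Non-NaN values 1, 2, 4, 5 carry weights 1, 1, 2, 1, so the CDF is
# 0.2, 0.4, 0.8, 1.0; the first value with CDF >= 0.5 is 4.
median = weighted_inverted_cdf_sketch(x, weights, 0.5)
print(median)
```

The weight attached to the NaN entry is discarded, so the remaining total weight is 5 rather than 6.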