jax.numpy.std
- jax.numpy.std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)
Compute the standard deviation along a given axis.
JAX implementation of numpy.std().
- Parameters:
a (ArrayLike) – input array.
axis (Axis) – optional, int or sequence of ints, default=None. Axis along which the standard deviation is computed. If None, the standard deviation is computed over all axes.
dtype (DTypeLike | None) – The type of the output array. Default=None.
ddof (int) – int, default=0. Delta degrees of freedom. The divisor in the standard deviation computation is N - ddof, where N is the number of elements along the given axis.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
where (ArrayLike | None) – optional, boolean array, default=None. The elements to include in the standard deviation. The array must be broadcast-compatible with the input.
mean (ArrayLike | None) – optional, default=None. The mean of the input array, computed along the given axis. If provided, it is used to compute the standard deviation instead of computing the mean from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean with keepdims=True and with axis matching this function’s axis argument.
correction (int | float | None) – int or float, default=None. Alternative name for ddof. ddof and correction cannot both be provided.
out (None) – Unused by JAX.
- Returns:
An array of the standard deviation along the given axis.
- Return type:
Array
See also
jax.numpy.var(): Compute the variance of array elements over a given axis.
jax.numpy.mean(): Compute the mean of array elements over a given axis.
jax.numpy.nanvar(): Compute the variance along a given axis, ignoring NaN values.
jax.numpy.nanstd(): Compute the standard deviation along a given axis, ignoring NaN values.
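To illustrate how these related functions fit together, here is a small sketch: the standard deviation equals the square root of jnp.var computed with the same axis and ddof, and jnp.nanstd skips NaN entries that would otherwise propagate through jnp.std. The arrays x and y are illustrative values, not taken from the library.

```python
import jax.numpy as jnp

x = jnp.array([[1., 3., 4., 2.],
               [4., 2., 5., 3.],
               [5., 4., 2., 3.]])

# std is the square root of var when both use the same axis and ddof.
std = jnp.std(x, axis=0, ddof=1)
var = jnp.var(x, axis=0, ddof=1)
print(jnp.allclose(std, jnp.sqrt(var)))  # True

# A single NaN propagates through jnp.std, while jnp.nanstd ignores it.
y = jnp.array([1., 5., jnp.nan])
print(jnp.std(y))     # nan
print(jnp.nanstd(y))  # 2.0, the standard deviation of [1., 5.]
```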
Examples
By default, jnp.std computes the standard deviation along all axes.
>>> x = jnp.array([[1, 3, 4, 2],
...                [4, 2, 5, 3],
...                [5, 4, 2, 3]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.std(x)
Array(1.21, dtype=float32)
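A quick way to see what "all axes" means is to compare the default call with a flattened array and with an explicit tuple of every axis; all three reduce over the same 12 elements. This is a minimal sketch reusing the example array; the names s_all, s_flat and s_tuple are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])

# axis=None (the default) reduces over every element of x.
s_all = jnp.std(x)
# Flattening first, or naming every axis explicitly, gives the same value.
s_flat = jnp.std(x.ravel())
s_tuple = jnp.std(x, axis=(0, 1))
print(s_all, s_flat, s_tuple)
```

Note that the integer input produces a floating-point result, as in the doctest above.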
With axis=0, the standard deviation is computed along axis 0, giving one value per column.
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0))
[1.7 0.82 1.25 0.47]
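Reducing along the other axis works the same way. In this sketch, axis=1 gives one value per row, and the negative index axis=-1 names the same (last) axis. For this particular array every row happens to have the same spread: the squared deviations from each row mean sum to 5 over 4 elements, so each entry is sqrt(1.25), about 1.118.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])

# axis=1 reduces across each row, giving an array of shape (3,).
row_std = jnp.std(x, axis=1)
# Negative axes count from the end, so axis=-1 is axis=1 for a 2-D array.
row_std_neg = jnp.std(x, axis=-1)
print(row_std)
```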
To preserve the dimensions of the input, you can set keepdims=True.
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True))
[[1.7 0.82 1.25 0.47]]
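One common reason to keep the reduced dimension is broadcasting the statistic back against the input. The sketch below standardizes each row (a z-score per row); it is an illustration of keepdims, not a JAX utility. With keepdims=True the row statistics have shape (3, 1), which broadcasts against x of shape (3, 4); without it they would have shape (3,), which does not.

```python
import jax.numpy as jnp

x = jnp.array([[1., 3., 4., 2.],
               [4., 2., 5., 3.],
               [5., 4., 2., 3.]])

# Row-wise statistics with the reduced axis kept: shape (3, 1).
mu = jnp.mean(x, axis=1, keepdims=True)
sigma = jnp.std(x, axis=1, keepdims=True)

# Broadcasting subtracts each row's mean and divides by its std.
z = (x - mu) / sigma

# Each row of z now has mean close to 0 and (population) std close to 1.
print(jnp.mean(z, axis=1))
print(jnp.std(z, axis=1))
```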
With ddof=1, the divisor is N - 1 instead of N (the sample standard deviation):
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True, ddof=1))
[[2.08 1. 1.53 0.58]]
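To make the N - ddof divisor from the parameter description concrete, this sketch computes the same column-wise values by hand and also checks that correction=1 matches ddof=1. It assumes a JAX version that accepts the correction keyword shown in the signature; manual, via_ddof and via_correction are illustrative names.

```python
import jax.numpy as jnp

x = jnp.array([[1., 3., 4., 2.],
               [4., 2., 5., 3.],
               [5., 4., 2., 3.]])

ddof = 1
n = x.shape[0]  # N: number of elements along axis 0
mean = jnp.mean(x, axis=0)

# Sum of squared deviations divided by N - ddof, then the square root.
manual = jnp.sqrt(jnp.sum((x - mean) ** 2, axis=0) / (n - ddof))
print(manual)

# correction is an alternative name for ddof; pass one or the other, not both.
via_ddof = jnp.std(x, axis=0, ddof=1)
via_correction = jnp.std(x, axis=0, correction=1)
```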
To compute the standard deviation over only selected elements of the array, pass a boolean mask as where.
>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.std(x, axis=0, keepdims=True, where=where)
Array([[2., 1., 1., 0.]], dtype=float32)
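A practical reason to prefer where over boolean indexing is that where keeps array shapes static, so it works inside jax.jit, whereas an expression like x[mask] has a data-dependent shape and cannot be traced under jit. The sketch below wraps the masked reduction in a jitted helper (masked_col_std is a hypothetical name) and cross-checks column 0 outside jit, where boolean indexing is allowed.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 3., 4., 2.],
               [4., 2., 5., 3.],
               [5., 4., 2., 3.]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [1, 1, 1, 0]], dtype=bool)

@jax.jit
def masked_col_std(values, m):
    # `where` only changes which elements contribute; the output shape
    # stays (4,) regardless of the mask, so this traces fine under jit.
    return jnp.std(values, axis=0, where=m)

print(masked_col_std(x, mask))  # values 2, 1, 1, 0

# Outside jit, boolean indexing selects column 0's entries [1., 5.].
print(jnp.std(x[:, 0][mask[:, 0]]))  # 2.0
```

A column whose mask is entirely False has no elements to average over, so its result would be NaN.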