jax.scipy.fft.idctn
- jax.scipy.fft.idctn(x, type=2, s=None, axes=None, norm=None)
Computes the multidimensional inverse discrete cosine transform of the input.
JAX implementation of scipy.fft.idctn().
- Parameters:
x (Array) – array
type (int) – integer, default = 2. Currently only type 2 is supported.
s (Sequence[int] | None) – integer or sequence of integers. Specifies the shape of the result. If not specified, it defaults to the shape of x along the specified axes.
axes (Sequence[int] | None) – integer or sequence of integers. Specifies the axes along which the transform is computed. If not given, the last len(s) axes are used, or all axes if s is also not specified.
norm (str | None) – string. The normalization mode: one of [None, "backward", "ortho"]. The default is None, which is equivalent to "backward".
- Returns:
array containing the inverse discrete cosine transform of x
- Return type:
Array
See also
- jax.scipy.fft.dct(): one-dimensional DCT
- jax.scipy.fft.dctn(): multidimensional DCT
- jax.scipy.fft.idct(): one-dimensional inverse DCT
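The norm parameter can be illustrated with a short sketch. It assumes that, as in SciPy, norm="ortho" makes jax.scipy.fft.dctn and jax.scipy.fft.idctn orthonormal transforms, so each undoes the other and both preserve the sum of squares. The input shape and the variable names (coeffs, x_back, energy_in, energy_out) are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn, idctn

# Illustrative input; the (4, 5) shape is arbitrary.
x = jax.random.normal(jax.random.key(0), (4, 5))

# With norm="ortho", dctn and idctn are orthonormal transforms,
# so idctn undoes dctn (up to float32 rounding).
coeffs = dctn(x, norm="ortho")
x_back = idctn(coeffs, norm="ortho")

# An orthonormal transform also preserves the sum of squares (Parseval).
energy_in = jnp.sum(x ** 2)
energy_out = jnp.sum(idctn(x, norm="ortho") ** 2)

print(jnp.allclose(x, x_back, atol=1e-5))
print(jnp.allclose(energy_in, energy_out, rtol=1e-4))
```

With the default norm=None ("backward"), the round trip through dctn and idctn still holds, but the sum of squares is not preserved, because the scaling is applied only on the inverse side.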
Examples
jax.scipy.fft.idctn computes the transform along both axes by default when the axes argument is None and s is also None.
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x))
[[ 0.12  0.11 -0.15]
 [ 0.07  0.17 -0.03]
 [ 0.19 -0.07 -0.02]]
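Computing the transform over both axes amounts to applying the one-dimensional inverse DCT along each axis in turn, because the multidimensional DCT is separable. The sketch below checks this numerically against jax.scipy.fft.idct. The names full and stepwise are illustrative, and the tolerance is an assumption chosen to absorb float32 rounding.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idct, idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

# Default call: s and axes are None, so both axes are transformed.
full = idctn(x)

# Separable equivalent: 1-D inverse DCT along axis 0, then along axis 1.
stepwise = idct(idct(x, axis=0), axis=1)

print(jnp.allclose(full, stepwise, atol=1e-6))
```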
When s=[2], the transform is computed only along the last axis, with its dimension padded or truncated to size 2:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x, s=[2]))
[[ 1.12 -0.31]
 [ 0.04 -0.08]
 [ 0.05 -0.3 ]]
When s=[2] and axes=[0], the transform is computed only along the specified axis, with its dimension padded or truncated to size 2:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x, s=[2], axes=[0]))
[[ 0.38  0.57 -0.45]
 [ 0.43  0.44  0.24]]
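Because this function is described as the JAX implementation of scipy.fft.idctn(), one way to sanity-check the s and axes combinations is to compare it directly against SciPy. The sketch assumes SciPy is installed (it is a JAX dependency) and that both libraries agree up to float32 rounding. The list of cases and the tolerance are illustrative choices.

```python
import numpy as np
import scipy.fft
import jax
import jax.scipy.fft

x = jax.random.normal(jax.random.key(0), (3, 3))
x_np = np.asarray(x)  # float32 NumPy copy used as the SciPy reference input

# Combinations of s and axes to compare, including the s=[2], axes=[0] case.
cases = [
    dict(),
    dict(s=[2], axes=[1]),
    dict(s=[2], axes=[0]),
    dict(s=[2, 4], axes=[0, 1]),
]

results = []
for kwargs in cases:
    ours = np.asarray(jax.scipy.fft.idctn(x, **kwargs))
    ref = scipy.fft.idctn(x_np, **kwargs)
    # Shapes must match and values must agree within float32 tolerance.
    results.append(ours.shape == ref.shape and np.allclose(ours, ref, atol=1e-5))

print(results)
```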
When s=[2, 4], the shape of the result is (2, 4):
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x, s=[2, 4]))
[[ 0.1   0.18  0.07 -0.16]
 [ 0.2   0.06 -0.03 -0.01]]
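For a (3, 3) input, s=[2, 4] applies to the last two axes: axis 0 is truncated from 3 to 2 and axis 1 is zero-padded from 3 to 4. The sketch below assumes SciPy-style pad/truncate semantics and checks that preparing the input by hand and then transforming with no s gives the same result. The names prepared and by_hand are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

result = idctn(x, s=[2, 4])

# Truncate axis 0 to length 2, then append one zero column along axis 1.
prepared = jnp.pad(x[:2, :], ((0, 0), (0, 1)))
by_hand = idctn(prepared)

print(result.shape)  # (2, 4)
print(jnp.allclose(result, by_hand, atol=1e-6))
```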
jax.scipy.fft.idctn can be used to reconstruct x from the result of jax.scipy.fft.dctn:
>>> x_dctn = jax.scipy.fft.dctn(x)
>>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
Array(True, dtype=bool)
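The same round trip works on a batch when the axes argument restricts the transform to the trailing axes, so each batch item is transformed independently. In this sketch the batch of 8 random 4x4 "images" is hypothetical, and axis 0 is assumed to index the batch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn, idctn

# Hypothetical batch: 8 single-channel 4x4 arrays; axis 0 indexes the batch.
images = jax.random.uniform(jax.random.key(1), (8, 4, 4))

# Transform only the two spatial axes, leaving the batch axis untouched.
coeffs = dctn(images, axes=[1, 2])
restored = idctn(coeffs, axes=[1, 2])

print(coeffs.shape)
print(jnp.allclose(images, restored, atol=1e-5))
```

Passing the same axes to both calls is what makes the pair inverses of each other. If axes were omitted, the transform would also mix values across the batch axis.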