jax.scipy.linalg.dft
- jax.scipy.linalg.dft(n, scale=None, *, dtype=None)
Construct an n-by-n discrete Fourier transform matrix.
JAX implementation of scipy.linalg.dft(). The DFT matrix \(W_n\) has entries \((W_n)_{ij} = \omega^{ij}\) for \(0 \le i, j < n\), where \(\omega = e^{-2\pi i / n}\) is a primitive n-th root of unity.
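The definition above can be made concrete with a short jax.numpy sketch. This is not the library's implementation. The helper name `dft_matrix_sketch` is hypothetical, and it builds \(W_n\) directly from \((W_n)_{ij} = \omega^{ij}\). Because \(\omega^n = 1\), the sketch reduces the exponent \(ij\) modulo \(n\) before exponentiating. The entries are unchanged, and the phase angle stays small, which limits float32 rounding error for larger \(n\).

```python
import jax.numpy as jnp

def dft_matrix_sketch(n, dtype=jnp.complex64):
    """Hypothetical helper: W[i, j] = omega**(i*j), omega = exp(-2*pi*1j/n)."""
    k = jnp.arange(n)
    # i*j mod n gives the same power of omega (omega**n == 1) with a smaller angle.
    m = jnp.outer(k, k) % n
    w = jnp.exp(-2j * jnp.pi * m / n)
    return w.astype(dtype)

W = dft_matrix_sketch(4)
print(W.round(3))
```

Each column of \(W_n\) is the DFT of a standard basis vector, so the result should agree with `jnp.fft.fft(jnp.eye(n), axis=0)` up to rounding.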
- Parameters:
n (int) – size of the matrix.
scale (str | None) – (optional) None (default, unscaled), 'sqrtn' (scale by \(1/\sqrt{n}\), making the matrix unitary), or 'n' (scale by \(1/n\)).
dtype (DTypeLike | None) – (optional) complex floating-point dtype for the output. Defaults to JAX’s default complex dtype.
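The three `scale` options differ only by a constant factor. The following is a minimal sketch, assuming only the semantics stated in the parameter description. The helper `scaled_dft_sketch` is hypothetical and not part of JAX. It also checks the unitarity claim: with 'sqrtn', \(U^H U = I\).

```python
import jax.numpy as jnp

def scaled_dft_sketch(n, scale=None):
    """Hypothetical helper mirroring the documented `scale` options."""
    k = jnp.arange(n)
    w = jnp.exp(-2j * jnp.pi * (jnp.outer(k, k) % n) / n)
    if scale is None:
        return w                  # unscaled: W^H W = n * I
    if scale == "sqrtn":
        return w / jnp.sqrt(n)    # unitary: U^H U = I
    if scale == "n":
        return w / n              # every entry scaled by 1/n
    raise ValueError(f"scale must be None, 'sqrtn', or 'n', got {scale!r}")

n = 8
U = scaled_dft_sketch(n, "sqrtn")
# The conjugate transpose times U should be the identity, up to float32 rounding.
print(jnp.allclose(U.conj().T @ U, jnp.eye(n), atol=1e-5))
```

For the unscaled matrix, the same product equals \(n I\). That factor is exactly what the 'sqrtn' option divides out.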
- Returns:
A DFT matrix of shape (n, n).
- Return type:
Array
Examples
>>> jax.scipy.linalg.dft(4).round(3)
Array([[ 1.+0.j,  1.+0.j,  1.+0.j,  1.+0.j],
       [ 1.+0.j, -0.-1.j, -1.+0.j,  0.+1.j],
       [ 1.+0.j, -1.+0.j,  1.-0.j, -1.+0.j],
       [ 1.+0.j,  0.+1.j, -1.+0.j, -0.-1.j]], dtype=complex64)
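Because \((W_n)_{ij} = \omega^{ij}\) with \(\omega = e^{-2\pi i / n}\), multiplying a vector by the unscaled matrix computes the same transform as `jnp.fft.fft`. The sketch below builds the matrix from the formula with jax.numpy rather than calling `jax.scipy.linalg.dft`, and then compares the two approaches. The dense product costs \(O(n^2)\), whereas the FFT costs \(O(n \log n)\), so the matrix form is mainly useful when an explicit linear operator is needed.

```python
import jax
import jax.numpy as jnp

n = 16
k = jnp.arange(n)
# DFT matrix built from the formula (not via jax.scipy.linalg.dft).
W = jnp.exp(-2j * jnp.pi * (jnp.outer(k, k) % n) / n)

x = jax.random.normal(jax.random.PRNGKey(0), (n,))
y_matrix = W @ x          # dense O(n^2) matrix-vector product
y_fft = jnp.fft.fft(x)    # O(n log n) FFT
print(jnp.allclose(y_matrix, y_fft, atol=1e-4))

# conj(W) / n inverts the unscaled DFT, matching jnp.fft.ifft.
x_back = (W.conj() / n) @ y_matrix
print(jnp.allclose(x_back, x, atol=1e-5))
```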