jax.scipy.linalg.dft

jax.scipy.linalg.dft(n, scale=None, *, dtype=None)

Construct an n-by-n discrete Fourier transform matrix.

JAX implementation of scipy.linalg.dft().

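The DFT matrix \(W_n\) has entries \(W_{ij} = \omega^{ij}\) for \(0 \le i, j < n\), where \(\omega = e^{-2\pi i / n}\) is a primitive n-th root of unity. Multiplying a length-n vector x by the unscaled matrix therefore computes its discrete Fourier transform, \((W_n x)_k = \sum_{j} \omega^{kj} x_j\), which is the same result as jax.numpy.fft.fft(x).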
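The definition can be checked directly. The sketch below is a minimal illustration using only jax.numpy. It builds \(W_n\) entry by entry from the formula and compares \(W_n x\) with jax.numpy.fft.fft. The helper `dft_matrix` is hypothetical. It is not part of the JAX API and is not the implementation of jax.scipy.linalg.dft.

```python
import jax.numpy as jnp


def dft_matrix(n: int) -> jnp.ndarray:
    """Hypothetical helper (not a JAX API): unscaled W_n with W[i, j] = omega**(i*j)."""
    idx = jnp.arange(n)
    # Reduce the exponent i*j modulo n first; omega**n == 1, so this leaves the
    # entries unchanged while keeping the phase angle small for large n.
    k = (idx[:, None] * idx[None, :]) % n
    # omega = exp(-2*pi*1j/n), so omega**k = exp(-2*pi*1j*k/n).
    return jnp.exp(-2j * jnp.pi * k / n)


W = dft_matrix(4)
x = jnp.array([1.0, 2.0, 3.0, 4.0])

# W @ x is the unnormalized DFT of x, i.e. the same as jnp.fft.fft(x).
print(jnp.allclose(W @ x, jnp.fft.fft(x), atol=1e-5))  # True
```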

Parameters:
  • n (int) – size of the matrix.

  • scale (str | None) – (optional) None (default, unscaled), 'sqrtn' (scale by \(1/\sqrt{n}\), making the matrix unitary), or 'n' (scale by \(1/n\)).

  • dtype (DTypeLike | None) – (optional) complex floating-point dtype for the output. Defaults to JAX’s default complex dtype.
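The three scale options can be checked the same way. The sketch below uses a hypothetical helper, `scaled_dft`, which divides the unscaled matrix by \(\sqrt{n}\) or \(n\) as the parameter description states. It then verifies that the 'sqrtn' matrix is unitary. It also shows that None, 'sqrtn', and 'n' line up with the "backward" (default), "ortho", and "forward" normalization modes of jax.numpy.fft.fft. That correspondence follows from the definitions here. It is not taken from the jax.scipy.linalg.dft source.

```python
import jax.numpy as jnp


def scaled_dft(n: int, scale=None) -> jnp.ndarray:
    """Hypothetical helper mirroring the documented `scale` options."""
    idx = jnp.arange(n)
    # Unscaled W_n: entries exp(-2*pi*1j*(i*j mod n)/n).
    W = jnp.exp(-2j * jnp.pi * ((idx[:, None] * idx[None, :]) % n) / n)
    if scale is None:
        return W
    if scale == "sqrtn":
        return W / jnp.sqrt(n)
    if scale == "n":
        return W / n
    raise ValueError(f"unexpected scale: {scale!r}")


n = 8
x = jnp.arange(n, dtype=jnp.float32)
U = scaled_dft(n, "sqrtn")

# With scale='sqrtn' the matrix is unitary: U @ U^H is the identity.
print(jnp.allclose(U @ U.conj().T, jnp.eye(n), atol=1e-4))  # True

# Each option matches one `norm` mode of jnp.fft.fft.
print(jnp.allclose(scaled_dft(n) @ x, jnp.fft.fft(x), atol=1e-4))  # True
print(jnp.allclose(U @ x, jnp.fft.fft(x, norm="ortho"), atol=1e-4))  # True
print(jnp.allclose(scaled_dft(n, "n") @ x, jnp.fft.fft(x, norm="forward"), atol=1e-4))  # True
```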

Returns:

A DFT matrix of shape (n, n).

Return type:

Array

Examples

>>> jax.scipy.linalg.dft(4).round(3)
Array([[ 1.+0.j,  1.+0.j,  1.+0.j,  1.+0.j],
       [ 1.+0.j, -0.-1.j, -1.+0.j,  0.+1.j],
       [ 1.+0.j, -1.+0.j,  1.-0.j, -1.+0.j],
       [ 1.+0.j,  0.+1.j, -1.+0.j, -0.-1.j]], dtype=complex64)
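The example output also shows that the matrix is symmetric, since each entry depends only on the product ij. As a result, a batch of signals stored one per row can be transformed with a single matrix product. The inverse transform is the conjugate matrix divided by n. The following is a minimal sketch using jax.numpy only, with the same hypothetical `dft_matrix` helper. It is an illustration, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def dft_matrix(n: int) -> jnp.ndarray:
    # Hypothetical helper: unscaled W_n, entries exp(-2*pi*1j*(i*j mod n)/n).
    idx = jnp.arange(n)
    return jnp.exp(-2j * jnp.pi * ((idx[:, None] * idx[None, :]) % n) / n)


n = 4
W = dft_matrix(n)
# A batch of three random length-4 signals, one signal per row.
signals = jax.random.normal(jax.random.PRNGKey(0), (3, n))

# W is symmetric (W[i, j] == W[j, i]), so transforming every row is signals @ W.
spectra = signals @ W
print(jnp.allclose(spectra, jnp.fft.fft(signals, axis=-1), atol=1e-4))  # True

# The inverse DFT matrix is conj(W) / n (equal to W^H / n because W is symmetric).
recovered = spectra @ (W.conj() / n)
print(jnp.allclose(recovered.real, signals, atol=1e-4))  # True
```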