jax.lax.linalg.tridiagonal_solve

jax.lax.linalg.tridiagonal_solve#

jax.lax.linalg.tridiagonal_solve(dl, d, du, b, *, perturb_singular=False)

Computes the solution of a tridiagonal linear system.

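This function solves the linear system

\[A \, X = B\]

where A is an m-by-m tridiagonal matrix given by its lower, main, and upper diagonals (dl, d, and du), and B is a matrix of right-hand sides. Any leading dimensions (written ... in the shapes below) are batch dimensions, so a batch of independent systems can be solved in one call.
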
Parameters:
  • dl (Array) – A batch of vectors with shape [..., m]. The lower diagonal of A: dl[i] := A[i, i-1] for i in [0, m). Note that dl[0] = 0, since A[0, -1] lies outside the matrix.

  • d (Array) – A batch of vectors with shape [..., m]. The main diagonal of A: d[i] := A[i, i] for i in [0, m).

  • du (Array) – A batch of vectors with shape [..., m]. The upper diagonal of A: du[i] := A[i, i+1] for i in [0, m). Note that du[m - 1] = 0, since A[m - 1, m] lies outside the matrix.

  • b (Array) – Right-hand side matrix with shape [..., m, n].

  • perturb_singular (bool) – Whether to perturb singular matrices so that a finite result is returned. Defaults to False. If True, systems involving a singular matrix are solved by perturbing near-zero pivots in the partially pivoted LU decomposition of A: to avoid overflow, tiny pivots are perturbed by an amount of order eps * max_{ij} |U(i,j)|, where U is the upper triangular factor of the LU decomposition and eps is the machine precision. This is useful for solving numerically singular systems when computing eigenvectors by inverse iteration. Currently implemented only on CPU and GPU.
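A minimal usage sketch, assuming a CPU or GPU backend with default float32 precision: it builds a small diagonally dominant tridiagonal system using the padding convention above (dl[0] = 0 and du[m - 1] = 0), solves it, and cross-checks the result against a dense solve. The matrix and right-hand side are illustrative values chosen so that the exact solution is a column of ones.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

# A 4x4 tridiagonal system. All three diagonals have length m = 4;
# dl[0] and du[m - 1] fall outside the matrix and are set to 0.
dl = jnp.array([0.0, 1.0, 1.0, 1.0])  # dl[i] = A[i, i-1]
d = jnp.array([4.0, 4.0, 4.0, 4.0])   # d[i]  = A[i, i]
du = jnp.array([1.0, 1.0, 1.0, 0.0])  # du[i] = A[i, i+1]

# b is a matrix of shape [m, n]; here a single right-hand side (n = 1).
# Each entry is a row sum of A, so the exact solution is all ones.
b = jnp.array([[5.0], [6.0], [6.0], [5.0]])

x = tridiagonal_solve(dl, d, du, b)  # shape (4, 1), same as b

# Dense reference: assemble A from its diagonals and solve directly.
A = jnp.diag(d) + jnp.diag(dl[1:], k=-1) + jnp.diag(du[:-1], k=1)
x_ref = jnp.linalg.solve(A, b)
```

The dense comparison is only for illustration: forming A takes O(m^2) memory, whereas the tridiagonal solver works directly on the three length-m diagonals.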

Returns:

Solution X of the tridiagonal system, with the same shape as b.

Return type:

Array
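A sketch of batched use, following the [..., m] and [..., m, n] shape conventions above: three independent 5-by-5 systems, each with two right-hand sides, are solved in one call. The random, diagonally dominant systems are illustrative (diagonal dominance keeps them non-singular), and the residual is checked without forming A by applying the three diagonals directly.

```python
import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

batch, m, n = 3, 5, 2
key_d, key_b = jax.random.split(jax.random.PRNGKey(0))

# Off-diagonals of 1 with the out-of-range entries zeroed; main diagonal in
# [4, 5), so every system is strictly diagonally dominant.
dl = jnp.ones((batch, m)).at[:, 0].set(0.0)
du = jnp.ones((batch, m)).at[:, -1].set(0.0)
d = 4.0 + jax.random.uniform(key_d, (batch, m))
b = jax.random.normal(key_b, (batch, m, n))

x = tridiagonal_solve(dl, d, du, b)  # shape (batch, m, n), same as b

# Residual without forming A: (A x)[i] = dl[i]*x[i-1] + d[i]*x[i] + du[i]*x[i+1].
x_prev = jnp.pad(x[:, :-1, :], ((0, 0), (1, 0), (0, 0)))  # x[i-1], 0 at i = 0
x_next = jnp.pad(x[:, 1:, :], ((0, 0), (0, 1), (0, 0)))   # x[i+1], 0 at i = m-1
Ax = dl[..., None] * x_prev + d[..., None] * x + du[..., None] * x_next
residual = jnp.max(jnp.abs(Ax - b))
```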