jax.scipy.linalg.qr_multiply

jax.scipy.linalg.qr_multiply(a: ArrayLike, c: ArrayLike, mode: str = 'right', pivoting: Literal[False] = False, conjugate: bool = False, overwrite_a: bool = False, overwrite_c: bool = False) → tuple[Array, Array]
jax.scipy.linalg.qr_multiply(a: ArrayLike, c: ArrayLike, mode: str = 'right', pivoting: Literal[True] = True, conjugate: bool = False, overwrite_a: bool = False, overwrite_c: bool = False) → tuple[Array, Array, Array]
jax.scipy.linalg.qr_multiply(a: ArrayLike, c: ArrayLike, mode: str = 'right', pivoting: bool = False, conjugate: bool = False, overwrite_a: bool = False, overwrite_c: bool = False) → tuple[Array, Array] | tuple[Array, Array, Array]

Calculate the QR decomposition and multiply Q with a matrix.

JAX implementation of scipy.linalg.qr_multiply().

Parameters:
  • a – array of shape (..., M, N). Matrix to be decomposed.

  • c – array to be multiplied by Q. For mode='left', c has shape (..., K, P) where K = min(M, N). For mode='right', c has shape (..., P, M). 1-D arrays are supported: for mode='left', treated as a length-K column vector; for mode='right', treated as a length-M row vector. The result is raveled back to 1-D in either case.

  • mode –

    'right' (default) or 'left'.

    • 'left': compute Q @ c (or conj(Q) @ c if conjugate=True) and return (Q @ c, R) with result shape (..., M, P)

    • 'right': compute c @ Q (or c @ conj(Q) if conjugate=True) and return (c @ Q, R) with result shape (..., P, K) where K = min(M, N)

  • pivoting – Allows the QR decomposition to be rank-revealing. If True, compute the column-pivoted QR decomposition and return permutation indices as a third element.

  • conjugate – If True, use conj(Q) (element-wise complex conjugate) instead of Q. For real arrays this has no effect.

  • overwrite_a – unused in JAX

  • overwrite_c – unused in JAX

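Returns:

  • If pivoting is False: (result, R)

  • If pivoting is True: (result, R, P)

  Here result is the product described under mode, R is the upper-triangular factor of the QR decomposition, and P holds the column permutation indices.

Return type:

tuple[Array, Array] | tuple[Array, Array, Array]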


Examples

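Use qr_multiply() to efficiently solve a least-squares problem. For an overdetermined system A @ x ≈ b, pass b as a 1-D row vector with mode='right'. For real inputs the result b @ Q equals Q^T @ b, so Q^T @ b and R are obtained in one step, and x then follows from the triangular system R @ x = Q^T @ b: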

>>> import jax
>>> import jax.numpy as jnp
>>> A = jnp.array([[1., 1.], [1., 2.], [1., 3.], [1., 4.]])
>>> b = jnp.array([2., 4., 5., 4.])
>>> Qtb, R = jax.scipy.linalg.qr_multiply(A, b, mode='right')
>>> x = jax.scipy.linalg.solve_triangular(R, Qtb)
>>> jnp.allclose(A.T @ A @ x, A.T @ b)
Array(True, dtype=bool)