jax.scipy.linalg.qr_multiply
- jax.scipy.linalg.qr_multiply(a: ArrayLike, c: ArrayLike, mode: str = 'right', pivoting: Literal[False] = False, conjugate: bool = False, overwrite_a: bool = False, overwrite_c: bool = False) → tuple[Array, Array]
- jax.scipy.linalg.qr_multiply(a: ArrayLike, c: ArrayLike, mode: str = 'right', pivoting: Literal[True] = True, conjugate: bool = False, overwrite_a: bool = False, overwrite_c: bool = False) → tuple[Array, Array, Array]
- jax.scipy.linalg.qr_multiply(a: ArrayLike, c: ArrayLike, mode: str = 'right', pivoting: bool = False, conjugate: bool = False, overwrite_a: bool = False, overwrite_c: bool = False) → tuple[Array, Array] | tuple[Array, Array, Array]
Calculate the QR decomposition and multiply Q with a matrix.
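Written out, and assuming the economic (reduced) factorization implied by the shapes listed under Parameters, the two modes compute the following (batch dimensions omitted):

```latex
\begin{aligned}
&A = QR, \qquad Q:\ M \times K, \quad R:\ K \times N, \quad K = \min(M, N) \\
&\texttt{mode='left'}:\quad C:\ K \times P \;\longmapsto\; QC:\ M \times P \\
&\texttt{mode='right'}:\quad C:\ P \times M \;\longmapsto\; CQ:\ P \times K
\end{aligned}
```

With conjugate=True, Q is replaced by its element-wise conjugate conj(Q) in both products. For real inputs the two are identical.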
JAX implementation of scipy.linalg.qr_multiply().
- Parameters:
  - a: array of shape (..., M, N). Matrix to be decomposed.
  - c: array to be multiplied by Q. For mode='left', c has shape (..., K, P) where K = min(M, N). For mode='right', c has shape (..., P, M). 1-D arrays are supported: for mode='left', c is treated as a length-K column vector; for mode='right', as a length-M row vector. The result is raveled back to 1-D in either case.
  - mode: 'right' (default) or 'left'.
    - 'left': compute Q @ c (or conj(Q) @ c if conjugate=True) and return (Q @ c, R), with result shape (..., M, P).
    - 'right': compute c @ Q (or c @ conj(Q) if conjugate=True) and return (c @ Q, R), with result shape (..., P, K), where K = min(M, N).
  - pivoting: allows the QR decomposition to be rank-revealing. If True, compute the column-pivoted QR decomposition and return the permutation indices as a third element.
  - conjugate: if True, use conj(Q) (the element-wise complex conjugate) instead of Q. For real arrays this has no effect.
  - overwrite_a: unused in JAX.
  - overwrite_c: unused in JAX.
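To make the shape rules concrete, the sketch below defines a hypothetical helper, qr_multiply_reference, that forms the economic Q explicitly with jax.scipy.linalg.qr(a, mode='economic') and then multiplies. It illustrates the documented semantics for 2-D c only and is not JAX's implementation; unlike qr_multiply, it materializes Q.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr


def qr_multiply_reference(a, c, mode="right", conjugate=False):
    """Hypothetical helper: explicit economic QR followed by a matmul.

    Illustrates the documented shape rules for 2-D `c` only. It does not
    handle 1-D `c` or pivoting, and unlike qr_multiply it materializes Q.
    """
    q, r = qr(a, mode="economic")  # q: (M, K), r: (K, N), K = min(M, N)
    if conjugate:
        q = jnp.conj(q)  # no-op for real inputs
    if mode == "left":
        return q @ c, r  # (M, K) @ (K, P) -> (M, P)
    if mode == "right":
        return c @ q, r  # (P, M) @ (M, K) -> (P, K)
    raise ValueError(f"mode must be 'left' or 'right', got {mode!r}")


# M = 4, N = 3, so K = min(M, N) = 3.
a = jnp.array([[1., 2., 0.],
               [0., 1., 3.],
               [2., 0., 1.],
               [1., 1., 1.]])

left, r = qr_multiply_reference(a, jnp.ones((3, 2)), mode="left")
right, _ = qr_multiply_reference(a, jnp.ones((5, 4)), mode="right")
print(left.shape, right.shape, r.shape)  # (4, 2) (5, 3) (3, 3)
```

The c passed with mode='left' needs K rows, while the c passed with mode='right' needs M columns. This mirrors the parameter description above.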
- Returns:
  If pivoting is False: (result, R)
  If pivoting is True: (result, R, P)
- Return type:
  tuple[Array, Array] if pivoting is False; tuple[Array, Array, Array] if pivoting is True
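The following is a sketch of the relation that the third output satisfies in the pivoted case. It assumes JAX follows the same convention as scipy.linalg.qr, where the returned integer array P permutes the columns of a:

```latex
A_{:,P} = QR, \qquad |R_{11}| \ge |R_{22}| \ge \cdots \ge |R_{KK}|
```

The ordering of the diagonal of R is what makes the factorization rank-revealing: small trailing diagonal entries indicate near rank deficiency. This ordering is the usual property of greedy column pivoting (as in LAPACK's ?GEQP3). It is stated here as an assumption rather than as a documented guarantee of this function.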
See also
- jax.scipy.linalg.qr(): SciPy-style QR decomposition API
- jax.lax.linalg.ormqr(): XLA-style Q-multiply primitive
Examples
Use qr_multiply() to solve a least-squares problem efficiently. For an overdetermined system A @ x ≈ b, pass b as a 1-D row vector with mode='right'. For real A, b @ Q is the same vector as Q^T @ b, so one call returns both Q^T @ b and R:
>>> import jax
>>> import jax.numpy as jnp
>>> A = jnp.array([[1., 1.], [1., 2.], [1., 3.], [1., 4.]])
>>> b = jnp.array([2., 4., 5., 4.])
>>> Qtb, R = jax.scipy.linalg.qr_multiply(A, b, mode='right')
>>> x = jax.scipy.linalg.solve_triangular(R, Qtb)
>>> jnp.allclose(A.T @ A @ x, A.T @ b)
Array(True, dtype=bool)
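As a cross-check that does not depend on qr_multiply itself, the same least-squares solution can be obtained from an explicit economic QR (via jax.scipy.linalg.qr) and compared against jnp.linalg.lstsq. This is a verification sketch only. Forming Q explicitly gives up the point of qr_multiply, which returns the product without returning Q.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr, solve_triangular

# Same overdetermined system as in the example: fit b ~ x0 + x1 * t for t = 1..4.
A = jnp.array([[1., 1.], [1., 2.], [1., 3.], [1., 4.]])
b = jnp.array([2., 4., 5., 4.])

# Economic QR: Q has shape (4, 2), R has shape (2, 2).
Q, R = qr(A, mode="economic")

# For real A, b @ Q equals Q.T @ b. This is the first output that
# qr_multiply(A, b, mode='right') is documented to return.
Qtb = b @ Q
x_qr = solve_triangular(R, Qtb)  # R is upper triangular (lower=False by default)

# Reference solution from a general least-squares solver.
x_lstsq, *_ = jnp.linalg.lstsq(A, b)
print(x_qr, x_lstsq)  # both approximately [2.0, 0.7]
```

Working through the normal equations by hand gives the same answer: intercept 2.0 and slope 0.7.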