jax.lax.linalg.ormqr
- jax.lax.linalg.ormqr(a, taus, c, *, left=True, transpose=False)
Multiplies a matrix by Q from a QR factorization without materializing Q.
Computes one of the following, depending on left and transpose:
  - Q @ C (left=True, transpose=False)
  - Q^T @ C (left=True, transpose=True)
  - C @ Q (left=False, transpose=False)
  - C @ Q^T (left=False, transpose=True)
For complex types, transpose=True computes the conjugate transpose (Q^H) instead of Q^T.
- Parameters:
  a (ArrayLike) – The Householder reflectors with shape [..., m, n], as returned by the internal geqrf/geqp3 primitives. Alternatively, you can use jax.numpy.linalg.qr() with mode="raw", but in that case the first returned array must be transposed with .mT before being passed as a (see the example below).
  taus (ArrayLike) – The Householder scalar factors with shape [..., k], where k = min(m, n), as returned by geqrf/geqp3 or as the second element of the tuple returned by jax.numpy.linalg.qr() with mode="raw".
  c (ArrayLike) – The matrix to multiply by Q, with shape [..., c_rows, c_cols]. Because Q is m x m, c_rows must equal m when left=True, and c_cols must equal m when left=False.
  left (bool) – If True, compute Q @ C. If False, compute C @ Q.
  transpose (bool) – If True, use Q^T (or Q^H for complex types) in place of Q.
- Returns:
  The result of multiplying c by Q (or by Q^T / Q^H), with the same shape as c.
- Return type:
  Array
Examples
Multiply the identity matrix by Q without forming Q explicitly. The result is Q itself, which matches the Q returned by jax.numpy.linalg.qr() with mode="complete":
>>> import jax.numpy as jnp
>>> from jax.lax.linalg import ormqr
>>> a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
>>> h, taus = jnp.linalg.qr(a, mode="raw")
>>> c = jnp.eye(3)
>>> Q_times_c = ormqr(h.mT, taus, c)
>>> Q_direct, _ = jnp.linalg.qr(a, mode="complete")
>>> jnp.allclose(Q_times_c, Q_direct, atol=1e-5)
Array(True, dtype=bool)
See also
jax.scipy.linalg.qr_multiply(): Higher-level API for computing Q @ C or C @ Q directly from a matrix a.