jax.random.orthogonal

jax.random.orthogonal(key, n, shape=(), dtype=None, m=None, *, out_sharding=None)

Sample uniformly from the orthogonal group O(n).

If the dtype is complex, sample uniformly from the unitary group U(n).
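As a minimal sketch of the two cases above, the snippet below draws one real and one complex sample and measures how far each is from satisfying the defining identity. The seed, the size 4, and the variable names are arbitrary choices for illustration, and it assumes complex64 is accepted as the dtype, as the description above implies.

```python
import jax
import jax.numpy as jnp

# Arbitrary seed; split so the two samples use independent keys.
key = jax.random.key(0)
k_real, k_complex = jax.random.split(key)

# Real case: a sample from O(4), so Q^T Q should be close to the identity.
Q = jax.random.orthogonal(k_real, 4)

# Complex case: a sample from U(4), so U^* U should be close to the identity.
U = jax.random.orthogonal(k_complex, 4, dtype=jnp.complex64)

# Largest absolute deviation from the identity in each case.
err_Q = jnp.max(jnp.abs(Q.T @ Q - jnp.eye(4)))
err_U = jnp.max(jnp.abs(U.conj().T @ U - jnp.eye(4)))
```

Both errors should be on the order of floating-point round-off for the chosen precision, not exactly zero.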

When the number of rows and columns differ (\(n \neq m\)), this samples a semi-orthogonal matrix instead. That is, if \(A\) is the resulting matrix and \(A^*\) is its conjugate transpose, then:

  • If \(n \leq m\), the rows are mutually orthonormal: \(A A^* = I_n\).

  • If \(m \leq n\), the columns are mutually orthonormal: \(A^* A = I_m\).

Parameters:
  • key (ArrayLike) – a PRNG key used as the source of randomness.

  • n (int) – an integer indicating the number of rows.

  • shape (Shape) – optional, the batch dimensions of the result. Default ().

  • dtype (DTypeLikeFloat | None) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32). As noted above, a complex dtype samples from the unitary group U(n).

  • m (int | None) – optional, an integer indicating the number of columns. Defaults to n.

  • out_sharding (NamedSharding | P | None) – optional, specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array of shape (*shape, n, m) and specified dtype.

Return type:

Array
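To illustrate how the shape parameter determines the batch dimensions of the result, here is a short sketch that draws a batch of square orthogonal matrices and checks each one. The batch shape (2, 5), the size 3, and the seed are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(2)  # arbitrary seed

# shape=(2, 5) requests a 2 x 5 batch of 3 x 3 orthogonal matrices,
# so the result has shape (*shape, n, n) = (2, 5, 3, 3).
Qs = jax.random.orthogonal(key, 3, shape=(2, 5))

# Batched Q Q^T: transpose only the last two (matrix) axes.
gram = Qs @ jnp.swapaxes(Qs, -1, -2)

# Largest deviation from the identity across the whole batch.
batch_err = jnp.max(jnp.abs(gram - jnp.eye(3)))
```

The identity matrix broadcasts against the leading batch dimensions, so a single comparison covers all ten matrices.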

References