jax.scipy.linalg.hankel
- jax.scipy.linalg.hankel(c, r=None)
Construct a Hankel matrix.
JAX implementation of scipy.linalg.hankel().
A Hankel matrix has constant anti-diagonals: H[i, j] = v[i + j], where v = concatenate([c, r[1:]]). Note that this implies r[0] is ignored.
- Parameters:
c (ArrayLike) – array of shape (..., N) specifying the first column.
r (ArrayLike | None) – (optional) array of shape (..., M) specifying the last row. Leading dimensions must be broadcast-compatible with those of c. If not specified, r defaults to zeros_like(c).
- Returns:
A Hankel matrix of shape (..., N, M).
- Return type:
Array
Examples
>>> c = jnp.array([1, 2, 3])
>>> jax.scipy.linalg.hankel(c)
Array([[1, 2, 3],
       [2, 3, 0],
       [3, 0, 0]], dtype=int32)
>>> r = jnp.array([999, 4, 5, 6])  # Note r[0] is ignored
>>> jax.scipy.linalg.hankel(c, r)
Array([[1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
For N-dimensional c and/or r, the result is a batch of Hankel matrices.