jax.numpy.trim_zeros
- jax.numpy.trim_zeros(filt, trim='fb', axis=None)
Trim leading and/or trailing zeros of the input array.
JAX implementation of numpy.trim_zeros().
- Parameters:
filt (ArrayLike) – N-dimensional input array.
trim (str) – string, optional, default = 'fb'. Specifies from which end the input is trimmed.
    'f' - trims only the leading zeros.
    'b' - trims only the trailing zeros.
    'fb' - trims both leading and trailing zeros.
axis (int | Sequence[int] | None) – optional axis or axes along which to trim. If not specified, trim along all axes of the array.
- Returns:
An array containing the trimmed input, with the same dtype as filt.
- Return type:
Array
Examples
One-dimensional input:
>>> import jax.numpy as jnp
>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
>>> jnp.trim_zeros(x, trim='f')
Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32)
>>> jnp.trim_zeros(x, trim='b')
Array([0, 0, 2, 0, 1, 4, 3], dtype=int32)
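For one-dimensional input, the trimming described above amounts to slicing between the first and last nonzero entries. The following is a minimal sketch of that reading, not the code JAX actually runs. The helper name trim_1d_sketch is hypothetical, and the sketch assumes a concrete (non-traced) array, because the output length depends on the data.

```python
import jax.numpy as jnp

def trim_1d_sketch(x, trim="fb"):
    """Hypothetical helper (not part of JAX): slice between first and last nonzero."""
    nz = jnp.nonzero(x)[0]          # indices of the nonzero entries
    if nz.size == 0:                # all zeros: nothing survives the trim
        return x[:0]
    start = int(nz[0]) if "f" in trim else 0
    stop = int(nz[-1]) + 1 if "b" in trim else x.shape[0]
    return x[start:stop]

x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
for mode in ("f", "b", "fb"):
    # The sketch should agree with jnp.trim_zeros for each trim mode.
    assert jnp.array_equal(trim_1d_sketch(x, mode), jnp.trim_zeros(x, trim=mode))
```

Interior zeros, such as the one between 2 and 1, are kept in every mode, because only the ends of the array are examined.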
Two-dimensional input:
>>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1)
>>> x
Array([[0., 0., 0., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
>>> jnp.trim_zeros(x, trim='f')
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=0)
Array([[0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=1)
Array([[0., 0., 0.],
       [1., 1., 1.],
       [1., 1., 1.],
       [0., 0., 0.]], dtype=float32)
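One way to read the N-dimensional behaviour shown above: along each selected axis, a position is dropped from an end only if the entire slice at that position is zero. The sketch below illustrates this under that assumption. The helper name trim_nd_sketch is hypothetical, it is not JAX's implementation, and it expects a concrete array rather than a traced one.

```python
import jax.numpy as jnp

def trim_nd_sketch(x, trim="fb", axis=None):
    """Hypothetical helper (not part of JAX): per-axis bounding-box trim."""
    if axis is None:
        axes = tuple(range(x.ndim))
    elif isinstance(axis, int):
        axes = (axis,)
    else:
        axes = tuple(axis)
    slices = [slice(None)] * x.ndim
    for ax in axes:
        ax = ax % x.ndim                       # allow negative axes
        occupied = x != 0
        # Reduce every other axis, highest first so indices stay valid.
        for other in reversed(range(x.ndim)):
            if other != ax:
                occupied = occupied.any(axis=other)
        idx = jnp.nonzero(occupied)[0]         # positions along `ax` holding a nonzero
        if idx.size == 0:                      # all zeros along this axis
            slices[ax] = slice(0, 0)
            continue
        start = int(idx[0]) if "f" in trim else 0
        stop = int(idx[-1]) + 1 if "b" in trim else x.shape[ax]
        slices[ax] = slice(start, stop)
    return x[tuple(slices)]

x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1)
print(trim_nd_sketch(x).shape)           # both axes trimmed
print(trim_nd_sketch(x, axis=0).shape)   # only rows trimmed
print(trim_nd_sketch(x, axis=1).shape)   # only columns trimmed
```

Passing a sequence such as axis=(0, 1) trims each listed axis in turn, which for this two-dimensional input gives the same result as leaving axis unspecified.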