jax.numpy.trim_zeros

jax.numpy.trim_zeros(filt, trim='fb', axis=None)

Trim leading and/or trailing zeros of the input array.

JAX implementation of numpy.trim_zeros().

Parameters:
  • filt (ArrayLike) – N-dimensional input array.

  • trim (str) – optional, default 'fb'. Specifies which end(s) of the input are trimmed.

    • 'f' - trims only the leading zeros.

    • 'b' - trims only the trailing zeros.

    • 'fb' - trims both leading and trailing zeros.

  • axis (int | Sequence[int] | None) – optional axis or axes along which to trim. If not specified, trim along all axes of the array.

Returns:

An array containing the trimmed input, with the same dtype as filt.

Return type:

Array

Examples

One-dimensional input:

>>> import jax.numpy as jnp
>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
>>> jnp.trim_zeros(x, trim='f')
Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32)
>>> jnp.trim_zeros(x, trim='b')
Array([0, 0, 2, 0, 1, 4, 3], dtype=int32)
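
The three trim modes shown above can be mimicked in plain Python, which may help when reasoning about which elements survive. This is only an illustrative sketch of the one-dimensional rules, not the JAX implementation; the helper name trim_zeros_1d is hypothetical and works on ordinary lists. Note that interior zeros are never removed.

```python
def trim_zeros_1d(values, trim="fb"):
    """Hypothetical pure-Python reference for the 1-D trimming rules."""
    start, stop = 0, len(values)
    if "f" in trim:
        # Advance past leading zeros.
        while start < stop and values[start] == 0:
            start += 1
    if "b" in trim:
        # Step back over trailing zeros.
        while stop > start and values[stop - 1] == 0:
            stop -= 1
    return values[start:stop]


x = [0, 0, 2, 0, 1, 4, 3, 0, 0, 0]
print(trim_zeros_1d(x))       # [2, 0, 1, 4, 3]
print(trim_zeros_1d(x, "f"))  # [2, 0, 1, 4, 3, 0, 0, 0]
print(trim_zeros_1d(x, "b"))  # [0, 0, 2, 0, 1, 4, 3]
```

In this sketch an all-zero input yields an empty list.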

Two-dimensional input:

>>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1)
>>> x
Array([[0., 0., 0., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
>>> jnp.trim_zeros(x, trim='f')
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=0)
Array([[0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=1)
Array([[0., 0., 0.],
       [1., 1., 1.],
       [1., 1., 1.],
       [0., 0., 0.]], dtype=float32)
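
The multi-dimensional results above are consistent with cropping the array to the bounding box of its nonzero entries along the selected axes. The following NumPy-only sketch illustrates that reading; it is an assumption drawn from the examples, not the JAX source, and the helper name trim_zeros_nd is hypothetical.

```python
import numpy as np


def trim_zeros_nd(arr, trim="fb", axis=None):
    """Hypothetical sketch: crop to the bounding box of nonzero entries."""
    arr = np.asarray(arr)
    if axis is None:
        axes = range(arr.ndim)
    elif isinstance(axis, int):
        axes = [axis]
    else:
        axes = list(axis)

    slices = [slice(None)] * arr.ndim
    for ax in axes:
        ax = ax % arr.ndim  # normalize a negative axis
        # Reduce over every other axis: True where this index holds a nonzero.
        other = tuple(i for i in range(arr.ndim) if i != ax)
        keep = np.any(arr != 0, axis=other)
        idx = np.flatnonzero(keep)
        if idx.size == 0:
            # No nonzero entries at all, so nothing survives along this axis.
            slices[ax] = slice(0, 0)
            continue
        start = idx[0] if "f" in trim else 0
        stop = idx[-1] + 1 if "b" in trim else arr.shape[ax]
        slices[ax] = slice(start, stop)
    return arr[tuple(slices)]


x = np.zeros((4, 5))
x[1:3, 1:4] = 1
print(trim_zeros_nd(x).shape)            # (2, 3)
print(trim_zeros_nd(x, trim="f").shape)  # (3, 4)
print(trim_zeros_nd(x, axis=0).shape)    # (2, 5)
print(trim_zeros_nd(x, axis=1).shape)    # (4, 3)
```

Because each axis is handled independently, trim='f' drops the leading all-zero row and the leading all-zero column while leaving the trailing ones in place, matching the (3, 4) result shown earlier.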