jax.lax.pad
- jax.lax.pad(operand, padding_value, padding_config)
Applies low, high, and/or interior padding to an array.
Wraps XLA’s Pad operator.
- Parameters:
operand (ArrayLike) – an array to be padded.
padding_value (ArrayLike) – the value to be inserted as padding. Must have the same dtype as operand.
padding_config (Sequence[tuple[int, int, int]]) – a sequence of (low, high, interior) tuples of integers, one per dimension of operand, giving the amount of low, high, and interior (dilation) padding to insert in each dimension. Negative values for low and high are allowed and remove elements from the edges of the array; interior values must be non-negative.
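As a sketch of how the three numbers in each tuple combine, the hypothetical helper `padded_length` below computes the expected size of one dimension: interior padding adds `interior` elements in each of the `n - 1` gaps between neighbouring elements, and `low` and `high` are then added (or, when negative, subtracted) at the two ends. The helper is illustrative only and is not part of the JAX API; the loop checks it against the shapes `lax.pad` actually returns.

```python
import jax.numpy as jnp
from jax import lax


def padded_length(n, low, high, interior):
    """Expected size of one dimension after lax.pad (hypothetical helper).

    Interior padding fills the (n - 1) gaps between neighbours with
    `interior` elements each; edge padding then adds (or, if negative,
    removes) `low` and `high` elements at the two ends.
    """
    return low + high + n + interior * max(n - 1, 0)


x = jnp.arange(5)
for config in [(2, 3, 0), (0, 0, 1), (1, 2, 2), (-1, -2, 0)]:
    out = lax.pad(x, 0, [config])
    # The shape produced by lax.pad should match the helper's prediction.
    assert out.shape == (padded_length(x.shape[0], *config),)
```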
- Returns:
The operand array with padding value padding_value inserted in each dimension according to the padding_config.
- Return type:
Array
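Because the shape of the returned array depends on the integers in `padding_config`, those integers have to be concrete Python values at trace time. A minimal sketch under `jax.jit`, assuming a hypothetical wrapper `pad_all_edges` that takes the pad width as a static argument and builds `padding_value` with the operand's dtype:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# `width` determines the output shape, so it is marked static for jit.
@partial(jax.jit, static_argnums=1)
def pad_all_edges(x, width):
    # Build a scalar fill value with the operand's dtype.
    fill = jnp.zeros((), dtype=x.dtype)
    # Same (low, high, interior) tuple for every dimension; x.ndim is static.
    return lax.pad(x, fill, [(width, width, 0)] * x.ndim)


y = pad_all_edges(jnp.ones((2, 3), dtype=jnp.float32), 1)
print(y.shape)  # (4, 5)
```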
Examples
>>> from jax import lax
>>> import jax.numpy as jnp
Pad a 1-dimensional array with zeros. We’ll specify two zeros in front and three at the end:
>>> x = jnp.array([1, 2, 3, 4])
>>> lax.pad(x, 0, [(2, 3, 0)])
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
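With non-negative edge padding and a constant fill, this is the same result `jnp.pad` produces in its default `mode="constant"`, which pads with zeros; the check below is a sketch under that assumption. `jnp.pad` has no counterpart to the interior padding that `lax.pad` offers.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3, 4])

# jnp.pad's default "constant" mode fills with zeros; pad_width=(2, 3)
# means two elements before and three after along the only axis.
via_jnp = jnp.pad(x, (2, 3))

# The same padding expressed as a lax.pad config: (low=2, high=3, interior=0).
via_lax = lax.pad(x, 0, [(2, 3, 0)])

assert (via_jnp == via_lax).all()
```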
Pad a 1-dimensional array with interior zeros; i.e. insert a single zero between each value:
>>> lax.pad(x, 0, [(0, 0, 1)])
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
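Edge and interior padding can also be combined in a single tuple. In the sketch below, interior padding places two zeros between neighbouring values, and edge padding then adds one zero at the front and two at the end:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3])

# (low=1, high=2, interior=2): two zeros between neighbours,
# then one zero in front and two zeros at the end.
out = lax.pad(x, 0, [(1, 2, 2)])
print(out)  # [0 1 0 0 2 0 0 3 0 0]
```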
Pad a 2-dimensional array with the value -1 at the front and end, with a pad size of 2 in each dimension:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
Array([[-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1,  1,  2,  3, -1, -1],
       [-1, -1,  4,  5,  6, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
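Each dimension takes its own tuple, so different kinds of padding can be applied along different axes. A sketch that inserts a row of -1 between the two rows (interior padding on axis 0) and a column of -1 on the left (low padding on axis 1):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# axis 0: (0, 0, 1) inserts one row of -1 between the existing rows;
# axis 1: (1, 0, 0) adds one column of -1 on the left.
out = lax.pad(x, -1, [(0, 0, 1), (1, 0, 0)])
print(out)
# [[-1  1  2  3]
#  [-1 -1 -1 -1]
#  [-1  4  5  6]]
```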
Use negative padding to remove elements from the edges of an array:
>>> x = jnp.array([1, 2, 3, 4, 5], dtype=jnp.int32)
>>> lax.pad(x, 0, [(-1, -2, 0)])
Array([2, 3], dtype=int32)
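Negative edge padding can also be combined with interior padding. Following XLA's Pad semantics, interior padding is applied first, and negative low or high values then remove elements from the interior-padded result, as in this sketch:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3], dtype=jnp.int32)

# Interior padding first gives [1, 0, 2, 0, 3];
# low=-1 then trims one element from the front of that result.
out = lax.pad(x, 0, [(-1, 0, 1)])
print(out)  # [0 2 0 3]
```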