jax.lax.pad

jax.lax.pad(operand, padding_value, padding_config)

Applies low, high, and/or interior padding to an array.

Wraps XLA’s Pad operator.
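One way to see this wrapping is to lower a jitted call and look at the emitted module text. This is a minimal sketch: widen is a hypothetical function name, and the exact spelling of the op in the output (for example stablehlo.pad) depends on the JAX version.

```python
import jax
import jax.numpy as jnp
from jax import lax


def widen(x):
    # Hypothetical example function: one element of low padding,
    # two of high padding, and one of interior padding.
    return lax.pad(x, jnp.zeros((), dtype=x.dtype), [(1, 2, 1)])


lowered = jax.jit(widen).lower(jnp.arange(4, dtype=jnp.int32))
hlo_text = lowered.as_text()

# Keep only the lines of the lowered module that mention the pad op.
pad_lines = [line for line in hlo_text.splitlines() if "pad" in line]
print("\n".join(pad_lines))
```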

Parameters:
  • operand (ArrayLike) – an array to be padded.

  • padding_value (ArrayLike) – the value to be inserted as padding. Must have the same dtype as operand.

  • padding_config (Sequence[tuple[int, int, int]]) – a sequence of (low, high, interior) integer tuples, one per dimension of operand, giving the amount of low (leading), high (trailing), and interior (dilation) padding to insert in that dimension. Negative values for low and high are allowed and remove elements from the edges of the array; interior padding must be non-negative.
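Because padding_value must have the same dtype as operand, it can help to convert the value explicitly before calling lax.pad. The helper below is a minimal sketch: pad_with_cast is a hypothetical name, not part of JAX, and it assumes the padding value is exactly representable in the operand's dtype (converting, say, 2.5 to an integer dtype would not preserve it).

```python
import jax.numpy as jnp
from jax import lax


def pad_with_cast(operand, padding_value, padding_config):
    """Hypothetical helper: convert padding_value to operand's dtype, then pad."""
    operand = jnp.asarray(operand)
    # A 0-d array with exactly the operand's dtype satisfies lax.pad's dtype check.
    value = jnp.asarray(padding_value, dtype=operand.dtype)
    return lax.pad(operand, value, padding_config)


x = jnp.array([1.5, 2.5], dtype=jnp.float16)
y = pad_with_cast(x, -1, [(1, 1, 0)])
print(y.dtype, y)
```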

Returns:

The operand array with padding_value inserted in each dimension according to padding_config.

Return type:

Array
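The shape of the returned array follows from padding_config alone: interior padding first stretches a dimension of size d to d + (d - 1) * interior elements (or 0 when d is 0), and low and high are then added. A minimal sketch that computes this and checks it against lax.pad; padded_shape is a hypothetical helper, not a JAX function.

```python
import jax.numpy as jnp
from jax import lax


def padded_shape(shape, padding_config):
    """Hypothetical helper: expected output shape of lax.pad."""
    return tuple(
        low + high + d + max(d - 1, 0) * interior
        for d, (low, high, interior) in zip(shape, padding_config)
    )


x = jnp.ones((2, 3), dtype=jnp.int32)
config = [(1, -1, 2), (0, 2, 1)]
expected = padded_shape(x.shape, config)                       # (4, 7)
actual = lax.pad(x, jnp.array(0, dtype=x.dtype), config).shape  # (4, 7)
print(expected, actual)
```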

Examples

>>> from jax import lax
>>> import jax.numpy as jnp

Pad a 1-dimensional array with zeros. We’ll specify two zeros in front and three at the end:

>>> x = jnp.array([1, 2, 3, 4])
>>> lax.pad(x, 0, [(2, 3, 0)])
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
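For plain constant edge padding like this (non-negative low and high, no interior padding), the NumPy-style jnp.pad produces the same result; lax.pad is the lower-level primitive, which also exposes interior padding. A minimal comparison:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3, 4])

# Two elements before and three after, filled with zeros.
via_lax = lax.pad(x, 0, [(2, 3, 0)])
via_jnp = jnp.pad(x, [(2, 3)])  # constant mode with a fill value of 0 by default

# The same comparison with a non-zero fill value.
via_lax_neg = lax.pad(x, -1, [(2, 3, 0)])
via_jnp_neg = jnp.pad(x, [(2, 3)], constant_values=-1)
```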

Pad a 1-dimensional array with interior zeros, i.e., insert a single zero between each pair of adjacent values:

>>> lax.pad(x, 0, [(0, 0, 1)])
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)

Pad a 2-dimensional array with the value -1 at the front and end of each dimension, using a pad size of 2:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
Array([[-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1,  1,  2,  3, -1, -1],
       [-1, -1,  4,  5,  6, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
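Each dimension takes its own (low, high, interior) tuple, so dimensions can be padded differently, and (0, 0, 0) leaves a dimension unchanged. For example, to pad only along the columns of the same array:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Axis 0 (rows): unchanged. Axis 1 (columns): one -1 in front and one -1
# between neighbouring columns, nothing at the end.
y = lax.pad(x, -1, [(0, 0, 0), (1, 0, 1)])
# y.tolist() == [[-1, 1, -1, 2, -1, 3], [-1, 4, -1, 5, -1, 6]]
```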

Use negative padding to remove elements from the edges of an array:

>>> x = jnp.array([1, 2, 3, 4, 5], dtype=jnp.int32)
>>> lax.pad(x, 0, [(-1, -2, 0)])
Array([2, 3], dtype=int32)
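Following XLA’s Pad semantics, interior padding is applied before edge padding, so negative low or high values trim elements from the interior-padded array. The sketch below restates these semantics as a NumPy reference implementation and checks it against lax.pad. pad_reference is a hypothetical helper for illustration, not how JAX implements the operation, and it assumes negative edge padding never removes more elements than the interior-padded dimension contains.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def pad_reference(x, value, padding_config):
    """Hypothetical NumPy re-implementation of lax.pad, for illustration only."""
    x = np.asarray(x)
    for axis, (low, high, interior) in enumerate(padding_config):
        n = x.shape[axis]
        # 1. Interior padding: n elements become n + (n - 1) * interior,
        #    with the originals placed at every (interior + 1)-th position.
        shape = list(x.shape)
        shape[axis] = n + max(n - 1, 0) * interior
        dilated = np.full(shape, value, dtype=x.dtype)
        index = [slice(None)] * x.ndim
        index[axis] = slice(None, None, interior + 1)
        dilated[tuple(index)] = x
        # 2. Negative low/high values crop the interior-padded array.
        index[axis] = slice(max(-low, 0), shape[axis] - max(-high, 0))
        cropped = dilated[tuple(index)]
        # 3. Positive low/high values add copies of `value` at either end.
        widths = [(0, 0)] * x.ndim
        widths[axis] = (max(low, 0), max(high, 0))
        x = np.pad(cropped, widths, mode="constant", constant_values=value)
    return x


x = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
config = [(1, -1, 2), (-1, 2, 1)]
ours = pad_reference(x, -1, config)
theirs = np.asarray(lax.pad(jnp.asarray(x), jnp.array(-1, dtype=jnp.int32), config))
print(np.array_equal(ours, theirs))
```

With no interior padding, negative low and high reduce to ordinary slicing, which is why the example above matches x[1:-2].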