jax.lax.tile

jax.lax.tile(operand, reps)

Tiles an array by repeating it reps[i] times along each dimension i.

Parameters:
  • operand (ArrayLike) – an array to tile.

  • reps (Sequence[int]) – a sequence of integers representing the number of repeats for each dimension. Must have the same length as operand.ndim.

Returns:

A tiled array with shape (operand.shape[0] * reps[0], ..., operand.shape[-1] * reps[-1]).
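
The shape rule above can be sketched in plain Python. This is only an illustration of the documented arithmetic, not the library's implementation: tiled_shape is a hypothetical helper name, and raising ValueError on a length mismatch is this sketch's own choice rather than documented behavior of jax.lax.tile.

```python
def tiled_shape(shape, reps):
    """Return the result shape the documentation describes for tiling.

    `shape` stands in for operand.shape; `reps` is the per-dimension
    repeat count. Hypothetical helper, for illustration only.
    """
    # The docs require one repeat count per dimension of the operand.
    if len(reps) != len(shape):
        raise ValueError(
            f"reps has length {len(reps)}, expected {len(shape)}"
        )
    # Each output dimension is the input dimension times its repeat count.
    return tuple(dim * rep for dim, rep in zip(shape, reps))


print(tiled_shape((2, 2), (2, 3)))  # (4, 6)
print(tiled_shape((3,), (2,)))      # (6,)
print(tiled_shape((2, 1), (1, 3)))  # (2, 3)
```

The three calls mirror the operand shapes and reps used in the Examples section, so the printed shapes match the arrays shown there.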

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([[1, 2], [3, 4]])
>>> lax.tile(x, (2, 3))
Array([[1, 2, 1, 2, 1, 2],
       [3, 4, 3, 4, 3, 4],
       [1, 2, 1, 2, 1, 2],
       [3, 4, 3, 4, 3, 4]], dtype=int32)
>>> y = jnp.array([1, 2, 3])
>>> lax.tile(y, (2,))
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>> z = jnp.array([[1], [2]])
>>> lax.tile(z, (1, 3))
Array([[1, 1, 1],
       [2, 2, 2]], dtype=int32)
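
As the outputs above show, tiling repeats the array as a whole block rather than repeating each element in place: tiling [1, 2, 3] twice gives [1, 2, 3, 1, 2, 3], not [1, 1, 2, 2, 3, 3]. A minimal pure-Python sketch of that behavior for the 2-D case follows. tile_2d is a hypothetical name that operates on nested lists; it is meant only to mirror the semantics shown in the examples, not how JAX computes the result.

```python
def tile_2d(rows, reps):
    """Tile a list of lists reps[0] times vertically, reps[1] horizontally.

    Hypothetical reference helper for the 2-D case only.
    """
    row_reps, col_reps = reps
    # Repeat every row end to end along the last dimension.
    widened = [list(row) * col_reps for row in rows]
    # Stack the widened block of rows along the first dimension.
    return [list(row) for _ in range(row_reps) for row in widened]


print(tile_2d([[1, 2], [3, 4]], (2, 3)))
# [[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4], [1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]]
print(tile_2d([[1], [2]], (1, 3)))
# [[1, 1, 1], [2, 2, 2]]
```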