jax.lax.tile
- jax.lax.tile(operand, reps)
Tiles an array by repeating it along each dimension.
- Parameters:
operand (ArrayLike) – an array to tile.
reps (Sequence[int]) – a sequence of integers giving the number of repeats for each dimension. Must have the same length as operand.ndim.
- Returns:
A tiled array with shape (operand.shape[0] * reps[0], ..., operand.shape[-1] * reps[-1]).
- Return type:
Array
Examples
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([[1, 2], [3, 4]])
>>> lax.tile(x, (2, 3))
Array([[1, 2, 1, 2, 1, 2],
       [3, 4, 3, 4, 3, 4],
       [1, 2, 1, 2, 1, 2],
       [3, 4, 3, 4, 3, 4]], dtype=int32)

>>> y = jnp.array([1, 2, 3])
>>> lax.tile(y, (2,))
Array([1, 2, 3, 1, 2, 3], dtype=int32)

>>> z = jnp.array([[1], [2]])
>>> lax.tile(z, (1, 3))
Array([[1, 1, 1],
       [2, 2, 2]], dtype=int32)
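
The tiling semantics and the output-shape rule above can be reproduced with a broadcast followed by a reshape. The sketch below is only an illustration of that idea, not necessarily how jax.lax.tile is implemented: tile_sketch is a hypothetical helper name, and the result is compared against jnp.tile with a full-length reps as a reference.

```python
import jax.numpy as jnp
from jax import lax


def tile_sketch(operand, reps):
    """Hypothetical helper: tile via broadcast + reshape (illustration only)."""
    operand = jnp.asarray(operand)
    if len(reps) != operand.ndim:
        raise ValueError("reps must have one entry per dimension of operand")

    # Build the interleaved shape (reps[0], shape[0], reps[1], shape[1], ...).
    interleaved = []
    for r, s in zip(reps, operand.shape):
        interleaved.extend([r, s])

    # Original axis i lands at position 2*i + 1 of the interleaved shape;
    # the new axes at even positions are filled by broadcasting.
    broadcast_dims = tuple(2 * i + 1 for i in range(operand.ndim))
    expanded = lax.broadcast_in_dim(operand, tuple(interleaved), broadcast_dims)

    # Merge each (reps[i], shape[i]) pair into one axis of size reps[i] * shape[i].
    out_shape = tuple(r * s for r, s in zip(reps, operand.shape))
    return expanded.reshape(out_shape)


x = jnp.array([[1, 2], [3, 4]])
out = tile_sketch(x, (2, 3))
print(out.shape)  # (4, 6)
print(bool(jnp.array_equal(out, jnp.tile(x, (2, 3)))))  # True
```

Placing each repeat axis before the axis it multiplies means the reshape lays whole copies of the operand end to end along that dimension, which is the block-repeat pattern shown in the examples.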