jax.lax.broadcast

jax.lax.broadcast(operand, sizes, *, out_sharding=None)

Broadcasts an array, adding new leading dimensions only.
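"Leading dimensions only" means the original shape always ends up at the end of the result. The sketch below compares `lax.broadcast` with `jnp.broadcast_to`, which gives the same values when the target shape is written out as `(*sizes, *operand.shape)`. The variable names are illustrative only. To place operand dimensions anywhere other than at the end, `lax.broadcast_in_dim` is the more general primitive.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)

# lax.broadcast only prepends dimensions: sizes (4,) turns shape (2, 3) into (4, 2, 3).
y = lax.broadcast(x, (4,))

# The same values via NumPy-style broadcasting to the explicit target shape.
z = jnp.broadcast_to(x, (4, *x.shape))

print(y.shape)                      # (4, 2, 3)
print(bool(jnp.array_equal(y, z)))  # True
```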

Parameters:
  • operand (ArrayLike) – the array to broadcast (any array-like value, including a Python scalar).

  • sizes (Sequence[int]) – a sequence of integers giving the sizes of the new leading dimensions to add to the front of the array.

Returns:

The result array, of shape (*sizes, *operand.shape), containing the values of operand broadcast along the new leading dimensions.

Return type:

Array
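The shape rule `(*sizes, *operand.shape)` can be checked directly. The sketch below expresses the same result with `lax.broadcast_in_dim` by mapping operand dimension `i` to output dimension `len(sizes) + i`. This mapping is derived from the documented output shape and is not a claim about how `lax.broadcast` is implemented internally. It also shows two edge cases of the rule: an empty `sizes` leaves the shape unchanged, and a scalar operand (shape `()`) yields a result whose shape is exactly `sizes`.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(12).reshape(3, 4)
sizes = (2, 5)
result = lax.broadcast(operand, sizes)

# Same result via lax.broadcast_in_dim: operand dimension i is mapped to
# output dimension len(sizes) + i, i.e. after the new leading dimensions.
dims = tuple(range(len(sizes), len(sizes) + operand.ndim))
explicit = lax.broadcast_in_dim(operand, (*sizes, *operand.shape), dims)

# Each index into the new leading dimensions selects a full copy of operand.
copy = result[1, 3]

# Edge cases of the shape rule (*sizes, *operand.shape):
unchanged = lax.broadcast(operand, ())   # no new dims, shape stays (3, 4)
from_scalar = lax.broadcast(7, (2, 2))   # scalar has shape (), so result shape is sizes

print(result.shape)                             # (2, 5, 3, 4)
print(bool(jnp.array_equal(result, explicit)))  # True
print(bool(jnp.array_equal(copy, operand)))     # True
print(unchanged.shape, from_scalar.shape)       # (3, 4) (2, 2)
```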

Examples

>>> import jax.numpy as jnp
>>> from jax import lax
>>> arr = jnp.zeros((4, 5))
>>> result = lax.broadcast(arr, (2, 3))
>>> result.shape
(2, 3, 4, 5)
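Because `sizes` determines the output shape, it has to be a concrete value when the function is traced, for example under `jax.jit`. The sketch below assumes a hypothetical helper, `tile_leading`, and marks `sizes` as a static argument. Static arguments must be hashable, so `sizes` is passed as a tuple rather than a list.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical helper: `sizes` fixes the output shape, so it is marked static
# (a hashable tuple of Python ints) instead of being traced as an array.
@partial(jax.jit, static_argnums=1)
def tile_leading(x, sizes):
    return lax.broadcast(x, sizes)


out = tile_leading(jnp.arange(3.0), (2,))
print(out.shape)  # (2, 3)
```

Each distinct `sizes` tuple triggers a separate compilation, which is the usual trade-off for static arguments.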