jax.lax.broadcast
- jax.lax.broadcast(operand, sizes, *, out_sharding=None)
Broadcasts an array, adding new leading dimensions only.
- Parameters:
operand (ArrayLike) – an array
sizes (Sequence[int]) – a sequence of integers, giving the sizes of new leading dimensions to add to the front of the array.
- Returns:
The result array, of shape (*sizes, *operand.shape), containing broadcasted values of operand.
- Return type:
Array
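A minimal sketch of the shape rule above, assuming a recent JAX install: the operand's shape becomes the trailing dimensions of the result, and each index into the new leading dimensions selects a full copy of the operand. The names `x` and `y` are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Operand of shape (2, 3) with distinct values so copies are easy to compare.
x = jnp.arange(6).reshape(2, 3)

# Prepend new leading dimensions of sizes 4 and 5.
y = lax.broadcast(x, (4, 5))

print(y.shape)  # (4, 5, 2, 3)

# Indexing into the new leading dimensions recovers the original operand.
print(bool(jnp.array_equal(y[3, 1], x)))  # True
```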
See also
jax.lax.broadcast_in_dim(): general broadcasting at any dimension in the array.
jax.numpy.broadcast_to(): NumPy-style API for general broadcasting.
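The relationship to the two functions listed above can be checked directly. In this sketch (variable names are illustrative), `lax.broadcast_in_dim` is told to map the operand's dimensions onto the last `operand.ndim` output dimensions, and `jnp.broadcast_to` produces the same result because NumPy broadcasting rules align trailing dimensions.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(20.0).reshape(4, 5)
sizes = (2, 3)
out_shape = (*sizes, *x.shape)

# Prepend leading dimensions with lax.broadcast.
a = lax.broadcast(x, sizes)

# Same result via broadcast_in_dim: operand dims 0 and 1 go to output dims 2 and 3.
n = len(sizes)
b = lax.broadcast_in_dim(x, out_shape, tuple(range(n, n + x.ndim)))

# Same result via NumPy-style broadcasting, which aligns trailing dimensions.
c = jnp.broadcast_to(x, out_shape)

print(a.shape)  # (2, 3, 4, 5)
print(bool(jnp.array_equal(a, b)) and bool(jnp.array_equal(a, c)))  # True
```

Because `lax.broadcast` only ever prepends dimensions, adding a dimension anywhere else (for example a trailing one) requires `lax.broadcast_in_dim` instead.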
Examples
>>> import jax.numpy as jnp
>>> from jax import lax
>>> arr = jnp.zeros((4, 5))
>>> result = lax.broadcast(arr, (2, 3))
>>> result.shape
(2, 3, 4, 5)
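Two edge cases, sketched under the same assumptions: a scalar (0-d) operand yields an array whose shape is exactly `sizes`, and an empty `sizes` adds no dimensions, so the shape is unchanged.

```python
import jax.numpy as jnp
from jax import lax

# A 0-d operand: the result shape is just `sizes`.
s = lax.broadcast(jnp.array(7.0), (3,))
print(s.shape)  # (3,)

# Empty `sizes`: no leading dimensions are added.
arr = jnp.zeros((4, 5))
same = lax.broadcast(arr, ())
print(same.shape)  # (4, 5)
```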