jax.lax.broadcast_like
- jax.lax.broadcast_like(arr, like_arr)
Broadcasts an array to match the shape and sharding of another array.
- Parameters:
arr (ArrayLike) – an array to be broadcasted.
like_arr (ArrayLike) – an array whose shape and sharding should be matched.
- Returns:
An array containing the broadcasted values of arr.
- Return type:
Array
See also
- jax.lax.broadcast(): simpler interface to add new leading dimensions.
- jax.lax.broadcast_in_dim(): general broadcasting at any dimension in the array.
- jax.numpy.broadcast_to(): NumPy-style API for general broadcasting.
Examples
>>> import jax.numpy as jnp
>>> from jax import lax
>>> arr = jnp.array([1, 2, 3])
>>> like_arr = jnp.zeros((2, 3))
>>> lax.broadcast_like(arr, like_arr)
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)