jax.lax.broadcast_like

jax.lax.broadcast_like(arr, like_arr)

Broadcasts an array to match the shape and sharding of another array.

Parameters:
  • arr (ArrayLike) – the array to broadcast.

  • like_arr (ArrayLike) – an array whose shape and sharding should be matched.

Returns:

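An array with the shape and sharding of like_arr, holding the values of arr broadcast to that shape. The dtype of arr is preserved; only the shape and sharding of like_arr are used.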

Return type:

Array
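As a rough single-device mental model (an assumption, not the library's implementation), the shape-matching part behaves like `jnp.broadcast_to(arr, jnp.shape(like_arr))`. The result takes its shape from like_arr but keeps the dtype of arr. The sketch below ignores sharding entirely, and `broadcast_like_sketch` is a hypothetical helper:

```python
import jax.numpy as jnp

def broadcast_like_sketch(arr, like_arr):
    # Hypothetical helper: match only the *shape* of like_arr.
    # Sharding is NOT handled here, and arr's dtype is left unchanged.
    return jnp.broadcast_to(arr, jnp.shape(like_arr))

arr = jnp.array([1, 2, 3])      # shape (3,), integer dtype
like_arr = jnp.zeros((2, 3))    # shape (2, 3), float dtype
out = broadcast_like_sketch(arr, like_arr)
print(out.shape, out.dtype)     # (2, 3) int32  (default 32-bit mode)
```

Note that the float dtype of like_arr does not leak into the result; only its shape is consulted.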


Examples

>>> import jax.numpy as jnp
>>> from jax import lax
>>> arr = jnp.array([1, 2, 3])
>>> like_arr = jnp.zeros((2, 3))
>>> lax.broadcast_like(arr, like_arr)
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
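The broadcast in the example can also be spelled out with `lax.broadcast_in_dim`, which makes the axis mapping explicit. The sketch below assumes NumPy-style trailing-axis alignment (which is what the example shows) and ignores sharding; `broadcast_trailing` is a hypothetical helper, not part of `jax.lax`:

```python
import jax.numpy as jnp
from jax import lax

def broadcast_trailing(arr, shape):
    # Hypothetical helper: align arr's axes with the *last* arr.ndim axes
    # of `shape`, NumPy-style. Each aligned axis of arr must have size 1
    # or the same size as the matching target axis.
    arr = jnp.asarray(arr)
    offset = len(shape) - arr.ndim
    dims = tuple(range(offset, len(shape)))
    return lax.broadcast_in_dim(arr, shape, broadcast_dimensions=dims)

arr = jnp.array([1, 2, 3])      # shape (3,)
like_arr = jnp.zeros((2, 3))    # only its shape is used here
out = broadcast_trailing(arr, like_arr.shape)
print(out)
# [[1 2 3]
#  [1 2 3]]
```

For this input, `dims` is `(1,)`: axis 0 of arr maps to axis 1 of the output, and output axis 0 is a new, replicated dimension.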