jax.experimental.pallas.broadcast_to

jax.experimental.pallas.broadcast_to(a, shape)

Broadcasts an array to a new shape.

Parameters:
  • a (Array) – The array to broadcast.

  • shape (tuple[int, ...]) – The desired shape to broadcast to.

Returns:

An array whose shape is the requested shape argument.

Return type:

Array
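This entry does not say which (a.shape, shape) pairs are valid. The following is a minimal sketch that assumes broadcast_to follows the same NumPy-style rules as jax.numpy.broadcast_to. Under those rules, a's dimensions are aligned from the right, each of its dimensions must either equal the target or be 1, and any extra leading target dimensions are added as new axes. The helper check_broadcast_to is hypothetical. It only computes shapes and does not call Pallas.

```python
def check_broadcast_to(a_shape: tuple[int, ...], shape: tuple[int, ...]) -> tuple[int, ...]:
    """Return `shape` if an array of `a_shape` can be broadcast to it, else raise ValueError.

    Hypothetical helper: it mirrors NumPy-style broadcast_to rules, which are
    ASSUMED (not confirmed by this page) to govern pallas.broadcast_to.
    """
    if len(a_shape) > len(shape):
        raise ValueError(f"cannot broadcast {a_shape} to lower-rank shape {shape}")
    # Align dimensions from the right; leading target dimensions become new axes.
    offset = len(shape) - len(a_shape)
    for i, dim in enumerate(a_shape):
        target = shape[offset + i]
        # A source dimension must match the target exactly or be 1 (stretched).
        if dim != target and dim != 1:
            raise ValueError(
                f"cannot broadcast {a_shape} to {shape}: "
                f"dimension {i} is {dim}, target is {target}"
            )
    return tuple(shape)


# A row vector gains a leading axis of size 4.
print(check_broadcast_to((3,), (4, 3)))       # (4, 3)
# A scalar can be broadcast to any shape, e.g. a hypothetical (8, 128) block.
print(check_broadcast_to((), (8, 128)))       # (8, 128)
# A size-1 dimension is stretched to match the target.
print(check_broadcast_to((1, 3), (2, 3)))     # (2, 3)
```

A mismatched pair such as (2,) to (4, 3) raises ValueError in this sketch. Under the same assumption, calling broadcast_to inside a kernel with such a pair would also be expected to fail.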