jax.lax.mul

jax.lax.mul(x, y, *, out_dtype=None)

Elementwise multiplication: \(x \times y\).
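A minimal sketch of the elementwise behaviour. It uses two same-shaped float32 arrays so that the shape and dtype requirements below are trivially met. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Two inputs with identical shape and dtype.
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
y = jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32)

# Each output entry is x[i] * y[i]; z holds 4.0, 10.0, 18.0.
z = lax.mul(x, y)
print(z)
```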

This function lowers directly to the stablehlo.multiply operation.
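One way to observe the lowering is to jit the function, lower it for concrete inputs, and inspect the emitted module text. This sketch assumes a recent JAX version in which `Lowered.as_text()` prints StableHLO by default. The exact surrounding text of the module varies between versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3), dtype=jnp.float32)
y = jnp.full((2, 3), 2.0, dtype=jnp.float32)

# Lower a jitted call to lax.mul and get the module as text.
lowered = jax.jit(lax.mul).lower(x, y)
hlo_text = lowered.as_text()

# The module body should contain a single stablehlo.multiply op.
print(hlo_text)
```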

Parameters:
  • x (ArrayLike) – Input array. x and y must have matching numerical dtypes. If neither is a scalar, x and y must have the same number of dimensions and be broadcast compatible.

  • y (ArrayLike) – Input array. Subject to the same requirements as x: x and y must have matching numerical dtypes and, if neither is a scalar, the same number of dimensions and broadcast-compatible shapes.

  • out_dtype (DTypeLike | None) – Optional. Either None (the default) or a dtype. If a dtype is given, the output has that dtype. Typically this is achieved by casting the inputs to the specified dtype before multiplying, but on some backends it may be done via a custom kernel.

Returns:

An array with the same dtype as x and y (or with out_dtype, if one is specified) containing the product of each pair of broadcasted entries.

Return type:

Array
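A short check of the return value, using integer inputs to show that no promotion occurs when out_dtype is left as None. The inputs share one shape, so the output has that shape as well.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2], [3, 4]], dtype=jnp.int32)
y = jnp.array([[5, 6], [7, 8]], dtype=jnp.int32)

# Result is a jax.Array; dtype stays int32 and shape stays (2, 2).
out = lax.mul(x, y)
print(type(out), out.dtype, out.shape)
```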

See also