jax.lax.mul
- jax.lax.mul(x, y, *, out_dtype=None)
Elementwise multiplication: \(x \times y\).
This function lowers directly to the stablehlo.multiply operation.
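A minimal sketch of a basic call, assuming a recent JAX release with default 32-bit types. It also lowers a jitted call with `jax.jit(...).lower(...).as_text()` so you can look for the multiply op in the emitted module text. The exact op name printed depends on the JAX version's default lowering dialect, so the check below only looks for the substring `multiply`.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
y = jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32)

# Elementwise product: [1*4, 2*5, 3*6] -> [4.0, 10.0, 18.0]
z = lax.mul(x, y)
print(z.tolist())

# Lower a jitted call and inspect the module text; on StableHLO-based
# releases this is expected to contain a stablehlo.multiply op.
hlo_text = jax.jit(lax.mul).lower(x, y).as_text()
print("multiply" in hlo_text)
```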
- Parameters:
  - x (ArrayLike) – Input array. Must have a numerical dtype matching that of y. If neither x nor y is a scalar, they must have the same number of dimensions and be broadcast compatible.
  - y (ArrayLike) – Input array. Must have a numerical dtype matching that of x. If neither x nor y is a scalar, they must have the same number of dimensions and be broadcast compatible.
  - out_dtype (DTypeLike | None) – Optional. Either None (default) or a dtype. If a dtype is given, the output has that dtype. Typically this is done by casting the inputs to out_dtype before the multiplication, but on some backends it may instead be done by a custom kernel.
- Returns:
  An array with the same dtype as x and y (or with out_dtype, if one is given) containing the product of each pair of broadcast entries.
- Return type:
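A small sketch of the output shape and dtype, assuming a recent JAX release in which same-rank operands with size-1 dimensions broadcast against each other, as the parameter description states. Here shapes (1, 3) and (2, 1) broadcast to (2, 3), each output entry is `x[0, j] * y[i, 0]`, and the int32 dtype of the inputs is preserved.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2, 3]], dtype=jnp.int32)   # shape (1, 3)
y = jnp.array([[10], [20]], dtype=jnp.int32)  # shape (2, 1), same rank as x

# Both operands broadcast to (2, 3); entry (i, j) is x[0, j] * y[i, 0].
z = lax.mul(x, y)
print(z.shape, z.dtype)
print(z.tolist())  # [[10, 20, 30], [20, 40, 60]]
```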
See also
jax.numpy.multiply(): NumPy-style multiplication supporting inputs with mixed dtypes and ranks.
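To illustrate the difference, here is a sketch (assuming a recent JAX release) in which `jax.numpy.multiply` accepts a float32 (2, 3) array and an int32 (3,) array directly, promoting dtypes and broadcasting ranks. With `lax.mul`, the caller does that work explicitly, here with `astype` and `lax.broadcast_in_dim`; both paths give the same result.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.arange(3, dtype=jnp.int32)  # different dtype and rank from `a`

# jnp.multiply promotes int32 * float32 to float32 and broadcasts (3,) to (2, 3).
c = jnp.multiply(a, b)
print(c.shape, c.dtype)

# lax.mul needs matching dtypes and ranks, so cast and broadcast explicitly:
# map b's only axis onto axis 1 of the (2, 3) target shape.
b_full = lax.broadcast_in_dim(b.astype(jnp.float32), (2, 3), (1,))
d = lax.mul(a, b_full)
print(bool(jnp.array_equal(c, d)))
```

Using `jnp.multiply` is usually more convenient. `lax.mul` keeps promotion and broadcasting explicit, which can help when you want precise control over the emitted operations.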