jax.nn.logmeanexp
- jax.nn.logmeanexp(x, axis=None, where=None, keepdims=False)
Log-mean-exp: the logarithm of the mean of the exponentials of the input.
Computes the function:
\[\text{logmeanexp}(x) = \log \frac{1}{n} \sum_{i=1}^n \exp x_i = \text{logsumexp}(x) - \log n\]
where \(n\) is the number of elements being averaged over.
- Parameters:
x (ArrayLike) – Input array.
axis (Axis) – Axis or axes along which to reduce. If None, the reduction is over all elements.
where (ArrayLike | None) – Elements to include in the reduction. Optional.
keepdims (bool) – If True, the reduced axes are kept in the result as dimensions of size 1, so the output has the same number of dimensions as the input.
- Returns:
An array containing the log-mean-exp of x over the reduced axes.
- Return type:
Array
See also
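As a minimal sketch (not the library's implementation), the right-hand identity in the formula, logsumexp(x) minus log n, can be checked with `jax.scipy.special.logsumexp`. The helper `logmeanexp_sketch` is hypothetical: it takes only `axis` and `keepdims` and ignores `where`. The example also shows why this form is preferable to evaluating log(mean(exp(x))) directly.

```python
import math

import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logmeanexp_sketch(x, axis=None, keepdims=False):
    """Hypothetical helper: log-mean-exp computed as logsumexp(x) - log(n).

    Illustrative only; this is not the jax.nn.logmeanexp source and it
    does not support the `where` argument.
    """
    x = jnp.asarray(x)
    # n is the number of elements averaged over along the reduced axes.
    if axis is None:
        n = x.size
    else:
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        n = math.prod(x.shape[a] for a in axes)
    return logsumexp(x, axis=axis, keepdims=keepdims) - jnp.log(n)


x = jnp.array([[1000.0, 1000.0],
               [0.0, math.log(3.0)]])

# Row-wise reduction: row 0 gives approximately 1000,
# row 1 gives log((1 + 3) / 2) = log 2.
row_means = logmeanexp_sketch(x, axis=1)

# keepdims=True keeps the reduced axis as a dimension of size 1.
row_means_kept = logmeanexp_sketch(x, axis=1, keepdims=True)

# The naive formula overflows: exp(1000) is inf in floating point.
naive_row0 = jnp.log(jnp.mean(jnp.exp(x[0])))

print(row_means, row_means_kept.shape, naive_row0)
```

Because `logsumexp` subtracts the maximum before exponentiating, the large entries in row 0 stay finite, whereas the naive expression overflows to infinity.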