jax.nn.logmeanexp

jax.nn.logmeanexp(x, axis=None, where=None, keepdims=False)

Log-mean-exp: the logarithm of the mean of the exponentials of the elements of x.

Computes the function:

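\[\text{logmeanexp}(x) = \log\left(\frac{1}{n} \sum_{i=1}^n \exp x_i\right) = \text{logsumexp}(x) - \log n\]

Here n is the number of elements in the reduction; if a where mask is given, only the elements it selects are counted.
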
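Evaluating the left-hand side of the formula literally can overflow, because every exp x_i is computed before the mean is taken. The right-hand side avoids this, since jax.nn.logsumexp shifts by the maximum before exponentiating. The sketch below checks the identity with a hypothetical helper, logmeanexp_ref, built only on jax.nn.logsumexp; it illustrates the formula and is not the library's own implementation.

```python
import math

import jax
import jax.numpy as jnp


def logmeanexp_ref(x, axis=None):
    """Hypothetical reference helper: logsumexp(x) - log(n), per the formula."""
    x = jnp.asarray(x)
    if axis is None:
        axes = range(x.ndim)
    elif isinstance(axis, tuple):
        axes = axis
    else:
        axes = (axis,)
    # n is the number of elements reduced over (product of the reduced dims).
    n = math.prod(x.shape[a] for a in axes)
    return jax.nn.logsumexp(x, axis=axis) - jnp.log(n)


# Small inputs: the literal definition and the logsumexp form agree.
x = jnp.array([1.0, 2.0, 3.0])
naive = jnp.log(jnp.mean(jnp.exp(x)))
via_lse = logmeanexp_ref(x)

# Large inputs: exp(1000.0) overflows float32 to inf, so the literal
# definition returns inf, while the logsumexp form stays finite (about 1000.0).
big = jnp.array([1000.0, 1000.0])
big_naive = jnp.log(jnp.mean(jnp.exp(big)))
big_stable = logmeanexp_ref(big)

print(naive, via_lse)
print(big_naive, big_stable)
```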
Parameters:
  • x (ArrayLike) – Input array.

  • axis (Axis) – Axis or axes along which to reduce. If None (the default), the reduction is over all elements.

  • where (ArrayLike | None) – Optional boolean mask, broadcastable to the shape of x, selecting the elements to include in the reduction.

  • keepdims (bool) – If True, the reduced axes are kept in the result as dimensions of size 1, so the output has the same number of dimensions as the input.

Returns:

An array containing the log-mean-exp of x over the reduced axes.

Return type:

Array