jax.nn.softmax

jax.nn.softmax(x, axis=-1, where=None)

Softmax function.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]

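As a quick sanity check of the formula, the sketch below compares jax.nn.softmax against a direct transcription of it. The helper softmax_from_formula is hypothetical and written only for illustration; it is not part of JAX. The last lines show why a direct transcription is fragile: softmax is unchanged when the same constant is added to every element, so a careful implementation can shift the inputs first, while the naive formula overflows exp() for large inputs in float32.

```python
import jax
import jax.numpy as jnp


def softmax_from_formula(x, axis=-1):
    # Hypothetical helper: a literal transcription of exp(x_i) / sum_j exp(x_j).
    # It is NOT numerically safe and only serves as a reference here.
    e = jnp.exp(x)
    return e / jnp.sum(e, axis=axis, keepdims=True)


x = jnp.array([1.0, 2.0, 3.0])
y = jax.nn.softmax(x)

# jax.nn.softmax agrees with the formula, and the outputs sum to 1.
print(jnp.allclose(y, softmax_from_formula(x)))  # True
print(jnp.allclose(jnp.sum(y), 1.0))  # True

# Adding a constant to every element leaves softmax unchanged ...
big = x + 1000.0
print(jnp.allclose(jax.nn.softmax(big), y))  # True

# ... but the literal formula overflows: exp(1001.0) is inf in float32,
# and inf / inf gives NaN for every element.
print(softmax_from_formula(big))
```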
Parameters:
  • x (ArrayLike) – input array

  • axis (Axis) – the axis or axes along which the softmax is computed. Summed across these dimensions, the softmax output equals \(1\). Either an integer, a tuple of integers, or None (all axes). Defaults to -1 (the last axis).

  • where (ArrayLike | None) – Elements to include in the softmax. The output for any masked-out element is zero.
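The sketch below illustrates the axis and where parameters on a small 2-D array. The array values and the mask are made up for illustration; the mask is a boolean array of the same shape as x, with True marking elements to include.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [1.0, 1.0, 1.0]])

# Default axis=-1: each row is normalized independently, so each row sums to 1.
rows = jax.nn.softmax(x, axis=-1)
row_sums = jnp.sum(rows, axis=-1)  # close to [1., 1.]

# axis=None normalizes over all elements, so the whole array sums to 1.
whole = jax.nn.softmax(x, axis=None)
total = jnp.sum(whole)  # close to 1.0

# where: only True entries take part in the softmax. Masked-out entries
# come out as 0, and the included entries of each row still sum to 1.
mask = jnp.array([[True, True, False],
                  [True, False, True]])
masked = jax.nn.softmax(x, axis=-1, where=mask)
# Row 0 is softmax([1., 2.]) padded with a 0; row 1 is [0.5, 0., 0.5].
print(masked)
```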

Returns:

An array with the same shape as x.

Return type:

Array

Note

If any input values are +inf, the result will be all NaN along the softmax axis (or axes) containing them: this reflects the fact that inf / inf is not well-defined in floating-point arithmetic.
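A small illustration of this behavior, using a made-up 2-D input where only the first row contains +inf. With the default axis=-1, each row is normalized separately, so the NaNs stay confined to the row that contains the infinity.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0.0, jnp.inf, 1.0],   # this row contains +inf
               [0.0, 1.0, 2.0]])      # this row is all finite

y = jax.nn.softmax(x, axis=-1)

# Row 0 comes out entirely NaN, as described in the note above.
# Row 1 is normalized independently and stays finite, summing to 1.
print(y)
```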

See also

log_softmax()
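A short sketch of how the two functions relate, with made-up inputs. For well-scaled inputs, exponentiating jax.nn.log_softmax recovers jax.nn.softmax. When one probability is tiny, softmax underflows to 0 in float32 and taking its log yields -inf, while log_softmax keeps the finite log-probability; this is one common reason to prefer log_softmax when log-probabilities are what you need (for example, in a cross-entropy loss).

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])

# For moderate inputs the two agree: exp(log_softmax(x)) matches softmax(x).
print(jnp.allclose(jnp.exp(jax.nn.log_softmax(x)), jax.nn.softmax(x)))  # True

# With a large gap between inputs, exp(-200.0) underflows to 0 in float32,
# so log(softmax(z)) produces -inf for the second entry ...
z = jnp.array([0.0, -200.0])
naive_log = jnp.log(jax.nn.softmax(z))

# ... whereas log_softmax returns the finite log-probabilities directly.
stable_log = jax.nn.log_softmax(z)
print(naive_log, stable_log)
```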