jax.nn.softmax
- jax.nn.softmax(x, axis=-1, where=None)
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\):
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters:
x (ArrayLike) – input array.
axis (Axis) – the axis or axes along which the softmax is computed; the output sums to \(1\) across these dimensions. Either an integer, a tuple of integers, or None (all axes).
where (ArrayLike | None) – elements to include in the softmax. The output for any masked-out element is zero.
- Returns:
An array.
- Return type:
Array
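A short sketch of the definition and of the axis and where parameters, assuming a JAX version whose jax.nn.softmax matches the signature above. The input values are made up for illustration. The manual line is a direct transcription of the formula; it agrees here only because the inputs are small, since exponentiating large values directly can overflow, so it is not a substitute for jax.nn.softmax.

```python
import jax
import jax.numpy as jnp

# Illustrative input: two rows of scores.
x = jnp.array([[1.0, 2.0, 3.0],
               [1.0, 1.0, 1.0]])

# Default axis=-1: each row is normalized independently.
probs = jax.nn.softmax(x)

# Direct transcription of exp(x_i) / sum_j exp(x_j) along the last axis.
# Safe only because these inputs are small; large values overflow exp.
manual = jnp.exp(x) / jnp.sum(jnp.exp(x), axis=-1, keepdims=True)
print(jnp.allclose(probs, manual))   # True
print(probs.sum(axis=-1))            # each row sums to 1 (up to rounding)

# axis=None normalizes over every element of x at once.
all_axes = jax.nn.softmax(x, axis=None)
print(all_axes.sum())                # 1 (up to rounding)

# where: masked-out entries come back as exactly 0, and the remaining
# entries in each row are renormalized to sum to 1.
mask = jnp.array([[True, True, False],
                  [True, False, True]])
masked = jax.nn.softmax(x, where=mask)
print(masked[1])                     # row values 0.5, 0.0, 0.5
```

Passing where is not the same as zeroing the output afterwards: masked entries are excluded from the normalizing sum, so the unmasked entries in each row still sum to 1.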
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
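A small sketch of the behavior described in the note, with illustrative values. A common way to compute softmax stably is to subtract the per-slice maximum before exponentiating. When that maximum is +inf, inf - inf yields NaN, and the NaN then propagates through the normalizing sum. With axis=-1, only the row containing +inf is affected; the other row is normalized independently and stays finite.

```python
import jax
import jax.numpy as jnp

# Row 0 contains +inf; row 1 is ordinary finite data.
x = jnp.array([[0.0, jnp.inf, 1.0],
               [0.0, 1.0, 2.0]])

out = jax.nn.softmax(x, axis=-1)

# Every entry of row 0 is NaN, including those whose inputs were finite.
print(jnp.isnan(out[0]))
# Row 1 is normalized independently and is unaffected.
print(out[1], out[1].sum())
```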