Numerical Stability
Softmax and Numerical Stability
The softmax function maps arbitrary reals to a probability distribution. Getting it right numerically, by avoiding overflow and underflow, is the first lesson in writing ML code that actually works.
Why This Matters
The softmax function appears in virtually every classification neural
network, every attention mechanism, every language model, and every
reinforcement learning policy. It converts a vector of real numbers (logits)
into a probability distribution. This sounds simple, but a naive
implementation will produce NaN or Inf on perfectly reasonable inputs.
Understanding why softmax breaks numerically, and how the log-sum-exp
trick fixes it, is the first real lesson in numerical computing for ML.
If you have ever seen NaN losses during training, there is a good chance
a softmax-related instability was the cause.
Mental Model
You have a vector of scores called logits. You want to convert them into probabilities that sum to 1. Softmax exponentiates each score and normalizes:

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}$$

The problem: if some $z_i$ is large (on the order of a few hundred), then $e^{z_i}$ overflows to Inf in floating point. If every $z_i$ is very negative, then each $e^{z_i}$ underflows to $0$ and the denominator becomes $0$. Both are fatal for the computation.
The Softmax Function
Softmax Function
The softmax function maps $z \in \mathbb{R}^n$ to a probability vector:

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}$$

Properties:
- Output is a valid probability distribution: $\mathrm{softmax}(z)_i > 0$ and $\sum_i \mathrm{softmax}(z)_i = 1$
- Monotone: if $z_i > z_j$ then $\mathrm{softmax}(z)_i > \mathrm{softmax}(z)_j$
- Translation invariant: $\mathrm{softmax}(z + c\mathbf{1}) = \mathrm{softmax}(z)$ for any scalar $c$
The translation invariance property is the key to numerical stability. Shifting all logits by the same constant does not change the output. We will exploit this shortly.
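The invariance is easy to check numerically. A minimal NumPy sketch (the function name `naive_softmax` is just for illustration):

```python
import numpy as np

def naive_softmax(z):
    """Unstabilized softmax; fine for small logits."""
    e = np.exp(np.asarray(z, dtype=np.float64))
    return e / e.sum()

z = np.array([1.0, 2.0, 3.0])
# Shifting every logit by the same constant leaves the output unchanged.
print(np.allclose(naive_softmax(z), naive_softmax(z - 100.0)))  # True
```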
Why It Overflows and Underflows
In IEEE 754 double precision floating point:
- $e^{709} \approx 8.2 \times 10^{307}$ (near DBL_MAX)
- $e^{710}$ overflows to Inf
- $e^{-746}$ underflows to 0

In single precision (float32, the ML default):
- $e^{89}$ overflows to Inf
- $e^{-104}$ underflows to 0

A logit vector like $z = (100, 200, 300)$ in float32 produces $e^{z_i} = \text{Inf}$ for every entry, and $\text{Inf}/\text{Inf} = \text{NaN}$.
Even with more modest logits, the ratio of a very small exponential to a very large sum gives underflow in the numerator, producing probabilities of exactly $0$ when they should be small but nonzero. This matters for cross-entropy loss, where $\log(0) = -\infty$.
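The failure is easy to reproduce. A minimal sketch (the specific logit values are illustrative):

```python
import numpy as np

def naive_softmax(z):
    """Textbook softmax with no stabilization, for demonstration only."""
    e = np.exp(z)
    return e / e.sum()

z = np.array([100.0, 200.0, 300.0], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    p = naive_softmax(z)
# exp(300) overflows float32 to Inf, and Inf/Inf is NaN:
print(p)  # [nan nan nan]
```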
The Log-Sum-Exp Trick
Numerical Stability of the Shifted Log-Sum-Exp
Statement
For any $z \in \mathbb{R}^n$ and any $c \in \mathbb{R}$:

$$\log \sum_{i=1}^{n} e^{z_i} = c + \log \sum_{i=1}^{n} e^{z_i - c}$$

Setting $c = \max_i z_i$ ensures that the largest exponent is $0$, preventing overflow. The other exponents are $\le 0$, so they cannot overflow either. The sum is in $[1, n]$, so the log is in $[0, \log n]$.
Intuition
By factoring out from the sum, we pull the "dangerously large" exponential outside as an additive constant in log space. What remains inside the log-sum-exp are manageable numbers between 0 and 1.
Proof Sketch
With , every , so . The largest term equals , so the sum is at least and the log is nonneg. No overflow occurs.
Why It Matters
This trick is used everywhere in numerical computing for ML. Cross-entropy
loss, KL divergence, log-likelihood of categorical distributions, attention
scores: all involve log-sum-exp internally. Every major ML framework
(PyTorch, JAX, TensorFlow) implements this automatically in functions like
log_softmax, cross_entropy, and logsumexp.
Failure Mode
Some of the terms $e^{z_i - c}$ may still underflow to $0$ when $z_i \ll c$.
This is usually acceptable: those terms contribute negligibly to the sum.
The stability guarantee is against overflow (which produces Inf and then NaN) rather
than underflow (which produces a small approximation error).
Stable Softmax and Log-Softmax
Using the trick, with $c = \max_j z_j$, the stable softmax computation is:

$$\mathrm{softmax}(z)_i = \frac{e^{z_i - c}}{\sum_{j} e^{z_j - c}}$$

Even more important is log-softmax, because cross-entropy loss uses $\log \mathrm{softmax}(z)_y$ directly:

$$\log \mathrm{softmax}(z)_i = z_i - c - \log \sum_{j} e^{z_j - c}$$

Computing $\log \mathrm{softmax}(z)_i$ directly (without first computing softmax and then taking $\log$) avoids the catastrophic loss of precision that occurs when softmax outputs a number very close to $1$ and $\log$ maps it to nearly $0$.
This is why PyTorch has separate F.softmax and F.log_softmax functions,
and why F.cross_entropy takes raw logits rather than probabilities.
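Putting the two formulas into code, here is a minimal NumPy sketch (PyTorch's F.log_softmax and F.softmax implement the equivalent stabilized computation):

```python
import numpy as np

def log_softmax(z):
    """log softmax(z)_i = z_i - c - log(sum_j exp(z_j - c)), c = max(z)."""
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max()
    return shifted - np.log(np.exp(shifted).sum())

def softmax(z):
    """Stable softmax, via exponentiating the stable log-softmax."""
    return np.exp(log_softmax(z))

z = [1000.0, 1001.0, 1002.0]   # naive exp(z) would overflow to Inf
p = softmax(z)
print(p.sum())                  # 1.0 (up to rounding)
print(log_softmax(z))           # all finite, no NaN/Inf
```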
Temperature Scaling
Temperature-Scaled Softmax
The temperature parameter $T > 0$ controls the "sharpness" of the distribution:

$$\mathrm{softmax}_T(z)_i = \frac{e^{z_i/T}}{\sum_{j} e^{z_j/T}}$$

- As $T \to \infty$: the distribution approaches uniform (maximum entropy)
- As $T \to 0^+$: the distribution concentrates on the argmax (minimum entropy)
- $T = 1$: standard softmax
Temperature is used in:
- Knowledge distillation: high $T$ softens teacher outputs to reveal inter-class relationships
- Language model sampling: $T < 1$ makes generation more deterministic, $T > 1$ makes it more random
- Reinforcement learning: Boltzmann exploration policies use temperature to trade off exploitation vs. exploration
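The two limits are easy to see numerically. A sketch with illustrative temperatures:

```python
import numpy as np

def softmax_t(z, temperature):
    """Temperature-scaled softmax: softmax(z / T), computed stably."""
    z = np.asarray(z, dtype=np.float64) / temperature
    e = np.exp(z - z.max())   # shift for numerical stability
    return e / e.sum()

z = [1.0, 2.0, 3.0]
print(softmax_t(z, 100.0))  # nearly uniform: every entry close to 1/3
print(softmax_t(z, 0.01))   # nearly one-hot on the argmax (index 2)
print(softmax_t(z, 1.0))    # standard softmax
```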
Connection to Probability Distributions
Softmax arises naturally as the canonical link function for the categorical (multinoulli) distribution in the exponential family. If $X \sim \mathrm{Categorical}(\pi_1, \dots, \pi_n)$ with natural parameters $\eta_i = \log \pi_i$, then the mean parameters are recovered by:

$$\pi_i = \mathrm{softmax}(\eta)_i = \frac{e^{\eta_i}}{\sum_j e^{\eta_j}}$$

In physics, this is the Boltzmann distribution (or Gibbs distribution): a system at temperature $T$ with energy levels $E_i$ occupies state $i$ with probability:

$$p_i = \frac{e^{-E_i/T}}{Z}, \qquad Z = \sum_j e^{-E_j/T}$$

The partition function $Z$ is precisely the denominator of softmax (with logits $z_i = -E_i/T$). Computing $\log Z$ is the log-sum-exp.
Gumbel-Softmax for Differentiable Sampling
A key problem: sampling from a categorical distribution is not differentiable, so you cannot backpropagate through it. The Gumbel-softmax trick provides a differentiable approximation.
Gumbel-Softmax (Concrete Distribution)
To approximately sample from $\mathrm{Categorical}(\mathrm{softmax}(z))$:
- Draw i.i.d. Gumbel noise: $g_i \sim \mathrm{Gumbel}(0, 1)$ (equivalently, $g_i = -\log(-\log u_i)$ where $u_i \sim \mathrm{Uniform}(0, 1)$)
- Compute the relaxed sample:

$$y_i = \frac{e^{(z_i + g_i)/\tau}}{\sum_j e^{(z_j + g_j)/\tau}}$$

where $\tau > 0$ is a temperature parameter.
As $\tau \to 0$, $y$ approaches a one-hot vector (exact sample). For $\tau > 0$, $y$ is a "soft" sample on the simplex that admits gradients via backpropagation.
This is used extensively in VAEs with discrete latent variables, neural architecture search, and any setting where you need to "differentiate through a categorical choice."
The Gumbel-max theorem guarantees that $\arg\max_i (z_i + g_i)$ is an exact sample from $\mathrm{Categorical}(\mathrm{softmax}(z))$. Gumbel-softmax relaxes the $\arg\max$ to a $\mathrm{softmax}$, making it differentiable at the cost of approximation.
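A minimal sketch of the sampler (function and variable names are illustrative). Note that softmax is monotone, so for any $\tau$ the argmax of the relaxed sample equals $\arg\max_i (z_i + g_i)$; hard decisions read off this way follow the target categorical exactly:

```python
import numpy as np

def gumbel_softmax_sample(logits, tau, rng):
    """One relaxed sample: softmax((z + g) / tau) with g ~ Gumbel(0, 1).

    As tau -> 0 the sample approaches one-hot (Gumbel-max); for tau > 0
    it lies in the simplex interior and admits gradients.
    """
    logits = np.asarray(logits, dtype=np.float64)
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))  # Gumbel(0, 1)
    y = (logits + g) / tau
    e = np.exp(y - y.max())   # stabilize before exponentiating
    return e / e.sum()

rng = np.random.default_rng(0)
logits = np.log([0.1, 0.3, 0.6])   # target class probabilities
soft = gumbel_softmax_sample(logits, tau=1.0, rng=rng)   # interior point
hard = gumbel_softmax_sample(logits, tau=0.01, rng=rng)  # close to one-hot
print(soft, hard)
```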
Common Confusions
Softmax IS argmax as temperature goes to zero
A common description is that softmax is a "soft version of argmax." This is correct, but the implication often drawn is wrong. People sometimes think softmax approximates argmax but is structurally different. In fact, $\lim_{T \to 0^+} \mathrm{softmax}(z/T)$ is exactly the argmax (as a one-hot vector), assuming a unique maximum. Softmax is a one-parameter family that interpolates continuously between the uniform distribution ($T \to \infty$) and the argmax ($T \to 0^+$). It is not an approximation. It is a temperature-parameterized generalization.
Never compute softmax then log; use log-softmax directly
If $\mathrm{softmax}(z)_i$ is very close to $1$ (say $1 - 10^{-17}$), then in float64 it rounds to exactly $1.0$, so $\log$ returns $0$, losing all the information in the small correction. Computing $\log \mathrm{softmax}(z)_i$ directly as $z_i - c - \log \sum_j e^{z_j - c}$ avoids this entirely, because the subtraction happens before any precision is lost. This is not a minor optimization. It is the difference between working and broken training.
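A concrete sketch of the failure, shown on the small-probability side where it is most visible (cross-entropy evaluates the log probability of the true class, which is exactly what breaks):

```python
import numpy as np

z = np.array([0.0, 1000.0])
shifted = z - z.max()

# Naive route: softmax first, then log. The first probability
# underflows to 0, and log(0) = -Inf destroys the loss.
with np.errstate(divide="ignore"):
    p = np.exp(shifted) / np.exp(shifted).sum()
    naive = np.log(p)

# Direct route: z_i - c - log(sum_j exp(z_j - c)). Finite and exact.
direct = shifted - np.log(np.exp(shifted).sum())

print(naive)   # first entry is -inf
print(direct)  # [-1000., 0.]: finite
```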
Summary
- Softmax: $\mathrm{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$ maps logits to probabilities
- Naive implementation overflows for large logits, underflows for very negative logits
- Log-sum-exp trick: subtract $c = \max_i z_i$ before exponentiating
- Always compute log_softmax directly, never log(softmax(z))
- Temperature $T$: high $T$ = uniform, low $T$ = peaked at argmax
- Softmax is the canonical link for the categorical distribution and the Boltzmann distribution
- Gumbel-softmax provides differentiable approximate sampling from categoricals
Exercises
Problem
Implement a numerically stable log_softmax function. Given a vector
$z \in \mathbb{R}^n$, compute $\log \mathrm{softmax}(z)_i$ for all $i$
without overflow or unnecessary precision loss.
Write the formula and explain each step.
Problem
Show that for the Gumbel-max trick, $\arg\max_i (z_i + g_i)$ with $g_i \sim \mathrm{Gumbel}(0, 1)$ i.i.d. is distributed as $\mathrm{Categorical}(\mathrm{softmax}(z))$. That is, prove:

$$\Pr\left[\arg\max_i (z_i + g_i) = k\right] = \frac{e^{z_k}}{\sum_j e^{z_j}}$$
References
Canonical:
- Goodfellow, Bengio, Courville, Deep Learning (2016), Section 4.1 (numerical stability)
- Bishop, Pattern Recognition and Machine Learning (2006), Section 4.3.4
Current:
- Jang, Gu, Poole, "Categorical Reparameterization with Gumbel-Softmax" (2017)
- Maddison, Mnih, Teh, "The Concrete Distribution" (2017)
- Hastie, Tibshirani, Friedman, The Elements of Statistical Learning (2009)
Next Topics
The natural next step from numerical stability:
- Conditioning and condition number: formalizing when a computation is inherently sensitive to perturbations, beyond specific tricks
Last reviewed: April 2026
Builds on This
- Attention Mechanism Theory (Layer 4)
- Decoding Strategies (Layer 3)
- Flash Attention (Layer 5)
- Log-Probability Computation (Layer 1)
- Quantization Theory (Layer 5)
- Transformer Architecture (Layer 4)