

Softmax and Numerical Stability

The softmax function maps arbitrary reals to a probability distribution. Getting it right numerically, which means avoiding overflow and underflow, is the first lesson in writing ML code that actually works.


Why This Matters

[Interactive figure: temperature-scaled softmax over five example tokens (cat, dog, car, tree, fish) with logits z = (2.5, 1.8, 0.3, -0.5, -1.2), shown at T = 1 with entropy 1.46 of a maximum 2.32 bits.]
Low temperature sharpens the distribution (the top token dominates). High temperature flattens it (all tokens become equally likely). T=1 is the unmodified softmax.

The softmax function appears in virtually every classification neural network, every attention mechanism, every language model, and every reinforcement learning policy. It converts a vector of real numbers (logits) into a probability distribution. This sounds simple, but a naive implementation will produce NaN or Inf on perfectly reasonable inputs.

Understanding why softmax breaks numerically, and how the log-sum-exp trick fixes it, is the first real lesson in numerical computing for ML. If you have ever seen NaN losses during training, there is a good chance a softmax-related instability was the cause.

Mental Model

You have a vector of scores z = (z_1, \ldots, z_K), called logits. You want to convert them into probabilities that sum to 1. Softmax exponentiates each score and normalizes:

p_i = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}

The problem: if z_i = 1000, then e^{1000} overflows to Inf in floating point. If z_i = -1000, then e^{-1000} underflows to 0. Both are fatal for the computation.

The Softmax Function

Definition

Softmax Function

The softmax function maps z \in \mathbb{R}^K to a probability vector:

\text{softmax}(z)_i = \frac{\exp(z_i)}{\sum_{j=1}^K \exp(z_j)}, \quad i = 1, \ldots, K

Properties:

  • Output is a valid probability distribution: \text{softmax}(z)_i > 0 and \sum_i \text{softmax}(z)_i = 1
  • Monotone: if z_i > z_j then \text{softmax}(z)_i > \text{softmax}(z)_j
  • Translation invariant: \text{softmax}(z + c\mathbf{1}) = \text{softmax}(z) for any scalar c

The translation invariance property is the key to numerical stability. Shifting all logits by the same constant does not change the output. We will exploit this shortly.
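The invariance is easy to check numerically. A minimal NumPy sketch (the helper name naive_softmax is ours, not a library function; the logits are arbitrary):

```python
import numpy as np

def naive_softmax(z):
    """Direct implementation of the definition -- fine for modest logits."""
    e = np.exp(z)
    return e / e.sum()

z = np.array([2.5, 1.8, 0.3, -0.5, -1.2])

# Shifting every logit by the same constant c leaves the output unchanged:
# exp(z + c) = exp(c) * exp(z), and exp(c) cancels in the normalization.
print(np.allclose(naive_softmax(z), naive_softmax(z + 50.0)))  # True
```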

Why It Overflows and Underflows

In IEEE 754 double precision floating point:

  • e^{709} \approx 8.2 \times 10^{307} (near DBL_MAX)
  • e^{710} overflows to Inf
  • e^{-745} underflows to 0

In single precision (float32, the ML default):

  • e^{89} overflows to Inf (\exp overflows for inputs above about 88.7)
  • e^{-104} underflows to 0

A logit vector like z = (100, 200, 300) in float32 produces \exp(z) = (\text{Inf}, \text{Inf}, \text{Inf}), and \text{Inf}/\text{Inf} = \text{NaN}.

Even with more modest logits, the ratio of a very small exponential to a very large sum gives underflow in the numerator, producing 0 probabilities when they should be small but nonzero. This matters for cross-entropy loss, where \log(0) = -\infty.
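The float32 failure above is reproducible in a few lines of NumPy (the helper name naive_softmax is ours):

```python
import numpy as np

def naive_softmax(z):
    """Textbook formula with no max-shift -- breaks on large logits."""
    e = np.exp(z)
    return e / e.sum()

z32 = np.array([100.0, 200.0, 300.0], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    p = naive_softmax(z32)
print(p)  # [nan nan nan]: every exp overflows to inf, and inf/inf = nan
```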

The Log-Sum-Exp Trick

Proposition

Numerical Stability of the Shifted Log-Sum-Exp

Statement

For any c \in \mathbb{R}:

\log \sum_{j=1}^K \exp(z_j) = c + \log \sum_{j=1}^K \exp(z_j - c)

Setting c = \max_j z_j ensures that the largest exponent is \exp(0) = 1, preventing overflow. The other exponents are \leq 1, so they cannot overflow either. The sum is in [1, K], so the log is in [0, \log K].

Intuition

By factoring out \exp(c) from the sum, we pull the "dangerously large" exponential outside as an additive constant in log space. What remains inside the log-sum-exp are manageable numbers between 0 and 1.

Proof Sketch

\log \sum_j \exp(z_j) = \log \left[\exp(c) \sum_j \exp(z_j - c)\right] = c + \log \sum_j \exp(z_j - c)

With c = \max_j z_j, every z_j - c \leq 0, so \exp(z_j - c) \in (0, 1]. The largest term equals 1, so the sum is at least 1 and the log is nonnegative. No overflow occurs.

Why It Matters

This trick is used everywhere in numerical computing for ML. Cross-entropy loss, KL divergence, log-likelihoods of categorical distributions, attention scores: all involve log-sum-exp internally. Every major ML framework (PyTorch, JAX, TensorFlow) implements this automatically in functions like log_softmax, cross_entropy, and logsumexp.

Failure Mode

Some of the \exp(z_j - c) terms may still underflow to 0 when z_j \ll c. This is usually acceptable: those terms contribute negligibly to the sum. The stability guarantee is against overflow (which produces NaN) rather than underflow (which produces a small approximation error).
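Putting the proposition into code, a minimal stable logsumexp in NumPy (a sketch in the spirit of scipy.special.logsumexp, not a replacement for it) might look like:

```python
import numpy as np

def logsumexp(z):
    """Stable log(sum(exp(z))): shift by the max so the largest exponent is 0."""
    c = np.max(z)
    return c + np.log(np.sum(np.exp(z - c)))

z = np.array([1000.0, 999.0, 998.0])
print(logsumexp(z))  # ~1000.4076; the naive np.log(np.sum(np.exp(z))) gives inf
```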

Stable Softmax and Log-Softmax

Using the trick, the stable softmax computation is:

\text{softmax}(z)_i = \frac{\exp(z_i - c)}{\sum_j \exp(z_j - c)}, \qquad c = \max_j z_j

Even more important is log-softmax, because cross-entropy loss uses \log p_i directly:

\log \text{softmax}(z)_i = z_i - c - \log \sum_j \exp(z_j - c)

= z_i - \text{logsumexp}(z)

Computing \log \text{softmax} directly (without first computing softmax and then taking \log) avoids the catastrophic loss of precision that occurs when softmax outputs a number very close to 1 and \log maps it to nearly 0.

This is why PyTorch has separate F.softmax and F.log_softmax functions, and why F.cross_entropy takes raw logits rather than probabilities.
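A sketch of both stable routines in NumPy (these mirror the framework functions in spirit; they are not the library implementations):

```python
import numpy as np

def log_softmax(z):
    """log softmax(z)_i = z_i - logsumexp(z), computed with the max-shift."""
    c = np.max(z)
    return z - c - np.log(np.sum(np.exp(z - c)))

def softmax(z):
    """Stable softmax: exponentiate the stable log-softmax."""
    return np.exp(log_softmax(z))

z = np.array([100.0, 200.0, 300.0])  # would produce NaN in a naive float32 softmax
print(softmax(z))      # ~[1.4e-87, 3.7e-44, 1.0], sums to 1
print(log_softmax(z))  # ~[-200, -100, 0]
```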

Temperature Scaling

Definition

Temperature-Scaled Softmax

The temperature parameter T > 0 controls the "sharpness" of the distribution:

\text{softmax}(z / T)_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}

  • As T \to \infty: the distribution approaches uniform (maximum entropy)
  • As T \to 0^+: the distribution concentrates on the argmax (minimum entropy)
  • T = 1: standard softmax

Temperature is used in:

  • Knowledge distillation: high T softens teacher outputs to reveal inter-class relationships
  • Language model sampling: T < 1 makes generation more deterministic, T > 1 makes it more random
  • Reinforcement learning: Boltzmann exploration policies use temperature to trade off exploitation vs. exploration
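The qualitative behavior is easy to see numerically. A small NumPy sketch (logits chosen arbitrarily, helper name softmax_T is ours):

```python
import numpy as np

def softmax_T(z, T):
    """Temperature-scaled softmax with the max-shift for stability."""
    a = z / T
    e = np.exp(a - a.max())
    return e / e.sum()

z = np.array([2.5, 1.8, 0.3, -0.5, -1.2])
for T in (0.1, 1.0, 10.0):
    p = softmax_T(z, T)
    entropy = -(p * np.log2(np.maximum(p, 1e-300))).sum()
    print(f"T={T}: top prob {p.max():.3f}, entropy {entropy:.2f} bits")
# Low T: top prob near 1, entropy near 0. High T: near uniform, ~log2(5) bits.
```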

Connection to Probability Distributions

Softmax arises naturally as the canonical link function for the categorical (multinoulli) distribution in the exponential family. If X \sim \text{Categorical}(\pi_1, \ldots, \pi_K) with natural parameters \eta_i = \log \pi_i, then the mean parameters are recovered by:

\pi_i = \frac{\exp(\eta_i)}{\sum_j \exp(\eta_j)} = \text{softmax}(\eta)_i

In physics, this is the Boltzmann distribution (or Gibbs distribution): a system at temperature T with energy levels E_1, \ldots, E_K occupies state i with probability:

p_i = \frac{\exp(-E_i / T)}{\sum_j \exp(-E_j / T)} = \text{softmax}(-E / T)_i

The partition function Z = \sum_j \exp(-E_j / T) is precisely the denominator of softmax. Computing \log Z is the log-sum-exp.

Gumbel-Softmax for Differentiable Sampling

A key problem: sampling from a categorical distribution is not differentiable, so you cannot backpropagate through it. The Gumbel-softmax trick provides a differentiable approximation.

Definition

Gumbel-Softmax (Concrete Distribution)

To approximately sample from \text{Categorical}(\text{softmax}(z)):

  1. Draw i.i.d. Gumbel noise: g_i \sim \text{Gumbel}(0, 1) (equivalently, g_i = -\log(-\log(u_i)) where u_i \sim \text{Uniform}(0,1))
  2. Compute the relaxed sample:

y_i = \text{softmax}\left(\frac{z + g}{\tau}\right)_i

where \tau > 0 is a temperature parameter.

As \tau \to 0, y approaches a one-hot vector (an exact sample). For \tau > 0, y is a "soft" sample on the simplex that admits gradients via backpropagation.

This is used extensively in VAEs with discrete latent variables, neural architecture search, and any setting where you need to "differentiate through a categorical choice."

The Gumbel-max theorem guarantees that \arg\max_i (z_i + g_i) is an exact sample from \text{Categorical}(\text{softmax}(z)). Gumbel-softmax relaxes the \arg\max to a \text{softmax}, making it differentiable at the cost of approximation.
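Both tricks can be checked empirically. A NumPy sketch (sample count, seed, and logits are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
z = np.array([2.0, 1.0, 0.0])

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Gumbel-max: argmax(z + g) is an exact sample from Categorical(softmax(z)).
g = -np.log(-np.log(rng.uniform(size=(100_000, 3))))
counts = np.bincount(np.argmax(z + g, axis=1), minlength=3)
print(counts / counts.sum())  # close to softmax(z) ~ [0.665, 0.245, 0.090]

# Gumbel-softmax: relax the argmax to a softmax at temperature tau.
tau = 0.5
y = softmax((z + g[0]) / tau)  # a "soft" sample, a point on the simplex
```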

Common Confusions

Watch Out

Softmax IS argmax as temperature goes to zero

A common description is that softmax is a "soft version of argmax." This is correct, but the implication often drawn is wrong. People sometimes think softmax approximates argmax but is structurally different. In fact, \lim_{T \to 0^+} \text{softmax}(z/T) is exactly the argmax (as a one-hot vector), assuming a unique maximum. Softmax is a one-parameter family that interpolates continuously between the uniform distribution (T \to \infty) and the argmax (T \to 0). It is not an approximation. It is a temperature-parameterized generalization.

Watch Out

Never compute softmax then log: use log-softmax directly

If p_i = \text{softmax}(z)_i is very close to 1 (say 1 - 10^{-15}), then in float64 \log(p_i) \approx \log(1) = 0, losing all the information in the small correction. Computing \log \text{softmax} directly as z_i - \text{logsumexp}(z) avoids this entirely, because the subtraction happens before any precision is lost. This is not a minor optimization. It is the difference between working and broken training.
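One concrete float64 failure, sketched in NumPy: when a logit sits far below the maximum, its softmax probability underflows to 0 and the subsequent log produces -inf, while the direct route returns the exact answer.

```python
import numpy as np

z = np.array([0.0, -800.0])

# Route 1: softmax, then log. exp(-800) underflows to 0, so log(0) = -inf.
with np.errstate(divide="ignore"):
    p = np.exp(z) / np.exp(z).sum()
    bad = np.log(p)
print(bad)   # second entry is -inf

# Route 2: log-softmax directly as z_i - logsumexp(z). No -inf anywhere.
lse = z.max() + np.log(np.sum(np.exp(z - z.max())))
good = z - lse
print(good)  # [0., -800.]
```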

Summary

  • Softmax: p_i = \exp(z_i) / \sum_j \exp(z_j) maps logits to probabilities
  • Naive implementation overflows for large logits, underflows for very negative logits
  • Log-sum-exp trick: subtract \max_j z_j before exponentiating
  • Always compute log_softmax directly, never log(softmax(z))
  • Temperature T: high = uniform, low = peaked at argmax
  • Softmax is the canonical link for the categorical distribution and the Boltzmann distribution
  • Gumbel-softmax provides differentiable approximate sampling from categoricals

Exercises

ExerciseCore

Problem

Implement a numerically stable log_softmax function. Given a vector z \in \mathbb{R}^K, compute \log \text{softmax}(z)_i for all i without overflow or unnecessary precision loss.

Write the formula and explain each step.

ExerciseAdvanced

Problem

Show that for the Gumbel-max trick, \arg\max_i (z_i + g_i), where g_i \sim \text{Gumbel}(0,1) i.i.d., is distributed as \text{Categorical}(\text{softmax}(z)). That is, prove:

P\left[\arg\max_i (z_i + g_i) = k\right] = \text{softmax}(z)_k

References

Canonical:

  • Goodfellow, Bengio, Courville, Deep Learning (2016), Section 4.1 (numerical stability)
  • Bishop, Pattern Recognition and Machine Learning (2006), Section 4.3.4

Current:

  • Jang, Gu, Poole, "Categorical Reparameterization with Gumbel-Softmax" (2017)

  • Maddison, Mnih, Teh, "The Concrete Distribution" (2017)

  • Hastie, Tibshirani, Friedman, The Elements of Statistical Learning (2009)


Last reviewed: April 2026
