Training Techniques
Mixed Precision Training
Train with FP16 or BF16 for speed while keeping FP32 master weights for accuracy. Loss scaling, overflow prevention, and when mixed precision fails.
Why This Matters
Training a large model in FP32 is roughly 2x slower and uses up to 2x more memory than FP16/BF16 on modern GPUs. Mixed precision training gives you most of this speedup while maintaining FP32-level accuracy. Essentially every large language model trained since 2018 has used some form of mixed precision. Understanding the underlying floating-point arithmetic is essential for diagnosing training failures.
Mental Model
Keep a "master copy" of weights in FP32. For each training step: cast weights to FP16/BF16, run the forward and backward pass in reduced precision (fast), then update the FP32 master weights with the computed gradients. The reduced-precision passes are fast because GPU tensor cores operate at 2x throughput for FP16/BF16 compared to FP32.
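The loop above can be sketched on a toy one-parameter model. This is a minimal illustration, not a real training API: numpy stands in for the GPU kernels, and all names are made up for the example.

```python
import numpy as np

# One mixed-precision step on a toy model y = w * x with loss (y - t)^2.
# Illustrative sketch: FP32 master weight, FP16 forward/backward.
def mixed_precision_step(w_fp32, x, t, lr=0.1):
    w16 = np.float16(w_fp32)                       # cast master weight to FP16
    x16, t16 = np.float16(x), np.float16(t)
    y16 = w16 * x16                                # forward pass in FP16
    grad16 = np.float16(2.0) * (y16 - t16) * x16   # backward pass in FP16
    grad32 = np.float32(grad16)                    # cast gradient back to FP32
    return np.float32(w_fp32 - lr * grad32)        # update the FP32 master weight

w = np.float32(0.0)
for _ in range(50):
    w = mixed_precision_step(w, x=1.0, t=1.0)
print(w)  # converges toward the target weight 1.0
```

In a real framework the casts and the FP16 passes are handled for you (e.g. by an autocast context); the sketch only shows where each precision is used.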
Formal Setup
Mixed Precision Training
A training procedure that maintains model weights in FP32 (the master weights) while computing forward activations and gradients in FP16 or BF16. Weight updates are applied in FP32:

$w_{t+1} = w_t - \eta \, g_t$

where $g_t$ is the gradient cast back to FP32 before the update.
FP16 vs BF16
FP16 (IEEE half precision) uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The representable range is approximately $[6 \times 10^{-8},\ 65504]$ (including subnormals; the smallest normal value is about $6 \times 10^{-5}$).
BF16 (brain floating point) uses 1 sign bit, 8 exponent bits, and 7 mantissa bits. The representable range is approximately $[10^{-38},\ 3.4 \times 10^{38}]$, the same as FP32.
The critical difference: BF16 has the same exponent range as FP32 but lower precision. FP16 has higher precision but a much smaller exponent range. For training, the exponent range matters more than mantissa precision. Gradients that are very small (below about $6 \times 10^{-8}$) underflow to zero in FP16 but are representable in BF16. Activations that are large (above 65504) overflow in FP16 but not in BF16.
This is why BF16 has largely replaced FP16 for training on hardware that supports it (A100 and later GPUs, TPUs).
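The range difference is easy to check numerically. Plain numpy has no BF16 dtype, so FP32 (which shares BF16's exponent range) stands in for it here:

```python
import numpy as np

# FP16's narrow exponent range: small magnitudes underflow, large ones overflow.
tiny = np.float16(1e-8)      # below FP16's smallest subnormal (~6e-8)
huge = np.float16(70000.0)   # above FP16's max finite value (65504)
print(tiny)                  # 0.0 -- underflow
print(huge)                  # inf -- overflow
print(np.float32(1e-8))      # representable in FP32 (and within BF16's range)
```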
Loss Scaling
Loss Scaling
Loss scaling multiplies the loss by a constant $S$ before the backward pass, then divides the gradients by $S$ after:

$\nabla_w L = \frac{1}{S} \nabla_w (S \cdot L)$

By linearity of differentiation, $\nabla_w (S \cdot L) = S \cdot \nabla_w L$. The purpose is to shift the gradient distribution into the representable range of FP16 during the backward pass.
In practice, dynamic loss scaling starts with a large $S$ (e.g., $2^{16}$) and halves $S$ whenever an overflow (NaN/Inf) is detected in the gradients. If no overflow occurs for a fixed number of consecutive steps (e.g., 2000), $S$ is doubled. This adapts $S$ to the gradient magnitude throughout training.
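A minimal sketch of this policy (names are illustrative; production frameworks ship an equivalent, e.g. PyTorch's `GradScaler`). The scaler receives gradients that were computed from the scaled loss, i.e. already multiplied by $S$:

```python
import math

# Sketch of dynamic loss scaling, assuming gradients arrive scaled by S.
class DynamicLossScaler:
    def __init__(self, init_scale=2.0**16, growth_interval=2000):
        self.scale = init_scale              # the current loss scale S
        self.growth_interval = growth_interval
        self.steps_since_overflow = 0

    def update(self, scaled_grads):
        """Unscale grads, or return None (skip step) on overflow; adjust S."""
        if any(not math.isfinite(g) for g in scaled_grads):
            self.scale /= 2                  # overflow detected: halve S
            self.steps_since_overflow = 0
            return None                      # caller skips this optimizer step
        unscaled = [g / self.scale for g in scaled_grads]
        self.steps_since_overflow += 1
        if self.steps_since_overflow >= self.growth_interval:
            self.scale *= 2                  # stable long enough: try growing S
            self.steps_since_overflow = 0
        return unscaled
```

Note the order of operations: gradients are unscaled with the current $S$ before $S$ is grown, and an overflowed step is skipped entirely rather than applied.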
Main Theorems
Loss Scaling Preserves Gradient Direction
Statement
Let $L$ be a differentiable loss function and $S > 0$ a finite scaling factor. If no floating-point overflow occurs during the computation of $\nabla_w (S \cdot L)$, then:

$\frac{1}{S} \nabla_w (S \cdot L) = \nabla_w L$

in exact arithmetic. In FP16 arithmetic, the scaled computation preserves gradient components that would otherwise underflow to zero, at the cost of potentially overflowing large components.
Intuition
Loss scaling shifts the entire gradient histogram to the right on a log scale. Components that were below the FP16 minimum (about $6 \times 10^{-8}$) are lifted into representable range. Components near the FP16 maximum (65504) may overflow. Dynamic scaling finds the sweet spot automatically.
Proof Sketch
By the chain rule, $\nabla_w (S \cdot L) = S \cdot \nabla_w L$. Dividing by $S$ recovers $\nabla_w L$ exactly. In floating-point arithmetic, multiplication by $S = 2^k$ shifts the exponent of each gradient component up by $k$, preventing underflow for components whose exponent was within $k$ of the minimum.
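A concrete numerical instance of the rescue, using an assumed gradient component of $10^{-8}$ and $S = 2^{16}$:

```python
import numpy as np

# A gradient component of 1e-8 underflows to zero in FP16. Scaling by
# S = 2**16 lifts it into FP16's range; dividing by S in FP32 recovers it.
g = 1e-8
print(np.float16(g))               # 0.0 -- underflow, the gradient is lost
S = 2.0**16
scaled = np.float16(g * S)         # ~6.55e-4, representable in FP16
recovered = np.float32(scaled) / S
print(recovered)                   # ~1e-8, recovered up to FP16 rounding error
```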
Why It Matters
Without loss scaling, FP16 training of deep networks fails because a large fraction of gradient components (often 50%+) fall below the FP16 minimum and become zero. Loss scaling is what makes FP16 training possible in practice. This is particularly important when combined with SGD convergence guarantees that assume nonzero gradient information.
Failure Mode
Loss scaling cannot help when gradients span a range wider than FP16 can represent (about 12 orders of magnitude, from roughly $6 \times 10^{-8}$ to 65504). If the largest gradient component overflows at the scale needed to keep the smallest from underflowing, no single $S$ works and mixed precision with FP16 breaks. BF16 avoids this by having a much wider exponent range.
When Mixed Precision Fails
Gradient accumulation errors. When accumulating gradients across microbatches in FP16, the running sum can lose precision. Small gradient contributions get rounded away when added to a large accumulator. The fix: accumulate in FP32.
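The effect is easy to reproduce with numpy. Summing 10,000 contributions of $10^{-4}$ in FP16 stalls once the accumulator reaches 0.25, because each new term is then smaller than half an FP16 ulp and rounds away:

```python
import numpy as np

# Accumulate many small contributions in FP16 vs FP32.
acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(10_000):
    acc16 = np.float16(acc16 + np.float16(1e-4))   # FP16 accumulator stalls
    acc32 = np.float32(acc32 + np.float32(1e-4))   # FP32 accumulator keeps going
print(acc16)   # stalls near 0.25
print(acc32)   # ~1.0, as expected
```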
Attention logit overflow. In transformers, the attention logits can exceed 65504 in FP16, causing NaN. This happens with long sequences or poorly scaled attention. The fix: compute attention in FP32, or use BF16 which handles the range.
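A toy illustration with made-up magnitudes: a single product of two activations of size 300 already exceeds the FP16 maximum.

```python
import numpy as np

# 300 * 300 = 90000 > 65504, so the FP16 product overflows to inf.
q, k = np.float16(300.0), np.float16(300.0)
logit16 = q * k
logit32 = np.float32(q) * np.float32(k)
print(logit16)   # inf in FP16
print(logit32)   # 90000.0 -- fine in FP32 (and within BF16's range)
```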
Small weight updates. When the learning rate is very small and the gradient is moderate, the update can underflow in FP16. The master weight strategy already handles this by updating in FP32, but naive implementations that skip master weights will fail.
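A two-line demonstration of why the FP32 master copy matters, assuming a weight of 1.0 and an update of $10^{-4}$:

```python
import numpy as np

# Near 1.0, the FP16 ulp is ~4.9e-4, so a 1e-4 update rounds away in FP16
# but is applied correctly to the FP32 master weight.
update = 1e-4
w16 = np.float16(1.0) - np.float16(update)
w32 = np.float32(1.0) - np.float32(update)
print(w16)   # 1.0 -- the update was lost
print(w32)   # ~0.9999 -- the FP32 master weight moved
```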
Common Confusions
Mixed precision does not mean training in FP16
Mixed precision means using both FP16/BF16 and FP32 at different stages. Pure FP16 training (without FP32 master weights) diverges for most models. The "mixed" part is the key.
BF16 does not need loss scaling
Because BF16 has the same exponent range as FP32, gradient underflow is not a problem. Loss scaling is only needed for FP16. This is a major practical advantage of BF16: simpler training code with fewer failure modes.
Memory savings are not 2x
The master weights are still FP32. The memory savings come from FP16 activations (which dominate memory for large models) and FP16 gradients. For a model with $N$ parameters, you need $4N$ bytes for master weights plus $2N$ bytes for the FP16 working copy, versus $4N$ bytes for FP32 only. Activation memory savings are model-dependent.
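The arithmetic for a hypothetical 7B-parameter model makes the point: weight memory alone actually goes up under mixed precision.

```python
# Back-of-envelope weight memory (excludes gradients, optimizer state,
# and activations) for an assumed 7B-parameter model.
N = 7_000_000_000
fp32_only = 4 * N          # 4 bytes per FP32 weight
mixed = 4 * N + 2 * N      # FP32 master copy + FP16 working copy
print(fp32_only / 1e9)     # 28.0 (GB)
print(mixed / 1e9)         # 42.0 (GB) -- the savings must come from elsewhere
```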
Key Takeaways
- Keep FP32 master weights, compute forward and backward in FP16/BF16
- Loss scaling prevents gradient underflow in FP16 by shifting the gradient distribution
- BF16 is preferred over FP16 for training because its wider exponent range eliminates most overflow and underflow issues
- Dynamic loss scaling adapts automatically; start high and halve on overflow
- Accumulate gradients in FP32 to avoid precision loss
Exercises
Problem
A gradient component has value $10^{-8}$ in FP32. Will this underflow to zero in FP16? If you apply loss scaling with $S = 2^{10}$, what is the scaled value, and does it survive in FP16?
Problem
In a transformer with hidden dimension $d$ and sequence length $n$, the attention logits are $QK^\top / \sqrt{d}$. If $d = 64$ and the entries of $Q$ and $K$ have standard deviation 1 after normalization, estimate the maximum attention logit and determine whether FP16 will overflow.
References
Canonical:
- Micikevicius et al., "Mixed Precision Training" (2018), ICLR
- Kalamkar et al., "A Study of BFLOAT16 for Deep Learning Training" (2019)
Current:
- NVIDIA, "Training with Mixed Precision" documentation (2023)
- Dehghani et al., "The Efficiency Misnomer" (2022), Section 4
Last reviewed: April 2026
Prerequisites
Foundations this topic depends on.
- Floating-Point Arithmetic (Layer 0A)