
Training Techniques

Mixed Precision Training

Train with FP16 or BF16 for speed while keeping FP32 master weights for accuracy. Loss scaling, overflow prevention, and when mixed precision fails.


Why This Matters

Training a large model in FP32 is roughly 2x slower and uses up to 2x more memory than FP16/BF16 on modern GPUs. Mixed precision training captures most of this speedup while matching FP32 accuracy. Every large language model trained since 2018 has used some form of mixed precision, and understanding the underlying floating-point arithmetic is essential for diagnosing training failures.

Mental Model

Keep a "master copy" of the weights in FP32. For each training step: cast the weights to FP16/BF16, run the forward and backward pass in reduced precision (fast), then update the FP32 master weights with the computed gradients. The reduced-precision passes are fast because GPU tensor cores deliver at least 2x the throughput for FP16/BF16 matrix multiplies compared to FP32.
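This loop can be sketched in a few lines. A minimal toy example using NumPy, with float16 standing in for the reduced precision; the layer, loss, and gradient are illustrative, not a real model:

```python
import numpy as np

# Toy mixed-precision step for one linear layer (illustrative, not a real model).
rng = np.random.default_rng(0)
master_w = rng.normal(size=(4, 4)).astype(np.float32)  # FP32 master weights
x = rng.normal(size=(4,)).astype(np.float16)           # input in reduced precision
lr = 1e-3

for step in range(3):
    w16 = master_w.astype(np.float16)    # 1. cast master weights down
    y = w16 @ x                          # 2. forward pass in FP16
    grad16 = np.outer(y - 1.0, x)        # 3. backward pass in FP16 (toy gradient)
    grad32 = grad16.astype(np.float32)   # 4. cast gradients back up
    master_w -= lr * grad32              # 5. update the FP32 master weights
```

Frameworks such as PyTorch's automatic mixed precision automate steps 1-4 per operation, but the structure is the same.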

Formal Setup

Definition

Mixed Precision Training

A training procedure that maintains model weights $\theta$ in FP32 (the master weights) while computing forward activations $a_l = f_l(W_l^{\text{fp16}} \cdot a_{l-1})$ and gradients in FP16 or BF16. Weight updates are applied in FP32:

$$\theta^{\text{fp32}}_{t+1} = \theta^{\text{fp32}}_t - \eta \cdot g_t^{\text{fp32}}$$

where $g_t^{\text{fp32}}$ is the gradient cast back to FP32 before the update.

FP16 vs BF16

FP16 (IEEE half precision) uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The representable magnitude range is approximately $[6 \times 10^{-8},\ 65504]$, including subnormals.

BF16 (brain floating point) uses 1 sign bit, 8 exponent bits, and 7 mantissa bits. The representable magnitude range is approximately $[1.2 \times 10^{-38},\ 3.4 \times 10^{38}]$ for normal numbers.

The critical difference: BF16 has the same exponent range as FP32 but lower precision. FP16 has higher precision but a much smaller exponent range. For training, the exponent range matters more than mantissa precision. Gradients that are very small (below $6 \times 10^{-8}$) underflow to zero in FP16 but are representable in BF16. Activations that are large (above 65504) overflow in FP16 but not in BF16.

This is why BF16 has largely replaced FP16 for training on hardware that supports it (A100 and later GPUs, TPUs).
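The range difference is easy to observe with NumPy, whose float16 is IEEE half precision. NumPy has no native BF16 type, so FP32 stands in here to show what BF16's exponent range would preserve:

```python
import numpy as np

# FP16 underflow and overflow at the ranges quoted above.
tiny = np.float16(2e-8)      # below the FP16 subnormal minimum (~6e-8): rounds to 0
huge = np.float16(70000.0)   # above the FP16 maximum (65504): rounds to inf
kept = np.float32(2e-8)      # representable in FP32 (and in BF16's exponent range)
```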

Loss Scaling

Definition

Loss Scaling

Loss scaling multiplies the loss by a constant $s > 1$ before the backward pass, then divides gradients by $s$ after:

$$\tilde{g} = \frac{1}{s} \nabla_\theta \big(s \cdot L(\theta)\big)$$

By linearity of differentiation, $\tilde{g} = \nabla_\theta L(\theta)$. The purpose is to shift the gradient distribution into the representable range of FP16 during the backward pass.

In practice, dynamic loss scaling starts with a large $s$ (e.g., $2^{16}$) and halves $s$ whenever an overflow (NaN/Inf) is detected. If no overflow occurs for $N$ consecutive steps (commonly $N = 2000$), $s$ is doubled. This adapts to the gradient magnitude throughout training.
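A dynamic loss scaler is only a few lines of state. A sketch under the assumption that gradients of the scaled loss arrive as a list of arrays; the class and method names are illustrative, not a library API:

```python
import numpy as np

# Sketch of a dynamic loss scaler. `grads` are gradients of the *scaled*
# loss s * L; names are illustrative, not a library API.
class DynamicLossScaler:
    def __init__(self, init_scale=2.0**16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.steps_since_overflow = 0

    def update(self, grads):
        """Return unscaled gradients, or None if an overflow was detected."""
        if any(not np.all(np.isfinite(g)) for g in grads):
            self.scale /= 2.0              # overflow: halve the scale, skip the step
            self.steps_since_overflow = 0
            return None
        self.steps_since_overflow += 1
        unscaled = [g / self.scale for g in grads]
        if self.steps_since_overflow >= self.growth_interval:
            self.scale *= 2.0              # stable for N steps: try a larger scale
            self.steps_since_overflow = 0
        return unscaled
```

PyTorch's `torch.cuda.amp.GradScaler` implements this same halve-on-overflow, grow-after-stable policy.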

Main Theorems

Proposition

Loss Scaling Preserves Gradient Direction

Statement

Let $L(\theta)$ be a differentiable loss function and $s > 0$ a finite scaling factor. If no floating-point overflow occurs during computation of $\nabla_\theta(s \cdot L(\theta))$, then:

$$\frac{1}{s}\nabla_\theta(s \cdot L(\theta)) = \nabla_\theta L(\theta)$$

in exact arithmetic. In FP16 arithmetic, the scaled computation preserves gradient components that would otherwise underflow to zero, at the cost of potentially overflowing large components.

Intuition

Loss scaling shifts the entire gradient histogram to the right on a log scale. Components that were below the FP16 minimum ($6 \times 10^{-8}$) are lifted into representable range. Components near the FP16 maximum ($65504$) may overflow. Dynamic scaling finds the sweet spot automatically.

Proof Sketch

By the chain rule, $\nabla_\theta(s \cdot L) = s \cdot \nabla_\theta L$. Dividing by $s$ recovers $\nabla_\theta L$ exactly. In floating-point arithmetic, the multiplication by $s$ shifts the exponent of each gradient component up by $\log_2(s)$ bits, preventing underflow for components whose exponent was within $\log_2(s)$ of the minimum.
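The underflow rescue can be checked numerically, assuming NumPy float16 as the FP16 format:

```python
import numpy as np

# Scaling lifts a small gradient into FP16 range; unscaling in FP32 recovers it.
g_true = np.float32(2e-8)              # true gradient component, FP32
s = np.float32(2.0**16)                # loss scale

naive = np.float16(g_true)             # direct FP16 cast: underflows to zero
scaled = np.float16(g_true * s)        # scaled value lands in FP16 range
recovered = np.float32(scaled) / s     # unscale in FP32: close to 2e-8 again
```

The recovery is exact up to FP16 rounding of the scaled value (relative error at most $2^{-11}$), which is why the gradient direction is preserved.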

Why It Matters

Without loss scaling, FP16 training of deep networks fails because a large fraction of gradient components (often 50%+) fall below the FP16 minimum and become zero. Loss scaling is what makes FP16 training possible in practice. This is particularly important when combined with SGD convergence guarantees that assume nonzero gradient information.

Failure Mode

Loss scaling cannot help when gradients span a range wider than FP16 can represent (about 12 orders of magnitude, from roughly $6 \times 10^{-8}$ to $65504$). If the largest gradient component overflows even at scale $s = 1$, or the smallest underflows even at maximum scale, mixed precision with FP16 breaks. BF16 avoids this by having a much wider exponent range.

When Mixed Precision Fails

Gradient accumulation errors. When accumulating gradients across microbatches in FP16, the running sum can lose precision. Small gradient contributions get rounded away when added to a large accumulator. The fix: accumulate in FP32.
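The accumulator effect is easy to reproduce with NumPy float16; the magnitudes below are illustrative:

```python
import numpy as np

# Accumulating many small contributions into a large FP16 sum.
small = np.float16(0.001)
acc16 = np.float16(2048.0)             # FP16 accumulator: spacing at 2048 is 2.0
acc32 = np.float32(2048.0)             # FP32 accumulator

for _ in range(1000):
    acc16 = np.float16(acc16 + small)  # each add rounds back to 2048.0: lost
    acc32 = acc32 + np.float32(small)  # contributions are retained
```

Each 0.001 contribution is far below the FP16 spacing at 2048 (which is 2.0), so the FP16 sum never moves, while the FP32 sum gains roughly 1.0.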

Attention logit overflow. In transformers, the attention logits $QK^T / \sqrt{d}$ can exceed 65504 in FP16, causing NaN. This happens with long sequences or poorly scaled attention. The fix: compute attention in FP32, or use BF16, which handles the range.
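A minimal reproduction of the failure, using NumPy float16 with a logit value chosen for illustration:

```python
import numpy as np

# A logit that is fine in FP32 (and BF16) overflows when cast to FP16.
logits32 = np.array([1.0, 7.0e4], dtype=np.float32)  # within FP32/BF16 range
logits16 = logits32.astype(np.float16)               # 7e4 > 65504 -> inf

with np.errstate(invalid="ignore"):
    shifted = logits16 - logits16.max()              # inf - inf -> nan
```

The overflowed logit becomes inf, and the standard max-subtraction step of softmax then produces inf - inf = NaN, poisoning the whole attention row.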

Small weight updates. When the learning rate is very small and the gradient is moderate, the update $\eta \cdot g$ can underflow in FP16. The master weight strategy already handles this by updating in FP32, but naive implementations that skip master weights will fail.
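A small demonstration of why the FP32 master copy matters, with NumPy float16 and values chosen for illustration:

```python
import numpy as np

# Without an FP32 master copy, a small update is rounded away in FP16.
update = 1e-4                               # eta * g, below the FP16 spacing near 1.0

w16 = np.float16(1.0)
w16 = np.float16(w16 - np.float16(update))  # rounds back to exactly 1.0: no progress

w32 = np.float32(1.0)                       # FP32 master weight
w32 = w32 - np.float32(update)              # the update is retained
```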

Common Confusions

Watch Out

Mixed precision does not mean training in FP16

Mixed precision means using both FP16/BF16 and FP32 at different stages. Pure FP16 training (without FP32 master weights) diverges for most models. The "mixed" part is the key.

Watch Out

BF16 does not need loss scaling

Because BF16 has the same exponent range as FP32, gradient underflow is not a problem. Loss scaling is only needed for FP16. This is a major practical advantage of BF16: simpler training code with fewer failure modes.

Watch Out

Memory savings are not 2x

The master weights are still FP32. The memory savings come from FP16 activations (which dominate memory for large models) and FP16 gradients. For a model with $N$ parameters, you need $4N$ bytes for master weights plus $2N$ bytes for FP16 weights, versus $4N$ bytes for FP32 only. Activation memory savings are model-dependent.
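The arithmetic, for a hypothetical 7B-parameter model (optimizer state, e.g. Adam moments kept in FP32, would add further bytes on top of this):

```python
# Weight-memory arithmetic for a model with N parameters (bytes).
N = 7_000_000_000          # e.g. a 7B-parameter model (illustrative)

fp32_only = 4 * N          # pure FP32 training: one FP32 copy of the weights
mixed = 4 * N + 2 * N      # mixed precision: FP32 master + FP16 working copy

ratio = mixed / fp32_only  # weight memory grows to 1.5x, it does not halve
```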

Key Takeaways

  • Keep FP32 master weights, compute forward and backward in FP16/BF16
  • Loss scaling prevents gradient underflow in FP16 by shifting the gradient distribution
  • BF16 is preferred over FP16 for training because its wider exponent range eliminates most overflow and underflow issues
  • Dynamic loss scaling adapts automatically; start high and halve on overflow
  • Accumulate gradients in FP32 to avoid precision loss

Exercises

ExerciseCore

Problem

A gradient component has value $3 \times 10^{-8}$ in FP32. Will this underflow to zero in FP16? If you apply loss scaling with $s = 2^{10} = 1024$, what is the scaled value, and does it survive in FP16?

ExerciseAdvanced

Problem

In a transformer with hidden dimension $d = 4096$ and sequence length $L = 8192$, the attention logits are $QK^T / \sqrt{d}$. If $Q$ and $K$ entries have standard deviation $\sigma = 1$ after normalization, estimate the maximum attention logit and determine whether FP16 will overflow.

References

Canonical:

  • Micikevicius et al., "Mixed Precision Training" (2018), ICLR
  • Kalamkar et al., "A Study of BFLOAT16 for Deep Learning Training" (2019)

Current:

  • NVIDIA, "Training with Mixed Precision" documentation (2023)
  • Dehghani et al., "The Efficiency Misnomer" (2022), Section 4

Last reviewed: April 2026
