Training Techniques
Activation Checkpointing
Trade compute for memory by recomputing activations during the backward pass instead of storing them all. Reduces activation memory from $O(L)$ to $O(\sqrt{L})$ for $L$ layers.
Why This Matters
For a network with $L$ layers, standard backpropagation stores all $L$ layers of activations simultaneously. For large transformers, activation memory dominates total GPU memory. A 7B parameter transformer with sequence length 2048 can require 60+ GB just for activations. Activation checkpointing reduces this to a manageable level, making it possible to train models that would otherwise not fit in memory.
Mental Model
Standard backpropagation is "store everything, compute once." Activation checkpointing is "store some, recompute the rest." You mark certain layers as checkpoints. During the forward pass, only checkpoint activations are kept in memory; the rest are discarded. During the backward pass, when you need a discarded activation, you recompute it from the nearest preceding checkpoint. This is a memory-compute tradeoff within the automatic differentiation computation graph: you trade extra forward-pass compute for reduced memory in the reverse-mode gradient computation.
Formal Setup
Activation Checkpointing
Consider a sequential network with layers $f_1, \dots, f_L$, where $a_i = f_i(a_{i-1})$ and $a_0$ is the input. Standard backpropagation stores all $a_1, \dots, a_L$ to compute gradients. Activation checkpointing selects a subset of checkpoint indices $C \subseteq \{1, \dots, L\}$ and only stores $\{a_c : c \in C\}$. For layer $i$ with nearest preceding checkpoint $c = \max\{j \in C : j \le i\}$, the activation is recomputed as:

$$a_i = f_i \circ f_{i-1} \circ \cdots \circ f_{c+1}(a_c)$$
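The storage and recomputation rule above can be sketched in a few lines of plain Python. This is an illustrative toy (scalar "layers", no gradients), not a framework implementation:

```python
def forward_with_checkpoints(layers, x, ckpt_idx):
    """Forward pass that keeps only activations at checkpoint indices."""
    stored = {0: x}  # a_0, the input, is always kept
    a = x
    for i, f in enumerate(layers, start=1):
        a = f(a)
        if i in ckpt_idx:
            stored[i] = a
    return a, stored

def recompute(layers, stored, i):
    """Recompute a_i from the nearest preceding checkpoint c <= i."""
    c = max(j for j in stored if j <= i)
    a = stored[c]
    for j in range(c + 1, i + 1):
        a = layers[j - 1](a)
    return a

# Four layers multiplying by 2, 3, 5, 7; checkpoint only layer 2.
layers = [lambda x, k=k: x * k for k in (2, 3, 5, 7)]
out, stored = forward_with_checkpoints(layers, 1, {2})
# out == 210; stored == {0: 1, 2: 6}; recompute(layers, stored, 3) == 30
```

During the backward pass, a real implementation would call `recompute` for each discarded activation it needs, walking segments in reverse order.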
Memory and Compute Analysis
Without checkpointing: Store all $L$ activations. Memory is $O(L)$. Compute is one forward pass plus one backward pass, totaling roughly $3F$ (where $F$ is the forward-pass cost, since the backward pass costs approximately $2F$).
With $k$ evenly spaced checkpoints: Store $k$ checkpoint activations. Between checkpoints, you need to temporarily store at most $L/k$ activations during recomputation. Total memory: $O(k + L/k)$.
Minimize $k + L/k$ by setting $k = \sqrt{L}$. This gives:

$$\text{Memory} = O\!\left(\sqrt{L} + \frac{L}{\sqrt{L}}\right) = O(\sqrt{L})$$
The compute overhead is one additional forward pass through each segment between checkpoints. Since there are $\sqrt{L}$ segments of $\sqrt{L}$ layers each, the total extra forward compute is at most one full forward pass. Total compute becomes roughly $4F$ instead of $3F$, a ~33% increase.
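A quick numeric check confirms the optimum. The snippet below (illustrative, with an arbitrary choice of $L = 100$) evaluates the peak activation count $k + L/k$ for every integer $k$:

```python
L = 100  # number of layers (illustrative choice)
# Peak stored activations with k evenly spaced checkpoints: k + L/k
peak = {k: k + L / k for k in range(1, L + 1)}
best_k = min(peak, key=peak.get)
print(best_k, peak[best_k])  # 10 20.0 — matches k = sqrt(L), peak 2*sqrt(L)
```

The minimum lands exactly at $k = \sqrt{L} = 10$ with peak $2\sqrt{L} = 20$, as the AM-GM argument below predicts.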
Main Theorems
Optimal Checkpoint Memory Bound
Statement
For a sequential network with $L$ layers, each producing activations of size $m$, placing $\sqrt{L}$ evenly spaced checkpoints reduces peak activation memory from $Lm$ to at most $2\sqrt{L}\,m$, while increasing compute by at most one additional forward pass.
Intuition
You store $\sqrt{L}$ checkpoints. When computing gradients for any segment, you recompute at most $\sqrt{L}$ activations from the nearest checkpoint. At any point, you hold $\sqrt{L}$ checkpoints plus at most $\sqrt{L}$ recomputed activations in memory.
Proof Sketch
Partition the $L$ layers into $k$ segments of size $L/k$. Store the first activation of each segment ($k$ activations). During backpropagation through segment $j$, recompute all activations in that segment from its checkpoint. Peak memory is $m(k + L/k)$. By AM-GM, $k + L/k \ge 2\sqrt{L}$, with equality at $k = \sqrt{L}$. The recomputation of each segment costs one forward pass through $L/k$ layers, and each segment is recomputed once, so total extra compute is $k \cdot (L/k) = L$ layers, which is one full forward pass.
Why It Matters
This result makes it possible to train models with hundreds of layers on a single GPU. Without it, activation memory grows linearly with depth and quickly exceeds hardware limits.
Failure Mode
The analysis assumes equal activation sizes per layer and sequential computation. Transformer models have non-uniform activation sizes (attention matrices are much larger than FFN activations). Residual connections mean the "sequential" assumption is approximate. In practice, checkpoint placement is tuned per-architecture rather than using uniform spacing.
Practical Considerations
What to checkpoint. In transformers, the standard practice is to checkpoint each transformer block. The attention computation and FFN within each block are recomputed from the block input during the backward pass.
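Checkpointing each block is a one-line change in PyTorch via `torch.utils.checkpoint`. A minimal sketch, using a residual MLP as a stand-in for a full transformer block:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    """Stand-in for a transformer block (just a residual MLP here)."""
    def __init__(self, d):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d, 4 * d), torch.nn.GELU(), torch.nn.Linear(4 * d, d)
        )

    def forward(self, x):
        return x + self.ff(x)  # residual connection

class Net(torch.nn.Module):
    def __init__(self, d, n_blocks):
        super().__init__()
        self.blocks = torch.nn.ModuleList(Block(d) for _ in range(n_blocks))

    def forward(self, x):
        for blk in self.blocks:
            # Only the block input is saved; the block's internals are
            # recomputed from it during the backward pass.
            x = checkpoint(blk, x, use_reentrant=False)
        return x

net = Net(d=8, n_blocks=3)
x = torch.randn(2, 8, requires_grad=True)
net(x).sum().backward()  # gradients flow through recomputed activations
```

`use_reentrant=False` selects the non-reentrant implementation, which PyTorch's documentation recommends for new code.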
Selective checkpointing. Not all layers have equally expensive activations. Attention layers produce activations of size $O(s^2)$ in the sequence length $s$ (the attention matrix), which dominates memory. Checkpointing only attention layers gives most of the memory savings.
Interaction with mixed precision. Checkpointed activations can be stored in FP16/BF16, further halving checkpoint memory. Recomputed activations are also in reduced precision, matching the original forward pass.
Memory Budget Comparison
The following table compares memory strategies for training deep networks. All entries assume $L$ layers, each producing activations of size $m$, with parameter memory held constant.
| Strategy | Peak activation memory | Compute overhead | When to use |
|---|---|---|---|
| No checkpointing | $O(Lm)$ | None (baseline $3F$) | Model fits in memory with room to spare. Small models or large GPUs. |
| Uniform checkpointing ($k = \sqrt{L}$) | $O(\sqrt{L}\,m)$ | ~33% ($4F$ total) | Standard choice for large models. The default in most frameworks. |
| Selective checkpointing (attention only) | Varies by architecture | 10-20% | When attention activations dominate and FFN activations are small. Reduces overhead compared to full checkpointing. |
| Recursive checkpointing | $O(m \log L)$ | $O(\log L)$ factor | Extreme memory constraints. Rarely used in practice because the compute cost grows superlinearly. |
| Offloading to CPU | $O(m)$ GPU memory per layer | PCIe transfer latency | When GPU memory is the hard constraint and PCIe bandwidth is sufficient. Often combined with checkpointing. |
The 33% compute overhead from uniform checkpointing is almost always acceptable. Training large language models already takes days or weeks; a 33% increase in step time is a minor cost compared to the alternative of reducing batch size (which harms convergence) or using model parallelism (which adds communication overhead).
Interaction with Other Training Techniques
Activation checkpointing does not exist in isolation. It interacts with several other memory optimization techniques.
Gradient accumulation. Gradient accumulation reduces memory by processing micro-batches sequentially and summing gradients; activation memory then scales with the micro-batch size rather than the full batch size. Checkpointing and gradient accumulation are complementary: checkpointing reduces per-layer memory, accumulation reduces per-sample memory. Using both together enables training with effective batch sizes that would otherwise require multiple GPUs.
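A minimal gradient-accumulation loop, sketched with an arbitrary toy model (the model, data shapes, and `accum_steps` value are illustrative assumptions, not prescriptions):

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4  # effective batch = accum_steps * micro-batch size

data = torch.randn(16, 4)    # 16 samples split into 4 micro-batches of 4
target = torch.randn(16, 1)

opt.zero_grad()
for step in range(accum_steps):
    micro_x = data[step * 4:(step + 1) * 4]
    micro_y = target[step * 4:(step + 1) * 4]
    loss = torch.nn.functional.mse_loss(model(micro_x), micro_y)
    # Scale each micro-loss so the summed gradients average over the
    # full effective batch, matching a single large-batch step.
    (loss / accum_steps).backward()
opt.step()  # one optimizer update per effective batch
```

Each `backward()` call frees that micro-batch's activations before the next micro-batch runs, which is where the memory saving comes from.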
Batch normalization. BatchNorm layers store running statistics and per-sample statistics during the forward pass. When a BatchNorm layer is inside a checkpointed segment, its forward pass runs twice (once in the original forward, once during recomputation). The running statistics must not be updated during recomputation, or they will be corrupted. PyTorch handles this automatically when using torch.utils.checkpoint, but custom implementations must be careful to freeze BatchNorm statistics during the recomputation pass.
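The hazard is easy to demonstrate: running a BatchNorm forward twice on the same batch (as a naive recomputation would) moves the running statistics twice. A small sketch:

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(4)
bn.train()
x = torch.randn(32, 4)

bn(x)
after_once = bn.running_mean.clone()
bn(x)  # naive recomputation pass: updates the running stats a second time
# bn.running_mean has now drifted away from after_once
```

A custom checkpointing implementation should run the recomputation with the BatchNorm layers in `eval()` mode (or with statistics tracking otherwise disabled) so the second pass leaves the running statistics untouched.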
Skip connections. Residual connections create dependencies that span across the "segments" of a checkpointed network. If layer $j$ has a skip connection from an earlier layer $i$, then recomputing layer $j$ requires the activation from layer $i$, which may be in a different checkpointed segment. Frameworks handle this by storing activations at skip connection boundaries as additional checkpoints. This slightly increases memory but preserves correctness.
Common Confusions
Gradient checkpointing and activation checkpointing are the same thing
These terms are used interchangeably. PyTorch calls it torch.utils.checkpoint, some papers call it "gradient checkpointing," others call it "activation checkpointing." They all refer to the same technique: discarding activations and recomputing them during the backward pass.
The 33% compute overhead is not always 33%
The ~33% overhead assumes recomputing one full forward pass (cost $F$) on top of the standard $3F$. But if only some layers are checkpointed, the overhead is less. If nested checkpointing is used (checkpoints within segments), the compute-memory tradeoff curve shifts. The 33% figure is a useful approximation, not an exact number.
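Under the same cost model as above, if only a fraction $p$ of the layers is recomputed, the extra compute is roughly $pF$, so:

```latex
\text{overhead} \approx \frac{pF}{3F} = \frac{p}{3},
\qquad
p = 1 \;\Rightarrow\; \tfrac{1}{3} \approx 33\%,
\qquad
p = 0.5 \;\Rightarrow\; \approx 17\%.
```

This back-of-envelope formula ignores that layers differ in cost (an attention layer's forward pass is not the same price as an FFN's), so treat it as a first approximation.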
Key Takeaways
- Standard backprop stores $O(L)$ activations; checkpointing reduces this to $O(\sqrt{L})$
- Optimal checkpoint count is $\sqrt{L}$, placed evenly
- Compute overhead is roughly 33% (one extra forward pass)
- In transformers, checkpoint each transformer block
- Activation checkpointing combines with mixed precision for additional memory savings
Exercises
Problem
A network has $L$ layers. With optimal checkpointing, how many checkpoints should you place, what is the peak activation memory (in units of single-layer activation size), and what is the compute overhead? Express your answers in terms of $L$.
Problem
Chen et al. (2016) showed that with recursive (nested) checkpointing, you can reduce memory to $O(\log L)$ at the cost of $O(L \log L)$ compute. Explain the recursive strategy and why the logarithmic memory bound holds.
References
Canonical:
- Chen et al., "Training Deep Nets with Sublinear Memory Cost" (2016), arXiv:1604.06174
- Griewank & Walther, "Algorithm 799: Revolve" (2000), ACM TOMS
Current:
- Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models" (2023), Megatron-LM
- Hastie, Tibshirani, Friedman, The Elements of Statistical Learning (2009)
Last reviewed: April 2026
Prerequisites
Foundations this topic depends on.
- Feedforward Networks and Backpropagation (Layer 2)
- Differentiation in Rn (Layer 0A)
- Sets, Functions, and Relations (Layer 0A)
- Basic Logic and Proof Techniques (Layer 0A)
- Matrix Calculus (Layer 1)
- The Jacobian Matrix (Layer 0A)
- The Hessian Matrix (Layer 0A)
- Activation Functions (Layer 1)
- Convex Optimization Basics (Layer 1)
- Matrix Operations and Properties (Layer 0A)