
Training Techniques

Activation Checkpointing

Trade compute for memory by recomputing activations during the backward pass instead of storing them all. Reduces memory from O(L) to O(sqrt(L)) for L layers.


Why This Matters

For a network with $L$ layers, standard backpropagation stores all $L$ layers of activations simultaneously. For large transformers, activation memory dominates total GPU memory. A 7B-parameter transformer with sequence length 2048 can require 60+ GB just for activations. Activation checkpointing reduces this to a manageable level, making it possible to train models that would otherwise not fit in memory.

Mental Model

Standard backpropagation is "store everything, compute once." Activation checkpointing is "store some, recompute the rest." You mark certain layers as checkpoints. During the forward pass, only checkpoint activations are kept in memory; the rest are discarded. During the backward pass, when you need a discarded activation, you recompute it from the nearest preceding checkpoint. This is a memory-compute tradeoff within the automatic differentiation computation graph: you trade extra forward-pass compute for reduced memory in the reverse-mode gradient computation.

Formal Setup

Definition

Activation Checkpointing

Consider a sequential network with layers $f_1, f_2, \ldots, f_L$, where $a_l = f_l(a_{l-1})$ and $a_0 = x$ is the input. Standard backpropagation stores all $a_0, a_1, \ldots, a_{L-1}$ to compute gradients. Activation checkpointing selects a subset $C \subset \{0, 1, \ldots, L-1\}$ of checkpoint indices and only stores $\{a_c : c \in C\}$. For layer $l$ with nearest preceding checkpoint $c < l$, the activation $a_{l-1}$ is recomputed as:

$$a_{l-1} = f_{l-1} \circ f_{l-2} \circ \cdots \circ f_{c+1}(a_c)$$
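The definition above can be sketched in a few lines of plain Python. This is an illustrative toy (the layer functions and helper names are made up for the example, not a library API): the forward pass keeps only the checkpointed activations, and any discarded activation is later rebuilt from the nearest preceding checkpoint.

```python
import math

def make_layers(L):
    # f_l(a) = a + sin(l * a): arbitrary smooth layers, purely for illustration
    return [lambda a, l=l: a + math.sin(l * a) for l in range(1, L + 1)]

def forward_with_checkpoints(layers, x, C):
    """Forward pass keeping only a_c for c in C.

    Convention: a_0 = x is the input and a_l = f_l(a_{l-1}).
    Returns the stored checkpoints and the final output a_L.
    """
    stored, a = {}, x
    for l, f in enumerate(layers):   # 0-based l: f is f_{l+1}, a is a_l
        if l in C:
            stored[l] = a            # keep checkpoint activation a_l
        a = f(a)                     # discard everything else
    return stored, a

def recompute(layers, stored, l):
    """Recompute a_l = f_l ∘ ... ∘ f_{c+1}(a_c) from the nearest checkpoint c <= l."""
    c = max(i for i in stored if i <= l)
    a = stored[c]
    for j in range(c, l):            # layers[c] is f_{c+1}, ..., layers[l-1] is f_l
        a = layers[j](a)
    return a

L = 9
layers = make_layers(L)
C = {0, 3, 6}                        # evenly spaced checkpoints

# Reference: full forward pass storing every activation a_0 .. a_L
full = [2.0]
for f in layers:
    full.append(f(full[-1]))

stored, out = forward_with_checkpoints(layers, 2.0, C)
assert abs(recompute(layers, stored, 5) - full[5]) < 1e-12  # a_5 recovered exactly
assert abs(out - full[-1]) < 1e-12
```

Only three activations are stored instead of ten, at the cost of re-running a few layers whenever a discarded activation is needed.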

Memory and Compute Analysis

Without checkpointing: Store all $L$ activations. Memory is $O(L)$. Compute is one forward pass plus one backward pass, totaling roughly $3F$ (where $F$ is the forward-pass cost, since the backward pass costs approximately $2F$).

With $k$ evenly spaced checkpoints: Store $k$ checkpoint activations. Between checkpoints, you need to temporarily store at most $L/k$ activations during recomputation. Total memory: $O(k + L/k)$.

Minimize $k + L/k$ by setting $k = \sqrt{L}$. This gives:

$$\text{Memory} = O(\sqrt{L})$$

The compute overhead is one additional forward pass through each segment between checkpoints. Since there are $k = \sqrt{L}$ segments of $\sqrt{L}$ layers each, the total extra forward compute is at most one full forward pass. Total compute becomes roughly $4F$ instead of $3F$, a ~33% increase.
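A quick numeric check of this analysis, under the simplified model that every layer's activation has unit size (the `peak_memory` helper is made up for this sketch):

```python
import math

def peak_memory(L, k):
    # k stored checkpoints, plus up to ceil(L/k) activations alive
    # while recomputing one segment (uniform activation sizes assumed)
    return k + math.ceil(L / k)

L = 100
best_k = min(range(1, L + 1), key=lambda k: peak_memory(L, k))
assert best_k == 10                  # = sqrt(100)
assert peak_memory(L, best_k) == 20  # 2 * sqrt(L), vs. 100 without checkpointing

# Cost model: forward F, backward ~2F, plus one extra forward for recomputation
overhead = (4 - 3) / 3
assert abs(overhead - 1 / 3) < 1e-9  # ~33% step-time increase
```

For $L = 100$ layers, peak activation memory drops from 100 units to 20 at $k = 10$ checkpoints, exactly the $2\sqrt{L}$ bound.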

Main Theorems

Theorem

Optimal Checkpoint Memory Bound

Statement

For a sequential network with $L$ layers, each producing activations of size $m$, placing $k = \lfloor\sqrt{L}\rfloor$ evenly spaced checkpoints reduces peak activation memory from $Lm$ to at most $(\sqrt{L} + \sqrt{L})\,m = 2\sqrt{L} \cdot m$, while increasing compute by at most one additional forward pass.

Intuition

You store $\sqrt{L}$ checkpoints. When computing gradients for any segment, you recompute at most $\sqrt{L}$ activations from the nearest checkpoint. At any point, you hold $\sqrt{L}$ checkpoints plus at most $\sqrt{L}$ recomputed activations in memory.

Proof Sketch

Partition the $L$ layers into $k$ segments of size $L/k$. Store the first activation of each segment ($k$ activations). During backpropagation through segment $i$, recompute all $L/k$ activations in that segment from its checkpoint. Peak memory is $k + L/k$. By the AM-GM inequality, $k + L/k \geq 2\sqrt{L}$, with equality at $k = \sqrt{L}$. Recomputing each segment costs one forward pass through $L/k$ layers, and each segment is recomputed once, so the total extra compute is $k \times (L/k) = L$ layers, i.e. one full forward pass.

Why It Matters

This $O(\sqrt{L})$ result makes it possible to train models with hundreds of layers on a single GPU. Without it, activation memory grows linearly with depth and quickly exceeds hardware limits.

Failure Mode

The analysis assumes equal activation sizes per layer and sequential computation. Transformer models have non-uniform activation sizes (attention matrices are much larger than FFN activations). Residual connections mean the "sequential" assumption is approximate. In practice, checkpoint placement is tuned per-architecture rather than using uniform spacing.

Practical Considerations

What to checkpoint. In transformers, the standard practice is to checkpoint each transformer block. The attention computation and FFN within each block are recomputed from the block input during the backward pass.
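A minimal PyTorch sketch of this practice, assuming PyTorch is available. The `Block` here is an illustrative stand-in for a transformer block (residual + feed-forward only, no attention); `torch.utils.checkpoint.checkpoint` wraps each block so that only the block input is kept and the internals are recomputed during the backward pass:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    """Toy stand-in for a transformer block: residual + FFN, no attention."""
    def __init__(self, d):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d, 4 * d), torch.nn.GELU(), torch.nn.Linear(4 * d, d)
        )

    def forward(self, x):
        return x + self.ff(x)        # residual connection

torch.manual_seed(0)
blocks = torch.nn.ModuleList(Block(16) for _ in range(4))
x = torch.randn(2, 8, 16, requires_grad=True)

# Checkpoint each block: only block inputs are stored; block internals
# are recomputed in backward. use_reentrant=False is the recommended mode.
h = x
for blk in blocks:
    h = checkpoint(blk, h, use_reentrant=False)
h.sum().backward()
grad_ckpt = x.grad.clone()

# Reference: plain forward/backward storing all intermediate activations
x.grad = None
h = x
for blk in blocks:
    h = blk(h)
h.sum().backward()
assert torch.allclose(grad_ckpt, x.grad, atol=1e-6)  # identical gradients
```

The gradients match the non-checkpointed run exactly (up to floating-point noise); only the memory/compute profile changes.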

Selective checkpointing. Not all layers have equally expensive activations. Attention layers produce activations of size $O(L_{\text{seq}}^2 \cdot d)$ (the attention matrix), which dominates memory. Checkpointing only attention layers gives most of the memory savings.

Interaction with mixed precision. Checkpointed activations can be stored in FP16/BF16, further halving checkpoint memory. Recomputed activations are also in reduced precision, matching the original forward pass.

Memory Budget Comparison

The following table compares memory strategies for training deep networks. All entries assume $L$ layers, each producing activations of size $m$, with parameter memory held constant.

| Strategy | Peak activation memory | Compute overhead | When to use |
| --- | --- | --- | --- |
| No checkpointing | $Lm$ | None (baseline $3F$) | Model fits in memory with room to spare. Small models or large GPUs. |
| Uniform checkpointing ($k = \sqrt{L}$) | $2\sqrt{L} \cdot m$ | ~33% ($4F$ total) | Standard choice for large models. The default in most frameworks. |
| Selective checkpointing (attention only) | Varies; typically $0.4$–$0.6 \times Lm$ | 10–20% | When attention activations dominate and FFN activations are small. Reduces overhead compared to full checkpointing. |
| Recursive checkpointing | $O(\log L) \cdot m$ | $O(L \log L)$ compute, a $\log L$ factor | Extreme memory constraints. Rarely used in practice because the compute cost grows superlinearly. |
| Offloading to CPU | $O(1)$ GPU memory per layer | PCIe transfer latency | When GPU memory is the hard constraint and PCIe bandwidth is sufficient. Often combined with checkpointing. |

The 33% compute overhead from uniform checkpointing is almost always acceptable. Training large language models already takes days or weeks; a 33% increase in step time is a minor cost compared to the alternative of reducing batch size (which harms convergence) or using model parallelism (which adds communication overhead).

Interaction with Other Training Techniques

Activation checkpointing does not exist in isolation. It interacts with several other memory optimization techniques.

Gradient accumulation. Gradient accumulation reduces memory by processing micro-batches sequentially and summing gradients. Activation memory then scales with the micro-batch size rather than the full batch size. Checkpointing and gradient accumulation are complementary: checkpointing reduces per-layer memory, accumulation reduces per-sample memory. Using both together enables training with effective batch sizes that would otherwise require multiple GPUs.
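The key invariant behind gradient accumulation is that a weighted sum of micro-batch gradients equals the full-batch gradient. A small NumPy check under a least-squares model (all names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 4)), rng.normal(size=32)
w = rng.normal(size=4)

def grad(Xb, yb, w):
    # Gradient of mean squared error over a (micro-)batch
    return 2 * Xb.T @ (Xb @ w - yb) / len(yb)

g_full = grad(X, y, w)               # full-batch gradient

# Accumulate over 4 micro-batches of 8 samples; only one micro-batch's
# activations need to be alive at a time
g_acc = np.zeros_like(w)
for i in range(0, 32, 8):
    g_acc += grad(X[i:i+8], y[i:i+8], w) * (8 / 32)  # weight by batch fraction
assert np.allclose(g_acc, g_full)
```

Because the mean-loss gradient is linear in per-sample gradients, the accumulated result is exact, not an approximation.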

Batch normalization. BatchNorm layers store running statistics and per-sample statistics during the forward pass. When a BatchNorm layer is inside a checkpointed segment, its forward pass runs twice (once in the original forward, once during recomputation). The running statistics must not be updated during recomputation, or they will be corrupted. PyTorch handles this automatically when using torch.utils.checkpoint, but custom implementations must be careful to freeze BatchNorm statistics during the recomputation pass.
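The corruption mechanism can be shown with a toy running-mean tracker (a deliberately simplified stand-in for BatchNorm's running statistics; the class and flag names are made up for this sketch):

```python
class RunningMean:
    """Toy stand-in for BatchNorm's running-mean statistic."""
    def __init__(self, momentum=0.1):
        self.momentum = momentum
        self.mean = 0.0

    def forward(self, batch_mean, update_stats=True):
        if update_stats:  # exponential moving average, as in BatchNorm
            self.mean = (1 - self.momentum) * self.mean + self.momentum * batch_mean
        return batch_mean

correct, corrupt = RunningMean(), RunningMean()
for batch_mean in [1.0, 2.0, 3.0]:
    correct.forward(batch_mean)                       # original forward pass
    correct.forward(batch_mean, update_stats=False)   # recompute: stats frozen
    corrupt.forward(batch_mean)                       # original forward pass
    corrupt.forward(batch_mean)                       # naive recompute: updates again

# Double updates skew the moving average away from the correct value
assert correct.mean != corrupt.mean
```

Each naive recomputation applies the exponential moving average a second time per batch, so the running statistics drift from what a single forward pass would produce.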

Skip connections. Residual connections create dependencies that span across the "segments" of a checkpointed network. If layer $l$ has a skip connection from layer $l - 3$, then recomputing layer $l$ requires the activation from layer $l - 3$, which may be in a different checkpointed segment. Frameworks handle this by storing activations at skip connection boundaries as additional checkpoints. This slightly increases memory but preserves correctness.

Common Confusions

Watch Out

Gradient checkpointing and activation checkpointing are the same thing

These terms are used interchangeably. PyTorch calls it torch.utils.checkpoint, some papers call it "gradient checkpointing," others call it "activation checkpointing." They all refer to the same technique: discarding activations and recomputing them during the backward pass.

Watch Out

The 33% compute overhead is not always 33%

The ~33% overhead assumes recomputing one full forward pass (cost $F$) on top of the standard $3F$. But if only some layers are checkpointed, the overhead is less. If nested checkpointing is used (checkpoints within segments), the compute-memory tradeoff curve shifts. The 33% figure is a useful approximation, not an exact number.
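Under the same simplified cost model (forward $F$, backward $2F$), the overhead scales with the fraction of layers whose activations must be recomputed. A one-line sketch (the helper name is invented for this example, and real overheads depend on which layers are checkpointed, not just how many):

```python
def step_time_overhead(frac_recomputed):
    """Relative step-time increase under the F / 2F cost model:
    baseline is 3F; recomputation adds frac_recomputed * F."""
    return frac_recomputed / 3

assert abs(step_time_overhead(1.0) - 1/3) < 1e-12  # full checkpointing: ~33%
assert abs(step_time_overhead(0.5) - 1/6) < 1e-12  # half the layers: ~17%
```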

Key Takeaways

  • Standard backprop stores $O(L)$ activations; checkpointing reduces this to $O(\sqrt{L})$
  • Optimal checkpoint count is $\sqrt{L}$, placed evenly
  • Compute overhead is roughly 33% (one extra forward pass)
  • In transformers, checkpoint each transformer block
  • Activation checkpointing combines with mixed precision for additional memory savings

Exercises

ExerciseCore

Problem

A network has $L = 100$ layers. With optimal checkpointing, how many checkpoints should you place, what is the peak activation memory (in units of single-layer activation size), and what is the compute overhead?

ExerciseAdvanced

Problem

Chen et al. (2016) showed that with recursive (nested) checkpointing, you can reduce memory to $O(\log L)$ at the cost of $O(L \log L)$ compute. Explain the recursive strategy and why the logarithmic memory bound holds.

References

Canonical:

  • Chen et al., "Training Deep Nets with Sublinear Memory Cost" (2016), arXiv:1604.06174
  • Griewank & Walther, "Algorithm 799: Revolve" (2000), ACM TOMS

Current:

  • Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models" (2023), Megatron-LM

  • Hastie, Tibshirani, Friedman, The Elements of Statistical Learning (2009)

Last reviewed: April 2026
