

Distributed Training Theory

Training frontier models requires thousands of GPUs. Data parallelism, model parallelism, and communication-efficient methods make this possible.


Why This Matters

A single GPU cannot train a frontier language model. GPT-4-class models have hundreds of billions of parameters and train on trillions of tokens. Even if you could fit the model on one GPU (you cannot), training would take decades.

Distributed training splits the work across hundreds or thousands of devices. But parallelism is not free: devices must communicate gradients, activations, or parameters, and this communication can bottleneck the entire system. The theory of distributed training is about understanding these tradeoffs and minimizing wasted computation.

Mental Model

Think of building a house. Data parallelism is hiring $N$ identical crews, each building a copy of the same wall using different bricks, then averaging their work. Model parallelism is splitting the blueprint so each crew builds a different section of the house. Pipeline parallelism is an assembly line: crew 1 pours foundations, passes to crew 2 for framing, crew 3 for roofing.

Each approach has different communication patterns. Data parallelism requires sharing gradients (the "average" step). Tensor parallelism requires sharing intermediate activations within each layer. Pipeline parallelism requires passing activations between stages. The art is choosing the right combination to minimize idle time and communication overhead.

Formal Setup and Notation

Let $f(\theta; x)$ be the loss for parameters $\theta$ on example $x$. We want to minimize $F(\theta) = \mathbb{E}_{x \sim \mathcal{D}}[f(\theta; x)]$ using SGD or its variants, distributed across $K$ workers.

Definition

Data Parallelism

In data parallelism, each of $K$ workers holds a complete copy of the model $\theta$. At each step:

  1. Each worker $k$ samples a minibatch $B_k$ of size $b$
  2. Each worker computes its local gradient $g_k = \frac{1}{b} \sum_{x \in B_k} \nabla f(\theta; x)$
  3. Workers average their gradients: $\bar{g} = \frac{1}{K} \sum_{k=1}^{K} g_k$
  4. Each worker updates: $\theta \leftarrow \theta - \eta \bar{g}$

The effective batch size is $Kb$. Step 3 is typically implemented via AllReduce.
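The four steps above can be sketched in a few lines of numpy. This is an illustrative single-process simulation (the helper `local_grad` and the least-squares loss are our choices, not from any library); it also verifies the key fact that the averaged gradient equals the gradient over the combined batch of size $Kb$:

```python
import numpy as np

# Toy data-parallel SGD step on a least-squares loss
# f(theta; x, y) = 0.5 * (x . theta - y)^2.
rng = np.random.default_rng(0)
K, b, d = 4, 8, 5                       # workers, local batch size, param dim
theta = rng.normal(size=d)

X = rng.normal(size=(K * b, d))         # global batch of size K*b
y = rng.normal(size=K * b)

def local_grad(theta, Xk, yk):
    """Average gradient over one worker's minibatch."""
    return Xk.T @ (Xk @ theta - yk) / len(yk)

# Step 2: each worker computes its local gradient on its own shard
grads = [local_grad(theta, X[k*b:(k+1)*b], y[k*b:(k+1)*b]) for k in range(K)]
# Step 3: AllReduce (here simulated as a plain average)
g_bar = np.mean(grads, axis=0)
# Step 4: every worker applies the identical update
theta_new = theta - 0.1 * g_bar

# The averaged gradient equals the gradient on the full K*b batch, so
# synchronous data parallelism matches single-device large-batch SGD.
g_full = local_grad(theta, X, y)
assert np.allclose(g_bar, g_full)
```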

Definition

AllReduce

AllReduce is a collective communication primitive that computes the sum (or average) of vectors across all workers and distributes the result to every worker. For $K$ workers each holding a vector of $d$ elements:

  • Ring AllReduce: each worker sends and receives $d/K$ elements per step over $2(K-1)$ steps. Total communication volume per worker: $2d(K-1)/K \approx 2d$ for large $K$.
  • The communication cost is $O(d/\text{bandwidth})$ and is independent of $K$ in the bandwidth-optimal ring algorithm.
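The two phases of ring AllReduce (reduce-scatter, then all-gather) can be simulated sequentially on one machine. This is a didactic sketch, not a real NCCL/MPI API; `ring_allreduce` and its chunked input format are our own convention, and we assume $d$ splits evenly into $K$ chunks:

```python
import numpy as np

def ring_allreduce(chunks_per_worker):
    """Simulate ring AllReduce. Input: K workers' vectors, each pre-split
    into K chunks. Returns each worker's view of the summed vector."""
    K = len(chunks_per_worker)
    # bufs[k][c] is worker k's current value for chunk c
    bufs = [[np.array(c, dtype=float) for c in w] for w in chunks_per_worker]

    # Phase 1: reduce-scatter. At step s, worker k sends its accumulated
    # chunk (k - s) mod K to worker (k + 1) mod K, which adds it in.
    for s in range(K - 1):
        for k in range(K):
            c = (k - s) % K
            bufs[(k + 1) % K][c] = bufs[(k + 1) % K][c] + bufs[k][c]
    # Now worker k holds the fully reduced chunk (k + 1) mod K.

    # Phase 2: all-gather. Each fully reduced chunk circulates the ring;
    # at step s, worker k forwards chunk (k + 1 - s) mod K to worker k + 1.
    for s in range(K - 1):
        for k in range(K):
            c = (k + 1 - s) % K
            bufs[(k + 1) % K][c] = bufs[k][c]

    return [np.concatenate(w) for w in bufs]

# Every worker ends up with the elementwise sum of all K vectors.
K, d = 4, 8
rng = np.random.default_rng(1)
vecs = [rng.normal(size=d) for _ in range(K)]
out = ring_allreduce([np.array_split(v, K) for v in vecs])
expected = np.sum(vecs, axis=0)
assert all(np.allclose(o, expected) for o in out)
```

Each worker transmits $d/K$ elements per step for $2(K-1)$ steps, matching the $2d(K-1)/K$ volume above.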
Definition

Tensor Parallelism

Tensor parallelism (TP) splits individual layers across devices. For a linear layer $Y = XW$, partition $W$ column-wise across $K$ devices:

$$W = [W_1, W_2, \ldots, W_K]$$

Device $k$ computes $Y_k = XW_k$ (a slice of the output). An AllGather collects the full output $Y = [Y_1, \ldots, Y_K]$. In the backward pass, a ReduceScatter (or AllReduce, depending on the layout) combines the partial input gradients.
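A minimal numpy sketch of the forward pass of a column-parallel linear layer, with the device shards simulated as list entries; the concatenation stands in for the AllGather:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4                                   # tensor-parallel degree
X = rng.normal(size=(32, 64))           # activations (batch x d_in)
W = rng.normal(size=(64, 128))          # full weight (d_in x d_out)

W_shards = np.split(W, K, axis=1)       # column-wise partition [W_1, ..., W_K]
Y_shards = [X @ Wk for Wk in W_shards]  # each device's local matmul
Y = np.concatenate(Y_shards, axis=1)    # the AllGather step

# The gathered shards reproduce the unpartitioned layer exactly.
assert np.allclose(Y, X @ W)
```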

Definition

Pipeline Parallelism

Pipeline parallelism (PP) partitions the model's layers into $S$ sequential stages, each on a different device. A microbatch flows through stage 1, then stage 2, etc. Multiple microbatches are in flight simultaneously to fill the pipeline, but pipeline bubbles (idle time) occur at startup and shutdown.

For $M$ microbatches and $S$ stages, the bubble fraction is:

$$\text{bubble} = \frac{S - 1}{M + S - 1}$$

which shrinks as $M$ grows relative to $S$.
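The bubble formula is simple enough to evaluate directly (the function name here is our own):

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a GPipe-style pipeline: (S - 1) / (M + S - 1)."""
    S, M = num_stages, num_microbatches
    return (S - 1) / (M + S - 1)

# With S = 8 stages, raising M from 8 to 64 cuts idle time from ~47% to ~10%.
assert abs(bubble_fraction(8, 8) - 7 / 15) < 1e-12
assert abs(bubble_fraction(8, 64) - 7 / 71) < 1e-12
```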

Core Definitions

ZeRO (Zero Redundancy Optimizer) from DeepSpeed eliminates memory redundancy in data parallelism. Standard data parallelism replicates the full optimizer state, gradients, and parameters on every worker. ZeRO partitions these across workers:

  • ZeRO Stage 1: partition optimizer states (Adam's $m$ and $v$ plus the FP32 master weights). With mixed-precision Adam this cuts per-worker model-state memory from 16 bytes/param toward 4 bytes/param as $K$ grows, roughly a $4\times$ saving.
  • ZeRO Stage 2: additionally partition gradients. Workers compute full gradients locally but only store their partition.
  • ZeRO Stage 3: additionally partition parameters. Workers must AllGather parameters before each forward/backward step, trading communication for memory.
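The three stages can be captured in a back-of-envelope calculator. This sketch assumes the memory model from the ZeRO paper (FP16 parameters and gradients at 2 bytes each, Adam state at 12 bytes/param) and ignores activations; the function name is ours:

```python
def zero_memory_gb(n_params: float, n_workers: int, stage: int) -> float:
    """Per-worker model-state memory (GB) under ZeRO, assuming FP16
    params/grads (2 B each) and FP32 Adam state (master copy + m + v,
    12 B/param). Stage 0 means plain data parallelism."""
    K = n_workers
    params, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if stage >= 1:
        optim /= K       # Stage 1: shard optimizer states
    if stage >= 2:
        grads /= K       # Stage 2: also shard gradients
    if stage >= 3:
        params /= K      # Stage 3: also shard parameters
    return (params + grads + optim) / 1e9

# 7B-parameter model across 64 workers:
assert abs(zero_memory_gb(7e9, 64, 0) - 112.0) < 1e-9   # full replication
assert abs(zero_memory_gb(7e9, 64, 3) - 1.75) < 1e-9    # fully sharded
```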

Gradient compression reduces communication volume:

  • Gradient quantization: send gradients in low precision (e.g., FP16 or INT8) instead of FP32, halving or quartering bandwidth.
  • Gradient sparsification: send only the top-$k$ gradient components. The error from dropped components is accumulated locally and sent in future rounds.
  • PowerSGD: approximate each $m \times n$ gradient matrix with a rank-$r$ decomposition, reducing communication from $O(mn)$ to $O(r(m+n))$.
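Top-$k$ sparsification with error feedback can be sketched in a few lines. The function name and the simple dense representation of the "transmitted" vector are illustrative; the point is that the residual telescopes, so no gradient mass is ever lost:

```python
import numpy as np

def sparsify_with_feedback(grad, error, k):
    """Top-k gradient sparsification with error feedback: add the residual
    carried over from earlier rounds, transmit only the k largest-magnitude
    components, and keep the dropped remainder as the new residual."""
    corrected = grad + error
    idx = np.argpartition(np.abs(corrected), -k)[-k:]   # top-k indices
    sparse = np.zeros_like(corrected)
    sparse[idx] = corrected[idx]                        # what is transmitted
    return sparse, corrected - sparse                   # (sent, new residual)

# Over several rounds, transmitted mass plus the final residual equals
# the true gradient sum: dropped components are only delayed, not lost.
rng = np.random.default_rng(0)
d, k = 100, 10
error = np.zeros(d)
total_sent, total_grad = np.zeros(d), np.zeros(d)
for _ in range(5):
    g = rng.normal(size=d)
    sent, error = sparsify_with_feedback(g, error, k)
    total_sent += sent
    total_grad += g
assert np.allclose(total_sent + error, total_grad)
```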

Main Theorems

Theorem

Linear Speedup of Data-Parallel SGD

Statement

For data-parallel SGD with $K$ workers, each using local batch size $b$, the effective batch size is $B = Kb$. After $T$ steps, the convergence rate for a smooth non-convex objective satisfies:

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\|\nabla F(\theta_t)\|^2 \leq O\left(\frac{F(\theta_0) - F^*}{\eta T} + \eta L \frac{\sigma^2}{Kb}\right)$$

With learning rate $\eta = O(\sqrt{Kb/T})$, this gives:

$$O\left(\sqrt{\frac{(F(\theta_0) - F^*) L \sigma^2}{KbT}}\right)$$

This is a factor of $\sqrt{K}$ better than single-worker SGD with batch size $b$ and the same number of steps. Equivalently, to reach the same accuracy, data-parallel SGD needs $K$ times fewer steps, giving linear speedup.

Intuition

Each worker contributes independent gradient estimates. Averaging $K$ independent estimates reduces variance by a factor of $K$, just like averaging $K$ i.i.d. random variables. Less variance means you can take larger steps or reach the same accuracy faster. The speedup is linear until the batch size becomes so large that the variance reduction saturates (the "critical batch size").
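The variance claim is easy to check numerically. A quick Monte Carlo sketch, with the noise level and worker count chosen arbitrarily:

```python
import numpy as np

# Averaging K independent noisy estimates cuts variance by roughly K.
rng = np.random.default_rng(0)
trials, K, sigma = 20_000, 16, 1.0
single = rng.normal(0.0, sigma, size=trials)                 # one worker
averaged = rng.normal(0.0, sigma, size=(trials, K)).mean(axis=1)

ratio = single.var() / averaged.var()
assert 13 < ratio < 19   # close to K = 16, up to sampling noise
```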

Proof Sketch

Start from the smoothness inequality: $F(\theta_{t+1}) \leq F(\theta_t) - \eta \nabla F(\theta_t)^\top \bar{g}_t + \frac{\eta^2 L}{2} \|\bar{g}_t\|^2$. Take expectations. The averaged gradient $\bar{g}_t$ has variance $\sigma^2/(Kb)$ (each of $K$ workers averages $b$ independent samples). The bias is zero since each worker computes unbiased gradients. Telescope over $T$ steps and optimize the learning rate.

Why It Matters

Linear speedup is the best you can hope for from parallelism. This theorem says data-parallel SGD achieves it, up to a critical batch size beyond which increasing KK gives diminishing returns. Understanding this limit is essential for choosing the right number of workers and avoiding wasted compute.

Failure Mode

Linear speedup breaks down when: (1) the batch size $Kb$ exceeds the critical batch size $B_{\text{crit}} \approx \sigma^2 / \|\nabla F\|^2$, beyond which gradient noise is no longer the bottleneck and further variance reduction buys little; (2) communication overhead exceeds computation savings; (3) large batch sizes require learning-rate warmup and careful tuning to maintain convergence.

Proposition

Communication-Computation Tradeoff

Statement

In data-parallel training with ring AllReduce, the total time per step is:

$$T_{\text{step}} = C + \frac{2d}{W}$$

where $C$ is the local computation time (forward + backward pass) and $2d/W$ is the AllReduce communication time (sending $2d$ values at bandwidth $W$). The computation-to-communication ratio is:

$$\text{ratio} = \frac{C}{2d/W} = \frac{CW}{2d}$$

Training is efficient when this ratio is much greater than 1 (computation dominates). It is communication-bound when the ratio is less than 1.

Intuition

Every step, you do some math (forward and backward pass, time $C$) and some networking (share gradients, time $2d/W$). If the network is slow relative to the GPU, workers spend most of their time waiting. Making each worker do more computation per communication round (larger local batch, gradient accumulation) improves the ratio.

Proof Sketch

Ring AllReduce for $d$ parameters across $K$ workers requires each worker to send and receive $d(K-1)/K \approx d$ values in each of two phases (reduce-scatter and all-gather). At bandwidth $W$, this takes $2d/W$ time. Computation and communication can overlap partially (communicate gradients for one layer while computing gradients for the next), but the non-overlapping portion still imposes the bound.

Why It Matters

This ratio determines whether adding more GPUs actually helps. For a model with $d = 10^9$ parameters in FP16 ($2 \times 10^9$ bytes), AllReduce takes $4 \times 10^9 / W$ seconds. At 100 Gbps bandwidth ($\approx 12.5$ GB/s), this is about 0.32 seconds. If the forward/backward pass takes 0.5 seconds, communication takes 39% of step time, a significant overhead. Faster interconnects (InfiniBand, NVLink) or gradient compression are needed.
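The arithmetic above as a quick check (the 0.5 s forward/backward time is the assumed figure from the text, not a measured value):

```python
# 1B parameters in FP16, ring AllReduce over a 100 Gbps link.
d_bytes = 1e9 * 2                      # 10^9 params x 2 bytes (FP16)
bandwidth = 100e9 / 8                  # 100 Gbps = 12.5 GB/s
t_comm = 2 * d_bytes / bandwidth       # ring AllReduce moves ~2d bytes
t_compute = 0.5                        # assumed forward+backward time (s)

comm_share = t_comm / (t_compute + t_comm)
assert abs(t_comm - 0.32) < 1e-9       # AllReduce takes 0.32 s
assert abs(comm_share - 0.39) < 0.01   # ~39% of each step is communication
```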

Canonical Examples

Example

Scaling from 1 to 256 GPUs

A model with $d = 7 \times 10^9$ parameters (7B) in FP16 uses 14 GB for parameters and another 14 GB for gradients. Adam optimizer states add 84 GB (FP32 master copy + $m$ + $v$, 12 bytes/param). Total: 112 GB per worker in standard data parallelism.

With ZeRO Stage 1: optimizer states are split across $K$ workers. Each worker stores $84/K$ GB of optimizer state plus 28 GB of parameters and gradients. At $K = 8$, this is $10.5 + 28 = 38.5$ GB per worker, fitting on 40 GB GPUs.

With ZeRO Stage 3: everything is split. Each worker stores $(14 + 14 + 84)/K$ GB. At $K = 8$, this is 14 GB per worker, but each forward and backward pass requires AllGather operations to reconstruct parameters.

Common Confusions

Watch Out

Data parallelism does not change the model

Data parallelism with synchronized gradients produces exactly the same updates as single-GPU training with the same effective batch size $Kb$. The only difference is speed. Asynchronous data parallelism, where workers do not wait for all gradients, does change the dynamics and can hurt convergence.

Watch Out

Pipeline parallelism is not just splitting layers

Naive pipeline parallelism has the pipeline bubble problem: only one stage is active at a time, giving utilization $1/S$. The key innovation (GPipe, PipeDream) is microbatching: split the minibatch into $M$ microbatches and overlap their execution. With $M \gg S$, utilization approaches 100%, but memory increases because activations for all in-flight microbatches must be stored.

Watch Out

3D parallelism is not optional at scale

Frontier models use all three forms simultaneously: data parallelism across groups, tensor parallelism within a node (fast NVLink), and pipeline parallelism across nodes. The 3D configuration is a system design choice driven by hardware topology, not just model architecture.

Summary

  • Data parallelism gives linear speedup up to a critical batch size; beyond that, adding workers wastes compute
  • AllReduce communication cost is $O(d/W)$, independent of worker count in the bandwidth-optimal case
  • Tensor parallelism splits layers within a node (requires high bandwidth); pipeline parallelism splits layers across nodes (tolerates lower bandwidth)
  • ZeRO eliminates memory redundancy: Stage 1 partitions optimizer states, Stage 2 adds gradients, Stage 3 adds parameters
  • The computation-to-communication ratio determines whether training is GPU-bound or network-bound

Exercises

ExerciseCore

Problem

A model has $d = 10^9$ parameters in FP32. Each worker has 100 Gbps network bandwidth. How long does AllReduce take, assuming the ring algorithm?

ExerciseAdvanced

Problem

With $K = 64$ workers and local batch size $b = 16$, the effective batch size is $1024$. If the critical batch size is $B_{\text{crit}} = 2048$, are we in the linear speedup regime? What happens if we double $K$ to $128$?

ExerciseResearch

Problem

Explain the memory-communication tradeoff in ZeRO Stages 1-3. For each stage, what memory is saved per worker and what additional communication is required?

References

Canonical:

  • Dean et al., Large Scale Distributed Deep Networks (NIPS 2012)
  • Rajbhandari et al., ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (SC 2020)

Current:

  • Shoeybi et al., Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (2020)
  • Narayanan et al., Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (SC 2021)

Next Topics

Distributed training connects to:

  • Scaling laws that determine how many GPUs and how much data you need
  • Mixed-precision training for reducing communication and memory further

Last reviewed: April 2026
