AI Systems Bridge · 45 min

Distributed Training Basics

When one GPU is too small or too slow, distributed training splits batches, tensors, layers, parameters, gradients, and optimizer state across accelerators.

Why This Matters

A 7B-parameter transformer in bfloat16 needs about 14 GB for weights, but Adam training commonly stores weights, gradients, two moment vectors, and often a master fp32 copy. A rough training footprint is 16 bytes per parameter, so 7B parameters is about 112 GB before activations. One 80 GB accelerator is already short.

Even when the model fits, a single accelerator may take weeks. Distributed training trades extra communication for more compute devices. The central question is numeric: does the saved compute time exceed the time spent moving gradients, activations, parameters, and optimizer state?

Core Definitions

Definition

Data parallelism

Data parallelism replicates the model on each worker, splits a global batch into per-worker mini-batches, computes gradients locally, then averages gradients with an all-reduce before the optimizer step. If there are $N$ workers and per-worker batch size $b$, the global batch size is $B = Nb$.

Definition

Tensor parallelism

Tensor parallelism shards a single tensor operation across workers. A large matrix multiplication, attention projection, or MLP projection is split along rows, columns, or heads. Workers exchange partial outputs with all-reduce, all-gather, or reduce-scatter.

Definition

Pipeline parallelism

Pipeline parallelism assigns contiguous layer ranges to different workers. Micro-batches flow through the stages so that stage 0 can process micro-batch 2 while stage 1 processes micro-batch 1.

Definition

Collective operation

A collective operation is a communication primitive involving a group of workers. The common collectives are all-reduce, all-gather, and reduce-scatter. Training systems usually call vendor libraries such as NCCL rather than writing collectives directly.

Definition

ZeRO and fully sharded data parallel

ZeRO partitions training state across data-parallel workers. Stage 1 shards optimizer state, stage 2 also shards gradients, and stage 3 also shards parameters. Fully sharded data parallel follows the same idea: gather parameters for a layer when needed, compute, then release or reshard them.

Data Parallel Training

In synchronous data parallel training, every worker starts a step with identical parameters. Worker $i$ computes a local gradient $g_i$ on its local mini-batch. The gradient used by the optimizer is the arithmetic mean:

$$g = \frac{1}{N}\sum_{i=0}^{N-1} g_i$$

For four workers, suppose a scalar parameter has local gradients 0.20, 0.10, -0.05, and 0.15. The all-reduce sum is 0.40, and the averaged gradient is 0.10. With learning rate 0.001, SGD changes the parameter by -0.0001.

A minimal C++ sketch hides many details but shows the ordering invariant. No worker may update parameters before the averaged gradient is available.

// One synchronous data-parallel step.
// all_reduce_sum writes the elementwise sum into grad.
for (int step = 0; step < steps; ++step) {
  zero_grad(model);
  forward_backward(local_batch, model, grad);

  all_reduce_sum(grad.data(), grad.numel(), group);
  for (size_t k = 0; k < grad.numel(); ++k) {
    grad[k] /= group.size();
  }

  adam_update(model.params(), grad, optimizer_state);
}
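The four-worker arithmetic above can be checked in a single process. The following is a minimal sketch, not a distributed implementation: `average_gradients` and `sgd_delta` are hypothetical helper names, and the "all-reduce" is simulated by summing vectors held in one address space.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simulates the all-reduce-and-average step in one process.
// local_grads[i] is worker i's gradient; all vectors must have equal length.
std::vector<double> average_gradients(
    const std::vector<std::vector<double>>& local_grads) {
  assert(!local_grads.empty());
  const std::size_t n = local_grads.front().size();
  std::vector<double> mean(n, 0.0);
  for (const auto& g : local_grads) {
    assert(g.size() == n);
    for (std::size_t k = 0; k < n; ++k) mean[k] += g[k];  // elementwise sum
  }
  for (double& v : mean) v /= static_cast<double>(local_grads.size());
  return mean;
}

// Plain SGD: returns the parameter change, -lr * g, for each element.
std::vector<double> sgd_delta(const std::vector<double>& grad, double lr) {
  std::vector<double> delta(grad.size());
  for (std::size_t k = 0; k < grad.size(); ++k) delta[k] = -lr * grad[k];
  return delta;
}
```

Passing the gradients 0.20, 0.10, -0.05, and 0.15 as four one-element vectors yields a mean of 0.10 and, with learning rate 0.001, a change of -0.0001, up to floating-point rounding.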

The communication volume is large. A model with 1B parameters has a bfloat16 gradient of 2 GB. Ring all-reduce over $N = 8$ workers moves approximately $2(N-1)S/N = 3.5$ GB per worker for a tensor of size $S = 2$ GB. On a 900 GB/s intra-node fabric, the bandwidth-only lower bound is about 3.9 ms. Across a 400 Gb/s link, about 50 GB/s before protocol overhead, the same traffic has a lower bound near 70 ms.
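The estimate above is easy to reproduce for other sizes and link speeds. This sketch assumes decimal units (1 GB = 1e9 bytes) and models bandwidth only; the function names are illustrative and do not come from any library.

```cpp
#include <cassert>

// Bytes each worker sends in a ring all-reduce of a tensor of size_bytes
// over n workers: 2 * (n - 1) / n * size_bytes.
double ring_allreduce_bytes_per_worker(double size_bytes, int n) {
  assert(n >= 1);
  return 2.0 * (n - 1) / n * size_bytes;
}

// Bandwidth-only lower bound in seconds. Latency, protocol overhead, and
// contention are ignored, so real transfers take longer.
double ring_allreduce_seconds(double size_bytes, int n,
                              double bytes_per_second) {
  assert(bytes_per_second > 0.0);
  return ring_allreduce_bytes_per_worker(size_bytes, n) / bytes_per_second;
}
```

With `size_bytes = 2e9` and `n = 8`, the first function returns 3.5e9 bytes. Dividing by 900e9 or 50e9 bytes per second gives roughly 3.9 ms and 70 ms.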

Gradient bucketing reduces exposed latency. Rather than waiting until backprop finishes, frameworks place gradients into buckets, for example 25 MB each. When a bucket is complete, its all-reduce starts while earlier layers continue backprop. The step time is then closer to the maximum of compute and communication than their sum.

Tensor Parallel Matrix Multiplication

Transformer layers contain large matrix multiplications. For $Y = XW$, let $X$ have shape $[b, d]$ and $W$ have shape $[d, 4d]$ for the MLP expansion. With $b = 2$ and $d = 4$, a small example is:

$$X = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \end{bmatrix}$$

Let column parallelism split $W$ into two column blocks $W_0$ and $W_1$, each of shape $[4, 8]$ if the full output width is 16. Worker 0 computes $Y_0 = XW_0$, and worker 1 computes $Y_1 = XW_1$. The full $Y$ is the concatenation $[Y_0, Y_1]$, so the next operation may require an all-gather.

For row parallelism, split $X$ and $W$ along the hidden dimension, so that $X$ is cut into column blocks $X_0, X_1$ of shape $[b, d/2]$ and $W$ into row blocks $W_0, W_1$ of shape $[d/2, 4d]$:

$$XW = X_0W_0 + X_1W_1$$

Each worker computes a partial output of shape $[b, 4d]$. The workers must all-reduce the partial outputs. This is why Megatron-style tensor parallelism alternates split dimensions to keep some operations local and place collectives at known points.
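Both splits can be verified against an unsharded product in one process. The sketch below assumes two workers and even dimensions. `Matrix`, `matmul`, `cols`, `rows`, `column_parallel`, and `row_parallel` are illustrative names, and concatenation and addition stand in for the all-gather and all-reduce.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Dense product A[m x k] * B[k x n].
Matrix matmul(const Matrix& a, const Matrix& b) {
  assert(!a.empty() && !b.empty() && a[0].size() == b.size());
  Matrix y(a.size(), std::vector<double>(b[0].size(), 0.0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        y[i][j] += a[i][k] * b[k][j];
  return y;
}

// Columns [first, last) of m.
Matrix cols(const Matrix& m, std::size_t first, std::size_t last) {
  Matrix out;
  for (const auto& row : m)
    out.emplace_back(row.begin() + first, row.begin() + last);
  return out;
}

// Rows [first, last) of m.
Matrix rows(const Matrix& m, std::size_t first, std::size_t last) {
  return Matrix(m.begin() + first, m.begin() + last);
}

// Column parallelism on two workers: each multiplies X by half of W's
// columns; concatenating the outputs stands in for an all-gather.
Matrix column_parallel(const Matrix& x, const Matrix& w) {
  const std::size_t half = w[0].size() / 2;
  Matrix y0 = matmul(x, cols(w, 0, half));
  Matrix y1 = matmul(x, cols(w, half, w[0].size()));
  for (std::size_t i = 0; i < y0.size(); ++i)
    y0[i].insert(y0[i].end(), y1[i].begin(), y1[i].end());
  return y0;
}

// Row parallelism on two workers: X is split by columns and W by rows;
// summing the partial outputs stands in for an all-reduce.
Matrix row_parallel(const Matrix& x, const Matrix& w) {
  const std::size_t half = w.size() / 2;
  Matrix y0 = matmul(cols(x, 0, half), rows(w, 0, half));
  Matrix y1 = matmul(cols(x, half, w.size()), rows(w, half, w.size()));
  for (std::size_t i = 0; i < y0.size(); ++i)
    for (std::size_t j = 0; j < y0[i].size(); ++j) y0[i][j] += y1[i][j];
  return y0;
}
```

For the $2 \times 4$ matrix $X$ above and any $4 \times 16$ matrix $W$ with integer entries, both functions return exactly the same $2 \times 16$ result as `matmul(x, w)`. With non-integer entries the row-parallel result can differ in the last bits because the summation order changes.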

The byte layout matters. Eight bfloat16 values occupy 16 bytes. If an activation of shape $[2, 4]$ is sharded by row so that each worker holds four bfloat16 elements, two workers hold these bytes (little-endian, low byte first):

worker 0 X0, shape [1,4], bfloat16:
value:  1.0   2.0   3.0   4.0
bytes:  80 3f 00 40 40 40 80 40

worker 1 X1, shape [1,4], bfloat16:
value:  5.0   6.0   7.0   8.0
bytes:  a0 40 c0 40 e0 40 00 41

The split is a logical tensor partition, not a compression. Total bytes remain the same, but no single device stores the entire activation.
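The byte values in the table can be reproduced with a short conversion routine. This is a sketch under two assumptions: the conversion truncates the low 16 bits of the IEEE-754 binary32 pattern (production libraries typically round to nearest even, which gives the same result for these small integers), and the bytes are written little-endian. The function names are illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Converts a float to bfloat16 bits by truncation: keep the top 16 bits
// of the binary32 pattern (sign, 8 exponent bits, 7 mantissa bits).
std::uint16_t float_to_bf16_bits(float value) {
  static_assert(sizeof(float) == 4, "expects 32-bit float");
  std::uint32_t bits = 0;
  std::memcpy(&bits, &value, sizeof bits);
  return static_cast<std::uint16_t>(bits >> 16);
}

// Serializes values as little-endian bfloat16, low byte first.
std::vector<std::uint8_t> bf16_bytes_le(const std::vector<float>& values) {
  std::vector<std::uint8_t> out;
  out.reserve(values.size() * 2);
  for (float v : values) {
    const std::uint16_t b = float_to_bf16_bits(v);
    out.push_back(static_cast<std::uint8_t>(b & 0xFF));
    out.push_back(static_cast<std::uint8_t>(b >> 8));
  }
  assert(out.size() == values.size() * 2);  // two bytes per element
  return out;
}
```

Calling `bf16_bytes_le` on `{1.0f, 2.0f, 3.0f, 4.0f}` yields `80 3f 00 40 40 40 80 40`, matching worker 0 in the table.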

Pipeline Parallel Training

Pipeline parallelism divides the layer sequence. For a 24-layer transformer on four devices, a simple split assigns layers 0-5, 6-11, 12-17, and 18-23. If a mini-batch is split into $m = 8$ micro-batches, the forward pass forms a time grid. The first six time steps are:

time:     0   1   2   3   4   5
stage 0: F0  F1  F2  F3  F4  F5
stage 1: .   F0  F1  F2  F3  F4
stage 2: .   .   F0  F1  F2  F3
stage 3: .   .   .   F0  F1  F2

The empty cells are pipeline bubbles. For $p$ stages and $m$ micro-batches, a simple forward-only utilization approximation is $m/(m+p-1)$. With $p = 4$ and $m = 8$, that is $8/11 \approx 0.727$. Larger $m$ reduces bubbles but stores more activations unless activation recomputation is used.
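The grid and the utilization figure can be generated for any stage and micro-batch count. This sketch assumes every stage takes one time unit per micro-batch and covers the forward pass only. `pipeline_utilization` and `forward_grid` are hypothetical names.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Forward-only utilization approximation m / (m + p - 1).
double pipeline_utilization(int stages, int micro_batches) {
  assert(stages >= 1 && micro_batches >= 1);
  return static_cast<double>(micro_batches) / (micro_batches + stages - 1);
}

// Builds the forward time grid: cell [s][t] is "F<k>" when stage s runs
// micro-batch k = t - s at time t, and "." when the stage is idle.
std::vector<std::vector<std::string>> forward_grid(int stages,
                                                   int micro_batches) {
  assert(stages >= 1 && micro_batches >= 1);
  const int steps = micro_batches + stages - 1;
  std::vector<std::vector<std::string>> grid(
      stages, std::vector<std::string>(steps, "."));
  for (int s = 0; s < stages; ++s)
    for (int k = 0; k < micro_batches; ++k)
      grid[s][s + k] = "F" + std::to_string(k);
  return grid;
}
```

For four stages and eight micro-batches the grid has 11 columns, and each stage is idle in 3 of them, which is the $8/11$ utilization given above.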

Backward scheduling adds weight-version constraints. GPipe uses synchronous flush scheduling, which keeps weight versions simple. PipeDream-style schedules can keep devices busier but must track which weight version produced each activation. If the wrong version is used in backward, the gradient no longer corresponds to the forward pass that created the loss.

Pipeline communication sends activations forward and activation gradients backward. If a boundary activation is $[b_{\mu}, s, d]$ in bfloat16, its size is $2 b_{\mu} s d$ bytes. With micro-batch $b_{\mu} = 2$, sequence length 2048, and hidden size 4096, one boundary activation is 33,554,432 bytes, or 32 MiB. Each boundary sends that in forward and again in backward.
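The size calculation generalizes directly. The sketch below uses hypothetical function names. The per-step figure assumes that every micro-batch crosses a boundary once in each direction and that the activation gradient has the same shape and dtype as the activation.

```cpp
#include <cassert>
#include <cstdint>

// Bytes in one boundary activation of shape [micro_batch, seq_len, hidden].
// bytes_per_element defaults to 2 for bfloat16.
std::uint64_t boundary_activation_bytes(std::uint64_t micro_batch,
                                        std::uint64_t seq_len,
                                        std::uint64_t hidden,
                                        std::uint64_t bytes_per_element = 2) {
  assert(bytes_per_element > 0);
  return bytes_per_element * micro_batch * seq_len * hidden;
}

// Bytes crossing one stage boundary per training step: one forward send
// and one backward send for each micro-batch.
std::uint64_t boundary_bytes_per_step(std::uint64_t activation_bytes,
                                      std::uint64_t micro_batches) {
  return 2 * activation_bytes * micro_batches;
}
```

`boundary_activation_bytes(2, 2048, 4096)` returns 33,554,432. Under the stated assumptions, eight micro-batches move 512 MiB across each boundary per step.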

ZeRO, FSDP, and Sharded State

The Adam state footprint dominates training memory. Per parameter, a common mixed-precision accounting is 2 bytes for the bf16 parameter, 2 bytes for the bf16 gradient, 4 bytes for the fp32 master parameter, 4 bytes for the first moment, and 4 bytes for the second moment. That is 16 bytes per parameter.

For a 10B-parameter model, the unsharded state is 160 GB. Across eight data-parallel workers:

replicated DP per worker:
params bf16       20 GB
grads bf16        20 GB
master fp32       40 GB
Adam m fp32       40 GB
Adam v fp32       40 GB
total            160 GB

ZeRO-1 per worker:
params bf16       20 GB
grads bf16        20 GB
optimizer shard   15 GB
total             55 GB

ZeRO-2 per worker:
params bf16       20 GB
grads shard        2.5 GB
optimizer shard   15 GB
total             37.5 GB

ZeRO-3 or FSDP per worker:
params shard       2.5 GB
grads shard        2.5 GB
optimizer shard   15 GB
total             20 GB
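The table follows from one rule: each stage divides one more component by the number of data-parallel workers. The sketch below encodes the 16-byte accounting from this section, with decimal GB and "stage 0" used as a label for replicated DP. `StateGB` and `zero_memory_gb` are illustrative names, and activations and allocator overhead are ignored.

```cpp
#include <cassert>

// Per-worker training-state memory, in GB.
struct StateGB {
  double params;
  double grads;
  double optimizer;
  double total;
};

// params_billions * bytes-per-parameter gives GB directly (1e9 * B / 1e9).
// stage 0 = replicated DP, stages 1-3 = ZeRO-1, ZeRO-2, ZeRO-3 / FSDP.
StateGB zero_memory_gb(double params_billions, int workers, int stage) {
  assert(workers >= 1 && stage >= 0 && stage <= 3);
  double params = 2.0 * params_billions;      // bf16 weights
  double grads = 2.0 * params_billions;       // bf16 gradients
  double optimizer = 12.0 * params_billions;  // fp32 master + Adam m + Adam v
  if (stage >= 1) optimizer /= workers;       // ZeRO-1 shards optimizer state
  if (stage >= 2) grads /= workers;           // ZeRO-2 also shards gradients
  if (stage >= 3) params /= workers;          // ZeRO-3 also shards parameters
  return {params, grads, optimizer, params + grads + optimizer};
}
```

For 10 billion parameters on eight workers, stages 0 through 3 give totals of 160, 55, 37.5, and 20 GB, matching the table.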

Stage 3 pays in communication. Before computing a layer, workers all-gather the parameter shards for that layer. After backward, gradients are reduce-scattered so each worker keeps only its shard. FSDP implementations often wrap module blocks so that parameter all-gather and reshard happen at module boundaries.

This changes the memory timeline. A worker may briefly hold a full layer’s parameters, but not the full model. The right wrap size balances communication frequency against peak memory.

Collective Operations and Interconnects

All-reduce, all-gather, and reduce-scatter are the vocabulary of distributed training. If each worker starts with one tensor chunk, all-gather leaves every worker with all chunks. Reduce-scatter first reduces elementwise, then leaves each worker with one reduced shard. All-reduce is equivalent to reduce-scatter followed by all-gather.
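The semantics of the three collectives can be written down as single-process functions over a vector of per-worker tensors. This is a reference sketch of what each worker ends up holding, not a model of how NCCL moves data. `Tensor` and the function names are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Tensor = std::vector<double>;

// all-gather: chunks[r] is worker r's chunk; every worker ends with the
// concatenation of all chunks in rank order.
std::vector<Tensor> all_gather(const std::vector<Tensor>& chunks) {
  Tensor full;
  for (const auto& c : chunks) full.insert(full.end(), c.begin(), c.end());
  return std::vector<Tensor>(chunks.size(), full);
}

// reduce-scatter (sum): inputs[r] is worker r's full-length tensor; worker r
// ends with shard r of the elementwise sum. The length must be divisible by
// the number of workers.
std::vector<Tensor> reduce_scatter(const std::vector<Tensor>& inputs) {
  assert(!inputs.empty());
  const std::size_t n = inputs.size();
  const std::size_t len = inputs[0].size();
  assert(len % n == 0);
  const std::size_t shard = len / n;
  std::vector<Tensor> out(n, Tensor(shard, 0.0));
  for (const auto& t : inputs) {
    assert(t.size() == len);
    for (std::size_t k = 0; k < len; ++k) out[k / shard][k % shard] += t[k];
  }
  return out;
}

// all-reduce (sum), composed as reduce-scatter followed by all-gather.
std::vector<Tensor> all_reduce(const std::vector<Tensor>& inputs) {
  return all_gather(reduce_scatter(inputs));
}
```

With worker 0 holding `[1, 2]` and worker 1 holding `[3, 4]`, `all_gather` leaves both workers with `[1, 2, 3, 4]`, while `all_reduce` leaves both with `[4, 6]`.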

A ring all-reduce splits the tensor into $N$ chunks. Each worker repeatedly sends one chunk to its neighbor and receives one chunk from the other neighbor. It has high bandwidth use for large tensors and cost proportional to $2(N-1)S/N$ bytes per worker. A tree all-reduce uses a reduction tree and a broadcast tree. It has lower latency terms, roughly $O(\log N)$ message steps, which makes it better for small tensors.
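The ring schedule can be simulated in one process to confirm both the result and the traffic count. The sketch below is one common formulation, assumed here rather than taken from any particular library: a reduce-scatter phase of $N-1$ steps in which worker $r$ sends chunk $(r - s) \bmod N$ at step $s$, followed by an all-gather phase of $N-1$ steps in which it sends chunk $(r + 1 - s) \bmod N$. `RingResult` and `ring_all_reduce` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

struct RingResult {
  std::vector<std::vector<double>> data;  // data[r] is worker r's tensor
  std::size_t elements_sent_per_worker;   // identical for every worker
};

// Simulated ring all-reduce (sum). data[r] is worker r's input; the tensor
// length must be divisible by the number of workers.
RingResult ring_all_reduce(std::vector<std::vector<double>> data) {
  const std::size_t n = data.size();
  assert(n >= 1);
  const std::size_t len = data[0].size();
  assert(len % n == 0);
  const std::size_t chunk = len / n;
  std::size_t sent = 0;

  // phase 0: reduce-scatter (receiver adds); phase 1: all-gather (receiver
  // overwrites). In phase 1 the chunk index is shifted by one because worker
  // r finishes phase 0 owning the fully reduced chunk (r + 1) mod n.
  for (std::size_t phase = 0; phase < 2; ++phase) {
    for (std::size_t s = 0; s + 1 < n; ++s) {
      for (std::size_t r = 0; r < n; ++r) {
        const std::size_t c = (r + phase + n - s) % n;  // chunk sent by r
        std::vector<double>& dst = data[(r + 1) % n];
        // Within a step, the chunk a worker sends is never the chunk it
        // receives, so updating in place is safe in this sequential loop.
        for (std::size_t k = c * chunk; k < (c + 1) * chunk; ++k) {
          if (phase == 0) dst[k] += data[r][k];
          else dst[k] = data[r][k];
        }
      }
      sent += chunk;  // each worker sends exactly one chunk per step
    }
  }
  return {std::move(data), sent};
}
```

Every worker sends $2(N-1)$ chunks of $S/N$ elements each, which is the $2(N-1)S/N$ cost quoted above. The number of steps grows linearly with $N$, which is the latency disadvantage relative to a tree for small tensors.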

Interconnect placement changes the parallelism choice. Within an 8-GPU node, NVLink on systems such as H100 SXM has about 900 GB/s aggregate GPU-to-GPU bandwidth per GPU. Across nodes, 200 to 400 Gb/s InfiniBand or RoCE gives about 25 to 50 GB/s before overhead. A tensor-parallel group that communicates every transformer block is usually kept inside a node. Data parallelism or ZeRO groups are more likely to cross nodes because they can communicate fewer, larger buckets.

Key Result

For one training step, a useful planning model is:

$$T_{\text{step}} \approx T_{\text{compute}} + T_{\text{unhidden comm}} + T_{\text{bubble}} + T_{\text{straggler}}$$

With overlap, communication that fits under backprop compute is hidden. For a gradient bucket of size $S$, ring all-reduce on $N$ workers has the bandwidth term:

$$T_{\text{ring}} \geq \frac{2(N-1)S}{NB}$$

Here $B$ is effective link bandwidth. For $N = 8$, $S = 25$ MiB, and $B = 900$ GB/s, the lower bound is about 0.051 ms. With $B = 50$ GB/s, it is about 0.92 ms. Latency, PCIe hops, kernel launch overhead, and contention add to this bound.
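The two figures can be checked by hand. The worked arithmetic below takes 1 MiB as $2^{20}$ bytes and 1 GB/s as $10^9$ bytes per second, an assumption about units that the text does not state explicitly.

```latex
\begin{aligned}
\frac{2(N-1)S}{N} &= \frac{2 \cdot 7}{8} \cdot 25 \cdot 2^{20}\ \text{bytes}
                   = 1.75 \cdot 26{,}214{,}400\ \text{bytes}
                   = 45{,}875{,}200\ \text{bytes} \\
T_{\text{ring}} &\geq \frac{45{,}875{,}200}{900 \times 10^{9}}\ \text{s}
                 \approx 5.10 \times 10^{-5}\ \text{s} \approx 0.051\ \text{ms} \\
T_{\text{ring}} &\geq \frac{45{,}875{,}200}{50 \times 10^{9}}\ \text{s}
                 \approx 9.18 \times 10^{-4}\ \text{s} \approx 0.92\ \text{ms}
\end{aligned}
```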

The correctness invariants are stricter than the timing formula. Synchronous DP requires identical parameters at the start of each step. Tensor parallelism requires matching shards for a single logical tensor operation. Pipeline parallelism requires backward to use the same weight version as forward, unless the algorithm explicitly tolerates stale weights. ZeRO-3 and FSDP require a parameter to be gathered before its computation and resharded only after all consumers finish.

Common Confusions

Watch Out

All-reduce is not the same as all-gather

All-gather concatenates shards. It does not add values. If worker 0 has gradient [1, 2] and worker 1 has [3, 4], all-gather gives both workers [1, 2, 3, 4]. All-reduce sum gives both workers [4, 6].

Watch Out

Tensor parallelism does not reduce total FLOPs

Sharding a matrix multiplication divides FLOPs across devices, but the global multiplication is the same size. The win is lower wall time per layer when communication is smaller than the saved local compute.

Watch Out

ZeRO-3 saves memory but adds parameter traffic

A ZeRO-3 worker does not hold all parameters all the time. It must all-gather full parameters for the active module. If modules are wrapped too finely, many small all-gathers expose latency.

Watch Out

A lost worker is different from a slow worker

A straggler delays the step but eventually contributes gradients. A lost worker breaks the collective. In synchronous training, the remaining workers cannot finish an NCCL all-reduce that includes the missing rank.

Exercises

Exercise (Core)

Problem

Four data-parallel workers train a model with two scalar parameters. Their local gradients are $g_0 = [1, 3]$, $g_1 = [2, 1]$, $g_2 = [-1, 5]$, and $g_3 = [0, -1]$. Compute the averaged gradient and the SGD update for learning rate 0.01.

Exercise (Core)

Problem

A tensor has 64 MiB of bf16 gradients. Estimate the per-worker byte traffic for ring all-reduce on eight workers. Then estimate the bandwidth-only time on 900 GB/s and 50 GB/s links.

Exercise (Advanced)

Problem

A 12B-parameter model uses the 16-byte-per-parameter Adam accounting from this page. Compute per-worker memory for replicated DP and for ZeRO-3 across eight workers. Ignore activations and allocator overhead.

References

Canonical:

  • Vaswani et al., Attention Is All You Need (2017), §§3.1-3.3, transformer layer shapes and attention projections
  • NVIDIA, CUDA C++ Programming Guide (2024), §§5.3, 8.2, memory hierarchy, streams, and asynchronous execution
  • Williams, Waterman, and Patterson, Roofline: An Insightful Visual Performance Model for Multicore Architectures (2009), §§1-3, compute and bandwidth ceilings
  • Rajbhandari et al., ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (SC 2020), §§3-4, ZeRO stages and partitioned optimizer state
  • Huang et al., GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism (NeurIPS 2019), §§2-3, micro-batch pipeline scheduling
  • Narayanan et al., Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (SC 2021), §§3-4, tensor and pipeline parallel transformer training

Accessible:

  • PyTorch documentation, FullyShardedDataParallel, API notes and sharding strategy descriptions
  • NVIDIA NCCL documentation, Collective Operations, definitions of all-reduce, all-gather, and reduce-scatter
  • Hugging Face documentation, Model Parallelism, overview of data, tensor, pipeline, and ZeRO-style sharding
