Training Techniques
Batch Size and Learning Dynamics
How batch size affects what SGD finds: gradient noise, implicit regularization, the linear scaling rule, sharp vs flat minima, and the gradient noise scale as the key quantity governing the tradeoff.
Why This Matters
Batch size is not just a memory/compute tradeoff. It changes the noise structure of SGD, which changes the loss landscape regions the optimizer explores, which changes the generalization properties of the trained model. Small-batch SGD and large-batch SGD can converge to different solutions with measurably different test performance.
Understanding this requires connecting optimization (convergence rate), statistics (gradient variance), and geometry (curvature of the loss surface).
Mental Model
SGD computes a gradient estimate from a mini-batch of size $B$. Small $B$ means noisy gradients: the optimizer takes jittery steps that help it escape sharp minima. Large $B$ means accurate gradients: the optimizer takes confident steps but may get stuck in the nearest sharp minimum. The noise is not a bug; it is an implicit regularizer.
The key quantity is the ratio of gradient noise to gradient signal. This ratio determines whether the optimizer behaves more like SGD (noisy, exploratory) or more like full-batch gradient descent (deterministic, exploitative).
Formal Setup
Consider minimizing the population loss $L(\theta) = \mathbb{E}_{x \sim \mathcal{D}}[\ell(\theta; x)]$ using mini-batch SGD. A mini-batch of size $B$ gives the gradient estimate:

$$\hat{g} = \frac{1}{B} \sum_{i=1}^{B} \nabla_\theta \ell(\theta; x_i), \qquad \mathbb{E}[\hat{g}] = \nabla_\theta L(\theta) = g.$$
Gradient Noise Covariance
The covariance of the mini-batch gradient estimate is:

$$\mathrm{Cov}(\hat{g}) = \frac{1}{B} \Sigma(\theta),$$

where $\Sigma(\theta) = \mathrm{Cov}_{x \sim \mathcal{D}}\big(\nabla_\theta \ell(\theta; x)\big)$ is the per-sample gradient covariance. The noise scales as $1/B$: doubling the batch size halves the gradient variance.
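The $1/B$ scaling can be checked numerically. The sketch below (a toy least-squares problem; all sizes and names are illustrative) estimates the covariance of mini-batch gradients at two batch sizes and compares the trace:

```python
import numpy as np

# Sketch: verify the 1/B scaling of mini-batch gradient variance on a toy
# least-squares problem. All names and sizes here are illustrative.
rng = np.random.default_rng(0)
n, d = 10_000, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + rng.normal(size=n)
theta = np.zeros(d)

# Per-sample gradients of 0.5*(x.theta - y)^2: g_i = (x_i.theta - y_i) * x_i
per_sample = (X @ theta - y)[:, None] * X          # shape (n, d)

def minibatch_grad_var(B, trials=2000):
    """Empirical trace of Cov(g_hat) for batch size B."""
    grads = np.empty((trials, d))
    for t in range(trials):
        idx = rng.choice(n, size=B, replace=False)
        grads[t] = per_sample[idx].mean(axis=0)
    return np.trace(np.cov(grads, rowvar=False))

v32, v64 = minibatch_grad_var(32), minibatch_grad_var(64)
print(f"tr Cov at B=32: {v32:.3f}, at B=64: {v64:.3f}, ratio = {v32 / v64:.2f}")
# The ratio should be close to 2: doubling B halves the gradient variance.
```

The same measurement on a real model (per-sample gradients from a few held-out mini-batches) is how the quantities in the next section are estimated in practice.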
Gradient Noise Scale
The gradient noise scale (McCandlish et al., 2018) is:

$$B_{\mathrm{noise}} = \frac{\mathrm{tr}(\Sigma)}{\lVert g \rVert^2}.$$

This is the ratio of gradient variance to squared gradient signal. When $B \ll B_{\mathrm{noise}}$, noise dominates and increasing $B$ gives near-linear speedup. When $B \gg B_{\mathrm{noise}}$, signal dominates and increasing $B$ gives diminishing returns.
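Given a matrix of per-sample gradients, $B_{\mathrm{noise}}$ is a two-line computation. A minimal sketch with synthetic per-sample gradients standing in for a real model's (the sizes and distribution are assumptions for illustration):

```python
import numpy as np

# Sketch: the noise scale B_noise = tr(Sigma) / |g|^2 estimated from a matrix
# of per-sample gradients (rows = samples). Synthetic, illustrative data.
rng = np.random.default_rng(1)
per_sample_grads = rng.normal(loc=0.3, scale=1.0, size=(50_000, 8))

g = per_sample_grads.mean(axis=0)                  # full-batch gradient estimate
Sigma = np.cov(per_sample_grads, rowvar=False)     # per-sample gradient covariance

B_noise = np.trace(Sigma) / np.dot(g, g)
print(f"gradient noise scale = {B_noise:.1f}")
# With per-component mean 0.3 and unit variance in 8 dims:
# tr(Sigma) ~ 8 and |g|^2 ~ 8 * 0.09 = 0.72, so B_noise is about 11.
```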
Main Theorems
Linear Scaling Rule
Statement
If SGD with learning rate $\eta$ and batch size $B$ produces a certain training trajectory (in the continuous-time SDE approximation), then SGD with learning rate $k\eta$ and batch size $kB$ produces approximately the same trajectory, for any scaling factor $k$.
The effective noise temperature of SGD is:

$$T = \frac{\eta}{B}.$$

Scaling both $\eta$ and $B$ by $k$ preserves $T$.
Intuition
What matters for the dynamics is the ratio $\eta / B$, not $\eta$ or $B$ individually. This ratio controls the magnitude of the noise injected per step. If you use larger batches, you must use a larger learning rate to maintain the same effective noise.
Proof Sketch
In the continuous-time limit, SGD on the loss $L$ is approximated by the SDE:

$$d\theta_t = -\nabla L(\theta_t)\, dt + \sqrt{\frac{\eta}{B}}\, \Sigma(\theta_t)^{1/2}\, dW_t.$$

The drift term is independent of $\eta$ and $B$ (after rescaling time by $\eta$). The diffusion coefficient depends only on the ratio $\eta / B$. Scaling $\eta \to k\eta$ and $B \to kB$ preserves the diffusion coefficient.
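A concrete check of this invariance on a 1-D quadratic, where the discrete-time stationary variance of the SGD iterates can be written in closed form (a standard calculation; the formula below assumes loss $L(\theta) = a\theta^2/2$ and per-sample gradient noise of variance $\sigma^2$):

```python
# Sketch: on L(theta) = a*theta^2/2 with per-sample gradient noise variance
# sigma^2, the discrete SGD update theta' = theta - eta*(a*theta + noise/B)
# has stationary iterate variance
#   V(eta, B) = eta * sigma^2 / (B * a * (2 - eta * a)),
# which to leading order in eta depends only on the temperature eta/B.
a, sigma2 = 1.0, 4.0

def stationary_var(eta, B):
    return eta * sigma2 / (B * a * (2 - eta * a))

v1 = stationary_var(0.01, 32)    # baseline (eta, B)
v2 = stationary_var(0.04, 128)   # both scaled by k = 4: same eta/B
print(v1, v2)
# The two variances agree up to the O(eta) discretization correction.
```

The small residual difference between `v1` and `v2` is exactly the discrete-step effect that the Failure Mode below describes: the invariance is exact only in the SDE limit $\eta \to 0$.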
Why It Matters
This is the theoretical justification for the practice of scaling the learning rate linearly with batch size, used in large-scale distributed training (Goyal et al., 2017). Without this rule, increasing the batch size would reduce noise and change the implicit regularization of SGD.
Failure Mode
The linear scaling rule breaks when: (1) the learning rate becomes so large that the continuous-time SDE approximation fails (discrete effects dominate), (2) the loss landscape is not smooth enough for the local diffusion approximation, or (3) during the initial transient before SGD reaches a stationary regime. In practice, a warm-up period is needed for large learning rates.
Critical Batch Size and Training Efficiency
Statement
Let $B_{\mathrm{noise}}$ be the gradient noise scale. The number of optimization steps $S$ to reach a target loss scales as:

$$S = S_{\min}\left(1 + \frac{B_{\mathrm{noise}}}{B}\right),$$

where $S_{\min}$ is the minimum number of steps achievable (in the limit $B \to \infty$). The total compute (in units of samples processed) is $C = S \cdot B = S_{\min}(B + B_{\mathrm{noise}})$; it stays within a factor of two of its minimum for $B \le B_{\mathrm{noise}}$, so the steps/compute tradeoff is balanced near $B \approx B_{\mathrm{noise}}$.
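The tradeoff is easy to tabulate. A sketch with illustrative values $S_{\min} = 1000$ and $B_{\mathrm{noise}} = 1024$ (both hypothetical):

```python
# Sketch of the tradeoff S(B) = S_min * (1 + B_noise / B) and C(B) = S(B) * B,
# using illustrative numbers S_min = 1000 and B_noise = 1024.
S_min, B_noise = 1000, 1024

def steps(B):
    return S_min * (1 + B_noise / B)

def compute(B):
    return steps(B) * B          # total samples processed

for B in [128, 1024, 8192]:
    print(f"B={B:5d}  steps={steps(B):8.0f}  compute={compute(B):.2e}")
# B << B_noise: many steps, near-minimal compute.
# B == B_noise: steps cut to 2*S_min, compute only doubled.
# B >> B_noise: steps barely shrink while compute balloons.
```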
Intuition
For $B \ll B_{\mathrm{noise}}$: noise dominates. Doubling $B$ roughly halves the number of steps at nearly the same total compute (linear speedup). For $B \gg B_{\mathrm{noise}}$: signal dominates. Doubling $B$ barely reduces the step count, so total compute roughly doubles (no speedup). The crossover is at $B \approx B_{\mathrm{noise}}$.
Proof Sketch
A second-order expansion gives an expected per-step improvement in loss of approximately $\eta \lVert g \rVert^2 - \frac{\eta^2}{2}\big(g^\top H g + \frac{\mathrm{tr}(H \Sigma)}{B}\big)$. With $\eta$ chosen optimally (balancing progress against noise), the per-step improvement is $\Delta L_{\max} / (1 + B_{\mathrm{noise}} / B)$. Inverting gives the step count formula.
Why It Matters
$B_{\mathrm{noise}}$ is measurable during training (estimate the gradient variance from multiple mini-batches). It tells you the largest useful batch size: going far beyond $B_{\mathrm{noise}}$ wastes compute. McCandlish et al. (2018) measured $B_{\mathrm{noise}}$ for language models and found it increases during training, explaining why larger batches become useful later.
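One practical estimator, in the spirit of McCandlish et al.'s appendix, uses squared gradient norms measured at two batch sizes to get unbiased estimates of $\lVert g \rVert^2$ and $\mathrm{tr}(\Sigma)$. The sketch below uses synthetic per-sample gradients with known ground truth; the batch sizes and repetition count are illustrative:

```python
import numpy as np

# Sketch of a two-batch-size estimator of the gradient noise scale: from
# E[|g_B|^2] = |g|^2 + tr(Sigma)/B at two batch sizes, solve for both terms.
# Toy per-sample gradients stand in for a real model's.
rng = np.random.default_rng(2)
d = 16
true_g = np.full(d, 0.2)                 # ground truth: |g|^2 = 0.64

def grad_estimate(B):
    """Mean of B per-sample gradients: true_g plus unit-variance noise."""
    return true_g + rng.normal(size=(B, d)).mean(axis=0)

B_small, B_big, reps = 64, 4096, 200
gs2 = np.mean([np.sum(grad_estimate(B_small) ** 2) for _ in range(reps)])
gb2 = np.mean([np.sum(grad_estimate(B_big) ** 2) for _ in range(reps)])

g2_hat = (B_big * gb2 - B_small * gs2) / (B_big - B_small)
trSigma_hat = (gs2 - gb2) / (1 / B_small - 1 / B_big)
B_noise_hat = trSigma_hat / g2_hat
print(f"estimated noise scale = {B_noise_hat:.1f}")
# Ground truth here: tr(Sigma) = 16, |g|^2 = 0.64, so B_noise = 25.
```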
Failure Mode
The analysis assumes the per-sample gradient covariance is approximately constant, but it changes as training progresses. The noise scale itself varies over training, so the optimal batch size is not a single number but a trajectory.
Sharp vs Flat Minima
Small Batch SGD Favors Flat Minima
Statement
In the SDE approximation of SGD, the stationary distribution concentrates on regions where the local loss is low and the Hessian eigenvalues are small. Specifically, near a minimum the stationary density is approximately proportional to:

$$p(\theta) \propto \exp\!\left(-\frac{L(\theta)}{T}\right) \det\!\big(H(\theta)\big)^{-1/2}, \qquad T = \frac{\eta}{B},$$

where $H(\theta) = \nabla^2 L(\theta)$ is the Hessian. The second factor favors flat minima (small eigenvalues of $H$).
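Plugging numbers into this density shows how flatness can outweigh a small loss advantage. A sketch with hypothetical minima (loss values, eigenvalues, and temperature are all illustrative):

```python
import numpy as np

# Sketch: relative stationary mass of two minima under
# p ~ exp(-L/T) * det(H)^(-1/2), with T = eta/B. Illustrative numbers:
# a "sharp" minimum with slightly lower loss vs a "flat" one.
def relative_mass(L, hessian_eigs, T):
    return np.exp(-L / T) / np.sqrt(np.prod(hessian_eigs))

T = 0.05                                   # noise temperature eta/B
sharp = relative_mass(L=0.10, hessian_eigs=[100.0, 100.0], T=T)
flat  = relative_mass(L=0.12, hessian_eigs=[1.0, 1.0], T=T)
print(f"flat/sharp mass ratio = {flat / sharp:.1f}")
# Despite its higher loss, the flat minimum holds far more stationary mass.
```

Lowering $T$ (smaller $\eta$ or larger $B$) shrinks the exponential factor's tolerance for extra loss, shifting mass back toward the sharp, lower-loss minimum: exactly the batch-size effect the theorem describes.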
Intuition
SGD noise acts like a temperature that helps the optimizer escape sharp minima (high curvature) but not flat minima (low curvature). The wider a minimum is, the harder it is for noise to push the iterate out. A higher noise temperature (large $\eta / B$) means only the flattest minima are stable.
Proof Sketch
Model SGD as a Langevin diffusion with temperature $T = \eta / B$. The stationary distribution of Langevin dynamics is the Gibbs measure $p(\theta) \propto e^{-L(\theta)/T}$, modified by the determinant factor when the noise covariance $\Sigma(\theta)$ depends on $\theta$.
Why It Matters
This provides a theoretical explanation for the empirical observation that small-batch SGD often generalizes better than large-batch SGD: flat minima are associated with better generalization because nearby points have similar loss (the function is stable under perturbation of parameters). This is an observation, not a theorem about generalization itself; the connection between flatness and generalization is debated.
Failure Mode
The sharp/flat minima story has known weaknesses. Dinh et al. (2017) showed you can reparameterize a network to make any minimum arbitrarily sharp without changing the function it computes. The SDE approximation also breaks for large learning rates. Flatness measured by the Hessian trace may not correlate with generalization in all architectures.
Practical Implications
Warm-up. When using the linear scaling rule with large batch sizes, the initial learning rate is large. A warm-up period (linearly increasing $\eta$ from a small value over the first few epochs) prevents divergence during the initial transient, when the loss landscape curvature is high.
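A schedule in this style is a few lines. The base learning rate 0.1 at batch 256, the 8192 target batch, and the 5-epoch ramp follow Goyal et al.'s ResNet-50 recipe; `steps_per_epoch` is an illustrative assumption:

```python
# Sketch of linear warm-up into the linearly scaled learning rate, in the
# style of Goyal et al. (2017). steps_per_epoch is illustrative.
def lr_at_step(step, base_lr=0.1, base_batch=256, batch=8192,
               warmup_epochs=5, steps_per_epoch=157):
    target = base_lr * batch / base_batch          # linear scaling rule: 3.2
    warmup_steps = warmup_epochs * steps_per_epoch
    if step < warmup_steps:
        return target * (step + 1) / warmup_steps  # linear ramp from near 0
    return target

print(lr_at_step(0), lr_at_step(10_000))   # start of warm-up vs. after it
```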
LARS and LAMB. For very large batch sizes (tens of thousands of samples), layer-wise adaptive learning rates (LARS for SGD, LAMB for Adam) adjust the step size per layer based on the ratio of parameter norm to gradient norm. This compensates for different layers having different gradient noise scales.
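The core of LARS is the per-layer trust ratio. A minimal sketch of the update for one layer (momentum and weight decay, which the full algorithm includes, are omitted; `trust_coef` is the usual small coefficient):

```python
import numpy as np

# Sketch of the LARS trust ratio: per layer, the step is rescaled by
# ||w|| / ||grad||, so layers whose gradients are large relative to their
# weights take proportionally smaller steps. Momentum/weight decay omitted.
def lars_update(w, grad, base_lr, trust_coef=0.001, eps=1e-9):
    w_norm, g_norm = np.linalg.norm(w), np.linalg.norm(grad)
    local_lr = trust_coef * w_norm / (g_norm + eps)   # layer-wise adaptation
    return w - base_lr * local_lr * grad

w = np.array([1.0, 2.0, 2.0])        # ||w|| = 3
g = np.array([0.0, 3.0, 4.0])        # ||g|| = 5
w_new = lars_update(w, g, base_lr=1.0)
print(w_new)   # effective step is trust_coef * 3/5 = 6e-4 along the gradient
```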
Diminishing returns. When training large language models, $B_{\mathrm{noise}}$ is typically on the order of $10^5$ to $10^6$ tokens. Beyond this, more parallelism does not reduce wall-clock training time proportionally.
Canonical Examples
ResNet-50 on ImageNet: batch size scaling
Goyal et al. (2017) trained ResNet-50 with batch sizes from 256 to 8192. With the linear scaling rule ($\eta \propto B$) and warm-up, they achieved equivalent accuracy across this range. At $B = 8192$ (32x the baseline), training completed in 1 hour on 256 GPUs. Beyond $B \approx 8192$, accuracy began to degrade, consistent with $B_{\mathrm{noise}}$ being roughly in that range for this task.
Gradient noise scale in language modeling
McCandlish et al. (2018) measured $B_{\mathrm{noise}}$ for a Transformer language model during training. Early in training, $B_{\mathrm{noise}}$ is small: gradients are large relative to their noise, so small batches suffice. Late in training, $B_{\mathrm{noise}}$ is large: the model is near a minimum, gradients are small, and large batches are needed to estimate the descent direction accurately.
Common Confusions
Larger batch does not always mean faster training
Larger batches reduce the number of steps, but each step processes more data. Total compute (samples processed) is near-minimal for $B \lesssim B_{\mathrm{noise}}$ and grows roughly linearly beyond it: you are paying more compute per step without a proportional reduction in steps. Wall-clock time may still decrease if you have idle GPUs, but compute efficiency drops.
The linear scaling rule is not a law
It is an approximation that holds when the SDE limit is valid, which requires $\eta$ to be small relative to the inverse curvature of the loss. For very large learning rates or very large batch sizes, discrete-time effects (finite step size) break the continuous-time approximation. In practice, the rule works well up to some critical batch size and then fails.
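The discrete-time breakdown is visible even on a quadratic. Gradient flow (the $\eta \to 0$ limit) converges for any positive learning rate, but discrete gradient descent multiplies the error by $(1 - \eta a)$ per step and diverges once $\eta > 2/a$:

```python
# Sketch: why very large learning rates break the continuous-time picture.
# On a quadratic with curvature a, discrete GD scales the error by (1 - eta*a)
# each step; |1 - eta*a| > 1 (i.e. eta > 2/a) diverges, while the
# gradient-flow/SDE limit would predict convergence for any eta > 0.
def iterate(eta, a=1.0, theta0=1.0, steps=50):
    theta = theta0
    for _ in range(steps):
        theta -= eta * a * theta
    return theta

print(iterate(eta=0.5))   # |1 - 0.5| = 0.5 per step: converges toward 0
print(iterate(eta=2.5))   # |1 - 2.5| = 1.5 per step: diverges
```

This is why linearly scaled learning rates need warm-up: early in training the curvature $a$ is effectively large, so the scaled $\eta$ can land in the divergent regime.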
Flat minima are not guaranteed to generalize better
The flat minima hypothesis is plausible but not proven. Reparameterization can change the curvature without changing the function. PAC-Bayes bounds provide some theoretical support (flat minima correspond to wide posteriors with low KL penalty), but the connection is not definitive. Treat the flat minima story as a useful heuristic, not a theorem.
Key Takeaways
- Batch size controls the noise-to-signal ratio of SGD gradients
- The gradient noise scale $B_{\mathrm{noise}} = \mathrm{tr}(\Sigma) / \lVert g \rVert^2$ is the critical batch size where noise and signal are balanced
- Linear scaling rule: scale learning rate proportionally with batch size to preserve dynamics
- Small batches inject more noise, which can help escape sharp minima (implicit regularization)
- Beyond $B_{\mathrm{noise}}$, larger batches give diminishing returns in compute efficiency
- The connection between flatness and generalization is suggestive but not proven
Exercises
Problem
You are training with batch size $B$ and learning rate $\eta$. You want to switch to batch size $kB$. What learning rate should you use according to the linear scaling rule? What quantity is preserved?
Problem
Suppose the measured gradient noise scale of your model is $B_{\mathrm{noise}} = 2048$. You have access to 16 GPUs, each handling a local batch of 128, so your global batch size is $16 \times 128 = 2048$. A colleague offers you 16 more GPUs. Should you double the global batch size to 4096? Justify using the step count formula.
References
Canonical:
- Goyal et al., "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" (2017)
- McCandlish et al., "An Empirical Model of Large-Batch Training" (2018)
Current:
- Smith et al., "Don't Decay the Learning Rate, Increase the Batch Size" (ICLR, 2018)
- Hoffer et al., "Train Longer, Generalize Better" (NeurIPS, 2017)
- Dinh et al., "Sharp Minima Can Generalize for Deep Nets" (ICML, 2017)
Next Topics
Natural extensions from batch size dynamics:
- Learning rate schedules: how to adjust learning rate over training, complementary to batch size choices
- Distributed training theory: communication costs and gradient compression when parallelizing across many workers
Last reviewed: April 2026
Prerequisites
Foundations this topic depends on.
- Stochastic Gradient Descent Convergence (Layer 2)
- Gradient Descent Variants (Layer 1)
- Convex Optimization Basics (Layer 1)
- Differentiation in R^n (Layer 0A)
- Sets, Functions, and Relations (Layer 0A)
- Basic Logic and Proof Techniques (Layer 0A)
- Matrix Operations and Properties (Layer 0A)
- Concentration Inequalities (Layer 1)
- Common Probability Distributions (Layer 0A)
- Expectation, Variance, Covariance, and Moments (Layer 0A)
- Adam Optimizer (Layer 2)