

Mean Field Theory

The mean field limit of neural networks: as width goes to infinity under the right scaling, neurons become independent particles whose weight distribution evolves by Wasserstein gradient flow, capturing feature learning that the NTK regime misses.


Why This Matters

The Neural Tangent Kernel shows that infinitely wide networks in the lazy regime behave like kernel methods --- but this is precisely the regime where networks do not learn features. Practical neural networks learn representations, and NTK cannot explain this.

Mean field theory provides an alternative infinite-width limit that does capture feature learning. Under a different parameterization (mean field scaling instead of NTK scaling), the network weights move substantially during training. In the infinite-width limit, individual neurons become independent, and the distribution of weights evolves according to a partial differential equation --- specifically, a Wasserstein gradient flow.

This is the theoretical framework for understanding what happens beyond the kernel regime, making it one of the most important directions in modern deep learning theory.

Mental Model

Think of a two-layer neural network as a collection of $m$ "particles" (neurons), each with a weight vector $w_j$. Each particle contributes $a_j \sigma(w_j^\top x) / m$ to the output. As $m \to \infty$, the sum becomes an integral over a probability distribution $\mu$ of weights:

$$f(x; \mu) = \int a \, \sigma(w^\top x) \, d\mu(a, w)$$

Training the network is equivalent to moving the particles $\{(a_j, w_j)\}_{j=1}^m$ by gradient descent. In the infinite-width limit, this becomes evolving the distribution $\mu$ by a continuous flow. The distribution moves in the direction that decreases the loss fastest --- this is Wasserstein gradient flow.

The key difference from NTK: in NTK, each particle barely moves (order $1/\sqrt{m}$ displacement). In mean field, particles move substantially (order 1 displacement). This substantial movement is what enables feature learning.

Formal Setup

Consider a two-layer neural network:

$$f_m(x) = \frac{1}{m}\sum_{j=1}^m a_j \sigma(w_j^\top x)$$

where $\theta_j = (a_j, w_j) \in \mathbb{R}^{d+1}$ are the parameters of neuron $j$, and $\sigma$ is an activation function (e.g., ReLU).

Definition

Mean Field Parameterization

In the mean field parameterization, the network output scales as $1/m$ (one factor of width in the denominator):

$$f_m(x) = \frac{1}{m}\sum_{j=1}^m \phi(x; \theta_j)$$

where $\phi(x; \theta) = a \, \sigma(w^\top x)$ is the contribution of a single neuron with parameters $\theta = (a, w)$.

Compare to the NTK parameterization, $f_m(x) = \frac{1}{\sqrt{m}}\sum_{j=1}^m a_j \sigma(w_j^\top x)$, where the output scales as $1/\sqrt{m}$ and individual parameters move only $O(1/\sqrt{m})$ during training.

The $1/m$ scaling in mean field is crucial: it means each neuron's contribution is order $1/m$, so the network depends on the distribution of neurons, not on any individual one.
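
A quick numerical illustration of this point (a sketch, not taken from the text: the Gaussian initialization, ReLU activation, and the specific widths are assumptions). Under the $1/m$ scaling, deleting a single neuron perturbs the output by roughly $1/m$:

```python
# Sketch: under 1/m scaling, each neuron's contribution to the output is O(1/m).
# Assumed setup: standard Gaussian weights, ReLU activation, a single input in R^8.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=8)

for m in [100, 1_000, 10_000]:
    a = rng.normal(size=m)
    W = rng.normal(size=(m, 8))
    phi = a * np.maximum(W @ x, 0.0)   # per-neuron contributions a_j * relu(w_j . x)
    f_mf = phi.mean()                  # mean field output: (1/m) * sum_j phi_j
    f_drop = phi[1:].mean()            # same network with one neuron removed
    print(f"m={m:6d}  f_mf={f_mf:+.4f}  |change from removing one neuron| = {abs(f_mf - f_drop):.2e}")
```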

Definition

Empirical Measure of Neurons

The empirical measure of the neuron parameters is:

$$\mu_m = \frac{1}{m}\sum_{j=1}^m \delta_{\theta_j}$$

where $\delta_{\theta_j}$ is a point mass at $\theta_j$. As $m \to \infty$, if the neurons are initialized i.i.d. from some distribution $\mu_0$, then $\mu_m \to \mu_0$ by the law of large numbers. The network output becomes:

$$f(x; \mu) = \int \phi(x; \theta) \, d\mu(\theta)$$

This is a linear functional of the measure $\mu$.
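
A small Monte Carlo check of this limit (a sketch; the choice of $\mu_0$ below, a Gaussian with mean-one output weights, is an assumption made here so that the limiting integral is not trivially zero). The finite-width average $\frac{1}{m}\sum_j \phi(x; \theta_j)$ approaches $\int \phi(x; \theta)\, d\mu_0(\theta)$, which is estimated with a very large sample:

```python
# Sketch: the empirical average over i.i.d. neurons converges to the integral
# against mu_0 (law of large numbers). mu_0 here is an assumed example:
# w ~ N(0, I_4), a ~ N(1, 0.25), activation ReLU.
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=4)

def f_hat(m):
    a = 1.0 + 0.5 * rng.normal(size=m)
    W = rng.normal(size=(m, 4))
    return np.mean(a * np.maximum(W @ x, 0.0))   # (1/m) sum_j a_j relu(w_j . x)

reference = f_hat(1_000_000)                      # stand-in for the integral over mu_0
for m in [100, 1_000, 10_000, 100_000]:
    print(f"m={m:7d}  |f_m(x) - f(x; mu_0)| ~ {abs(f_hat(m) - reference):.4f}")
```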

Definition

Wasserstein Gradient Flow

The Wasserstein gradient flow is the continuous-time evolution of a probability measure $\mu_t$ that follows the steepest descent direction of a functional $\mathcal{L}(\mu)$ in the Wasserstein-2 metric space:

$$\partial_t \mu_t = \nabla \cdot \left(\mu_t \nabla \frac{\delta \mathcal{L}}{\delta \mu}(\mu_t)\right)$$

where $\frac{\delta \mathcal{L}}{\delta \mu}$ is the first variation (functional derivative) of $\mathcal{L}$ with respect to $\mu$.

In the neural network context, $\mathcal{L}(\mu) = L(f(\cdot; \mu))$ is the training loss viewed as a functional of the weight distribution. The first variation $\frac{\delta \mathcal{L}}{\delta \mu}(\mu, \theta)$ measures how the loss changes when an infinitesimal mass of neurons is added at $\theta$; its gradient in $\theta$, evaluated at the current distribution, plays the role of the gradient seen by a single neuron located at $\theta$.
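
As a concrete instance (a sketch under the assumption of an empirical-risk loss with a differentiable per-example loss $\ell$; the dataset notation below is introduced here for illustration, not taken from the setup above): because $f(x; \mu)$ is linear in $\mu$, the first variation has a simple closed form,

$$\mathcal{L}(\mu) = \frac{1}{n}\sum_{i=1}^n \ell\bigl(f(x_i; \mu),\, y_i\bigr), \qquad \frac{\delta \mathcal{L}}{\delta \mu}(\mu, \theta) = \frac{1}{n}\sum_{i=1}^n \partial_1 \ell\bigl(f(x_i; \mu),\, y_i\bigr)\, \phi(x_i; \theta),$$

so a neuron (particle) sitting at $\theta$ feels the gradient

$$\nabla_\theta \frac{\delta \mathcal{L}}{\delta \mu}(\mu, \theta) = \frac{1}{n}\sum_{i=1}^n \partial_1 \ell\bigl(f(x_i; \mu),\, y_i\bigr)\, \nabla_\theta \phi(x_i; \theta),$$

where $\partial_1 \ell$ denotes the derivative of $\ell$ in its first argument.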

Main Theorems

Theorem

Mean Field Limit for Two-Layer Networks

Statement

Consider a two-layer network with $m$ neurons trained by gradient flow on loss $L$. Under regularity conditions on $\sigma$ and $L$, as $m \to \infty$:

  1. The empirical measure $\mu_m(t)$ converges weakly to a deterministic measure $\mu(t)$ for all $t \geq 0$
  2. The limiting measure $\mu(t)$ satisfies the mean field PDE:

$$\partial_t \mu_t = \nabla \cdot \left(\mu_t \nabla_\theta \frac{\delta \mathcal{L}}{\delta \mu}(\mu_t, \theta)\right)$$

  3. The network output $f_m(x; t)$ converges to $f(x; \mu_t) = \int \phi(x; \theta) \, d\mu_t(\theta)$
  4. Each neuron evolves independently in the limit, following $\dot{\theta}_j = -\nabla_\theta \frac{\delta \mathcal{L}}{\delta \mu}(\mu_t, \theta_j)$, where $\mu_t$ is the population-level distribution

Intuition

The $1/m$ scaling means each individual neuron has a vanishing effect on the total output. As $m \to \infty$, changing one neuron does not affect the loss gradient seen by other neurons. This is the "propagation of chaos" phenomenon: interacting particles become independent in the many-particle limit. Each neuron follows its own gradient as if the distribution $\mu_t$ were fixed --- but $\mu_t$ itself evolves self-consistently as the aggregate of all neurons.

This is analogous to how individual gas molecules become effectively independent in the thermodynamic limit, even though they all interact via the mean field.

Proof Sketch

The proof uses propagation of chaos techniques from mathematical physics:

Step 1: Show that the gradient update for neuron $j$ depends on the other neurons only through the empirical measure $\mu_m$. The gradient is $\nabla_{\theta_j} L = \frac{1}{m}\nabla_\theta \frac{\delta \mathcal{L}}{\delta \mu}(\mu_m, \theta_j)$, so gradient flow run on the accelerated time scale $mt$ (equivalently, with a learning rate scaled by $m$) gives each neuron the velocity $-\nabla_\theta \frac{\delta \mathcal{L}}{\delta \mu}(\mu_m, \theta_j)$.

Step 2: Show that $\mu_m(t)$ concentrates around a deterministic trajectory $\mu(t)$. This uses the law of large numbers for interacting particle systems: as $m \to \infty$, the empirical measure of particles that are initialized i.i.d. and interact only through their empirical measure converges to the solution of the mean field PDE.

Step 3: Verify that the limiting PDE is well-posed (existence and uniqueness of solutions) under the regularity assumptions on $\sigma$.

Why It Matters

This theorem says that infinitely wide mean-field networks are described by a PDE, not by a kernel. The distribution $\mu_t$ evolves nontrivially during training --- the neurons move to new locations in parameter space. This is feature learning: the features $\sigma(w_j^\top x)$ change during training because the $w_j$ change. NTK theory, by contrast, freezes the features at their initialization.

The mean field limit shows that feature learning is not a finite-width artifact --- it persists at infinite width under the right scaling.

Failure Mode

The regularity conditions on $\sigma$ typically require smoothness, which excludes ReLU. Extensions to ReLU exist but require more delicate analysis. The convergence rate is typically $O(1/\sqrt{m})$ in Wasserstein distance, which is slow. Most importantly, the mean field limit applies cleanly only to two-layer networks. Extending to deep networks requires multi-layer mean field theories that are still under active development.

Proposition

Training Loss Decreases Along Wasserstein Gradient Flow

Statement

Along the Wasserstein gradient flow $\mu_t$, the training loss $\mathcal{L}(\mu_t)$ is non-increasing:

$$\frac{d}{dt}\mathcal{L}(\mu_t) = -\int \left\|\nabla_\theta \frac{\delta \mathcal{L}}{\delta \mu}(\mu_t, \theta)\right\|^2 d\mu_t(\theta) \leq 0$$

The loss decreases at a rate proportional to the expected squared gradient norm under the current weight distribution.

Intuition

This is the infinite-width analogue of "gradient descent decreases the loss." Each neuron moves in its negative gradient direction, and the aggregate effect is a decrease in the loss. The rate of decrease depends on how large the gradients are on average under $\mu_t$. The flow stops (reaches a critical point) when the gradient is zero $\mu_t$-almost everywhere.
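
A small numerical sanity check (a sketch: the toy dataset, width, bias-augmented input, and step size are assumptions, and exact monotonicity holds only for the continuous-time flow, so a small discrete step is used). It runs the finite-$m$ particle dynamics under mean field scaling and tracks the training loss:

```python
# Sketch: finite-m particle dynamics under mean field scaling on an assumed toy
# 1D regression problem (input augmented with a constant bias coordinate).
# For a small step, the loss should be approximately non-increasing.
import numpy as np

rng = np.random.default_rng(2)
x1 = np.linspace(-1.0, 1.0, 64)
X = np.stack([x1, np.ones_like(x1)], axis=1)     # (n, 2), second column is a bias input
y = np.abs(x1) - 0.5                             # assumed toy target
m, dt = 500, 0.02                                # dt = step size in mean-field time
a, W = rng.normal(size=m), rng.normal(size=(m, 2))

losses = []
for _ in range(2001):
    pre = X @ W.T                                # (n, m) pre-activations
    h = np.maximum(pre, 0.0)                     # ReLU features
    f = h @ a / m                                # mean field output
    losses.append(0.5 * np.mean((f - y) ** 2))
    r = (f - y) / len(y)                         # d(loss)/d(output), squared loss
    grad_a = h.T @ r / m                         # = (1/m) * gradient of the first variation in a
    grad_W = ((pre > 0) * r[:, None]).T @ X * a[:, None] / m
    a -= dt * m * grad_a                         # learning rate scaled by m: one particle step
    W -= dt * m * grad_W

print("loss at steps 0/500/1000/2000:", [round(losses[k], 4) for k in (0, 500, 1000, 2000)])
print("fraction of non-increasing steps:", float(np.mean(np.diff(losses) <= 0)))
```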

Proof Sketch

By the chain rule in Wasserstein space:

$$\frac{d}{dt}\mathcal{L}(\mu_t) = \int \frac{\delta \mathcal{L}}{\delta \mu}(\mu_t, \theta) \, \partial_t \mu_t(\theta)\, d\theta$$

Substituting the mean field PDE and integrating by parts:

$$= -\int \left\|\nabla_\theta \frac{\delta \mathcal{L}}{\delta \mu}\right\|^2 d\mu_t \leq 0$$

The integration by parts moves the divergence operator $\nabla \cdot$ onto the first variation, producing the squared gradient norm with a negative sign.

Why It Matters

This guarantees convergence to a critical point (under appropriate compactness conditions). Combined with global optimality results for over-parameterized mean field networks, it shows that gradient flow on infinitely wide networks finds good solutions. The key advantage over NTK is that this convergence happens while the features are being learned.

Failure Mode

The critical point reached by the flow may be a saddle point or local minimum, not a global minimum. Global convergence results require additional assumptions (e.g., convexity of the loss functional in Wasserstein space, or specific properties of the activation function).

Mean Field vs. NTK: The Central Comparison

| Property | NTK (Lazy Regime) | Mean Field (Rich Regime) |
| --- | --- | --- |
| Parameterization | $1/\sqrt{m}$ scaling | $1/m$ scaling |
| Weight movement | $O(1/\sqrt{m})$ --- vanishing | $O(1)$ --- substantial |
| Features | Frozen at initialization | Learned during training |
| Infinite-width limit | Kernel regression (linear) | Wasserstein gradient flow (nonlinear) |
| Mathematical tool | Kernel theory, RKHS | PDE, optimal transport |
| Feature learning | No | Yes |
| Captures practice | Poorly | Better (but still limited) |

The fundamental reason for the difference: under the $1/\sqrt{m}$ NTK scaling, each neuron's weights change by only order $1/\sqrt{m}$ during training, so the linearization of the network around initialization stays accurate. The $1/m$ mean field scaling means the output depends on the average over $m$ neurons. Each neuron can move substantially (order 1) because its individual contribution to the output is only $1/m$.
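
This contrast can be probed numerically. The sketch below (the toy target, bias-augmented input, and all hyperparameters are assumptions; exact numbers will vary) trains the same two-layer ReLU network under both parameterizations and reports the average displacement of individual first-layer weight vectors. The mean field run uses a learning rate scaled by $m$, matching the accelerated time scale of the mean field limit, so both runs cover a comparable amount of gradient-flow time:

```python
# Sketch: compare per-neuron weight movement under NTK (1/sqrt(m)) and
# mean field (1/m) scaling on an assumed toy regression problem.
import numpy as np

rng = np.random.default_rng(3)
x1 = np.linspace(-1.0, 1.0, 128)
X = np.stack([x1, np.ones_like(x1)], axis=1)        # constant bias coordinate appended
y = np.abs(x1) - 0.5                                 # assumed toy target

def train(m, scaling, steps=4000, dt=0.05):
    a, W = rng.normal(size=m), rng.normal(size=(m, 2))
    W0 = W.copy()
    norm = m if scaling == "mean_field" else np.sqrt(m)
    lr = dt * m if scaling == "mean_field" else dt   # mean-field time runs m times faster
    for _ in range(steps):
        pre = X @ W.T
        h = np.maximum(pre, 0.0)
        f = h @ a / norm
        r = (f - y) / len(y)
        grad_a = h.T @ r / norm
        grad_W = ((pre > 0) * r[:, None]).T @ X * a[:, None] / norm
        a, W = a - lr * grad_a, W - lr * grad_W
    loss = 0.5 * np.mean((np.maximum(X @ W.T, 0.0) @ a / norm - y) ** 2)
    return loss, np.mean(np.linalg.norm(W - W0, axis=1))

for m in [100, 400, 1600]:
    for scaling in ("ntk", "mean_field"):
        loss, disp = train(m, scaling)
        print(f"m={m:5d}  {scaling:10s}  loss={loss:.4f}  mean per-neuron |w - w0| = {disp:.3f}")
```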

Canonical Examples

Example

Mean field dynamics for a toy problem

Consider fitting the target $\mathrm{sign}(x)$ with a two-layer ReLU network on $[-1, 1]$. Under NTK: the random features $\sigma(w_j^\top x)$ are fixed, and the network can only learn a linear combination of these random features. This is like trying to approximate a step function using a fixed random basis.

Under mean field: the weight distribution $\mu_t$ evolves, and neurons move to parameter values that help represent the jump at $x = 0$. The neurons discover that placing features near the discontinuity is useful. This is feature learning in action: the network adapts its features to the target function, rather than relying on random features.

Common Confusions

Watch Out

Mean field does not mean mean field approximation from physics

In statistical physics, "mean field approximation" means replacing interactions with their average --- an approximation that becomes exact in certain limits. In the neural network context, the mean field limit is not an approximation of a finite-width network. It is the exact limit as width goes to infinity under the $1/m$ scaling. The name comes from the same mathematical structure (propagation of chaos, interacting particle systems), but it is a theorem, not an approximation.

Watch Out

Mean field is not a replacement for NTK --- they describe different scaling regimes

NTK and mean field describe different limits of the same architecture with different parameterizations. Neither is "wrong." The question is which limit better describes the behavior of a practical network. For very wide networks with small learning rate, NTK is relevant. For networks that learn features (which includes most practical networks), the mean field perspective is more informative.

Watch Out

Mean field theory is currently limited to shallow networks

The cleanest mean field results are for two-layer networks. Deep mean field theory exists (e.g., through tensor programs) but is substantially more complex. The infinite-width limit for deep networks depends on the order in which layers are taken to infinity, and different orderings give different limits.

Summary

  • Mean field parameterization uses $1/m$ scaling; NTK uses $1/\sqrt{m}$
  • At infinite width, neurons become independent particles (propagation of chaos)
  • The weight distribution evolves by Wasserstein gradient flow (a PDE)
  • Mean field captures feature learning: weights move $O(1)$, not $O(1/\sqrt{m})$
  • NTK freezes features (lazy regime); mean field learns features (rich regime)
  • The mean field limit is a PDE, not a kernel --- structurally different mathematical object
  • Currently best understood for two-layer networks; deep extensions are active research
  • Mean field is the right framework for understanding why neural networks outperform kernel methods

Exercises

Exercise (Advanced)

Problem

Consider a two-layer network $f_m(x) = \frac{1}{m}\sum_{j=1}^m a_j \sigma(w_j x)$ with scalar input $x \in \mathbb{R}$. Under the mean field limit, write the network output as an integral over the weight distribution $\mu$ and compute the first variation $\frac{\delta \mathcal{L}}{\delta \mu}(\theta)$ for the squared loss $\mathcal{L}(\mu) = \frac{1}{2}(f(x_0; \mu) - y_0)^2$ at a single data point $(x_0, y_0)$.

Exercise (Advanced)

Problem

Explain why the NTK parameterization ($1/\sqrt{m}$ scaling) prevents feature learning in the infinite-width limit, while the mean field parameterization ($1/m$ scaling) allows it. Consider the magnitude of the gradient update for a single neuron's weight $w_j$ in both cases.

Exercise (Research)

Problem

The mean field PDE is a Wasserstein gradient flow. Explain what it means for the loss functional $\mathcal{L}(\mu)$ to be displacement convex in the Wasserstein-2 metric, and why displacement convexity would guarantee global convergence of the mean field dynamics to the global minimum. Does displacement convexity hold for typical neural network loss functionals?


References

Canonical:

  • Mei, Montanari, Nguyen, "A Mean Field View of the Landscape of Two-Layer Neural Networks" (PNAS 2018)
  • Chizat & Bach, "On the Global Convergence of Gradient Descent for Over-Parameterized Models using Optimal Transport" (NeurIPS 2018)

Current:

  • Rotskoff & Vanden-Eijnden, "Trainability and Accuracy of Artificial Neural Networks" (CPAM 2022)
  • Yang and Hu, "Tensor Programs" series (2020-2023) --- extends mean field ideas to deep networks via the feature learning limit
