LLM Construction
Multi-Token Prediction
Predicting k future tokens simultaneously using auxiliary prediction heads: forces planning, improves code generation, and connects to speculative decoding.
Why This Matters
Standard language models predict one token at a time. At each position $t$, the model produces a distribution $P_\theta(x_{t+1} \mid x_{1:t})$ over the next token given the prefix $x_{1:t} = (x_1, \dots, x_t)$. This is the autoregressive objective that underlies GPT, Llama, and nearly every decoder-only LLM.
Multi-token prediction changes the training objective: predict the next $k$ tokens simultaneously. This forces the model to plan ahead rather than making greedy, myopic predictions. Empirically, multi-token prediction improves performance on tasks requiring planning (code generation, mathematical reasoning) while providing a natural connection to speculative decoding at inference time.
Mental Model
Think of single-token prediction as a chess player who only considers the next move. Multi-token prediction is like requiring the player to announce the next $k$ moves in advance. The player must think further ahead, considering how each move constrains future options. The model learns internal representations that encode longer-horizon structure.
Formal Setup
Let $x = (x_1, \dots, x_T)$ be a sequence of tokens. The standard autoregressive loss is:

$$\mathcal{L}_1 = -\sum_{t=1}^{T-1} \log P_\theta(x_{t+1} \mid x_{1:t})$$
Multi-Token Prediction Objective
The multi-token prediction objective with lookahead $k$ uses $k$ independent prediction heads sharing a common trunk (transformer body). The loss is:

$$\mathcal{L}_k = -\sum_{t=1}^{T-k} \sum_{i=1}^{k} \log P^{(i)}_\theta(x_{t+i} \mid x_{1:t})$$

Head $i$ predicts the token $i$ steps ahead given the same hidden representation $h_t$ from the trunk at position $t$. The trunk is trained with gradients from all heads simultaneously.
Auxiliary Prediction Head
An auxiliary prediction head is a separate output layer (typically a linear projection to vocabulary logits) attached to the shared transformer trunk. Head $i$ maps the trunk's hidden state $h_t$ at position $t$ to a distribution over $x_{t+i}$. During inference, only head 1 (next-token) is required, but the other heads can serve as draft predictions for speculative decoding.
Architecture
The key architectural choice: the trunk is shared, but the heads are independent. Each head $i$ is a separate linear layer $W_i \in \mathbb{R}^{V \times d}$ mapping from the trunk's hidden dimension $d$ to the vocabulary size $V$.
During training, a single forward pass through the trunk produces hidden states $h_t$ at each position. Each head independently computes:

$$P^{(i)}_\theta(v \mid x_{1:t}) = \mathrm{softmax}(W_i h_t)_v$$

for vocabulary token $v$. The memory overhead is $k - 1$ additional weight matrices of size $V \times d$, which is small relative to the trunk.
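The shared-trunk, independent-heads forward pass and loss can be sketched as follows. This is a minimal illustration with toy dimensions and random placeholder hidden states standing in for a real transformer trunk; the shapes and the loss computation are the point, not the model quality.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions (illustrative only, far below Llama scale)
d, V, k, T = 8, 16, 4, 10

# Stand-in for the trunk: in a real model these hidden states come from
# a transformer forward pass; here they are random placeholders.
H = rng.normal(size=(T, d))           # trunk hidden states h_1..h_T
W = rng.normal(size=(k, V, d)) * 0.1  # one (V x d) linear head per offset i=1..k
x = rng.integers(0, V, size=T + k)    # token ids, with k extra targets at the end

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Multi-token prediction loss: head i at position t predicts x_{t+i}
total_loss = 0.0
for i in range(1, k + 1):
    logits = H @ W[i - 1].T            # (T, V) logits from head i
    probs = softmax(logits)
    targets = x[i : T + i]             # the token i steps ahead of each position
    total_loss += -np.log(probs[np.arange(T), targets]).mean()

print(f"mean multi-token loss over {k} heads: {total_loss / k:.3f}")
```

Note that all $k$ heads read the same hidden state `H`; only the projection matrices differ, which is exactly why the extra cost is limited to $k-1$ additional $V \times d$ matrices.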
Main Theorems
Multi-Token Gradient Provides Richer Learning Signal
Statement
The gradient of the multi-token loss with respect to the trunk parameters $\theta_{\text{trunk}}$ is:

$$\nabla_{\theta_{\text{trunk}}} \mathcal{L}_k = \sum_{i=1}^{k} \nabla_{\theta_{\text{trunk}}} \mathcal{L}^{(i)}$$

where $\mathcal{L}^{(i)}$ is the loss from head $i$. The trunk receives gradient contributions from all future positions simultaneously, encouraging representations that encode information relevant to predicting multiple future tokens.
Intuition
With single-token prediction, the trunk at position $t$ only needs to represent enough information to predict $x_{t+1}$. With multi-token prediction, the same representation must support predicting $x_{t+1}, \dots, x_{t+k}$. This is a strictly harder task that requires richer, more structured internal representations. The trunk must encode not just "what comes next" but "what trajectory the sequence is on."
Why It Matters
This explains why multi-token prediction improves downstream performance even when only the next-token head is used at inference. The training objective forces the trunk to learn better representations. The auxiliary heads are scaffolding that can be discarded after training (or repurposed for speculative decoding).
Failure Mode
If the $i$-step-ahead prediction is nearly independent of the current context (high-entropy futures), the auxiliary heads contribute noisy gradients that may not improve trunk representations. This is more likely for large $i$ in domains with high local entropy, such as open-ended dialogue. The benefit is largest for structured domains like code, where future tokens are strongly constrained by the current context.
Connection to Speculative Decoding
Multi-token prediction provides a natural draft model for speculative decoding without needing a separate model. At inference time:
- The auxiliary heads generate $k - 1$ draft tokens in parallel (one forward pass)
- The following forward pass, over the context extended with the drafts, lets the next-token head verify all drafts at once
- Accept or reject using the standard speculative decoding rejection sampling scheme
This is called self-speculative decoding: the model is its own draft model. The advantage over external draft models is that the heads share the trunk's representation, so draft quality tends to be higher.
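The acceptance logic above can be sketched with a toy greedy variant (accept a draft only if it matches what the next-token head would have produced, which guarantees output identical to ordinary decoding). The `heads_predict` function below is a hypothetical stand-in for a model forward pass; a real implementation would verify all draft positions in a single batched transformer pass rather than in a Python loop.

```python
import numpy as np

V, k = 16, 4  # toy vocabulary size and number of prediction heads

def heads_predict(ctx):
    # Hypothetical stand-in for one forward pass: returns the greedy token
    # from each of the k heads at the last position. Seeding from the
    # context makes it deterministic, mimicking a fixed trained model.
    local = np.random.default_rng(hash(tuple(ctx)) % 2**32)
    logits = local.normal(size=(k, V))
    return logits.argmax(axis=-1)

def self_speculative_step(ctx):
    # 1. One forward pass: head 1 gives the next token, heads 2..k give drafts.
    preds = heads_predict(ctx)
    accepted = [int(preds[0])]        # head 1's token is always kept
    # 2. Verify drafts left to right: a draft is accepted if it matches what
    #    head 1 predicts once the preceding tokens are known. In a real model
    #    all these verifications come from one forward pass over ctx + drafts.
    for draft in preds[1:]:
        verified = int(heads_predict(ctx + accepted)[0])
        if int(draft) == verified:
            accepted.append(verified)  # draft matches: keep it, continue
        else:
            break                      # first mismatch: stop accepting
    return accepted

ctx = [3, 1, 4]
out = self_speculative_step(ctx)
print(f"accepted {len(out)} tokens this step: {out}")
```

When drafts are accepted, one verification pass yields several tokens, which is the source of the speedup.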
Training Procedure and Memory Efficiency
Training with $k$ heads naively requires storing $k$ copies of the vocabulary logits at each position, which is $O(k \cdot T \cdot V)$ memory. For $k = 4$, $T = 4096$, and $V = 32{,}000$, this is roughly 2 GB of activations in float32 per sequence, on top of the trunk's activations.
The memory-efficient approach computes the auxiliary losses sequentially. At each position $t$, the trunk state $h_t$ is computed once and stored. Then each head $i$ computes its logits, computes the cross-entropy loss, backpropagates through the head to get $\partial \mathcal{L}^{(i)} / \partial h_t$, and discards the logits. The gradient contributions from all heads are accumulated into $h_t$'s gradient before backpropagating through the trunk. This reduces peak memory from $O(k \cdot T \cdot V)$ to $O(T \cdot V)$ at the cost of $k$ sequential head computations.
The trunk backward pass is unchanged: it receives the accumulated gradient and backpropagates normally. The computational cost increases by approximately the cost of forward and backward passes through the head layers, which is small relative to the trunk.
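The sequential-head trick can be made concrete with explicit numpy gradients. The sketch below computes each head's cross-entropy gradient by hand (for a linear head, $\partial \mathcal{L} / \partial h_t = W_i^\top (p_t - \mathrm{onehot}(x_{t+i}))$), accumulates it into a single trunk-gradient buffer, and discards each head's logits before the next head runs; dimensions are toy values.

```python
import numpy as np

rng = np.random.default_rng(2)
d, V, k, T = 8, 16, 4, 10

H = rng.normal(size=(T, d))           # trunk hidden states (stored once)
W = rng.normal(size=(k, V, d)) * 0.1  # one linear head per offset
x = rng.integers(0, V, size=T + k)    # token ids with k extra targets

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

grad_H = np.zeros_like(H)  # accumulated dL/dH across all heads

for i in range(1, k + 1):
    logits = H @ W[i - 1].T           # (T, V): exists only inside this iteration
    probs = softmax(logits)
    targets = x[i : T + i]
    # Cross-entropy gradient w.r.t. logits: (softmax - onehot), averaged over T
    g = probs.copy()
    g[np.arange(T), targets] -= 1.0
    g /= T
    grad_H += g @ W[i - 1]            # backprop through head i only: W_i^T g_t
    # logits/probs for head i are discarded here, so peak logit memory is
    # O(T*V), not O(k*T*V); only grad_H (size T*d) persists across heads.

# grad_H now carries the full multi-head gradient signal; the trunk backward
# pass would proceed from here unchanged.
print("accumulated trunk gradient norm:", float(np.linalg.norm(grad_H)))
```

In a framework like PyTorch the same effect is achieved by calling backward on each head's loss in turn while retaining the trunk graph, rather than materializing all logits at once.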
Training cost breakdown for Llama-scale model
For a 7B-parameter model with $d = 4096$, $V = 32{,}000$, and $k = 4$ heads:
Trunk forward + backward: ~95% of total compute (dominated by attention and MLP). Head forward + backward: ~2.4 GFLOP per token (three extra $V \times d$ projections; forward + backward $\approx 3\times$ the forward cost). Trunk cost per token: approximately 42 GFLOP, using the standard $\approx 6N$ FLOPs-per-token estimate for forward + backward.
The auxiliary heads add about 5.6% to the per-token compute. The memory overhead (393M extra parameters) adds about 5.6% to the model size. These are modest costs for the training-signal improvement.
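The back-of-envelope arithmetic above is easy to reproduce; the script below assumes the Llama-7B-style dimensions stated earlier ($d = 4096$, $V = 32{,}000$, $k = 4$) and the common $6N$ FLOPs-per-token estimate for a forward + backward pass.

```python
# Back-of-envelope costs for multi-token prediction heads on a 7B model.
d, V, k = 4096, 32_000, 4
trunk_params = 7e9

extra_heads = k - 1                    # head 1 already exists in the baseline
extra_params = extra_heads * d * V     # one (V x d) matrix per extra head
print(f"extra parameters: {extra_params / 1e6:.0f}M")  # ~393M

# FLOPs per token: ~2*params for a linear forward pass, ~3x that for fwd+bwd
head_flops = extra_heads * 2 * d * V * 3
trunk_flops = 6 * trunk_params         # standard ~6N fwd+bwd estimate
print(f"head overhead: {head_flops / 1e9:.1f} GFLOP vs trunk {trunk_flops / 1e9:.0f} GFLOP")
print(f"compute overhead:   {head_flops / trunk_flops:.1%}")
print(f"parameter overhead: {extra_params / trunk_params:.1%}")
```

Both overheads come out near 5.6%, matching the figures in the text.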
When Multi-Token Prediction Helps
The benefit depends on the domain:
Code generation: strong improvement. Code has rigid syntactic structure where future tokens are highly constrained by the current context. Predicting $k$ tokens ahead forces the model to plan syntactic closures, variable usage, and control flow.
Mathematical reasoning: moderate improvement. Multi-step derivations benefit from planning, but the model must learn to sequence logical steps.
Open-ended text: marginal improvement. Natural language has high local entropy. The 5th token ahead is often weakly determined by the current position alone.
Multi-token prediction is not the same as beam search
Beam search explores multiple alternative continuations at each step. Multi-token prediction generates a single continuation but predicts multiple positions ahead. Beam search is an inference algorithm. Multi-token prediction is a training objective (and optionally an inference optimization via speculative decoding).
The auxiliary heads are not autoregressive with each other
Each head predicts from the trunk state $h_t$ independently. Head $i$ does not condition on the predictions of heads $1$ through $i - 1$. This independence is what makes parallel prediction possible, but it also limits the expressiveness of later heads. They cannot model dependencies between the predicted tokens themselves.
Summary
- Standard LLMs predict one token at a time; multi-token prediction predicts the next $k$ tokens using $k$ heads sharing a common trunk
- The trunk receives gradients from all heads, learning richer representations that encode longer-horizon structure
- Auxiliary heads can be repurposed as draft predictors for self-speculative decoding at inference
- Benefits are largest for structured domains (code, math) where future tokens are strongly constrained by context
- The heads predict independently. They do not condition on each other's outputs
Exercises
Problem
A model uses multi-token prediction with $k = 4$ heads. The trunk has hidden dimension $d = 4096$ and vocabulary size $V = 32{,}000$. How many additional parameters do the auxiliary heads add (heads 2, 3, 4), and what fraction is this of a 7B-parameter trunk?
Problem
Suppose you train with $k = 4$ but at inference only use head 1 (standard autoregressive decoding). Would you expect the model to outperform a model trained with $k = 1$ on code completion tasks? What about on open-ended story generation? Explain the mechanism.
Problem
The auxiliary heads predict independently of each other. Design a modification that allows head $i$ to condition on the predictions of heads $1, \dots, i - 1$ while still permitting parallel training. What is the computational cost?
References
Canonical:
- Gloeckle et al., "Better & Faster Large Language Models via Multi-token Prediction" (Meta, 2024)
Current:
- Stern et al., "Insertion Transformer: Flexible Sequence Generation via Insertion Operations" (2019)
- Cai et al., "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" (2024)
- Phuong & Hutter, "Formal Algorithms for Transformers" (2022), arXiv:2207.09238
Next Topics
The natural next steps from multi-token prediction:
- Latent reasoning: another approach to making models plan ahead, but in continuous hidden state space rather than token space
- Speculative decoding and quantization: the inference framework where multi-token prediction heads serve as draft models
Last reviewed: April 2026
Prerequisites
Foundations this topic depends on.
- Transformer Architecture (Layer 4)
- Attention Mechanism Theory (Layer 4)
- Matrix Operations and Properties (Layer 0A)
- Sets, Functions, and Relations (Layer 0A)
- Basic Logic and Proof Techniques (Layer 0A)
- Softmax and Numerical Stability (Layer 1)
- Feedforward Networks and Backpropagation (Layer 2)
- Differentiation in R^n (Layer 0A)
- Matrix Calculus (Layer 1)
- The Jacobian Matrix (Layer 0A)
- The Hessian Matrix (Layer 0A)
- Activation Functions (Layer 1)
- Convex Optimization Basics (Layer 1)