
LLM Construction

Multi-Token Prediction

Predicting k future tokens simultaneously using auxiliary prediction heads: forces planning, improves code generation, and connects to speculative decoding.


Why This Matters

Standard language models predict one token at a time. At each position $t$, the model produces a distribution over the next token $x_{t+1}$ given the prefix $x_{\leq t}$. This is the autoregressive objective that underlies GPT, Llama, and nearly every decoder-only LLM.

Multi-token prediction changes the training objective: predict the next $k$ tokens simultaneously. This forces the model to plan ahead rather than making greedy, myopic predictions. Empirically, multi-token prediction improves performance on tasks requiring planning (code generation, mathematical reasoning) while providing a natural connection to speculative decoding at inference time.

Mental Model

Think of single-token prediction as a chess player who only considers the next move. Multi-token prediction is like requiring the player to announce the next $k$ moves in advance. The player must think further ahead, considering how each move constrains future options. The model learns internal representations that encode longer-horizon structure.

Formal Setup

Let $x = (x_1, \ldots, x_T)$ be a sequence of tokens. The standard autoregressive loss is:

$$\mathcal{L}_1 = -\sum_{t=1}^{T-1} \log p(x_{t+1} \mid x_{\leq t})$$

Definition

Multi-Token Prediction Objective

The multi-token prediction objective with lookahead $k$ uses $k$ independent prediction heads $h_1, \ldots, h_k$ sharing a common trunk (transformer body). The loss is:

$$\mathcal{L}_k = -\sum_{j=1}^{k} \sum_{t=1}^{T-j} \log p_j(x_{t+j} \mid x_{\leq t})$$

Each head $h_j$ predicts the token $j$ steps ahead given the same hidden representation from the trunk at position $t$. The trunk is trained with gradients from all $k$ heads simultaneously.
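
The double sum in $\mathcal{L}_k$ can be made concrete with a small numpy sketch (toy sizes, random weights; the array names are illustrative, not from any real implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, V, k = 6, 4, 10, 3                  # sequence length, hidden dim, vocab, lookahead

tokens = rng.integers(0, V, size=T)       # toy token sequence x_1 .. x_T
z = rng.standard_normal((T, d))           # trunk hidden states z_t
W = rng.standard_normal((k, d, V))        # one output matrix per head

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# L_k = -sum_{j=1..k} sum_{t} log p_j(x_{t+j} | x_{<=t})
loss = 0.0
for j in range(1, k + 1):                 # head j predicts j steps ahead
    probs = softmax(z @ W[j - 1])         # (T, V): head j's prediction at each t
    for t in range(T - j):                # 0-indexed; the target is tokens[t + j]
        loss -= np.log(probs[t, tokens[t + j]])
print(loss)
```

Note that every head reads the same hidden state $z_t$; only the output matrices differ.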

Definition

Auxiliary Prediction Head

An auxiliary prediction head is a separate output layer (typically a linear projection to vocabulary logits) attached to the shared transformer trunk. Head $h_j$ maps the trunk's hidden state at position $t$ to a distribution over $x_{t+j}$. During inference, only head $h_1$ (next-token) is required, but the other heads can serve as draft predictions for speculative decoding.

Architecture

The key architectural choice: the trunk is shared, but the heads are independent. Each head $h_j$ is a separate linear layer mapping from the trunk's hidden dimension $d$ to the vocabulary size $|\mathcal{V}|$.

During training, a single forward pass through the trunk produces hidden states $z_t$ at each position. Each head independently computes:

$$p_j(v \mid x_{\leq t}) = \text{softmax}(W_j z_t + b_j)_v$$

for vocabulary token $v$. The memory overhead is $k$ additional weight matrices of size $d \times |\mathcal{V}|$, which is small relative to the trunk.
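
As a sketch of this layout (a numpy stand-in for the linear heads; the class name and dimensions are toy illustrations, not Llama's):

```python
import numpy as np

class MultiTokenHeads:
    """k independent linear heads, each mapping the shared trunk
    state z_t (dimension d) to a distribution over the vocabulary."""
    def __init__(self, d, vocab, k, seed=0):
        rng = np.random.default_rng(seed)
        self.W = [0.02 * rng.standard_normal((d, vocab)) for _ in range(k)]
        self.b = [np.zeros(vocab) for _ in range(k)]

    def __call__(self, z):
        """z: (T, d) trunk states -> list of k (T, vocab) distributions."""
        dists = []
        for W, b in zip(self.W, self.b):
            logits = z @ W + b
            e = np.exp(logits - logits.max(axis=-1, keepdims=True))
            dists.append(e / e.sum(axis=-1, keepdims=True))
        return dists

heads = MultiTokenHeads(d=8, vocab=16, k=4)
dists = heads(np.zeros((5, 8)))           # the same z_t feeds every head
print(len(dists), dists[0].shape)         # 4 heads, each (5, 16)
```

A single trunk forward pass produces `z`; the heads add only the final matmuls.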

Main Theorems

Proposition

Multi-Token Gradient Provides Richer Learning Signal

Statement

The gradient of the multi-token loss $\mathcal{L}_k$ with respect to trunk parameters $\theta$ is:

$$\nabla_\theta \mathcal{L}_k = \sum_{j=1}^{k} \nabla_\theta \mathcal{L}^{(j)}$$

where $\mathcal{L}^{(j)} = -\sum_t \log p_j(x_{t+j} \mid x_{\leq t})$ is the loss from head $j$. The trunk receives gradient contributions from all $k$ future positions simultaneously, encouraging representations that encode information relevant to predicting multiple future tokens.

Intuition

With single-token prediction, the trunk at position $t$ only needs to represent enough information to predict $x_{t+1}$. With multi-token prediction, the same representation must support predicting $x_{t+1}, x_{t+2}, \ldots, x_{t+k}$. This is a strictly harder task that requires richer, more structured internal representations. The trunk must encode not just "what comes next" but "what trajectory the sequence is on."

Why It Matters

This explains why multi-token prediction improves downstream performance even when only the next-token head is used at inference. The training objective forces the trunk to learn better representations. The auxiliary heads are scaffolding that can be discarded after training (or repurposed for speculative decoding).

Failure Mode

If the $k$-step-ahead prediction is nearly independent of the current context (high-entropy futures), the auxiliary heads contribute noisy gradients that may not improve trunk representations. This is more likely for large $k$ in domains with high local entropy, such as open-ended dialogue. The benefit is largest for structured domains like code, where future tokens are strongly constrained by the current context.

Connection to Speculative Decoding

Multi-token prediction provides a natural draft model for speculative decoding without needing a separate model. At inference time:

  1. The auxiliary heads $h_2, \ldots, h_k$ generate $k-1$ draft tokens in parallel (one forward pass)
  2. The next-token head $h_1$ verifies these drafts in the following forward pass
  3. Accept or reject using the standard speculative decoding rejection sampling scheme

This is called self-speculative decoding: the model is its own draft model. The advantage over external draft models is that the heads share the trunk's representation, so draft quality tends to be higher.
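
A greedy-acceptance sketch of this loop (a simplification of the rejection-sampling scheme; `trunk` and the head matrices below are toy stand-ins, not a real model):

```python
import numpy as np

def speculative_step(trunk, W_heads, prefix):
    """One draft-then-verify step. Head 1 commits x_{t+1}; heads 2..k
    draft x_{t+2}..x_{t+k} from the same hidden state; a single pass over
    the candidate verifies the drafts. Greedy acceptance stands in for
    the full rejection-sampling scheme."""
    z_last = trunk(prefix)[-1]
    nxt = int(np.argmax(z_last @ W_heads[0]))                  # head 1: x_{t+1}
    draft = [int(np.argmax(z_last @ W)) for W in W_heads[1:]]  # heads 2..k
    candidate = prefix + [nxt] + draft
    z = trunk(candidate)                        # one verification forward pass
    accepted = [nxt]
    for i, tok in enumerate(draft):
        pos = len(prefix) + i                   # state just before draft[i]
        if int(np.argmax(z[pos] @ W_heads[0])) != tok:
            break                               # head 1 disagrees: stop here
        accepted.append(tok)
    return accepted

# toy model: the state at position t one-hot encodes token t,
# and head j deterministically predicts (token + j) mod V
V = 10
trunk = lambda seq: np.eye(V)[seq]
W_heads = [np.roll(np.eye(V), j, axis=1) for j in (1, 2, 3)]
print(speculative_step(trunk, W_heads, [0]))    # all drafts accepted: [1, 2, 3]
```

In the toy model the heads agree perfectly with head 1, so every draft is accepted and one step emits $k$ tokens.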

Training Procedure and Memory Efficiency

Training with $k$ heads naively requires storing $k$ copies of the vocabulary logits at each position, which is $O(kT|\mathcal{V}|)$ memory. For $k = 4$, $T = 2048$, and $|\mathcal{V}| = 32000$, this is $4 \times 2048 \times 32000 \times 4 \approx 1.05$ GB of activations in float32, on top of the trunk's activations.
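
The arithmetic above, as a quick check (pure Python; the final factor of 4 is bytes per float32):

```python
k, T, V, bytes_per = 4, 2048, 32_000, 4   # lookahead, seq length, vocab, float32

naive = k * T * V * bytes_per             # all k heads' logits held at once
print(f"{naive / 1e9:.2f} GB")            # 1.05 GB of extra activations
```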

The memory-efficient approach computes the auxiliary losses sequentially. At each position $t$, the trunk state $z_t$ is computed once and stored. Then each head $h_j$ computes its logits, computes the cross-entropy loss, backpropagates through the head to get $\partial \mathcal{L}^{(j)} / \partial z_t$, and discards the logits. The gradient contributions from all $k$ heads are accumulated into $z_t$'s gradient before backpropagating through the trunk. This reduces peak memory from $O(kT|\mathcal{V}|)$ to $O(T|\mathcal{V}|)$ at the cost of $k$ sequential head computations.

The trunk backward pass is unchanged: it receives the accumulated gradient $\sum_{j=1}^k \nabla_{z_t} \mathcal{L}^{(j)}$ and backpropagates normally. The computational cost increases by approximately the cost of $k$ forward and backward passes through the head layers, which is small relative to the trunk.
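
The accumulation can be sketched with manual gradients (numpy, toy sizes; `accumulate_head_grads` is an illustrative name, and a real implementation would rely on autograd):

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def accumulate_head_grads(z, W_heads, tokens):
    """Compute L_k and dL_k/dz head by head, so at most one head's
    (T, V) logits are alive at a time: peak memory O(T*V), not O(k*T*V)."""
    grad_z = np.zeros_like(z)               # accumulated gradient for the trunk
    loss = 0.0
    for j, W in enumerate(W_heads, start=1):  # head j predicts x_{t+j}
        probs = softmax(z @ W)              # this head's probabilities
        for t in range(len(tokens) - j):
            target = tokens[t + j]
            loss -= np.log(probs[t, target])
            g = probs[t].copy()             # d(-log p)/d(logits) = p - onehot
            g[target] -= 1.0
            grad_z[t] += g @ W.T            # chain rule back into z_t
        # probs is discarded before the next head's logits are built
    return loss, grad_z

rng = np.random.default_rng(1)
z = rng.standard_normal((5, 4))
W_heads = [rng.standard_normal((4, 8)) for _ in range(3)]
tokens = rng.integers(0, 8, size=5)
loss, grad_z = accumulate_head_grads(z, W_heads, tokens)
print(loss, grad_z.shape)
```

The trunk then backpropagates once from `grad_z`, exactly as in the single-head case.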

Example

Training cost breakdown for Llama-scale model

For a 7B parameter model with $d = 4096$, $|\mathcal{V}| = 32000$, and $k = 4$ heads:

Trunk forward + backward: the dominant cost (attention and MLP), approximately $6 \times 7 \times 10^9 \approx 42$ GFLOP per token. Head forward + backward: roughly 6 FLOPs per weight, so $k \times 6 \times d \times |\mathcal{V}| = 4 \times 6 \times 4096 \times 32000 \approx 3.1$ GFLOP per token.

The auxiliary heads add about $3.1/42 \approx 7.5\%$ to the per-token compute. The memory overhead (393M extra parameters for heads 2 through 4) adds about 5.6% to the model size. These are modest costs for the improved training signal.
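
As a quick check of the parameter overhead (pure Python arithmetic; head 1 is the model's standard output layer, so only heads 2 through 4 are extra):

```python
d, V, params = 4096, 32_000, 7e9

extra = 3 * d * V                                 # heads 2-4, each a d x |V| matrix
print(f"{extra / 1e6:.0f}M extra parameters")     # 393M
print(f"+{100 * extra / params:.1f}% of 7B")      # +5.6%
```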

When Multi-Token Prediction Helps

The benefit depends on the domain:

Code generation: strong improvement. Code has rigid syntactic structure where future tokens are highly constrained by the current context. Predicting $k$ tokens ahead forces the model to plan syntactic closures, variable usage, and control flow.

Mathematical reasoning: moderate improvement. Multi-step derivations benefit from planning, but the model must learn to sequence logical steps.

Open-ended text: marginal improvement. Natural language has high local entropy. The 5th token ahead is often weakly determined by the current position alone.

Watch Out

Multi-token prediction is not the same as beam search

Beam search explores multiple alternative continuations at each step. Multi-token prediction generates a single continuation but predicts multiple positions ahead. Beam search is an inference algorithm. Multi-token prediction is a training objective (and optionally an inference optimization via speculative decoding).

Watch Out

The auxiliary heads are not autoregressive with each other

Each head $h_j$ predicts $x_{t+j}$ from the trunk state $z_t$ independently. Head $h_3$ does not condition on the predictions of heads $h_1$ or $h_2$. This independence is what makes parallel prediction possible, but it also limits the expressiveness of later heads. They cannot model dependencies between the predicted tokens themselves.
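
A two-token toy example makes the limitation concrete (plain Python; the distributions are invented for illustration):

```python
import itertools

# true continuations: "a b" or "b a", each with probability 0.5
joint = {("a", "b"): 0.5, ("b", "a"): 0.5}

# independent heads can only represent the per-position marginals
p1 = {"a": 0.5, "b": 0.5}          # head 1's distribution over x_{t+1}
p2 = {"a": 0.5, "b": 0.5}          # head 2's distribution over x_{t+2}

# the factorized model leaks probability onto impossible pairs
factored = {pair: p1[pair[0]] * p2[pair[1]]
            for pair in itertools.product("ab", repeat=2)}
print(factored[("a", "a")])        # 0.25, though the true probability is 0
```

No choice of independent marginals can recover the anti-correlation between the two positions; only an autoregressive factorization can.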

Summary

  • Standard LLMs predict one token at a time; multi-token prediction predicts the next $k$ tokens using $k$ heads sharing a common trunk
  • The trunk receives gradients from all $k$ heads, learning richer representations that encode longer-horizon structure
  • Auxiliary heads can be repurposed as draft predictors for self-speculative decoding at inference
  • Benefits are largest for structured domains (code, math) where future tokens are strongly constrained by context
  • The heads predict independently. They do not condition on each other's outputs

Exercises

ExerciseCore

Problem

A model uses multi-token prediction with $k = 4$ heads. The trunk has hidden dimension $d = 4096$ and vocabulary size $|\mathcal{V}| = 32000$. How many additional parameters do the auxiliary heads add (heads 2, 3, 4), and what fraction is this of a 7B-parameter trunk?

ExerciseAdvanced

Problem

Suppose you train with $k = 8$ but at inference only use head $h_1$ (standard autoregressive decoding). Would you expect the model to outperform a model trained with $k = 1$ on code completion tasks? What about on open-ended story generation? Explain the mechanism.

ExerciseResearch

Problem

The auxiliary heads predict $x_{t+j}$ independently of each other. Design a modification that allows head $h_j$ to condition on the predictions of heads $h_1, \ldots, h_{j-1}$ while still permitting parallel training. What is the computational cost?

References

Canonical:

  • Gloeckle et al., "Better & Faster Large Language Models via Multi-token Prediction" (Meta, 2024)

Current:

  • Stern et al., "Insertion Transformer: Flexible Sequence Generation via Insertion Operations" (2019)

  • Cai et al., "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" (2024)

  • Phuong & Hutter, "Formal Algorithms for Transformers" (2022), arXiv:2207.09238


Last reviewed: April 2026
