
LLM Construction

Multi-Token Prediction

Predicting k future tokens simultaneously using auxiliary prediction heads: forces planning, improves code generation, and connects to speculative decoding.


Why This Matters

Standard language models predict one token at a time. At each position $t$, the model produces a distribution over the next token $x_{t+1}$ given the prefix $x_{\leq t}$. This is the autoregressive objective that underlies GPT, Llama, and nearly every decoder-only LLM.

Multi-token prediction changes the training objective: predict the next $k$ tokens simultaneously. This forces the model to plan ahead rather than making greedy, myopic predictions. Empirically, multi-token prediction improves performance on tasks requiring planning (code generation, mathematical reasoning) while providing a natural connection to speculative decoding at inference time.

Mental Model

Think of single-token prediction as a chess player who only considers the next move. Multi-token prediction is like requiring the player to announce the next $k$ moves in advance. The player must think further ahead, considering how each move constrains future options. The model learns internal representations that encode longer-horizon structure.

Formal Setup

Let $x = (x_1, \ldots, x_T)$ be a sequence of tokens. The standard autoregressive loss is:

$$\mathcal{L}_1 = -\sum_{t=1}^{T-1} \log p(x_{t+1} \mid x_{\leq t})$$

Definition

Multi-Token Prediction Objective

The multi-token prediction objective with lookahead $k$ uses $k$ independent prediction heads $h_1, \ldots, h_k$ sharing a common trunk (transformer body). The loss is:

$$\mathcal{L}_k = -\sum_{j=1}^{k} \sum_{t=1}^{T-j} \log p_j(x_{t+j} \mid x_{\leq t})$$

Each head $h_j$ predicts the token $j$ steps ahead given the same hidden representation from the trunk at position $t$. The trunk is trained with gradients from all $k$ heads simultaneously.
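The objective can be sketched numerically. A minimal NumPy version, assuming the trunk states $z_t$ are already computed and each head is a plain linear layer (illustrative shapes and random weights, not a real model):

```python
import numpy as np

rng = np.random.default_rng(0)
d, V, T, k = 16, 50, 12, 4                  # hidden dim, vocab size, seq length, lookahead
z = rng.standard_normal((T, d))             # trunk states z_t (assumed precomputed)
x = rng.integers(0, V, size=T)              # token ids x_1 .. x_T
W = 0.02 * rng.standard_normal((k, V, d))   # one linear head per offset j
b = np.zeros((k, V))

def log_softmax(logits):
    m = logits - logits.max(axis=-1, keepdims=True)
    return m - np.log(np.exp(m).sum(axis=-1, keepdims=True))

# L_k = -sum_{j=1}^k sum_{t=1}^{T-j} log p_j(x_{t+j} | x_<=t)
loss = 0.0
for j in range(1, k + 1):                        # head j predicts the token j steps ahead
    logits = z[: T - j] @ W[j - 1].T + b[j - 1]  # positions t = 1 .. T-j
    logp = log_softmax(logits)
    loss -= logp[np.arange(T - j), x[j:]].sum()

print(f"L_k = {loss:.3f}")
```

Note how head $j$ simply pairs the trunk state at position $t$ with the target $j$ steps ahead; no head sees another head's output.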

Definition

Auxiliary Prediction Head

An auxiliary prediction head is a separate output layer (typically a linear projection to vocabulary logits) attached to the shared transformer trunk. Head $h_j$ maps the trunk's hidden state at position $t$ to a distribution over $x_{t+j}$. During inference, only head $h_1$ (next-token) is required, but the other heads can serve as draft predictions for speculative decoding.

Architecture

The key architectural choice: the trunk is shared, but the heads are independent. Each head $h_j$ is a separate linear layer mapping from the trunk's hidden dimension $d$ to the vocabulary size $|\mathcal{V}|$.

During training, a single forward pass through the trunk produces hidden states $z_t$ at each position. Each head independently computes:

$$p_j(v \mid x_{\leq t}) = \mathrm{softmax}(W_j z_t + b_j)_v$$

for vocabulary token $v$. The memory overhead is $k$ weight matrices of size $d \times |\mathcal{V}|$ in place of the usual single output head ($k - 1$ of them additional), which is small relative to the trunk.

Gradient Structure

By linearity of the gradient and the chain rule, the gradient of the multi-token loss $\mathcal{L}_k$ with respect to trunk parameters $\theta$ decomposes as:

$$\nabla_\theta \mathcal{L}_k = \sum_{j=1}^{k} \nabla_\theta \mathcal{L}^{(j)}$$

where $\mathcal{L}^{(j)} = -\sum_t \log p_j(x_{t+j} \mid x_{\leq t})$ is the loss from head $j$. This is an algebraic identity, not a substantive theorem. The engineering question is how to parallelize head gradients without stalling the trunk, and how to manage the activation memory of $k$ vocabulary-sized logit tensors.

The content is in the consequence. With single-token prediction, the trunk at position $t$ only needs to represent enough information to predict $x_{t+1}$. With multi-token prediction, the same representation must support predicting $x_{t+1}, x_{t+2}, \ldots, x_{t+k}$. This is a strictly harder task that empirically produces richer internal representations. The trunk must encode not just "what comes next" but "what trajectory the sequence is on."

This explains why multi-token training improves downstream performance even when only the next-token head is used at inference. The auxiliary heads are scaffolding that can be discarded after training or repurposed for speculative decoding.

Failure mode. If the $k$-step-ahead prediction is nearly independent of the current context (high-entropy futures), the auxiliary heads contribute noisy gradients that may not improve trunk representations. This is more likely for large $k$ in domains with high local entropy, such as open-ended dialogue. The benefit is largest for structured domains like code, where future tokens are strongly constrained by the current context.

Speculative-Decoding Speedup Bound

Proposition

Expected Speedup from k-Head Speculative Decoding

Statement

Let $A_k$ be the number of draft tokens accepted in one verification step under independent acceptance with probability $\alpha \in (0, 1)$ per position, up to $k$ drafts. Then

$$\mathbb{E}[A_k] = \sum_{j=1}^{k} \alpha^j = \frac{\alpha (1 - \alpha^k)}{1 - \alpha}$$

and the expected wall-clock speedup over single-token decoding is $1 + \mathbb{E}[A_k]$, which is bounded above by $1 + \alpha / (1 - \alpha)$ as $k \to \infty$.
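The closed form and its diminishing returns are easy to check directly against the sum of survival probabilities:

```python
def expected_accepted(alpha: float, k: int) -> float:
    """E[A_k] = alpha (1 - alpha^k) / (1 - alpha): the geometric partial sum."""
    return alpha * (1 - alpha**k) / (1 - alpha)

alpha = 0.7
for k in (2, 4, 8, 16):
    ea = expected_accepted(alpha, k)
    print(f"k={k:>2}: E[A_k] = {ea:.3f}, expected speedup = {1 + ea:.3f}")
print(f"k -> inf bound on speedup: {1 + alpha / (1 - alpha):.3f}")
```

At $\alpha = 0.7$ the gain from each doubling of $k$ shrinks quickly toward the asymptotic bound of about 3.33x.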

Intuition

Acceptance is a geometric stopping time: drafts $1, \ldots, j$ survive only if each of them was accepted. The expected run-length is the sum of survival probabilities. Adding more heads past the point where $\alpha^j$ is small brings diminishing returns, which is why the Medusa and EAGLE papers use $k$ in the range of 3 to 5 rather than scaling $k$ arbitrarily.

Why It Matters

The bound exposes the central trade-off: speedup is gated by acceptance rate, not by head count. Doubling $k$ from 4 to 8 at $\alpha = 0.7$ only moves expected accepted drafts from $1.77$ to $2.20$. Engineering effort should go into raising $\alpha$ (better draft head quality, EAGLE-style feature conditioning, DeepSeek-V3-style MTP training) rather than stacking more heads.

Failure Mode

The token-level independence assumption is optimistic. Real acceptance events are positively correlated: if draft position 2 is rejected due to a hard context, later positions are often rejected too. This makes the bound an upper estimate; measured speedups on real workloads are typically below $1 + \mathbb{E}[A_k]$.

Connection to Speculative Decoding

Multi-token prediction provides a natural draft model for speculative decoding without needing a separate model. At inference time:

  1. The auxiliary heads $h_2, \ldots, h_k$ generate $k - 1$ draft tokens in parallel (one forward pass)
  2. The next-token head $h_1$ verifies these drafts in the following forward pass
  3. Accept or reject using the standard speculative decoding rejection sampling scheme

This is called self-speculative decoding: the model is its own draft model. The advantage over external draft models is that the heads share the trunk's representation, so draft quality tends to be higher.
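The accept/verify loop can be illustrated with a toy model under greedy acceptance (keep a draft only while it matches what $h_1$ would produce step by step). The `draft_heads` and `verify_token` functions below are hypothetical stand-ins for real heads, and a real implementation verifies all drafts in one batched forward pass rather than token by token:

```python
V = 8  # toy vocabulary size

def draft_heads(prefix, k):
    """Hypothetical k-head draft: head j guesses the token j steps ahead
    from the same 'trunk state' (here just the prefix sum)."""
    s = sum(prefix)
    return [(s + j) % V for j in range(1, k + 1)]

def verify_token(prefix):
    """Hypothetical next-token head h_1: the reference greedy prediction."""
    return (sum(prefix) + 1) % V

def self_speculative_step(prefix, k):
    """One draft-then-verify step: accept drafts while they match h_1,
    and replace the first mismatch with h_1's own token."""
    accepted, ctx = [], list(prefix)
    for d in draft_heads(prefix, k):
        t = verify_token(ctx)      # what h_1 says given the extended context
        if d != t:
            accepted.append(t)     # first rejected draft: fall back to h_1
            break
        accepted.append(d)
        ctx.append(d)
    return accepted                # always emits at least one token

print(self_speculative_step([3, 1, 4], k=4))   # → [1, 2, 4]: two drafts accepted, then a fallback
```

Because head 1 is exact and later heads only approximate the stepwise prediction, a step usually emits several tokens but rarely all $k$.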

DeepSeek-V3 (2024) is the flagship production example. It trains an MTP objective as a dense auxiliary signal and reuses the extra heads at inference for speculative decoding. Medusa (Cai et al., 2024) and EAGLE (Li et al., 2024) are the canonical multi-head speculative-decoding precedents; they target inference acceleration rather than training-signal density, and EAGLE adds a small autoregressive draft network on top of the trunk features.

Training Procedure and Memory Efficiency

Training with $k$ heads naively requires storing $k$ copies of the vocabulary logits at each position, which is $O(kT|\mathcal{V}|)$ memory. For $k = 4$, $T = 2048$, and $|\mathcal{V}| = 32000$, this is $4 \times 2048 \times 32000 \times 4$ bytes $\approx 1.05$ GB of activations in float32, on top of the trunk's activations.
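The arithmetic, as a quick sanity check (assuming float32 logits and counting only the logit tensors):

```python
k, T, V = 4, 2048, 32000
bytes_fp32 = 4

naive = k * T * V * bytes_fp32       # all k logit tensors alive at once: O(kT|V|)
sequential = T * V * bytes_fp32      # one head's logits at a time: O(T|V|)
print(f"naive: {naive / 1e9:.2f} GB, sequential: {sequential / 1e9:.2f} GB")
# → naive: 1.05 GB, sequential: 0.26 GB
```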

The memory-efficient approach computes the auxiliary losses sequentially. At each position $t$, the trunk state $z_t$ is computed once and stored. Then each head $h_j$ computes its logits, computes the cross-entropy loss, backpropagates through the head to get $\partial \mathcal{L}^{(j)} / \partial z_t$, and discards the logits. The gradient contributions from all $k$ heads are accumulated into $z_t$'s gradient before backpropagating through the trunk. This reduces peak memory from $O(kT|\mathcal{V}|)$ to $O(T|\mathcal{V}|)$ at the cost of $k$ sequential head computations.

The trunk backward pass is unchanged: it receives the accumulated gradient $\sum_{j=1}^k \nabla_{z_t} \mathcal{L}^{(j)}$ and backpropagates normally. The computational cost increases by approximately the cost of $k$ forward and backward passes through the head layers, which is small relative to the trunk.
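The per-position accumulation can be sketched in NumPy. For a single trunk state $z_t$ with $k$ toy heads, each head's contribution $W_j^\top (p_j - \text{onehot})$ is added into the trunk gradient and that head's logits become discardable (illustrative sizes, not a real model):

```python
import numpy as np

rng = np.random.default_rng(0)
d, V, k = 8, 20, 3
z = rng.standard_normal(d)                # trunk state z_t at one position
W = 0.1 * rng.standard_normal((k, V, d))  # one linear head per offset
targets = rng.integers(0, V, size=k)      # x_{t+1}, ..., x_{t+k}

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

# Sequential per-head backprop: accumulate dL/dz_t, then discard the logits.
grad_z = np.zeros(d)
for j in range(k):
    p = softmax(W[j] @ z)        # head j's logits -> probabilities
    p[targets[j]] -= 1.0         # d(cross-entropy)/d(logits) = p - onehot
    grad_z += W[j].T @ p         # chain rule back to z_t

print(grad_z[:3])
```

Since the total loss is a sum over heads, the accumulated `grad_z` is exactly what a simultaneous all-heads backward pass would produce; only the peak activation memory differs.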

Example

Training cost breakdown for Llama-scale model

All FLOP counts use the forward+backward accounting (factor 6 per weight-token pair), per Hoffmann et al. 2022 / Kaplan et al. 2020 convention. Forward-only would be factor 2. The comparison only makes sense when both sides use the same accounting.

For a 7B parameter model with $d = 4096$, $|\mathcal{V}| = 32000$, and $k = 4$ heads:

Trunk forward + backward: dominated by attention and MLP, approximately $6 \cdot N_{\text{params}} = 6 \times 7 \times 10^9 \approx 42$ GFLOP per token.

Head forward + backward: $6 \cdot k \cdot d \cdot |\mathcal{V}| = 6 \times 4 \times 4096 \times 32000 \approx 3.15$ GFLOP per token.

The auxiliary heads add about $3.15 / 42 \approx 7.5\%$ to the per-token compute. The ratio is $\frac{k \cdot d \cdot |\mathcal{V}|}{N_{\text{params}}}$, independent of the factor 6. The memory overhead (393M extra parameters) adds about 5.6% to the model size. These are modest costs for the training signal improvement.
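The same breakdown in a few lines of Python (using the stated factor-6 convention; head 1 replaces the standard LM head, so only $k - 1$ heads add parameters):

```python
N_params, d, V, k = 7e9, 4096, 32000, 4

trunk_flops = 6 * N_params     # fwd + bwd FLOPs per token, factor-6 convention
head_flops = 6 * k * d * V     # k linear heads, fwd + bwd
print(f"head overhead: {head_flops / trunk_flops:.1%}")              # → 7.5%

extra_params = (k - 1) * d * V  # heads 2..k beyond the usual single LM head
print(f"extra params: {extra_params / 1e6:.0f}M "
      f"({extra_params / N_params:.1%} of the trunk)")               # → 393M (5.6%)
```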

When Multi-Token Prediction Helps

The benefit depends on the domain:

Code generation: strong improvement. Code has rigid syntactic structure where future tokens are highly constrained by the current context. Predicting $k$ tokens ahead forces the model to plan syntactic closures, variable usage, and control flow.

Mathematical reasoning: moderate improvement. Multi-step derivations benefit from planning, but the model must learn to sequence logical steps.

Open-ended text: marginal improvement. Natural language has high local entropy. The 5th token ahead is often weakly determined by the current position alone.

Watch Out

Multi-token prediction is not the same as beam search

Beam search explores multiple alternative continuations at each step. Multi-token prediction generates a single continuation but predicts multiple positions ahead. Beam search is an inference algorithm. Multi-token prediction is a training objective (and optionally an inference optimization via speculative decoding).

Watch Out

The auxiliary heads are not autoregressive with each other

Each head $h_j$ predicts $x_{t+j}$ from the trunk state $z_t$ independently. Head $h_3$ does not condition on the predictions of heads $h_1$ or $h_2$. This independence is what makes parallel prediction possible, but it also limits the expressiveness of later heads. They cannot model dependencies between the predicted tokens themselves.

Summary

  • Standard LLMs predict one token at a time; multi-token prediction predicts the next $k$ tokens using $k$ heads sharing a common trunk
  • The trunk receives gradients from all $k$ heads, learning richer representations that encode longer-horizon structure
  • Auxiliary heads can be repurposed as draft predictors for self-speculative decoding at inference
  • Benefits are largest for structured domains (code, math) where future tokens are strongly constrained by context
  • The heads predict independently. They do not condition on each other's outputs

Exercises

ExerciseCore

Problem

A model uses multi-token prediction with $k = 4$ heads. The trunk has hidden dimension $d = 4096$ and vocabulary size $|\mathcal{V}| = 32000$. How many additional parameters do the auxiliary heads add (heads 2, 3, 4), and what fraction is this of a 7B-parameter trunk?

ExerciseAdvanced

Problem

Suppose you train with $k = 8$ but at inference only use head $h_1$ (standard autoregressive decoding). Would you expect the model to outperform a model trained with $k = 1$ on code completion tasks? What about on open-ended story generation? Explain the mechanism.

ExerciseResearch

Problem

Medusa (Cai et al., 2024, arXiv:2401.10774) augments a frozen base model with $k$ independent heads and accepts drafts under speculative-decoding rejection sampling. Assume token-level independent acceptance with probability $\alpha$ per draft position. Derive the expected number of accepted tokens per decoding step as a function of $\alpha$ and $k$, and the expected wall-clock speedup over single-token decoding when the verification pass dominates cost.

References

Canonical:

  • Gloeckle et al., "Better & Faster Large Language Models via Multi-token Prediction" (Meta, 2024), arXiv:2404.19737
  • DeepSeek-AI, "DeepSeek-V3 Technical Report" (2024), arXiv:2412.19437. Uses MTP as an auxiliary training objective for dense supervision and at inference for speculative decoding.

Speculative-decoding heads:

  • Cai et al., "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" (2024), arXiv:2401.10774
  • Li et al., "EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty" (2024), arXiv:2401.15077

Reference architecture:

  • Phuong & Hutter, "Formal Algorithms for Transformers" (2022), arXiv:2207.09238


Last reviewed: April 18, 2026
