LLM Construction
Attention Variants and Efficiency
Multi-head, multi-query, grouped-query, linear, and sparse attention: how each variant trades expressivity for efficiency, and when to use which.
Why This Matters
Standard scaled dot-product attention has time complexity $O(n^2 d)$ and memory complexity $O(n^2)$, where $n$ is the sequence length and $d$ the model dimension. For long sequences (documents, code, genomics), this becomes the bottleneck. Different attention variants trade expressivity for efficiency, and the right choice depends on the deployment scenario: training throughput, inference latency, KV cache size, or sequence length.
Understanding these variants requires knowing exactly what each one changes and what it preserves.
Baseline: Multi-Head Attention (MHA)
Multi-Head Attention
Given $h$ attention heads, each head $i$ has its own projection matrices $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d \times d_k}$, where $d_k = d/h$:

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(XW_i^Q)(XW_i^K)^\top}{\sqrt{d_k}}\right) XW_i^V, \qquad \mathrm{MHA}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O$$

where $W^O \in \mathbb{R}^{d \times d}$ is the output projection.
Parameter count per layer: $4d^2$ (three projections $W^Q, W^K, W^V$ plus the output $W^O$, each $d \times d$).
KV cache per token: $2hd_k = 2d$ values must be stored for each token in the context (keys and values across all heads). For a model with $L$ layers, the total KV cache per token is $2Ld$.
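The formulas above can be made concrete with a shape-level NumPy sketch (single sequence, no causal mask, random weights; all names are illustrative, not a production implementation):

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, h):
    """X: (n, d). Wq, Wk, Wv, Wo: (d, d). h: number of heads, d_k = d // h."""
    n, d = X.shape
    dk = d // h
    # Project, then split the width-d projections into h heads of width dk
    Q = (X @ Wq).reshape(n, h, dk).transpose(1, 0, 2)  # (h, n, dk)
    K = (X @ Wk).reshape(n, h, dk).transpose(1, 0, 2)
    V = (X @ Wv).reshape(n, h, dk).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(dk)    # (h, n, n)
    heads = softmax(scores) @ V                        # (h, n, dk)
    # Concatenate heads back to width d, then mix with the output projection
    return heads.transpose(1, 0, 2).reshape(n, d) @ Wo

rng = np.random.default_rng(0)
n, d, h = 8, 16, 4
X = rng.standard_normal((n, d))
Wq, Wk, Wv, Wo = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
Y = multi_head_attention(X, Wq, Wk, Wv, Wo, h)  # shape (n, d)
```

Note that the per-head computation is just $h$ independent attention calls over width-$d_k$ slices; the KV cache for this layer is the $K$ and $V$ tensors, $2hd_k = 2d$ values per token.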
Multi-Head Expressivity
Statement
Each attention head can learn a different attention pattern (a different notion of "relevance" between tokens). With $h$ heads, the layer can simultaneously attend to $h$ different types of relationships. The concatenation followed by $W^O$ allows the layer to mix information from all $h$ attention patterns.
Formally, the rank of each head's value contribution is at most $d_k$, but the combined output through $W^O$ can represent functions that no single head of dimension $d$ could represent. The multi-head decomposition is not merely a factorization; it enables attending to $h$ different subspaces independently.
Intuition
Think of each head as asking a different question about the input. Head 1 might attend based on syntactic relationships (subject to verb). Head 2 might attend based on semantic similarity. Head 3 might attend to positional proximity. By running these in parallel and combining, the model captures multiple types of dependencies simultaneously.
Proof Sketch
Consider a function that requires attending to two incompatible patterns simultaneously (e.g., the nearest token and the most semantically similar token). A single softmax attention head produces a convex combination of values, which cannot represent two sharp peaks simultaneously. Two separate heads can each produce one sharp peak, and the output projection can combine them.
Why It Matters
This explains why multi-head attention outperforms single-head attention with the same total dimension. The computational cost is identical (same number of parameters and FLOPs), but the representational capacity increases.
Failure Mode
In practice, many heads learn redundant or near-identical patterns. Michel et al. (2019) showed that most heads can be pruned after training with minimal quality loss. The theoretical expressivity advantage is not always used. This observation motivates the reduced-head variants below.
Multi-Query Attention (MQA)
Multi-Query Attention
Multi-query attention (Shazeer, 2019) uses separate projections $W_i^Q$ per head but shares a single $W^K$ and $W^V \in \mathbb{R}^{d \times d_k}$ across all heads:

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(XW_i^Q)(XW^K)^\top}{\sqrt{d_k}}\right) XW^V$$
All heads compute attention over the same key-value pairs but with different queries.
KV cache per token: $2d_k$ per layer (one set of keys and values instead of $h$ sets). This is an $h$-fold reduction compared to MHA.
Tradeoff: MQA reduces the KV cache by a factor of $h$ (typically 32x to 128x for large models), enabling much longer contexts and larger batch sizes during inference. Quality degrades slightly because heads can no longer attend to different subspaces of keys and values. For many tasks, this degradation is small (less than 1% on benchmarks).
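The change relative to MHA is small enough to show directly: only the key/value projections shrink to a single head, and NumPy broadcasting shares them across all query heads (a hedged sketch with illustrative names, mirroring the MHA sketch above):

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_query_attention(X, Wq, Wk, Wv, Wo, h):
    """Wq, Wo: (d, d). Wk, Wv: (d, dk) -- a SINGLE shared KV head."""
    n, d = X.shape
    dk = d // h
    Q = (X @ Wq).reshape(n, h, dk).transpose(1, 0, 2)  # (h, n, dk)
    K = X @ Wk                                         # (n, dk), shared by all heads
    V = X @ Wv                                         # (n, dk)
    # K.T broadcasts against all h query heads: (h, n, dk) @ (dk, n) -> (h, n, n)
    scores = Q @ K.T / np.sqrt(dk)
    heads = softmax(scores) @ V                        # (h, n, dk)
    return heads.transpose(1, 0, 2).reshape(n, d) @ Wo

rng = np.random.default_rng(0)
n, d, h = 8, 16, 4
X = rng.standard_normal((n, d))
Wq = rng.standard_normal((d, d)) / np.sqrt(d)
Wk = rng.standard_normal((d, d // h)) / np.sqrt(d)
Wv = rng.standard_normal((d, d // h)) / np.sqrt(d)
Wo = rng.standard_normal((d, d)) / np.sqrt(d)
Y = multi_query_attention(X, Wq, Wk, Wv, Wo, h)  # shape (n, d)
```

During generation, only `K` and `V` need caching: $2d_k$ values per token per layer instead of $2d$.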
Grouped-Query Attention (GQA)
Grouped-Query Attention
Grouped-query attention (Ainslie et al., 2023) divides the $h$ query heads into $g$ groups. Each group shares one set of key-value projections:

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(XW_i^Q)(XW_{g(i)}^K)^\top}{\sqrt{d_k}}\right) XW_{g(i)}^V$$

where $g(i) \in \{1, \ldots, g\}$ maps head $i$ to its group.
KV cache per token: $2gd_k$ per layer. When $g = h$, GQA is MHA. When $g = 1$, GQA is MQA.
GQA is the current standard for large language models (Llama 2/3, Mistral). It recovers most of MHA's quality while getting most of MQA's cache efficiency. Typical choices: $h = 32$ query heads with $g = 8$ KV groups (4x cache reduction).
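One common way to implement the group mapping $g(i)$ is to project $g$ KV heads and repeat each one across its $h/g$ query heads (a sketch with illustrative names; real implementations often broadcast instead of materializing the repeat):

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(X, Wq, Wk, Wv, Wo, h, g):
    """Wq, Wo: (d, d). Wk, Wv: (d, g * dk) -- g shared KV groups. g must divide h."""
    n, d = X.shape
    dk = d // h
    Q = (X @ Wq).reshape(n, h, dk).transpose(1, 0, 2)  # (h, n, dk)
    K = (X @ Wk).reshape(n, g, dk).transpose(1, 0, 2)  # (g, n, dk)
    V = (X @ Wv).reshape(n, g, dk).transpose(1, 0, 2)
    # Each KV group serves h // g consecutive query heads
    K = np.repeat(K, h // g, axis=0)                   # (h, n, dk)
    V = np.repeat(V, h // g, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(dk)    # (h, n, n)
    heads = softmax(scores) @ V
    return heads.transpose(1, 0, 2).reshape(n, d) @ Wo

rng = np.random.default_rng(0)
n, d, h, g = 8, 16, 4, 2
X = rng.standard_normal((n, d))
Wq = rng.standard_normal((d, d)) / np.sqrt(d)
Wk = rng.standard_normal((d, g * d // h)) / np.sqrt(d)
Wv = rng.standard_normal((d, g * d // h)) / np.sqrt(d)
Wo = rng.standard_normal((d, d)) / np.sqrt(d)
Y = grouped_query_attention(X, Wq, Wk, Wv, Wo, h, g)  # shape (n, d)
```

Only the pre-repeat `K` and `V` need caching, i.e. $2gd_k$ values per token per layer.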
Summary of Head-Sharing Variants
| Variant | KV heads | KV cache per token (per layer) | Quality | Inference speed |
|---|---|---|---|---|
| MHA | $h$ | $2hd_k = 2d$ | Best | Slowest (large cache) |
| GQA | $g$ ($1 < g < h$) | $2gd_k$ | Near MHA | Fast |
| MQA | $1$ | $2d_k$ | Slightly worse | Fastest |
Linear Attention
Standard attention computes $\mathrm{softmax}(QK^\top / \sqrt{d_k})\,V$, which requires materializing the $n \times n$ attention matrix. Linear attention replaces the softmax with a kernel feature map $\phi$ to avoid this.
Linear Attention as Kernel Approximation
Statement
Define $\phi: \mathbb{R}^{d_k} \to \mathbb{R}^m$ as a feature map. Linear attention replaces softmax attention with:

$$\mathrm{Attn}(q_i) = \frac{\phi(q_i)^\top \sum_j \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_j \phi(k_j)}$$

The key observation: $\sum_j \phi(k_j)\, v_j^\top$ is an $m \times d_v$ matrix that can be computed once and reused for all queries. This gives $O(n\,m\,d_v)$ complexity instead of $O(n^2 d_v)$.
Intuition
Standard attention computes pairwise similarity between all query-key pairs, then uses these similarities to weight values. Linear attention factorizes this: first summarize all key-value pairs into a fixed-size matrix, then match each query against this summary. The summary is a compressed representation of the entire context.
Proof Sketch
Write the standard kernel attention: $\mathrm{Attn}(q_i) = \frac{\sum_j \mathrm{sim}(q_i, k_j)\, v_j}{\sum_j \mathrm{sim}(q_i, k_j)}$, where $\mathrm{sim}(q, k) = \exp(q^\top k / \sqrt{d_k})$. Substitute the kernel decomposition $\mathrm{sim}(q, k) = \phi(q)^\top \phi(k)$ and rearrange: the numerator becomes $\phi(q_i)^\top \sum_j \phi(k_j)\, v_j^\top$, which separates the query from the key-value aggregation.
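The rearrangement can be verified numerically: computing the factored form and the explicit $n \times n$ kernel form gives identical outputs. A hedged sketch using the ELU+1 feature map from Katharopoulos et al. (2020), with illustrative names:

```python
import numpy as np

def phi(x):
    # elu(x) + 1: a positive-valued feature map (Katharopoulos et al., 2020)
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Q, K: (n, dk); V: (n, dv). Cost O(n * dk * dv); no n x n matrix."""
    S = phi(K).T @ V            # (dk, dv): sum_j phi(k_j) v_j^T, computed once
    z = phi(K).sum(axis=0)      # (dk,):    sum_j phi(k_j), the normalizer state
    return (phi(Q) @ S) / (phi(Q) @ z)[:, None]

rng = np.random.default_rng(0)
Q, K = rng.standard_normal((6, 4)), rng.standard_normal((6, 4))
V = rng.standard_normal((6, 5))
fast = linear_attention(Q, K, V)

# Reference: materialize all n^2 kernel values, then normalize row-wise
W = phi(Q) @ phi(K).T
slow = (W @ V) / W.sum(axis=1, keepdims=True)
# fast and slow agree to floating-point precision
```

The state `S` and normalizer `z` are exactly what makes causal linear attention expressible as an RNN: both can be updated incrementally as tokens arrive.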
Why It Matters
Linear attention achieves $O(n)$ complexity in sequence length (for fixed $m$ and $d_v$). This enables attention over sequences of length $10^5$ or more, which is impractical with standard attention.
Failure Mode
The feature map $\phi$ must approximate the softmax kernel well. Common choices (ELU+1, random Fourier features) provide poor approximations, leading to quality degradation. On language modeling benchmarks, linear attention models consistently underperform softmax attention by a meaningful margin. The $n \times n$ matrix computed by softmax attention contains fine-grained pairwise information that the $m \times d_v$ summary cannot fully capture when $n \gg m$.
Sparse Attention
Instead of replacing softmax, sparse attention restricts which query-key pairs are computed.
Local window attention: each token attends only to the nearest $w$ tokens. Complexity: $O(nw)$. Captures local dependencies but misses long-range ones.
Strided attention: attend to every $s$-th token plus a local window. Captures periodic structure and some long-range dependencies.
Hash-based (LSH) attention (Kitaev et al., 2020): use locality-sensitive hashing to find the most similar keys for each query. Only compute attention for likely-high-similarity pairs. Expected complexity: $O(n \log n)$.
Block-sparse patterns: combine local windows with a few global tokens that attend to everything. Longformer (Beltagy et al., 2020) and BigBird (Zaheer et al., 2020) use this approach.
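The simplest of these patterns, a causal local window, reduces to a boolean mask over query-key pairs (an illustrative sketch; efficient implementations compute only the allowed blocks rather than masking a dense matrix):

```python
import numpy as np

def local_window_mask(n, w):
    """True where query i may attend key j: causal, within the last w tokens."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (i - j <= w)

mask = local_window_mask(6, 2)
# Each row has at most w + 1 = 3 allowed positions, so the score
# computation touches O(n * w) pairs instead of O(n^2).
```

Strided and block-sparse patterns are built the same way, by OR-ing additional boolean conditions (e.g. `j % s == 0` for a stride, or a set of always-visible global token indices).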
The practical problem: sparse patterns must be known in advance or computed cheaply. For autoregressive generation, where the attention pattern depends on the content being generated, fixed sparse patterns may miss important long-range dependencies.
Common Confusions
MQA does not reduce training FLOPs
MQA and GQA save memory for the KV cache during inference, which allows larger batch sizes and longer contexts. During training, the dominant cost is the matrix multiplication $QK^\top$ and the attention-weighted sum over values, which are the same for all variants (the shared $K$, $V$ projections are simply broadcast across heads). The training speedup from MQA is modest; the inference speedup is large.
Linear attention is not just softmax without exp
Dropping the exponential from softmax gives unnormalized attention, not linear attention. Linear attention specifically uses a kernel decomposition $\mathrm{sim}(q, k) = \phi(q)^\top \phi(k)$ to factorize the computation. The feature map $\phi$ is what makes the factorization possible. Simply removing exp would give negative attention weights and no computational benefit.
Canonical Examples
KV cache savings with GQA
Consider a model with $d = 4096$, $h = 32$ heads, $L = 32$ layers, serving at sequence length $n = 8192$ with float16 precision (2 bytes per value). MHA KV cache per sequence: $2 \times 8192 \times 32 \times 4096 \times 2$ bytes $\approx 4.29$ GB. GQA with $g = 8$: $\approx 1.07$ GB. MQA: $\approx 0.134$ GB. On a GPU with 80 GB memory, the difference between 4.29 GB and 0.134 GB per sequence determines whether you can serve 10 or 500 concurrent requests.
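The arithmetic generalizes to a one-line formula; a small sketch, assuming the Llama-2-7B-like dimensions used above ($d = 4096$, $h = 32$, $L = 32$, $n = 8192$, float16):

```python
def kv_cache_bytes(n, L, d_model, h, kv_heads, bytes_per_val=2):
    """Per-sequence KV cache: 2 tensors (K and V) x n tokens x L layers
    x kv_heads x head dim, at bytes_per_val bytes each (2 for float16)."""
    dk = d_model // h
    return 2 * n * L * kv_heads * dk * bytes_per_val

n, L, d, h = 8192, 32, 4096, 32
mha = kv_cache_bytes(n, L, d, h, kv_heads=32)  # ~4.29 GB
gqa = kv_cache_bytes(n, L, d, h, kv_heads=8)   # ~1.07 GB
mqa = kv_cache_bytes(n, L, d, h, kv_heads=1)   # ~0.134 GB
```

Since the formula is linear in `kv_heads`, the cache ratio between variants is exactly $h / g$, independent of all other dimensions.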
Exercises
Problem
A Transformer has $h = 32$ query heads and uses GQA with $g = 8$ KV groups. How many query heads share each KV group? What is the KV cache reduction factor compared to standard MHA?
Problem
Show that standard softmax attention can be written as kernel attention $\mathrm{Attn}(q_i) = \frac{\sum_j \mathrm{sim}(q_i, k_j)\, v_j}{\sum_j \mathrm{sim}(q_i, k_j)}$ with $\mathrm{sim}(q, k) = \exp(q^\top k / \sqrt{d_k})$. Why can this kernel not be exactly decomposed as $\phi(q)^\top \phi(k)$ for a finite-dimensional $\phi$?
References
Canonical:
- Vaswani et al., "Attention Is All You Need" (NeurIPS 2017), Section 3.2
- Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need" (2019)
Current:
- Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (EMNLP 2023)
- Katharopoulos et al., "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (ICML 2020)
Next Topics
- KV cache: how these attention variants affect the memory and speed of autoregressive generation
Last reviewed: April 2026
Prerequisites
Foundations this topic depends on.
- Attention Mechanism Theory (Layer 4)
- Matrix Operations and Properties (Layer 0A)
- Sets, Functions, and Relations (Layer 0A)
- Basic Logic and Proof Techniques (Layer 0A)
- Softmax and Numerical Stability (Layer 1)
- Flash Attention (Layer 5)