

Attention Variants and Efficiency

Multi-head, multi-query, grouped-query, linear, and sparse attention: how each variant trades expressivity for efficiency, and when to use which.


Why This Matters

Standard scaled dot-product attention has $O(n^2 d)$ time and $O(n^2)$ memory complexity, where $n$ is sequence length. For long sequences (documents, code, genomics), this becomes the bottleneck. Different attention variants trade expressivity for efficiency, and the right choice depends on the deployment scenario: training throughput, inference latency, KV cache size, or sequence length.

Understanding these variants requires knowing exactly what each one changes and what it preserves.

Baseline: Multi-Head Attention (MHA)

Definition

Multi-Head Attention

Given $H$ attention heads, each head $h$ has its own projection matrices $W_h^Q, W_h^K, W_h^V \in \mathbb{R}^{d \times d_k}$, where $d_k = d/H$:

$$\text{head}_h = \text{Attention}(XW_h^Q, XW_h^K, XW_h^V) = \text{softmax}\left(\frac{XW_h^Q (XW_h^K)^\top}{\sqrt{d_k}}\right) XW_h^V$$

$$\text{MHA}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H)\, W^O$$

where $W^O \in \mathbb{R}^{d \times d}$ is the output projection.

Parameter count per layer: $4d^2$ (three projections $W^Q, W^K, W^V$ plus output $W^O$, each $d \times d$).

KV cache per token: $2 \cdot H \cdot d_k = 2d$ values must be stored for each token in the context. For a model with $L$ layers, the total KV cache per token is $2Ld$.
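As a concrete reference, here is a minimal single-sequence MHA forward pass in numpy. This is a sketch under simplifying assumptions (no batching, no causal mask, no KV cache); the function names are ours, not from any library:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha(X, Wq, Wk, Wv, Wo, H):
    """Multi-head attention over one sequence X of shape (n, d)."""
    n, d = X.shape
    dk = d // H
    # Project, then split into H heads of width dk: shape (H, n, dk)
    Q = (X @ Wq).reshape(n, H, dk).transpose(1, 0, 2)
    K = (X @ Wk).reshape(n, H, dk).transpose(1, 0, 2)
    V = (X @ Wv).reshape(n, H, dk).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(dk)   # (H, n, n) per-head attention
    heads = softmax(scores) @ V                        # (H, n, dk)
    out = heads.transpose(1, 0, 2).reshape(n, d)       # concatenate heads
    return out @ Wo                                    # output projection

rng = np.random.default_rng(0)
n, d, H = 16, 64, 8
X = rng.standard_normal((n, d))
Wq, Wk, Wv, Wo = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
Y = mha(X, Wq, Wk, Wv, Wo, H)
print(Y.shape)  # (16, 64)
```

Note that each of the four weight matrices is $d \times d$, matching the $4d^2$ parameter count above.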

Proposition

Multi-Head Expressivity

Statement

Each attention head can learn a different attention pattern (a different notion of "relevance" between tokens). With $H$ heads, the layer can simultaneously attend to $H$ different types of relationships. The concatenation followed by $W^O$ allows the layer to mix information from all attention patterns.

Formally, the rank of the attention matrix in each head is at most $d_k$, but the combined output through $W^O$ can represent functions that no single head of dimension $d$ could represent. The multi-head decomposition is not merely a factorization; it enables attending to different subspaces independently.

Intuition

Think of each head as asking a different question about the input. Head 1 might attend based on syntactic relationships (subject to verb). Head 2 might attend based on semantic similarity. Head 3 might attend to positional proximity. By running these in parallel and combining, the model captures multiple types of dependencies simultaneously.

Proof Sketch

Consider a function that requires attending to two incompatible patterns simultaneously (e.g., the nearest token and the most semantically similar token). A single softmax attention head produces a convex combination of values, which cannot represent two sharp peaks simultaneously. Two separate heads can each produce one sharp peak, and the output projection can combine them.

Why It Matters

This explains why multi-head attention outperforms single-head attention with the same total dimension. The computational cost is identical (same number of parameters and FLOPs), but the representational capacity increases.

Failure Mode

In practice, many heads learn redundant or near-identical patterns. Michel et al. (2019) showed that most heads can be pruned after training with minimal quality loss. The theoretical expressivity advantage is not always used. This observation motivates the reduced-head variants below.

Multi-Query Attention (MQA)

Definition

Multi-Query Attention

Multi-query attention (Shazeer, 2019) uses separate $W_h^Q$ projections per head but shares a single $W^K$ and $W^V$ across all heads:

$$\text{head}_h = \text{softmax}\left(\frac{XW_h^Q (XW^K)^\top}{\sqrt{d_k}}\right) XW^V$$

All heads compute attention over the same key-value pairs but with different queries.

KV cache per token: $2d_k$ (one set of keys and values instead of $H$ sets). This is an $H$-fold reduction compared to MHA.

Tradeoff: MQA reduces the KV cache by a factor of $H$ (typically 32x to 128x for large models), enabling much longer contexts and larger batch sizes during inference. Quality degrades slightly because heads can no longer attend to different subspaces of keys and values. For many tasks, this degradation is small (less than 1% on benchmarks).
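The change from MHA is small in code: the K and V projections shrink from $d \times d$ to $d \times d_k$, and the single KV head broadcasts across all query heads. A numpy sketch under the same simplifying assumptions as before (single sequence, no mask; the function name is ours):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mqa(X, Wq, Wk, Wv, Wo, H):
    """Multi-query attention: per-head queries, one shared K/V head."""
    n, d = X.shape
    dk = d // H
    Q = (X @ Wq).reshape(n, H, dk).transpose(1, 0, 2)  # (H, n, dk)
    K = X @ Wk                                          # (n, dk), shared by all heads
    V = X @ Wv                                          # (n, dk), shared by all heads
    scores = Q @ K.T / np.sqrt(dk)                      # K broadcasts to (H, n, n)
    heads = softmax(scores) @ V                         # (H, n, dk)
    return heads.transpose(1, 0, 2).reshape(n, d) @ Wo

rng = np.random.default_rng(0)
n, d, H = 16, 64, 8
X = rng.standard_normal((n, d))
Wq = rng.standard_normal((d, d)) / np.sqrt(d)
Wk = rng.standard_normal((d, d // H)) / np.sqrt(d)     # one KV head: d x d_k
Wv = rng.standard_normal((d, d // H)) / np.sqrt(d)
Wo = rng.standard_normal((d, d)) / np.sqrt(d)
Y = mqa(X, Wq, Wk, Wv, Wo, H)
print(Y.shape)  # (16, 64)
```

Only `K` and `V` need to be cached during generation, which is where the $H$-fold savings comes from.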

Grouped-Query Attention (GQA)

Definition

Grouped-Query Attention

Grouped-query attention (Ainslie et al., 2023) divides the $H$ query heads into $G$ groups. Each group shares one set of key-value projections:

$$\text{head}_h = \text{softmax}\left(\frac{XW_h^Q (XW_{g(h)}^K)^\top}{\sqrt{d_k}}\right) XW_{g(h)}^V$$

where $g(h) = \lfloor hG/H \rfloor$ maps head $h$ to its group.

KV cache per token: $2G \cdot d_k$. When $G = H$, GQA is MHA. When $G = 1$, GQA is MQA.

GQA is the current standard for large language models (Llama 2/3, Mistral). It recovers most of MHA's quality while getting most of MQA's cache efficiency. Typical choices: $H = 32$ query heads with $G = 8$ KV groups (4x cache reduction).

Summary of Head-Sharing Variants

| Variant | KV heads | KV cache per token | Quality | Inference speed |
|---------|----------|--------------------|---------|-----------------|
| MHA | $H$ | $2Hd_k = 2d$ | Best | Slowest (large cache) |
| GQA | $G$ ($1 < G < H$) | $2Gd_k$ | Near MHA | Fast |
| MQA | $1$ | $2d_k$ | Slightly worse | Fastest |

Linear Attention

Standard attention computes $\text{softmax}(QK^\top/\sqrt{d_k})\,V$, which requires materializing the $n \times n$ attention matrix. Linear attention replaces the softmax with a kernel function to avoid this.

Proposition

Linear Attention as Kernel Approximation

Statement

Define a feature map $\phi: \mathbb{R}^{d_k} \to \mathbb{R}^m$. Linear attention replaces softmax attention with:

$$\text{LinearAttn}(Q, K, V)_i = \frac{\phi(q_i)^\top \sum_{j=1}^{n} \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_{j=1}^{n} \phi(k_j)}$$

The key observation: $\sum_{j=1}^{n} \phi(k_j)\, v_j^\top$ is an $m \times d_v$ matrix that can be computed once and reused for all queries. This gives $O(n m d_v)$ complexity instead of $O(n^2 d_k)$.

Intuition

Standard attention computes pairwise similarity between all query-key pairs, then uses these similarities to weight values. Linear attention factorizes this: first summarize all key-value pairs into a fixed-size matrix, then match each query against this summary. The summary is a compressed representation of the entire context.

Proof Sketch

Write the standard kernel attention: $\text{Attn}(q_i, K, V) = \sum_j \kappa(q_i, k_j)\, v_j \,/\, \sum_j \kappa(q_i, k_j)$ where $\kappa(q, k) = \phi(q)^\top \phi(k)$. Substitute the kernel decomposition and rearrange: the numerator becomes $\phi(q_i)^\top \left(\sum_j \phi(k_j)\, v_j^\top\right)$, which separates the query from the key-value aggregation.

Why It Matters

Linear attention achieves $O(n)$ complexity in sequence length (for fixed $m$ and $d_v$). This enables attention over sequences of length $10^5$ or more, which is impractical with standard $O(n^2)$ attention.

Failure Mode

The feature map $\phi$ must approximate the softmax kernel well. Common choices (ELU+1, random Fourier features) provide poor approximations, leading to quality degradation. On language modeling benchmarks, linear attention models consistently underperform softmax attention by a meaningful margin. The $n \times n$ matrix computed by softmax attention contains fine-grained pairwise information that the $m \times d_v$ summary cannot fully capture when $m \ll n$.
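The factorization in the proposition can be checked numerically: the $O(n)$ summary form and the explicit $O(n^2)$ kernel form produce the same output. A numpy sketch using the ELU+1 feature map mentioned above (the function names are ours):

```python
import numpy as np

def phi(x):
    """ELU(x) + 1: a positive feature map (as in Katharopoulos et al., 2020)."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """O(n) form: summarize all keys/values once, then match each query."""
    S = phi(K).T @ V                      # (m, d_v) key-value summary
    z = phi(K).sum(axis=0)                # (m,)    normalizer summary
    num = phi(Q) @ S                      # (n, d_v)
    den = phi(Q) @ z                      # (n,)
    return num / den[:, None]

def kernel_attention(Q, K, V):
    """Equivalent O(n^2) form: explicit pairwise kernel weights."""
    A = phi(Q) @ phi(K).T                 # (n, n) kernel matrix
    return (A / A.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((32, 8)) for _ in range(3))
print(np.allclose(linear_attention(Q, K, V), kernel_attention(Q, K, V)))  # True
```

The linear form never materializes the $32 \times 32$ kernel matrix; it only keeps the fixed-size summaries `S` and `z`.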

Sparse Attention

Instead of replacing softmax, sparse attention restricts which query-key pairs are computed.

Local window attention: each token attends only to the $w$ nearest tokens. Complexity: $O(nw)$. Captures local dependencies but misses long-range ones.

Strided attention: attend to every $s$-th token plus a local window. Captures periodic structure and some long-range dependencies.

Hash-based (LSH) attention (Kitaev et al., 2020): use locality-sensitive hashing to find the most similar keys for each query. Only compute attention for likely-high-similarity pairs. Expected complexity: $O(n \log n)$.

Block-sparse patterns: combine local windows with a few global tokens that attend to everything. Longformer (Beltagy et al., 2020) and BigBird (Zaheer et al., 2020) use this approach.

The practical problem: sparse patterns must be known in advance or computed cheaply. For autoregressive generation, where the attention pattern depends on the content being generated, fixed sparse patterns may miss important long-range dependencies.
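A fixed block-sparse pattern can be expressed as a boolean mask over query-key pairs; its density relative to the full $n \times n$ matrix is the fraction of attention scores that actually get computed. A small numpy sketch of the local-window-plus-global-tokens pattern (the function name and parameters are ours, illustrative rather than matching any particular library):

```python
import numpy as np

def block_sparse_mask(n, window, global_idx):
    """Boolean mask of allowed query-key pairs: local window + global tokens."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    mask = np.abs(i - j) <= window            # local window around each token
    mask[global_idx, :] = True                # global tokens attend everywhere...
    mask[:, global_idx] = True                # ...and every token attends to them
    return mask

n = 512
mask = block_sparse_mask(n, window=8, global_idx=[0])
print(f"density: {mask.mean():.3f}")          # small fraction of the n^2 pairs
```

In a real kernel, the masked-out pairs are never computed at all (rather than computed and zeroed), which is where the $O(nw)$ cost comes from.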

Common Confusions

Watch Out

MQA does not reduce training FLOPs

MQA and GQA save memory for the KV cache during inference, which allows larger batch sizes and longer contexts. During training, the dominant cost is the matrix multiplications for $QK^\top$ and the attention-weighted sum over values, which are the same for all variants (the shared K, V projections are simply broadcast). The training speedup from MQA is modest; the inference speedup is large.

Watch Out

Linear attention is not just softmax without exp

Dropping the exponential from softmax gives unnormalized attention, not linear attention. Linear attention specifically uses a kernel decomposition $\kappa(q, k) = \phi(q)^\top \phi(k)$ to factorize the computation. The feature map $\phi$ is what makes the factorization possible. Simply removing the exponential would give negative attention weights and no computational benefit.

Canonical Examples

Example

KV cache savings with GQA

Consider a model with $d = 4096$, $H = 32$ heads (so $d_k = 128$), $L = 32$ layers, serving at sequence length $n = 8192$ with float16 precision (2 bytes per value). MHA KV cache per sequence: $2 \times 32 \times 128 \times 32 \times 8192 \times 2 = 4.29$ GB. GQA with $G = 8$: $2 \times 8 \times 128 \times 32 \times 8192 \times 2 = 1.07$ GB. MQA: $2 \times 1 \times 128 \times 32 \times 8192 \times 2 = 0.134$ GB. On a GPU with 80 GB memory, the difference between 4.29 GB and 0.134 GB per sequence determines whether you can serve 10 or 500 concurrent requests.
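These figures follow directly from the formula $2 \cdot (\text{KV heads}) \cdot d_k \cdot L \cdot n \cdot (\text{bytes per value})$, and are easy to reproduce (the helper name `kv_cache_bytes` is ours):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_val=2):
    """Per-sequence KV cache: 2 (K and V) x KV heads x head dim x layers x tokens x bytes."""
    return 2 * n_kv_heads * d_head * n_layers * seq_len * bytes_per_val

# The example model: d_k = 128, L = 32, n = 8192, float16.
for name, kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_layers=32, n_kv_heads=kv_heads, d_head=128, seq_len=8192)
    print(f"{name}: {size / 1e9:.3f} GB")
# MHA: 4.295 GB
# GQA: 1.074 GB
# MQA: 0.134 GB
```

Multiply by the number of concurrent sequences to estimate total serving memory for the cache alone.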

Exercises

Exercise (Core)

Problem

A Transformer has $H = 64$ query heads and uses GQA with $G = 8$ KV groups. How many query heads share each KV group? What is the KV cache reduction factor compared to standard MHA?

Exercise (Advanced)

Problem

Show that standard softmax attention can be written as a kernel: $\text{Attn}(q_i, K, V) = \sum_j \kappa(q_i, k_j)\, v_j \,/\, \sum_j \kappa(q_i, k_j)$ with $\kappa(q, k) = \exp(q^\top k / \sqrt{d_k})$. Why can this kernel not be exactly decomposed as $\phi(q)^\top \phi(k)$ for a finite-dimensional $\phi$?

References

Canonical:

  • Vaswani et al., "Attention Is All You Need" (NeurIPS 2017), Section 3.2
  • Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need" (2019)

Current:

  • Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (EMNLP 2023)
  • Katharopoulos et al., "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (ICML 2020)

Next Topics

  • KV cache: how these attention variants affect the memory and speed of autoregressive generation

Last reviewed: April 2026
