

Multi-Head vs. Multi-Query vs. Grouped-Query Attention

Multi-head attention (MHA) gives each head its own K, V projections. Multi-query attention (MQA) shares a single K, V across all heads. Grouped-query attention (GQA) shares K, V within groups of heads. MQA and GQA reduce KV cache size during autoregressive inference, trading a small quality loss for dramatically lower memory and faster decoding.

What Each Mechanism Does

All three are variants of the attention mechanism in transformers. They differ in how key and value projections are shared (or not) across attention heads.

Multi-Head Attention (MHA): The original formulation from Vaswani et al. Each of $h$ heads has independent query, key, and value projections:

$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$$

where $Q_i = XW^Q_i$, $K_i = XW^K_i$, $V_i = XW^V_i$, and each $W^K_i, W^V_i \in \mathbb{R}^{d \times d_k}$ with $d_k = d/h$.

Multi-Query Attention (MQA): All heads share a single key and value projection. Each head still has its own query projection:

$$\text{head}_i = \text{softmax}\left(\frac{Q_i K^T}{\sqrt{d_k}}\right) V$$

where $K = XW^K$ and $V = XW^V$ are shared across all $h$ heads.

Grouped-Query Attention (GQA): Heads are divided into $g$ groups. Heads within each group share key and value projections. With $h$ total heads and $g$ groups, each group has $h/g$ heads sharing one K, V:

$$\text{head}_i = \text{softmax}\left(\frac{Q_i K_{g(i)}^T}{\sqrt{d_k}}\right) V_{g(i)}$$

where $g(i) = \lfloor i \cdot g / h \rfloor$ maps head $i$ to its group.

MHA is GQA with $g = h$ (every head is its own group). MQA is GQA with $g = 1$ (all heads share one group). GQA interpolates between the two.
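Because of this, all three variants can be written as one function parameterized by the number of KV groups. The sketch below (NumPy, single sequence, no masking, batching, or output projection; all names are illustrative) implements the $g(i) = \lfloor i \cdot g / h \rfloor$ mapping and recovers MHA at $g = h$ and MQA at $g = 1$:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerically stable softmax
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V, g):
    """Q: (h, s, d_k); K, V: (g, s, d_k). Head i reads KV group floor(i*g/h)."""
    h, s, d_k = Q.shape
    out = np.empty_like(Q)
    for i in range(h):
        j = i * g // h                          # group index g(i)
        scores = Q[i] @ K[j].T / np.sqrt(d_k)   # (s, s) attention logits
        out[i] = softmax(scores) @ V[j]         # (s, d_k) per-head output
    return out

rng = np.random.default_rng(0)
h, s, d_k = 8, 4, 16
Q = rng.normal(size=(h, s, d_k))
# g = h -> MHA (per-head KV); 1 < g < h -> GQA; g = 1 -> MQA (one shared KV)
for g in (h, 2, 1):
    K = rng.normal(size=(g, s, d_k))
    V = rng.normal(size=(g, s, d_k))
    out = grouped_query_attention(Q, K, V, g)
    assert out.shape == (h, s, d_k)
```

Production kernels vectorize the group lookup (e.g. by repeating KV heads), but the per-head loop makes the sharing pattern explicit.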

The KV Cache Problem

During autoregressive generation, the model produces one token at a time. To avoid recomputing attention over the entire sequence, the key and value tensors for all previous positions are cached (the KV cache).

For a model with $L$ layers, $h$ heads, head dimension $d_k$, and sequence length $s$, the KV cache stores:

| Mechanism | KV Cache Size (per layer) | Total KV Cache |
| --- | --- | --- |
| MHA | $2 \times h \times s \times d_k$ | $2Lhsd_k$ |
| MQA | $2 \times 1 \times s \times d_k$ | $2Lsd_k$ |
| GQA ($g$ groups) | $2 \times g \times s \times d_k$ | $2Lgsd_k$ |

For a 70B-scale model with $L = 80$, $h = 64$, $d_k = 128$, and context $s = 8192$ in bf16 (2 bytes per element), the per-request KV cache is roughly 20 GB for MHA, 2.5 GB for GQA with $g = 8$, and 0.3 GB for MQA.

MHA's KV cache can exceed the model weight memory at long sequences. GQA with 8 groups reduces it by 8x. MQA reduces it by 64x.
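These figures follow directly from the per-layer formula above. A back-of-envelope helper (not from any library; the dimensions are the 70B-scale example from the text):

```python
def kv_cache_bytes(layers, kv_heads, seq_len, d_k, bytes_per_elem=2):
    """Total KV cache: 2 (K and V) x layers x KV heads x positions x head dim."""
    return 2 * layers * kv_heads * seq_len * d_k * bytes_per_elem

# L = 80, s = 8192, d_k = 128, bf16 (2 bytes per element)
for name, kv_heads in [("MHA (64 KV heads)", 64),
                       ("GQA (8 groups)", 8),
                       ("MQA (1 KV head)", 1)]:
    gib = kv_cache_bytes(80, kv_heads, 8192, 128) / 2**30
    print(f"{name}: {gib:.2f} GiB")  # 20.00, 2.50, 0.31
```

The reduction factor is exactly the ratio of KV heads: 64/8 = 8x for GQA, 64/1 = 64x for MQA.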

Quality vs. Efficiency Tradeoff

The quality ordering is: MHA > GQA > MQA.

MQA forces all heads to attend with the same keys and values, reducing the model's ability to capture diverse attention patterns. Shazeer (2019) showed MQA quality loss is small (0.1-0.5% on benchmarks) but measurable.

GQA recovers most of the quality gap. Ainslie et al. (2023) showed that GQA with $g = 8$ for a 64-head model matches MHA quality within measurement noise on most benchmarks while providing an 8x KV cache reduction.

The key insight: most of the representational capacity comes from diverse queries. Keys and values are more redundant across heads. Sharing them sacrifices relatively little modeling power.

Side-by-Side Comparison

| Property | MHA | GQA | MQA |
| --- | --- | --- | --- |
| KV heads | $h$ (all independent) | $g$ groups ($1 < g < h$) | 1 (all shared) |
| KV cache size | $2hsd_k$ per layer | $2gsd_k$ per layer | $2sd_k$ per layer |
| KV cache reduction | 1x (baseline) | $h/g$ times | $h$ times |
| Quality | Best (baseline) | Near-MHA | Slight degradation |
| Inference throughput | Slowest (memory-bound) | 2-4x faster than MHA | Fastest |
| Training cost | Baseline | Same | Same (fewer KV params) |
| Used in | BERT, GPT-2, GPT-3 | LLaMA-2-70B, Gemma, Mistral | PaLM, Falcon |
| Parameter count | $d(3d)$ per layer | $d(d + 2gd_k)$ per layer | $d(d + 2d_k)$ per layer |

Inference Throughput: Why KV Cache Size Matters

Autoregressive generation is memory-bandwidth bound, not compute-bound. At each step, the model reads the entire KV cache to compute attention for one new token. The time per token is proportional to the amount of data read from memory.

For a single batch on an A100 (2 TB/s memory bandwidth), generating one token with a 70B MHA model at sequence length 8192 requires reading 20 GB of KV cache, taking ~10 ms just for memory transfer. GQA with $g = 8$ reads 2.5 GB (~1.25 ms). MQA reads 0.3 GB (~0.15 ms).

Larger batch sizes multiply the KV cache linearly. Serving 32 concurrent requests with MHA at 8K context requires 640 GB of KV cache, far exceeding a single GPU. GQA makes large-batch serving feasible.
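A rough latency model makes these numbers concrete. It assumes decode time is purely KV-cache reads at the quoted 2 TB/s; real kernels also read weights and activations, so treat it as a lower bound:

```python
def kv_read_ms(layers, kv_heads, seq_len, d_k, bytes_per_elem=2, bandwidth=2e12):
    """Lower-bound per-token decode time: KV bytes read / memory bandwidth (ms)."""
    kv_bytes = 2 * layers * kv_heads * seq_len * d_k * bytes_per_elem
    return 1e3 * kv_bytes / bandwidth

# Same 70B-scale setup: L = 80, s = 8192, d_k = 128, bf16, A100-class bandwidth
for name, kv_heads in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    print(f"{name}: {kv_read_ms(80, kv_heads, 8192, 128):.2f} ms/token")
# Prints about 10.74 / 1.34 / 0.17; the article's ~10 / ~1.25 / ~0.15 ms
# come from rounding the cache sizes to 20 / 2.5 / 0.3 GB first.
```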

Converting MHA to GQA

Ainslie et al. (2023) showed that pretrained MHA models can be converted to GQA through a two-step process:

  1. Construct grouped K, V: Average the K (and V) weight matrices within each group: $W^K_{\text{group}_j} = \frac{1}{h/g} \sum_{i \in \text{group}_j} W^K_i$.

  2. Continue pretraining: Fine-tune with the grouped architecture for 5-10% of original pretraining tokens to recover quality.
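Step 1 amounts to mean-pooling per-head projection matrices within each contiguous group. A minimal sketch (NumPy; assumes heads are stored along a leading axis and $h$ is divisible by $g$; names are illustrative):

```python
import numpy as np

def mean_pool_kv(W, g):
    """W: (h, d, d_k) per-head K (or V) weights from an MHA checkpoint.
    Returns (g, d, d_k): each group's weights averaged over its h/g heads."""
    h, d, d_k = W.shape
    assert h % g == 0, "head count must be divisible by the group count"
    return W.reshape(g, h // g, d, d_k).mean(axis=1)

h, g, d, d_k = 8, 2, 16, 4
W_k = np.random.default_rng(1).normal(size=(h, d, d_k))
W_k_grouped = mean_pool_kv(W_k, g)
assert W_k_grouped.shape == (g, d, d_k)
# Group 0 is the mean of the first h/g heads, matching g(i) = floor(i*g/h)
assert np.allclose(W_k_grouped[0], W_k[: h // g].mean(axis=0))
```

Query projections are left untouched; only the pooled K, V weights change, which is why a short continued-pretraining phase (step 2) is needed to recover quality.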

This is cheaper than training GQA from scratch and allows upgrading existing MHA checkpoints. The released LLaMA-2-70B ships with GQA and $g = 8$; the Llama 2 report used this uptraining recipe when ablating attention variants.

Common Confusions

Watch Out

GQA does not reduce training compute

GQA reduces the number of K, V projection parameters, but the dominant training cost is the attention computation ($O(s^2 d)$) and the feedforward layers, not the KV projection. Training FLOPs for GQA and MHA are nearly identical. The benefit is almost entirely at inference time, through smaller KV caches and higher throughput.
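A parameter-count proxy illustrates the point (projection parameters scale with their per-token matmul FLOPs). The dimensions below are illustrative LLaMA-2-70B-like numbers, not taken from the text:

```python
# Per-layer parameter counts for one transformer block (assumed dims)
d, h, d_k, g, d_ff = 8192, 64, 128, 8, 28672

qo = 2 * d * d               # Q and output projections
kv_mha = 2 * d * (h * d_k)   # K, V with per-head projections (MHA)
kv_gqa = 2 * d * (g * d_k)   # K, V with 8 shared groups (GQA)
ffn = 3 * d * d_ff           # gated FFN: up, gate, down projections

total_mha = qo + kv_mha + ffn
total_gqa = qo + kv_gqa + ffn
print(f"KV share of layer params: MHA {kv_mha/total_mha:.1%}, "
      f"GQA {kv_gqa/total_gqa:.1%}")
```

Even under MHA the KV projections are a minority of layer parameters, so shrinking them barely moves training FLOPs; the win is the 8x smaller cache at decode time.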

Watch Out

MQA is not the same as single-head attention

MQA has $h$ heads with independent query projections. Each head computes its own attention pattern using the shared keys. The model still captures $h$ different attention patterns. Single-head attention uses one Q, one K, one V and computes one attention pattern. MQA retains multi-head diversity in queries while sharing the key-value memory.
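A tiny demonstration of this (NumPy, random weights, names illustrative): two MQA heads share one key projection yet produce different attention patterns, because their query projections differ:

```python
import numpy as np

rng = np.random.default_rng(0)
s, d, d_k, h = 4, 16, 8, 2
X = rng.normal(size=(s, d))
W_q = rng.normal(size=(h, d, d_k))  # per-head query projections
W_k = rng.normal(size=(d, d_k))     # single shared key projection (MQA)
K = X @ W_k                         # (s, d_k), computed once for all heads

def attn_weights(q):
    """Row-stochastic attention pattern for one head's queries q: (s, d_k)."""
    sc = q @ K.T / np.sqrt(d_k)
    e = np.exp(sc - sc.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

A0 = attn_weights(X @ W_q[0])
A1 = attn_weights(X @ W_q[1])
assert not np.allclose(A0, A1)  # distinct patterns despite the shared K
```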

Watch Out

KV cache size does not depend on batch size linearly in all serving frameworks

Paged attention (vLLM) allocates KV cache in blocks, allowing non-contiguous memory. This reduces waste from padding and variable-length sequences, but the total KV cache for active requests still scales linearly with batch size and sequence length. GQA reduces the per-request cache, which compounds with batching.

Watch Out

The number of GQA groups is not always 8

$g = 8$ is common (LLaMA-2-70B, Mistral-7B) but not universal. The optimal $g$ depends on the model size, target context length, and serving constraints. Smaller models with fewer heads may use $g = 2$ or $g = 4$. The choice balances KV cache reduction against quality preservation.

References

  1. Vaswani, A. et al. (2017). "Attention Is All You Need." NeurIPS 2017. (Original multi-head attention.)
  2. Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150. (Multi-query attention, single shared KV.)
  3. Ainslie, J. et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. (Grouped-query attention, uptraining from MHA.)
  4. Touvron, H. et al. (2023). "Llama 2: Open Foundation and Fine-Tuned Chat Models." arXiv:2307.09288. (LLaMA-2-70B using GQA with 8 groups.)
  5. Pope, R. et al. (2023). "Efficiently Scaling Transformer Inference." MLSys 2023. (Analysis of memory-bandwidth bottleneck in autoregressive inference.)
  6. Kwon, W. et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023. (vLLM paged attention for KV cache management.)
  7. Jiang, A. Q. et al. (2023). "Mistral 7B." arXiv:2310.06825. (GQA in a 7B model with sliding window attention.)