What Each Mechanism Does
All three are variants of the attention mechanism in transformers. They differ in how key and value projections are shared (or not) across attention heads.
Multi-Head Attention (MHA): The original formulation from Vaswani et al. Each of $h$ heads has independent query, key, and value projections:

$$\text{head}_i = \text{Attention}(X W_i^Q,\; X W_i^K,\; X W_i^V)$$

where $W_i^Q \in \mathbb{R}^{d_\text{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_\text{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_\text{model} \times d_k}$, and $i \in \{1, \dots, h\}$, each with $d_k = d_\text{model}/h$.
Multi-Query Attention (MQA): All heads share a single key and value projection. Each head still has its own query projection:

$$\text{head}_i = \text{Attention}(X W_i^Q,\; X W^K,\; X W^V)$$

where $W^K$ and $W^V$ are shared across all heads.
Grouped-Query Attention (GQA): Heads are divided into groups. Heads within each group share key and value projections. With $h$ total heads and $g$ groups, each group has $h/g$ heads sharing one K, V:

$$\text{head}_i = \text{Attention}(X W_i^Q,\; X W_{g(i)}^K,\; X W_{g(i)}^V)$$

where $g(i)$ maps head $i$ to its group.
MHA is GQA with $g = h$ (every head is its own group). MQA is GQA with $g = 1$ (all heads share one group). GQA interpolates between the two.
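To make the interpolation concrete, here is a minimal GQA forward pass in NumPy. This is an illustrative sketch, not a production implementation: it is non-batched, non-causal, and all names and shapes are my own choices. Setting `n_groups = n_heads` recovers MHA, and `n_groups = 1` recovers MQA.

```python
import numpy as np

def gqa_attention(x, w_q, w_k, w_v, n_heads, n_groups):
    """Grouped-query attention over a sequence x of shape (n, d_model).

    w_q: (n_heads, d_model, d_k)  -- one query projection per head
    w_k, w_v: (n_groups, d_model, d_k) -- one K/V projection per group
    """
    d_k = w_q.shape[-1]
    heads_per_group = n_heads // n_groups
    outputs = []
    for i in range(n_heads):
        g = i // heads_per_group          # g(i): maps head i to its group
        q = x @ w_q[i]                    # (n, d_k), per-head query
        k = x @ w_k[g]                    # keys shared within the group
        v = x @ w_v[g]                    # values shared within the group
        scores = q @ k.T / np.sqrt(d_k)   # (n, n) attention logits
        e = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = e / e.sum(axis=-1, keepdims=True)
        outputs.append(weights @ v)
    return np.concatenate(outputs, axis=-1)  # (n, n_heads * d_k)
```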
The KV Cache Problem
During autoregressive generation, the model produces one token at a time. To avoid recomputing attention over the entire sequence, the key and value tensors for all previous positions are cached (the KV cache).
For a model with $L$ layers, $h$ heads, head dimension $d_k$, and sequence length $n$, the KV cache stores:
| Mechanism | KV Cache Size (per layer) | Total KV Cache |
|---|---|---|
| MHA | $2 \cdot h \cdot n \cdot d_k$ | $2 \cdot L \cdot h \cdot n \cdot d_k$ |
| MQA | $2 \cdot n \cdot d_k$ | $2 \cdot L \cdot n \cdot d_k$ |
| GQA ($g$ groups) | $2 \cdot g \cdot n \cdot d_k$ | $2 \cdot L \cdot g \cdot n \cdot d_k$ |

Sizes are in elements; the leading factor of 2 counts keys and values. Multiply by bytes per element (2 for bf16) to get memory.
For a 70B model with $L = 80$, $h = 64$, $d_k = 128$, and $n = 8192$ context in bf16:
- MHA: $2 \times 80 \times 64 \times 8192 \times 128 \times 2$ bytes = 20.0 GB
- GQA ($g = 8$): $2 \times 80 \times 8 \times 8192 \times 128 \times 2$ bytes = 2.5 GB
- MQA: $2 \times 80 \times 1 \times 8192 \times 128 \times 2$ bytes = 0.3 GB
MHA's KV cache can exceed the model weight memory at long sequences. GQA with 8 groups reduces it by 8x. MQA reduces it by 64x.
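The arithmetic above can be checked with a few lines of Python. The function name and the treatment of GB as $1024^3$ bytes are my own conventions:

```python
# KV cache bytes = 2 (K and V) x layers x KV heads x seq_len x head_dim x bytes/elem
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_head, bytes_per_elem=2):
    return 2 * n_layers * n_kv_heads * seq_len * d_head * bytes_per_elem

GB = 1024**3
mha_gb = kv_cache_bytes(80, 64, 8192, 128) / GB  # all 64 heads cache K, V
gqa_gb = kv_cache_bytes(80, 8, 8192, 128) / GB   # 8 KV groups
mqa_gb = kv_cache_bytes(80, 1, 8192, 128) / GB   # one shared KV head
print(mha_gb, gqa_gb, mqa_gb)  # 20.0 2.5 0.3125
```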
Quality vs. Efficiency Tradeoff
The quality ordering is: MHA > GQA > MQA.
MQA forces all heads to attend with the same keys and values, reducing the model's ability to capture diverse attention patterns. Shazeer (2019) showed MQA quality loss is small (0.1-0.5% on benchmarks) but measurable.
GQA recovers most of the quality gap. Ainslie et al. (2023) showed that GQA with $g = 8$ for a 64-head model matches MHA quality within measurement noise on most benchmarks while providing 8x KV cache reduction.
The key insight: most of the representational capacity comes from diverse queries. Keys and values are more redundant across heads. Sharing them sacrifices relatively little modeling power.
Side-by-Side Comparison
| Property | MHA | GQA | MQA |
|---|---|---|---|
| KV heads | $h$ (all independent) | $g$ groups ($1 < g < h$) | 1 (all shared) |
| KV cache size | $2 h n d_k$ per layer | $2 g n d_k$ per layer | $2 n d_k$ per layer |
| KV cache reduction | 1x (baseline) | $h/g$ times | $h$ times |
| Quality | Best (baseline) | Near-MHA | Slight degradation |
| Inference throughput | Slowest (memory-bound) | 2-4x faster than MHA | Fastest |
| Training cost | Baseline | Same | Same (fewer KV params) |
| Used in | BERT, GPT-2, GPT-3 | LLaMA-2-70B, Gemma, Mistral | PaLM, Falcon |
| KV projection parameters | $2 h d_k d_\text{model}$ per layer | $2 g d_k d_\text{model}$ per layer | $2 d_k d_\text{model}$ per layer |
Inference Throughput: Why KV Cache Size Matters
Autoregressive generation is memory-bandwidth bound, not compute-bound. At each step, the model reads the entire KV cache to compute attention for one new token. The time per token is proportional to the amount of data read from memory.
For a single batch on an A100 (2 TB/s memory bandwidth), generating one token with a 70B MHA model at sequence length 8192 requires reading 20 GB of KV cache, taking ~10 ms just for memory transfer. GQA with $g = 8$ reads 2.5 GB (~1.25 ms). MQA reads 0.3 GB (~0.15 ms).
Larger batch sizes multiply the KV cache linearly. Serving 32 concurrent requests with MHA at 8K context requires 640 GB of KV cache, far exceeding a single GPU. GQA makes large-batch serving feasible.
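These back-of-envelope numbers follow directly from cache size divided by bandwidth. The sketch below assumes the A100-class 2 TB/s figure and the per-request cache sizes computed earlier; it is a lower bound on latency (memory transfer only, ignoring compute and weight reads):

```python
# Memory-bound decode: time per token >= KV cache bytes read / memory bandwidth.
BANDWIDTH_GB_S = 2000.0  # assumed A100-class bandwidth, 2 TB/s

for name, cache_gb in [("MHA", 20.0), ("GQA-8", 2.5), ("MQA", 0.3125)]:
    ms_per_token = cache_gb / BANDWIDTH_GB_S * 1000  # one full cache read per token
    batch32_gb = 32 * cache_gb                       # cache scales linearly with batch
    print(f"{name}: {ms_per_token:.2f} ms/token, {batch32_gb:.0f} GB KV cache at batch 32")
```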
Converting MHA to GQA
Ainslie et al. (2023) showed that pretrained MHA models can be converted to GQA through a two-step process:
1. Construct grouped K, V: Average the K (and V) projection matrices within each group: $W_g^K = \frac{1}{h/g} \sum_{i \in \text{group } g} W_i^K$, and likewise for $W_g^V$.
2. Continue pretraining: Fine-tune with the grouped architecture for 5-10% of original pretraining tokens to recover quality.
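The first step (mean pooling) is mechanically simple. Here is a sketch in NumPy, assuming the per-head K or V projections are stacked along a leading head axis; the function name and layout are my own:

```python
import numpy as np

def mean_pool_kv(w_heads, n_groups):
    """Average per-head K (or V) projection matrices within each group.

    w_heads: (n_heads, d_model, d_k) from an MHA checkpoint.
    Returns (n_groups, d_model, d_k). Uptraining follows this step.
    """
    n_heads = w_heads.shape[0]
    assert n_heads % n_groups == 0, "heads must divide evenly into groups"
    grouped = w_heads.reshape(n_groups, n_heads // n_groups, *w_heads.shape[1:])
    return grouped.mean(axis=1)  # average the h/g heads in each group
```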
This is cheaper than training GQA from scratch and allows upgrading existing MHA checkpoints. LLaMA-2-70B used this approach, starting from an MHA architecture and converting to GQA with $g = 8$.
Common Confusions
GQA does not reduce training compute
GQA reduces the number of K, V projection parameters, but the dominant training cost is the attention computation ($O(n^2 \cdot d)$) and the feedforward layers, not the KV projection. Training FLOPs for GQA and MHA are nearly identical. The benefit is almost entirely at inference time, through smaller KV caches and higher throughput.
MQA is not the same as single-head attention
MQA has $h$ heads with independent query projections. Each head computes its own attention pattern using the shared keys. The model still captures different attention patterns. Single-head attention uses one Q, one K, one V and computes one attention pattern. MQA retains multi-head diversity in queries while sharing the key-value memory.
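A tiny NumPy experiment illustrates the distinction: with one shared key projection but per-head query projections, the heads still produce different attention maps. All shapes and the random setup are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, n = 32, 8, 6
x = rng.normal(size=(n, d_model))        # toy sequence of 6 token embeddings
k = x @ rng.normal(size=(d_model, d_k))  # single shared key projection (MQA)

patterns = []
for _ in range(4):                       # 4 heads, each with its own W_i^Q
    q = x @ rng.normal(size=(d_model, d_k))
    scores = q @ k.T / np.sqrt(d_k)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    patterns.append(e / e.sum(axis=-1, keepdims=True))  # row-wise softmax

# the heads attend differently despite sharing the same keys
heads_differ = not np.allclose(patterns[0], patterns[1])
```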
KV cache size does not depend on batch size linearly in all serving frameworks
Paged attention (vLLM) allocates KV cache in blocks, allowing non-contiguous memory. This reduces waste from padding and variable-length sequences, but the total KV cache for active requests still scales linearly with batch size and sequence length. GQA reduces the per-request cache, which compounds with batching.
The number of GQA groups is not always 8
$g = 8$ is common (LLaMA-2-70B, Mistral-7B) but not universal. The optimal $g$ depends on the model size, target context length, and serving constraints. Smaller models with fewer heads may use $g = 2$ or $g = 4$. The choice balances KV cache reduction against quality preservation.
References
- Vaswani, A. et al. (2017). "Attention Is All You Need." NeurIPS 2017. (Original multi-head attention.)
- Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150. (Multi-query attention, single shared KV.)
- Ainslie, J. et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. (Grouped-query attention, uptraining from MHA.)
- Touvron, H. et al. (2023). "Llama 2: Open Foundation and Fine-Tuned Chat Models." arXiv:2307.09288. (LLaMA-2-70B using GQA with 8 groups.)
- Pope, R. et al. (2023). "Efficiently Scaling Transformer Inference." MLSys 2023. (Analysis of memory-bandwidth bottleneck in autoregressive inference.)
- Kwon, W. et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023. (vLLM paged attention for KV cache management.)
- Jiang, A. Q. et al. (2023). "Mistral 7B." arXiv:2310.06825. (GQA in a 7B model with sliding window attention.)