Disclaimer: Beta. Content is under active construction and has not been peer-reviewed. Report errors on GitHub.

LLM Construction

Flash Attention

IO-aware exact attention: tile QKV matrices into SRAM-sized blocks, compute attention without materializing the full attention matrix in HBM, reducing memory reads/writes from quadratic to linear.

Advanced · Tier 2 · ~55 min

Why This Matters

[Figure: Flash Attention tiled computation. Standard (slow): the N \times N attention matrix is materialized in HBM, O(N^2) memory. Flash Attention (fast): Q, K, V are tiled into blocks; each Q_i K_j^\top product, softmax, and multiply by V_j is computed in on-chip SRAM, and the N \times N matrix is never materialized in off-chip HBM: O(N) extra memory, same result. Standard: \Theta(Nd + N^2) HBM accesses; Flash: \Theta(N^2 d^2 / M). Dao 2022 reports up to 7.6\times speedup on GPT-2 attention kernels; the speedup grows with sequence length.]

Standard transformer attention computes \text{softmax}(QK^\top / \sqrt{d})V by materializing the full N \times N attention matrix in GPU high-bandwidth memory (HBM). For sequence length N = 128{,}000, this matrix has 16 billion entries: roughly 32 GB in FP16. This does not fit in SRAM (tens of MB) and barely fits in HBM.

Flash Attention computes the exact same output without ever forming the full attention matrix. It does this by tiling the computation so that each block fits in SRAM, performing all the work there, and writing only the final result back to HBM. The output is mathematically identical. The speedup comes entirely from reducing memory traffic.

Mental Model

Think of attention as a matrix multiply with a softmax in the middle. Standard implementations write the intermediate N \times N matrix to HBM, then read it back to apply softmax, then write the softmax output to HBM, then read it back to multiply by V. Each of these read/write round-trips is expensive.

Flash Attention keeps intermediate results in SRAM (fast, small memory) and never writes the N \times N matrix to HBM at all.
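To make the round-trips concrete, here is a minimal NumPy sketch of standard attention (illustrative only; on a GPU, the two N x N intermediates S and P are what force the HBM round-trips):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention: materializes two full N x N intermediates."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                     # N x N scores (round-trip 1)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)           # N x N softmax (round-trip 2)
    return P @ V                                 # N x d output

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = naive_attention(Q, K, V)
print(O.shape)  # (256, 64), but S and P each held N * N = 65,536 floats
```

The output is only N x d; everything quadratic in N is intermediate state, which is exactly what Flash Attention refuses to write out.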

The IO Bottleneck

Definition

Arithmetic Intensity

The ratio of floating-point operations to bytes transferred between HBM and compute units. When arithmetic intensity is low, the operation is memory-bound: the GPU spends most of its time waiting for data, not computing.

Standard attention has arithmetic intensity O(1) for the softmax step: each element requires a few FLOPs but must be read from and written to HBM. The matrix multiply steps (QK^\top and \text{softmax} \cdot V) have higher arithmetic intensity, but the softmax bottleneck dominates wall-clock time.
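A back-of-envelope comparison makes the gap visible. This is a sketch: the FLOP counts (e.g. 5 FLOPs per softmax element) are rough conventions, not exact hardware numbers:

```python
N, d, fp16 = 4096, 128, 2  # sequence length, head dim, bytes per element

# QK^T matmul: 2*N^2*d FLOPs; streams Q and K in, the N x N scores out.
matmul_intensity = (2 * N * N * d) / ((2 * N * d + N * N) * fp16)

# Softmax: a few FLOPs per score; reads and writes the whole N x N matrix.
softmax_intensity = (5 * N * N) / (2 * N * N * fp16)

print(f"matmul:  {matmul_intensity:.1f} FLOPs/byte")
print(f"softmax: {softmax_intensity:.2f} FLOPs/byte")  # 1.25: memory-bound
```

Two orders of magnitude separate the two steps, which is why the softmax pass, not the matmuls, sets the wall-clock time in an unfused implementation.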

The Tiling Algorithm

The key idea: partition Q, K, V into blocks that fit in SRAM.

  1. Divide Q into blocks Q_1, \ldots, Q_{T_r} of size B_r \times d.
  2. Divide K and V into blocks K_1, \ldots, K_{T_c} and V_1, \ldots, V_{T_c} of size B_c \times d.
  3. For each query block Q_i: iterate over all key-value blocks (K_j, V_j), computing partial attention scores in SRAM.
  4. Accumulate the output using the online softmax trick (see below).

The block sizes B_r and B_c are chosen so that Q_i, K_j, V_j, and the partial output all fit simultaneously in SRAM.
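For concreteness, the paper's block-size choice (B_c = M/(4d), B_r = min(B_c, d)) can be evaluated numerically. The 228 KB figure is the H100's per-SM shared memory; all numbers are illustrative:

```python
sram_bytes = 228 * 1024   # H100 shared memory per SM (illustrative)
fp16 = 2                  # bytes per element
d = 128

M = sram_bytes // fp16    # SRAM capacity in elements
B_c = M // (4 * d)        # key/value block rows
B_r = min(B_c, d)         # query block rows
print(B_c, B_r)           # 228 128
```

So a single tile covers a 128 x 228 patch of the attention matrix; the full N x N matrix is swept patch by patch without ever existing in HBM.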

Online Softmax

The challenge with tiling is that softmax requires the full row of scores s_i = Q_i K^\top / \sqrt{d} to compute the normalizer. You cannot compute \text{softmax}(s_i) from blocks independently.

Definition

Online Softmax

Maintain running statistics (m, \ell), where m is the running row-wise maximum and \ell is the running sum of exponentials. When processing a new block of scores s_{ij}:

m_{\text{new}} = \max(m_{\text{old}}, \max(s_{ij}))

\ell_{\text{new}} = \ell_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + \sum_k e^{s_{ijk} - m_{\text{new}}}

The output accumulator is rescaled by e^{m_{\text{old}} - m_{\text{new}}} at each step to account for the updated maximum.

This produces the exact same result as computing softmax over the entire row at once.

Main Theorems

Theorem

Flash Attention IO Complexity

Statement

Standard attention requires \Theta(Nd + N^2) HBM accesses. Flash Attention requires \Theta(N^2 d^2 / M) HBM accesses, where M is the SRAM size.

Intuition

Each key-value tile is loaded into SRAM once and fully processed before moving on, but the query blocks and output accumulator must be streamed past every key-value tile. With \Theta(Nd/M) key-value passes, each moving \Theta(Nd) elements, the product gives the total IO.

Proof Sketch

There are T_r = N/B_r query blocks and T_c = N/B_c key-value blocks. The outer loop runs over the key-value blocks: each iteration loads K_j and V_j (2 B_c d elements) into SRAM once, then streams every query block and its output accumulator through it, moving \Theta(Nd) elements per pass. Total HBM accesses: \Theta(Nd \cdot T_c) = \Theta(N^2 d / B_c). Setting B_c = M/(4d) and B_r = \min(M/(4d), d) gives the stated \Theta(N^2 d^2 / M) bound.
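Evaluating the theorem's two expressions numerically (constants dropped, M measured in elements; a rough sketch, not a profiler measurement) shows the ratio settling at the constant factor rather than growing with N:

```python
def io_ratio(N, d, M):
    """Ratio of the theorem's two Theta-expressions (constants dropped)."""
    standard = N * d + N * N       # Theta(Nd + N^2)
    flash = N * N * d * d / M      # Theta(N^2 d^2 / M)
    return standard / flash

M = 228 * 1024 // 2                # H100-ish SRAM per SM, in FP16 elements
for N in (1_024, 8_192, 131_072):
    print(f"N={N:>7}: standard/flash ~ {io_ratio(N, 128, M):.1f}")
```

With M counted in elements the ratio plateaus near M/d^2 (about 7 here); the exact constant depends on units and on how many round-trips the unfused baseline actually makes, which is why reported reductions span roughly an order of magnitude.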

Why It Matters

In the operational regime, M is fixed by hardware (for example, the H100 has about 228 KB of SRAM per SM) while N grows. For practical long-context workloads, Nd is far larger than M. At N = 100{,}000, d = 128, FP16, a single Nd-element array occupies about 25 MB, roughly 100\times the SRAM budget. Both standard attention and Flash Attention therefore remain \Omega(N^2) in HBM traffic as N \to \infty with M fixed. The win is a large constant-factor reduction of memory traffic by \Theta(N^2 / (N^2 d^2 / M)) = \Theta(M / d^2). For M = 228{,}000 bytes and d = 128 in FP16, this constant factor is on the order of 10 to 30\times fewer HBM bytes moved, which is enough to shift attention from memory-bound to closer to compute-bound on current accelerators.

Failure Mode

The IO advantage diminishes when d is very large relative to M, because fewer elements fit in SRAM per tile. Also, if the attention pattern is sparse (most entries near zero), sparse attention methods from the efficient transformers survey may achieve even lower IO by skipping blocks entirely; Flash Attention computes exact dense attention and cannot exploit sparsity. A final pitfall: claims of "linear in N" HBM traffic require M = \Theta(Nd), which is physically false on current GPUs once N exceeds a few thousand.

Proposition

Online Softmax Equivalence

Statement

The online softmax algorithm with running maximum and denominator produces the same output as the two-pass softmax algorithm (first pass to find max and sum, second pass to normalize), up to floating-point rounding.

Intuition

The rescaling factor e^{m_{\text{old}} - m_{\text{new}}} exactly compensates for having used a stale maximum in earlier blocks. The algebra telescopes: each partial sum, when rescaled, equals what it would have been if computed with the global maximum from the start.

Proof Sketch

By induction on the number of blocks. The base case (one block) is trivial. For the inductive step, the rescaled normalizer after k+1 blocks equals \sum_{j=1}^{k+1} \sum_i e^{s_{ji} - m_{k+1}}, which is the same as the full-row computation with maximum m_{k+1}; the output accumulator telescopes identically.

Why It Matters

Without online softmax, tiled attention would produce approximate results. This proposition guarantees exact equivalence, which means Flash Attention is a pure systems optimization with zero accuracy cost.

Failure Mode

Floating-point rounding order differs between the tiled and standard implementations, so outputs may differ at the level of machine epsilon. This is not a mathematical failure but can cause bitwise non-determinism in practice.
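A quick FP32 sketch of the effect (illustrative; whether the low bits actually differ depends on the data and the block layout):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)
m = np.float32(x.max())

# One pass over the full row.
full = np.exp(x - m).sum(dtype=np.float32)

# Tiled: per-block partial sums combined afterwards (different rounding order).
tiled = np.float32(0.0)
for block in np.split(x, 8):
    tiled += np.exp(block - m).sum(dtype=np.float32)

# Mathematically equal; the low bits may differ between the two orders.
print(abs(full - tiled) / full)
```

The relative discrepancy is at the level of machine epsilon, which is harmless for accuracy but matters if a pipeline assumes bitwise-reproducible activations.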

Flash Attention 2 and 3

Flash Attention 2 (Dao 2023) improves work partitioning. The key changes: swap the loop order so the outer loop is over query blocks (better parallelism across warps), reduce non-matmul FLOPs, and partition work across warps within a thread block more evenly. Result: roughly 2\times speedup over Flash Attention 1.

Flash Attention 3 (Shah et al. 2024) targets Hopper GPUs (H100). Key ideas: asynchronous memory copies via TMA (tensor memory accelerator), FP8 quantized attention for further throughput gains, and warp specialization for overlapping computation with data movement. Shah et al. report about 740 TFLOP/s in FP16 on H100, roughly 75% of the 989 TFLOP/s FP16 peak. This is a large jump over Flash Attention 2 (about 35% utilization on the same hardware) but still below what dense GEMM kernels reach (typically 90%+ of peak). Framing FA3 as "near peak" overstates it. Framing it as "75% of peak, nearly double FA2 on H100" is accurate.

Common Confusions

Watch Out

Flash Attention does not approximate attention

Flash Attention computes exact standard attention. It is not an approximation scheme like Linformer, Performer, or random feature attention. The output is identical (up to floating-point rounding) to naive attention. The improvement is purely in IO efficiency.

Watch Out

Flash Attention reduces memory, not FLOPs

Flash Attention actually performs more total FLOPs than standard attention (due to recomputation in the backward pass). It is faster because wall-clock time for attention is dominated by memory access time, not compute time. Reducing IO is what matters.

Watch Out

"SRAM" here means per-SM shared memory, not L2 cache

On NVIDIA GPUs, SRAM refers to the shared memory within each streaming multiprocessor. Its size (typically 48-228 KB per SM) determines the maximum block size in Flash Attention. This is distinct from L2 cache, which is larger but not directly addressable by the programmer.

Summary

  • Standard attention is memory-bound: the bottleneck is HBM reads/writes, not FLOPs
  • Flash Attention tiles Q, K, V into SRAM-sized blocks and never materializes the full N \times N attention matrix in HBM
  • Online softmax enables exact tiled computation by maintaining running statistics
  • IO traffic drops from \Theta(N^2 + Nd) to \Theta(N^2 d^2 / M) HBM accesses, a constant-factor \Theta(M / d^2) reduction for fixed hardware M
  • Flash Attention 2 improves parallelism; Flash Attention 3 adds asynchrony and FP8

Exercises

Exercise (Core)

Problem

A GPU has 192 KB of SRAM per SM and uses FP16 (2 bytes per element). The head dimension is d = 128. What is the maximum block size B_c for key and value blocks, assuming we need to fit K_j (B_c \times d) and V_j (B_c \times d) simultaneously in SRAM, with half the SRAM reserved for other use?

Exercise (Advanced)

Problem

Prove that the online softmax rescaling is exact. Specifically, show that after processing blocks 1, \ldots, k, the accumulated output O_k equals \sum_{j=1}^{k} \text{softmax}(Q_i [K_1, \ldots, K_k]^\top / \sqrt{d})_{:, \text{block}_j} V_j, where the softmax is computed over all k blocks jointly.


References

Canonical:

  • Dao, Fu, Ermon, Rudra, Re, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022, arXiv:2205.14135), Sections 3-4 for the IO analysis and tiling algorithm
  • Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023, arXiv:2307.08691), Sections 3.1-3.2 for the loop-order and warp-partitioning changes

Current:

  • Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao, FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (2024, arXiv:2407.08608), Section 3 for warp specialization and 4 for FP8
  • Milakov & Gimelshein, Online normalizer calculation for softmax (2018, arXiv:1805.02867) for the online softmax recursion
  • Rabe & Staats, Self-attention Does Not Need O(n^2) Memory (2021, arXiv:2112.05682), the memory-efficient attention algorithm that precedes FlashAttention
  • NVIDIA, H100 Tensor Core GPU Architecture Whitepaper (2022), Sections on SM SRAM, TMA, and FP16/FP8 tensor cores, for the hardware numbers cited above
  • Jurafsky & Martin, Speech and Language Processing (3rd ed., draft), Chapter 9 for the transformer attention background used by this page

Next Topics

  • Fused kernels: Flash Attention is itself a fused kernel; understanding kernel fusion explains why tiling helps

Last reviewed: April 2026
