LLM Construction
Flash Attention
IO-aware exact attention: tile the Q, K, V matrices into SRAM-sized blocks and compute attention without materializing the full N × N attention matrix in HBM, reducing HBM reads/writes from Θ(N² + Nd) to Θ(N²d²/M).
Why This Matters
Standard transformer attention computes softmax(QK^T / √d)V by materializing the full N × N attention matrix in GPU high-bandwidth memory (HBM). For sequence length N = 128K, this matrix has N² ≈ 17 billion entries; in FP16 that is roughly 32 GB. This does not fit in SRAM (tens of MB aggregated across the whole GPU) and barely fits in HBM.
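A quick back-of-envelope check of these numbers (assuming N = 131,072, i.e. 128K tokens, and 2-byte FP16 entries):

```python
# Size of the full N x N attention matrix for a single head at long context.
n = 131_072                                  # sequence length (128K tokens)
bytes_per_entry = 2                          # FP16
entries = n * n                              # ~17 billion score entries
size_gb = entries * bytes_per_entry / 2**30  # total size in GiB
print(f"{entries / 1e9:.1f} billion entries, {size_gb:.0f} GiB")
```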
Flash Attention computes the exact same output without ever forming the full attention matrix. It does this by tiling the computation so that each block fits in SRAM, performing all the work there, and writing only the final result back to HBM. The output is mathematically identical. The speedup comes entirely from reducing memory traffic.
Mental Model
Think of attention as two matrix multiplies with a softmax in the middle. Standard implementations write the intermediate score matrix S = QK^T to HBM, then read it back to apply softmax, then write the softmax output P to HBM, then read it back to multiply by V. Each of these read/write round-trips is expensive.
Flash Attention keeps intermediate results in SRAM (fast, small memory) and never writes the N × N matrix to HBM at all.
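The baseline that Flash Attention improves on can be written in a few lines of numpy. This sketch makes the problem visible: the score matrix `s` and the probability matrix `p` are both explicit N × N arrays, and in a GPU kernel each would round-trip through HBM.

```python
import numpy as np

def standard_attention(q, k, v):
    """Naive attention: materializes the full N x N score matrix.
    In a real GPU kernel, s and p would each round-trip through HBM."""
    d = q.shape[-1]
    s = q @ k.T / np.sqrt(d)                  # N x N scores: the matrix Flash Attention never forms
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p = p / p.sum(axis=-1, keepdims=True)     # row-wise softmax
    return p @ v                              # N x d output

rng = np.random.default_rng(0)
n, d = 64, 16
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out = standard_attention(q, k, v)
print(out.shape)
```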
The IO Bottleneck
Arithmetic Intensity
The ratio of floating-point operations to bytes transferred between HBM and compute units. When arithmetic intensity is low, the operation is memory-bound: the GPU spends most of its time waiting for data, not computing.
Standard attention has arithmetic intensity O(1) for the softmax step: each of the N² score entries requires only a few FLOPs but must be read from and written to HBM. The matrix multiply steps (QK^T and PV) have higher arithmetic intensity, but the softmax bottleneck dominates wall-clock time.
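The gap can be estimated with a rough cost model. The numbers below are illustrative assumptions (FP16, about 5 FLOPs per softmax element to cover the exp and the max/sum passes), not measurements:

```python
# Rough arithmetic intensity (FLOPs per HBM byte) for one attention head.
n, d, b = 4096, 128, 2            # sequence length, head dim, bytes per FP16 element

# QK^T matmul: 2*n*n*d FLOPs; reads Q and K, writes the n x n score matrix.
matmul_flops = 2 * n * n * d
matmul_bytes = b * (2 * n * d + n * n)
print(f"QK^T intensity:    {matmul_flops / matmul_bytes:.1f} FLOP/byte")

# Softmax: ~5 FLOPs per score entry (assumed); reads and writes all n*n entries.
softmax_flops = 5 * n * n
softmax_bytes = b * (2 * n * n)
print(f"softmax intensity: {softmax_flops / softmax_bytes:.2f} FLOP/byte")
```

On this model the matmul does over a hundred FLOPs per byte while the softmax does about one, which is why the softmax step is memory-bound.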
The Tiling Algorithm
The key idea: partition Q, K, V into blocks that fit in SRAM.
- Divide Q into T_r blocks Q_1, ..., Q_Tr of size B_r × d.
- Divide K and V into T_c blocks K_j and V_j of size B_c × d.
- For each query block Q_i: iterate over all key-value blocks (K_j, V_j), computing the partial attention scores S_ij = Q_i K_j^T in SRAM.
- Accumulate the output using the online softmax trick (see below).
The block sizes B_r and B_c are chosen so that Q_i, K_j, V_j, and the partial output O_i all fit simultaneously in SRAM.
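A concrete feasibility check, using assumed (not vendor-specified) numbers for the SRAM budget and block sizes:

```python
# Check that chosen block sizes fit a per-SM SRAM budget (assumed numbers).
sram_bytes = 192 * 1024        # hypothetical SRAM budget per SM
b = 2                          # bytes per FP16 element
d = 128                        # head dimension
br, bc = 64, 64                # query-block and key/value-block row counts

tile_bytes = b * (
    br * d                     # Q_i
    + bc * d                   # K_j
    + bc * d                   # V_j
    + br * bc                  # score block S_ij
    + br * d                   # output accumulator O_i
)
assert tile_bytes <= sram_bytes, "blocks too large for SRAM"
print(f"{tile_bytes / 1024:.0f} KiB of {sram_bytes / 1024:.0f} KiB used")
```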
Online Softmax
The challenge with tiling is that softmax requires the full row of scores to compute the normalizer (the row-wise maximum and sum of exponentials). You cannot compute it from blocks independently without a correction step.
Online Softmax
Maintain running statistics (m, ℓ), where m is the running row-wise maximum and ℓ is the running sum of exponentials. When processing a new block of scores S^(j):
m_new = max(m, rowmax(S^(j)))
ℓ_new = e^(m - m_new) · ℓ + rowsum(e^(S^(j) - m_new))
The output accumulator is rescaled by e^(m - m_new) at each step to account for the updated maximum.
This produces the exact same result as computing softmax over the entire row at once.
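The whole tiled algorithm, including the online softmax recursion, fits in a short numpy sketch. This is a minimal illustration of the technique, not the actual GPU kernel; the block size and test shapes are arbitrary assumptions.

```python
import numpy as np

def flash_attention(q, k, v, block=16):
    """Tiled attention with online softmax (a numpy sketch, not a GPU kernel).
    Never forms the full N x N score matrix; streams over K/V blocks."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    o = np.zeros((n, d))                 # unnormalized output accumulator
    m = np.full(n, -np.inf)              # running row-wise maximum
    l = np.zeros(n)                      # running sum of exponentials
    for j in range(0, n, block):
        s = (q @ k[j:j + block].T) * scale        # scores for this block only
        m_new = np.maximum(m, s.max(axis=1))
        alpha = np.exp(m - m_new)                 # rescale factor for the stale max
        p = np.exp(s - m_new[:, None])
        l = alpha * l + p.sum(axis=1)
        o = alpha[:, None] * o + p @ v[j:j + block]
        m = m_new
    return o / l[:, None]                         # normalize at the end

# Verify exact agreement with the naive two-pass implementation.
rng = np.random.default_rng(1)
n, d = 64, 16
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
ref = np.exp(q @ k.T / np.sqrt(d))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
assert np.allclose(flash_attention(q, k, v), ref)
print("tiled output matches naive attention")
```

Note that `o` is kept unnormalized until the final division by `l`, which is how the streaming kernel avoids ever holding a full softmax row.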
Main Theorems
Flash Attention IO Complexity
Statement
Standard attention requires Θ(Nd + N²) HBM accesses. Flash Attention requires Θ(N²d²/M) HBM accesses, where M is the SRAM size in elements.
Intuition
Each SRAM-sized tile of the computation is loaded once and fully processed before moving on. The total number of tile pairs is Θ(N² / (B_r · B_c)) and each pair loads Θ((B_r + B_c) · d) elements. The product gives the total IO.
Proof Sketch
There are T_r = N/B_r query blocks and T_c = N/B_c key-value blocks. The outer loop iterates over all T_r · T_c block pairs. Each pair loads Θ((B_r + B_c) · d) elements from HBM. Total HBM reads: Θ(T_r · T_c · (B_r + B_c) · d). Setting B_r = B_c = Θ(M/d) gives the stated Θ(N²d²/M) bound.
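The two bounds can be compared numerically. The constants here are leading-order counts from the proof sketch (constant factors dropped), with an assumed SRAM size of about 228 KB of FP16:

```python
# Leading-order HBM element counts from the IO-complexity proof sketch.
def standard_io(n, d):
    return n * d + n * n               # Theta(Nd + N^2): read QKV, round-trip the scores

def flash_io(n, d, m):
    return n * n * d * d / m           # Theta(N^2 d^2 / M), with B_r = B_c = M/d

n, d = 16_384, 64
m = 114_000                            # SRAM size in elements (~228 KB of FP16)
ratio = standard_io(n, d) / flash_io(n, d, m)
print(f"standard: {standard_io(n, d):.2e} elements")
print(f"flash:    {flash_io(n, d, m):.2e} elements")
print(f"reduction: {ratio:.1f}x")
```

The reduction comes out close to M/d², as the Why It Matters discussion predicts.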
Why It Matters
In the operational regime, M is fixed by hardware (for example, the H100 has about 228 KB of SRAM per SM) while N grows. For practical long-context workloads, N · d is far larger than M. At N = 128K, d = 96, FP16, a single K or V array occupies about 25 MB, roughly the aggregate SRAM budget of the whole GPU. Both standard attention and Flash Attention therefore remain Θ(N²) in HBM traffic as N grows with M fixed. The win is a large constant-factor reduction of memory traffic by Θ(M/d²). For M = 228 KB and d = 64-128 in FP16, this constant factor is on the order of 7× to 30× fewer HBM bytes moved, which is enough to shift attention from memory-bound to closer to compute-bound on current accelerators.
Failure Mode
The IO advantage diminishes when d² approaches M, because fewer key-value rows fit in SRAM per tile and the Θ(M/d²) reduction factor shrinks toward 1. Also, if the attention pattern is sparse (most entries near zero), sparse attention methods from the efficient transformers survey may achieve even lower IO by skipping blocks entirely. Flash Attention computes exact dense attention and cannot exploit sparsity. A final pitfall: claims of "linear in N" HBM traffic require M = Ω(Nd), which is physically false on current GPUs once N exceeds a few thousand.
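That last threshold is easy to compute. Assuming roughly 228 KB of per-SM SRAM holding FP16 elements, the largest N for which M ≥ N · d holds is:

```python
# Largest N for which SRAM could hold a full N x d array per SM,
# i.e. the only regime where HBM traffic could actually be linear in N.
sram_elements = 228 * 1024 // 2        # ~228 KB of FP16, in elements
limits = {}
for d in (64, 128):
    limits[d] = sram_elements // d     # M >= N*d  =>  N <= M/d
    print(f"d={d}: 'linear in N' IO would require N <= {limits[d]}")
```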
Online Softmax Equivalence
Statement
The online softmax algorithm with running maximum m and denominator ℓ produces the same output as the two-pass softmax algorithm (first pass to find the max and sum, second pass to normalize), up to floating-point rounding.
Intuition
The rescaling factor e^(m_old - m_new) exactly compensates for having used a stale maximum in earlier blocks. The algebra telescopes: each partial sum, when rescaled, equals what it would have been if computed with the global maximum from the start.
Proof Sketch
By induction on the number of blocks. The base case (one block) is trivial. For the inductive step, the rescaled accumulator after k blocks equals Σ_{j ≤ k} e^(S^(j) - m_k) V_j, where m_k is the maximum over the first k blocks, which is the same as the full-row computation with maximum m_k.
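The induction can be checked numerically for a single row of scores: after every block, the rescaled running statistics must match a two-pass computation over the prefix seen so far. (Block count and sizes below are arbitrary.)

```python
import numpy as np

# Numeric check of the induction: after each block, the rescaled running
# statistics equal the full-prefix values computed with the current maximum.
rng = np.random.default_rng(2)
scores = rng.standard_normal(20)       # one row of attention scores
blocks = np.split(scores, 4)           # four blocks of five scores each

m, l = -np.inf, 0.0
seen = np.empty(0)
for s in blocks:
    m_new = max(m, s.max())
    l = np.exp(m - m_new) * l + np.exp(s - m_new).sum()   # online update
    m = m_new
    seen = np.concatenate([seen, s])
    # two-pass reference over the prefix processed so far
    assert np.isclose(l, np.exp(seen - seen.max()).sum())
    assert np.isclose(m, seen.max())
print("online statistics match the two-pass reference at every step")
```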
Why It Matters
Without online softmax, tiled attention would produce approximate results. This proposition guarantees exact equivalence, which means Flash Attention is a pure systems optimization with zero accuracy cost.
Failure Mode
Floating-point rounding order differs between the tiled and standard implementations, so outputs may differ at the level of machine epsilon. This is not a mathematical failure but can cause bitwise non-determinism in practice.
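Order-dependence of floating-point addition is easy to demonstrate; the same effect, at machine-epsilon scale, is what makes tiled and untiled attention differ bitwise:

```python
# Floating-point addition is not associative: summation order changes the result.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c        # one summation order
right = a + (b + c)       # another summation order
print(left == right, abs(left - right))   # differ by ~1e-16
```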
Flash Attention 2 and 3
Flash Attention 2 (Dao 2023) improves work partitioning. The key changes: swap the loop order so the outer loop is over query blocks (better parallelism across warps), reduce non-matmul FLOPs, and partition work across warps within a thread block more evenly. Result: roughly 2× speedup over Flash Attention 1.
Flash Attention 3 (Shah et al. 2024) targets Hopper GPUs (H100). Key ideas: asynchronous memory copies via TMA (tensor memory accelerator), FP8 quantized attention for further throughput gains, and warp specialization for overlapping computation with data movement. Shah et al. report about 740 TFLOP/s in FP16 on H100, roughly 75% of the 989 TFLOP/s FP16 peak. This is a large jump over Flash Attention 2 (about 35% utilization on the same hardware) but still below what dense GEMM kernels reach (typically 90%+ of peak). Framing FA3 as "near peak" overstates it. Framing it as "75% of peak, nearly double FA2 on H100" is accurate.
Common Confusions
Flash Attention does not approximate attention
Flash Attention computes exact standard attention. It is not an approximation scheme like Linformer, Performer, or random feature attention. The output is identical (up to floating-point rounding) to naive attention. The improvement is purely in IO efficiency.
Flash Attention reduces memory, not FLOPs
Flash Attention actually performs more total FLOPs than standard attention (due to recomputation in the backward pass). It is faster because wall-clock time for attention is dominated by memory access time, not compute time. Reducing IO is what matters.
SRAM is not programmer-visible shared memory alone
On NVIDIA GPUs, SRAM refers to the shared memory within each streaming multiprocessor. Its size (typically 48-228 KB per SM) determines the maximum block size in Flash Attention. This is distinct from L2 cache, which is larger but not directly addressable by the programmer.
Summary
- Standard attention is memory-bound: the bottleneck is HBM reads/writes, not FLOPs
- Flash Attention tiles QKV into SRAM-sized blocks and never materializes the full attention matrix in HBM
- Online softmax enables exact tiled computation by maintaining running statistics
- IO traffic drops from Θ(N² + Nd) to Θ(N²d²/M) HBM accesses, a constant-factor reduction for fixed hardware
- Flash Attention 2 improves parallelism; Flash Attention 3 adds asynchrony and FP8
Exercises
Problem
A GPU has 192 KB of SRAM per SM and uses FP16 (2 bytes per element). The head dimension is d = 64. What is the maximum block size B_c for key and value blocks, assuming we need to fit K_j (B_c × d) and V_j (B_c × d) simultaneously in SRAM, with half the SRAM reserved for other use?
Problem
Prove that the online softmax rescaling is exact. Specifically, show that after processing blocks S^(1), ..., S^(k), the accumulated output equals softmax([S^(1), ..., S^(k)]) · [V_1; ...; V_k], where the softmax is computed over all blocks jointly.
Related Comparisons
References
Canonical:
- Dao, Fu, Ermon, Rudra, Ré, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022, arXiv:2205.14135), Sections 3-4 for the IO analysis and tiling algorithm
- Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023, arXiv:2307.08691), Sections 3.1-3.2 for the loop-order and warp-partitioning changes
Current:
- Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao, FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (2024, arXiv:2407.08608), Section 3 for warp specialization and 4 for FP8
- Milakov & Gimelshein, Online normalizer calculation for softmax (2018, arXiv:1805.02867) for the online softmax recursion
- Rabe & Staats, Self-attention Does Not Need Memory (2021, arXiv:2112.05682), the memory-efficient attention algorithm that precedes FlashAttention
- NVIDIA, H100 Tensor Core GPU Architecture Whitepaper (2022), Sections on SM SRAM, TMA, and FP16/FP8 tensor cores, for the hardware numbers cited above
- Jurafsky & Martin, Speech and Language Processing (3rd ed., draft), Chapter 9 for the transformer attention background used by this page
Next Topics
- Fused kernels: Flash Attention is itself a fused kernel; understanding kernel fusion explains why tiling helps
Last reviewed: April 2026
Prerequisites
Foundations this topic depends on.
- Attention Mechanism Theory (Layer 4)
- Matrix Operations and Properties (Layer 0A)
- Sets, Functions, and Relations (Layer 0A)
- Basic Logic and Proof Techniques (Layer 0A)
- Softmax and Numerical Stability (Layer 1)