The Key Fact
FlashAttention computes the exact same result as vanilla attention. It is not an approximation. The outputs are numerically identical (up to floating-point ordering).
The difference is how the computation is organized in GPU memory. Vanilla attention materializes intermediate matrices in high-bandwidth memory (HBM). FlashAttention avoids this by tiling the computation into on-chip SRAM and fusing the operations into a single kernel. The result is the same. The memory footprint and wall-clock time are dramatically different.
Standard Attention Computation
Vanilla Attention
Given queries $Q$, keys $K$, and values $V$, each of shape $N \times d$:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

This requires materializing the score matrix $S = QK^\top/\sqrt{d}$ and the probability matrix $P = \mathrm{softmax}(S)$ in HBM. Both are $N \times N$: at $N = 8192$ in 16-bit precision, each matrix is 128 MB. With multiple heads and layers, this dominates GPU memory.
The bottleneck is not FLOPs. Modern GPUs have enormous compute throughput but limited memory bandwidth. The matrices must be written to HBM after the matmul, read back for the softmax, written again after softmax, and read again for the final matmul with . Each HBM read/write is slow relative to on-chip computation.
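As a concrete reference point, here is a minimal NumPy sketch of the vanilla computation (function name and toy sizes are illustrative):

```python
import numpy as np

def vanilla_attention(Q, K, V):
    """Naive attention: materializes two full N x N matrices (S and P),
    which on a GPU would live in HBM."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                # (N, N) scores
    S -= S.max(axis=-1, keepdims=True)      # subtract row max for stability
    P = np.exp(S)
    P /= P.sum(axis=-1, keepdims=True)      # (N, N) probabilities
    return P @ V                            # (N, d) output

# Toy usage
rng = np.random.default_rng(0)
N, d = 128, 16
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = vanilla_attention(Q, K, V)
```

The two `(N, N)` arrays are exactly the intermediates that vanilla attention round-trips through HBM.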
FlashAttention Computation
FlashAttention (Tiled, IO-Aware)
FlashAttention splits $Q$, $K$, $V$ into blocks that fit in SRAM (typically 64-128 rows at a time). For each block of queries $Q_i$:
- Load $Q_i$ into SRAM.
- Iterate over blocks $K_j$, $V_j$:
  - Compute the score block $S_{ij} = Q_i K_j^\top / \sqrt{d}$ in SRAM.
  - Update running softmax statistics (max and sum) using the online softmax algorithm.
  - Accumulate the weighted output incrementally.
- Write the final output block $O_i$ back to HBM.

The $N \times N$ attention matrix is never materialized in HBM. Each block of $S$ is computed, used, and discarded in SRAM. The output is accumulated in a single streaming pass.
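The tiled forward pass can be sketched in NumPy as below. Block size, loop structure, and variable names are illustrative; the real implementation runs these loops inside a single fused CUDA kernel with the tiles held in SRAM:

```python
import numpy as np

def flash_attention(Q, K, V, block=32):
    """Tiled exact attention, forward pass only (illustrative sketch).

    The N x N score matrix is never formed; only block-sized tiles
    exist at any moment, standing in for SRAM residency."""
    N, d = Q.shape
    O = np.zeros((N, d))
    for i in range(0, N, block):
        Qi = Q[i:i+block]                     # load query block "into SRAM"
        m = np.full(Qi.shape[0], -np.inf)     # running row max
        l = np.zeros(Qi.shape[0])             # running normalizer
        acc = np.zeros((Qi.shape[0], d))      # unnormalized output accumulator
        for j in range(0, N, block):
            Kj, Vj = K[j:j+block], V[j:j+block]
            S = Qi @ Kj.T / np.sqrt(d)        # score tile, discarded after use
            m_new = np.maximum(m, S.max(axis=1))
            scale = np.exp(m - m_new)         # rescale old stats to the new max
            P = np.exp(S - m_new[:, None])
            l = l * scale + P.sum(axis=1)
            acc = acc * scale[:, None] + P @ Vj
            m = m_new
        O[i:i+block] = acc / l[:, None]       # final normalization, write "to HBM"
    return O
```

Because the rescaling is exact, the result matches the untiled softmax attention up to floating-point ordering.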
The Online Softmax Trick
The main algorithmic insight is computing softmax incrementally without seeing the full row of scores at once.
Standard softmax of a vector $x = (x_1, \dots, x_N)$:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j$$

Computing this normally requires two passes: one to find $m$ and compute the sum, one to normalize. The online softmax algorithm (Milakov and Gimelshein, 2018) maintains running estimates of the max $m$ and the normalizer $\ell$ as blocks arrive.

When a new block of scores $x^{\mathrm{new}}$ arrives with local max $\tilde{m}$:
- Update the max: $m' = \max(m, \tilde{m})$.
- Rescale the existing normalizer: $\ell \leftarrow \ell\, e^{m - m'}$.
- Add the new contributions: $\ell \leftarrow \ell + \sum_j e^{x^{\mathrm{new}}_j - m'}$.
- Rescale the accumulated output by the same factor $e^{m - m'}$, then set $m \leftarrow m'$.
This produces the exact softmax result in a single pass over the data.
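A sketch of the running-statistics update in NumPy (function name and the way scores are split into blocks are illustrative):

```python
import numpy as np

def online_softmax_stats(blocks):
    """One-pass max and normalizer over a stream of score blocks
    (Milakov & Gimelshein, 2018). Returns (m, l) such that
    softmax(x)_i = exp(x_i - m) / l for the concatenated scores x."""
    m, l = -np.inf, 0.0
    for x in blocks:
        m_new = max(m, x.max())       # update running max
        l = l * np.exp(m - m_new)     # rescale existing normalizer
        l += np.exp(x - m_new).sum()  # add the new block's contributions
        m = m_new
    return m, l
```

The returned statistics are identical to what the two-pass algorithm would produce, which is why the tiled attention output is exact.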
Memory Complexity
| | Vanilla Attention | FlashAttention |
|---|---|---|
| HBM for attention scores | $O(N^2)$ (full $S$ and $P$ matrices) | None (only block-sized intermediates, held in SRAM) |
| HBM for output | $O(Nd)$ | $O(Nd)$ |
| Total HBM | $O(N^2 + Nd)$ | $O(Nd)$ |
| SRAM usage | Minimal | $O(Bd)$ per block, where $B$ is the block size |
For $N = 8192$, $d = 64$: vanilla stores a 128 MB attention matrix (16-bit); FlashAttention stores nothing beyond the block currently in SRAM.
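For reference, the per-head attention-matrix footprint at 16-bit precision is just $2N^2$ bytes:

```python
def attn_matrix_bytes(N, bytes_per_elem=2):
    """Bytes for one N x N attention score matrix (per head, per layer)."""
    return N * N * bytes_per_elem

for N in (4096, 8192, 16384, 65536):
    print(f"N={N:6d}: {attn_matrix_bytes(N) / 2**20:8,.0f} MiB per head")
# N=  4096:       32 MiB per head
# N=  8192:      128 MiB per head
# N= 16384:      512 MiB per head
# N= 65536:    8,192 MiB per head
```

Multiply by heads and layers to see how quickly this dominates GPU memory.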
IO Complexity
Dao et al. (2022) analyze the algorithm in terms of HBM reads and writes, not FLOPs.
| | Vanilla Attention | FlashAttention |
|---|---|---|
| HBM reads/writes | $\Theta(Nd + N^2)$ | $\Theta(N^2 d^2 / M)$ |
| FLOP count | $O(N^2 d)$ | $O(N^2 d)$ (identical) |
| Wall-clock speedup | Baseline | 2-4x for long sequences |
Here $M$ is the SRAM size. The key: FlashAttention does the same number of FLOPs but far fewer HBM accesses (fewer by roughly a factor of $M/d^2$). Since modern GPUs are memory-bandwidth bound for attention, fewer HBM accesses means faster wall-clock time.
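To get a feel for the bound, a back-of-envelope ratio of the two access counts under assumed hardware values ($d$ and $M$ below are illustrative, not measured):

```python
# Ratio of the two Theta-bounds: N^2 / (N^2 d^2 / M) = M / d^2.
# Illustrative values only: d = 64 (typical head dim), M = 96K elements
# (roughly 192 KB of SRAM holding 16-bit values).
d = 64
M = 96 * 1024
ratio = M / d**2
print(f"vanilla performs ~{ratio:.0f}x more HBM accesses than FlashAttention")
```

The exact factor depends on the GPU's SRAM size and the head dimension, but the direction is always the same: the savings grow as $M/d^2$.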
Comparison Table
| Property | Vanilla Attention | FlashAttention |
|---|---|---|
| Mathematical output | Exact softmax attention | Exact softmax attention (identical) |
| Memory for attention matrix | $O(N^2)$ in HBM | Not stored in HBM |
| Peak memory scaling | Quadratic in sequence length | Linear in sequence length |
| FLOP count | $O(N^2 d)$ | $O(N^2 d)$ (same) |
| IO complexity | $\Theta(Nd + N^2)$ | $\Theta(N^2 d^2 / M)$ |
| Kernel fusion | Separate matmul, softmax, matmul kernels | Single fused kernel |
| Backward pass | Store the $N \times N$ attention matrix for gradient computation | Recompute attention blocks from $Q$, $K$ (no $O(N^2)$ storage) |
| Dropout support | Store the dropout mask ($O(N^2)$) | Recompute mask from RNG state |
| Custom masking | Straightforward | Requires block-aware mask handling |
| Implementation complexity | Simple (two matmuls + softmax) | Requires custom CUDA kernels |
| Framework support | All frameworks | PyTorch 2.0+, xformers, FlashAttention library |
Where Each Is Stronger
Vanilla attention wins on simplicity
Vanilla attention is three operations: two matrix multiplications and a softmax. Any deep learning framework implements it in a few lines. Debugging is straightforward. Custom attention patterns (sparse masks, relative position biases) are easy to add.
FlashAttention wins on everything else
For any sequence length beyond a few hundred tokens, FlashAttention is strictly better in memory and usually better in wall-clock time. The exact crossover depends on GPU architecture; on A100 GPUs it sits at a few hundred tokens.
The memory savings are what matter most. Vanilla attention at $N = 16384$ requires 512 MB per head (16-bit) for the attention matrix alone. FlashAttention makes such sequence lengths feasible on a single GPU. This is why long-context models (32k, 128k, 1M tokens) exist at all.
Where Each Fails
Vanilla attention fails at scale
At $N = 4096$, the attention matrix is 32 MB per head (16-bit). At $N = 65536$, it is 8 GB per head. Vanilla attention hits an OOM wall long before compute becomes the bottleneck.
FlashAttention fails on custom patterns
Arbitrary sparse attention masks, cross-attention with complex routing, or attention variants that modify the softmax denominator require careful adaptation. FlashAttention-2 and FlashAttention-3 added support for many patterns, but novel architectures may need to fall back to vanilla attention or write custom tiling logic.
Common Confusions
FlashAttention is NOT an approximation
This is the most common misconception. Linear attention, Performer, Linformer: these are approximations that change the output. FlashAttention computes the exact same output as standard softmax attention. It is a systems optimization, not a mathematical approximation. The softmax is computed exactly via the online algorithm.
FlashAttention does not reduce FLOPs
The FLOP count is identical: $O(N^2 d)$ for both. The speedup comes entirely from reduced memory bandwidth usage. Fewer HBM reads and writes means less time waiting for data transfer. On memory-bandwidth-bound hardware (which modern GPUs are for attention), this translates directly to wall-clock speedup.
The backward pass also benefits
In vanilla attention, the backward pass requires the stored attention matrix to compute gradients. FlashAttention recomputes the attention blocks from $Q$ and $K$ during the backward pass, trading a small amount of extra compute for massive memory savings. This recomputation is a net win because the saved memory allows larger batch sizes.
References
- Dao, T., Fu, D.Y., Ermon, S., Rudra, A., and Re, C. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. Sections 3-4 derive the tiling algorithm and IO complexity bounds.
- Vaswani, A. et al. "Attention Is All You Need." NeurIPS 2017. Section 3.2 defines scaled dot-product attention.
- Milakov, M. and Gimelshein, N. "Online Normalizer Calculation for Softmax." 2018. The online softmax algorithm that FlashAttention depends on.
- Dao, T. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. Sections 3-4 improve parallelism and add causal masking.
- Rabe, M.N. and Staats, C. "Self-Attention Does Not Need Memory." 2022. Independent derivation of memory-efficient exact attention.
- Dao, T., Haziza, D., Massa, F., and Sizov, G. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." 2024. Hopper GPU optimizations and FP8 support.
- Jia, Z. et al. "Dissecting the NVIDIA Hopper GPU Architecture via Microbenchmarking." 2023. HBM vs SRAM bandwidth measurements that explain FlashAttention speedups.