
Comparison

FlashAttention vs. Vanilla Attention

FlashAttention and vanilla attention compute the exact same output. The difference is entirely in IO complexity: vanilla materializes the full n x n attention matrix in GPU HBM, while FlashAttention tiles the computation into SRAM blocks using an online softmax trick, reducing memory from O(n^2) to O(n) and achieving 2-4x wall-clock speedup.

The Key Fact

FlashAttention computes the exact same result as vanilla attention. It is not an approximation. The outputs are numerically identical (up to floating-point ordering).

The difference is how the computation is organized in GPU memory. Vanilla attention materializes intermediate matrices in high-bandwidth memory (HBM). FlashAttention avoids this by tiling the computation into on-chip SRAM and fusing the operations into a single kernel. The result is the same. The memory footprint and wall-clock time are dramatically different.

Standard Attention Computation

Definition

Vanilla Attention

Given queries $Q \in \mathbb{R}^{n \times d}$, keys $K \in \mathbb{R}^{n \times d}$, and values $V \in \mathbb{R}^{n \times d}$:

$$S = QK^\top \in \mathbb{R}^{n \times n}, \quad P = \mathrm{softmax}(S / \sqrt{d}) \in \mathbb{R}^{n \times n}, \quad O = PV \in \mathbb{R}^{n \times d}$$

This requires materializing $S$ and $P$ in HBM. For sequence length $n = 8192$ and 16-bit precision, each matrix is $8192^2 \times 2$ bytes $= 128$ MB. With multiple heads and layers, this dominates GPU memory.

The bottleneck is not FLOPs. Modern GPUs have enormous compute throughput but limited memory bandwidth. The $n \times n$ matrices must be written to HBM after the matmul, read back for the softmax, written again after softmax, and read again for the final matmul with $V$. Each HBM read/write is slow relative to on-chip computation.
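Those HBM round-trips are easy to see in code. A minimal NumPy sketch of vanilla attention, with variable names mirroring the definition above (illustrative; a real implementation runs these steps as separate GPU kernels):

```python
import numpy as np

def vanilla_attention(Q, K, V):
    """Standard attention: materializes the full n x n matrices S and P."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # n x n scores (written to HBM)
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # numerically stable softmax
    P /= P.sum(axis=-1, keepdims=True)             # n x n probabilities (HBM again)
    return P @ V                                   # n x d output

rng = np.random.default_rng(0)
n, d = 256, 64
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
O = vanilla_attention(Q, K, V)
```

Every intermediate here has an explicit $n \times n$ footprint; that footprint is exactly what FlashAttention eliminates.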

FlashAttention Computation

Definition

FlashAttention (Tiled, IO-Aware)

FlashAttention splits $Q$, $K$, $V$ into blocks that fit in SRAM (typically 64-128 rows at a time). For each block of queries $Q_i$:

  1. Load $Q_i$ into SRAM.
  2. Iterate over blocks of $K_j$, $V_j$:
    • Compute $S_{ij} = Q_i K_j^\top$ in SRAM.
    • Update running softmax statistics (max and sum) using the online softmax algorithm.
    • Accumulate the weighted output $O_i$ incrementally.
  3. Write the final $O_i$ back to HBM.

The $n \times n$ attention matrix is never materialized in HBM. Each block of $S$ is computed, used, and discarded in SRAM. The output $O$ is accumulated in a single streaming pass.
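The loop structure above can be simulated in NumPy. This is an illustrative sketch only (real FlashAttention runs it inside a single fused CUDA kernel with blocks resident in SRAM); the `block` size and variable names are this sketch's own:

```python
import numpy as np

def flash_attention(Q, K, V, block=64):
    """Tiled attention: never materializes the full n x n matrix."""
    n, d = Q.shape
    O = np.zeros((n, d))
    scale = 1.0 / np.sqrt(d)
    for i in range(0, n, block):                  # query blocks Q_i
        Qi = Q[i:i+block] * scale
        m = np.full(Qi.shape[0], -np.inf)         # running row max
        l = np.zeros(Qi.shape[0])                 # running normalizer
        acc = np.zeros((Qi.shape[0], d))          # unnormalized output
        for j in range(0, n, block):              # key/value blocks K_j, V_j
            Sij = Qi @ K[j:j+block].T             # block of scores, never stored
            m_new = np.maximum(m, Sij.max(axis=-1))
            alpha = np.exp(m - m_new)             # rescale old stats to new max
            p = np.exp(Sij - m_new[:, None])
            l = l * alpha + p.sum(axis=-1)
            acc = acc * alpha[:, None] + p @ V[j:j+block]
            m = m_new
        O[i:i+block] = acc / l[:, None]           # normalize once at the end
    return O

rng = np.random.default_rng(1)
n, d = 200, 32                                    # n not a multiple of block
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
O = flash_attention(Q, K, V, block=64)
```

Despite the tiling, the result matches vanilla attention to floating-point precision, which is the whole point: same math, different memory traffic.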

The Online Softmax Trick

The main algorithmic insight is computing softmax incrementally without seeing the full row of scores at once.

Standard softmax of a vector $x \in \mathbb{R}^n$:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}}, \quad m = \max_j x_j$$

Computing this normally requires two passes: one to find $m$ and compute the sum, one to normalize. The online softmax algorithm (Milakov and Gimelshein, 2018) maintains running estimates of the max $m$ and the normalizer $\ell$ as blocks arrive.

When a new block of scores arrives with local max $m_{\text{new}}$ and local normalizer $\ell_{\text{new}}$ (the sum of $e^{x_j - m_{\text{new}}}$ over the block), the running statistics are updated as:

$$m' = \max(m, m_{\text{new}}), \qquad \ell' = \ell \, e^{m - m'} + \ell_{\text{new}} \, e^{m_{\text{new}} - m'}$$
This produces the exact softmax result in a single pass over the data.
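A one-pass sketch of the normalizer computation, assuming blocks of the score vector arrive in sequence (function name and block size are illustrative):

```python
import numpy as np

def online_softmax_normalizer(x, block=4):
    """Single streaming pass: returns (max, normalizer) for softmax(x)."""
    m, l = -np.inf, 0.0
    for j in range(0, len(x), block):
        xb = x[j:j+block]
        m_new = max(m, xb.max())
        # rescale the old sum to the new max, then add the block's contribution
        l = l * np.exp(m - m_new) + np.exp(xb - m_new).sum()
        m = m_new
    return m, l

x = np.array([1.0, 3.0, -2.0, 0.5, 4.0, 2.0, -1.0])
m, l = online_softmax_normalizer(x)
```

The returned `m` and `l` match what the two-pass algorithm would produce, so the final `exp(x - m) / l` is the exact softmax, not an approximation.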

Memory Complexity

| | Vanilla Attention | FlashAttention |
|---|---|---|
| HBM for attention scores | $O(n^2)$ (full $S$ and $P$ matrices) | $O(n)$ (only block-sized intermediates) |
| HBM for output | $O(nd)$ | $O(nd)$ |
| Total HBM | $O(n^2 + nd)$ | $O(nd)$ |
| SRAM usage | Minimal | $O(B^2)$ per block, where $B$ is block size |

For $n = 8192$, $d = 128$: vanilla stores $\sim 128$ MB for the attention matrix; FlashAttention stores nothing beyond the block currently in SRAM.
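The arithmetic behind that figure, for one head at 16-bit precision:

```python
n = 8192
bytes_per_elem = 2                        # fp16 / bf16
score_matrix_bytes = n * n * bytes_per_elem
print(score_matrix_bytes / 2**20)         # 128 MiB for S alone; P doubles it
```

And this is per head, per layer: a model with dozens of heads and layers multiplies it accordingly.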

IO Complexity

Dao et al. (2022) analyze the algorithm in terms of HBM reads and writes, not FLOPs.

| | Vanilla Attention | FlashAttention |
|---|---|---|
| HBM reads/writes | $O(n^2 d + n^2)$ | $O(n^2 d^2 / M)$ |
| FLOP count | $O(n^2 d)$ | $O(n^2 d)$ (identical) |
| Wall-clock speedup | Baseline | 2-4x for long sequences |

Here $M$ is the SRAM size. The key: FlashAttention does the same number of FLOPs but far fewer HBM accesses. Since modern GPUs are memory-bandwidth bound for attention, fewer HBM accesses means faster wall-clock time.

Comparison Table

| Property | Vanilla Attention | FlashAttention |
|---|---|---|
| Mathematical output | Exact softmax attention | Exact softmax attention (identical) |
| Memory for attention matrix | $O(n^2)$ in HBM | Not stored in HBM |
| Peak memory scaling | Quadratic in sequence length | Linear in sequence length |
| FLOP count | $O(n^2 d)$ | $O(n^2 d)$ (same) |
| IO complexity | $O(n^2 d + n^2)$ | $O(n^2 d^2 / M)$ |
| Kernel fusion | Separate matmul, softmax, matmul kernels | Single fused kernel |
| Backward pass | Store $P$ for gradient computation | Recompute $P$ from $Q$, $K$ blocks (no storage) |
| Dropout support | Store dropout mask ($O(n^2)$) | Recompute mask from RNG state |
| Custom masking | Straightforward | Requires block-aware mask handling |
| Implementation complexity | Simple (two matmuls + softmax) | Requires custom CUDA kernels |
| Framework support | All frameworks | PyTorch 2.0+, xformers, FlashAttention library |

Where Each Is Stronger

Vanilla attention wins on simplicity

Vanilla attention is three operations: two matrix multiplications and a softmax. Any deep learning framework implements it in a few lines. Debugging is straightforward. Custom attention patterns (sparse masks, relative position biases) are easy to add.

FlashAttention wins on everything else

For any sequence length beyond a few hundred tokens, FlashAttention is strictly better in memory and usually better in wall-clock time. The crossover point depends on GPU architecture, but on A100 GPUs, FlashAttention is faster for $n \geq 256$.

The memory savings are what matter most. Vanilla attention with $n = 16384$ requires 512 MB per head for the attention matrix alone. FlashAttention makes such sequence lengths feasible on a single GPU. This is why long-context models (32k, 128k, 1M tokens) exist at all.

Where Each Fails

Vanilla attention fails at scale

At $n = 4096$, the attention matrix is 32 MB per head (16-bit). At $n = 65536$, it is 8 GB per head. Vanilla attention hits an OOM wall long before compute becomes the bottleneck.

FlashAttention fails on custom patterns

Arbitrary sparse attention masks, cross-attention with complex routing, or attention variants that modify the softmax denominator require careful adaptation. FlashAttention-2 and FlashAttention-3 added support for many patterns, but novel architectures may need to fall back to vanilla attention or write custom tiling logic.

Common Confusions

Watch Out

FlashAttention is NOT an approximation

This is the most common misconception. Linear attention, Performer, Linformer: these are approximations that change the output. FlashAttention computes the exact same output as standard softmax attention. It is a systems optimization, not a mathematical approximation. The softmax is computed exactly via the online algorithm.

Watch Out

FlashAttention does not reduce FLOPs

The FLOP count is identical: $O(n^2 d)$ for both. The speedup comes entirely from reduced memory bandwidth usage. Fewer HBM reads and writes means less time waiting for data transfer. On memory-bandwidth-bound hardware (which modern GPUs are for attention), this translates directly to wall-clock speedup.

Watch Out

The backward pass also benefits

In vanilla attention, the backward pass requires the stored $n \times n$ attention matrix $P$ to compute gradients. FlashAttention recomputes $P$ block-by-block from $Q$ and $K$ during the backward pass, trading a small amount of extra compute for massive memory savings. This recomputation is a net win because the saved memory allows larger batch sizes.
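The idea can be sketched for one gradient, $\nabla V = P^\top \nabla O$: recompute each block of $P$ from $Q$, $K$ and the saved per-row softmax statistics instead of loading a stored $n \times n$ matrix. Function names are hypothetical, and the real kernel fuses this with the $\nabla Q$ and $\nabla K$ computations:

```python
import numpy as np

def attention_dV(Q, K, dO, m, l, block=32):
    """dV = P^T dO with P recomputed block-by-block.
    m, l are the per-row softmax max and normalizer saved from the
    forward pass: O(n) storage instead of the O(n^2) matrix P."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    dV = np.zeros((n, d))
    for i in range(0, n, block):
        Si = Q[i:i+block] @ K.T * scale                       # recomputed scores
        Pi = np.exp(Si - m[i:i+block, None]) / l[i:i+block, None]
        dV += Pi.T @ dO[i:i+block]                            # block contribution
        # Si and Pi are discarded here; nothing n x n is ever stored
    return dV

rng = np.random.default_rng(2)
n, d = 128, 16
Q, K, dO = (rng.standard_normal((n, d)) for _ in range(3))
# forward-pass statistics (computed densely here only for the demo)
S = Q @ K.T / np.sqrt(d)
m = S.max(axis=-1)
l = np.exp(S - m[:, None]).sum(axis=-1)
dV = attention_dV(Q, K, dO, m, l)
```

Because the saved `m` and `l` pin down each row's softmax exactly, the recomputed blocks reproduce $P$ bit-for-bit up to floating-point ordering, and the accumulated `dV` matches the stored-matrix gradient.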

References

  1. Dao, T., Fu, D.Y., Ermon, S., Rudra, A., and Re, C. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. Sections 3-4 derive the tiling algorithm and IO complexity bounds.
  2. Vaswani, A. et al. "Attention Is All You Need." NeurIPS 2017. Section 3.2 defines scaled dot-product attention.
  3. Milakov, M. and Gimelshein, N. "Online Normalizer Calculation for Softmax." 2018. The online softmax algorithm that FlashAttention depends on.
  4. Dao, T. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. Sections 3-4 improve parallelism and add causal masking.
  5. Rabe, M.N. and Staats, C. "Self-Attention Does Not Need $O(n^2)$ Memory." 2022. Independent derivation of memory-efficient exact attention.
  6. Dao, T., Haziza, D., Massa, F., and Sizov, G. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." 2024. Hopper GPU optimizations and FP8 support.
  7. Jia, Z. et al. "Dissecting the NVIDIA Hopper GPU Architecture via Microbenchmarking." 2023. HBM vs SRAM bandwidth measurements that explain FlashAttention speedups.