Fused Kernels

Combine multiple GPU operations into a single kernel launch to eliminate intermediate HBM reads and writes, and see why kernel fusion is the primary optimization technique for memory-bound ML operations.

Why This Matters

Consider a sequence of three element-wise operations: residual addition, layer normalization, and dropout. Without fusion, each operation is a separate GPU kernel. Each kernel reads its input from HBM, computes, and writes its output back to HBM. The next kernel reads that output from HBM to start its work. Three kernels produce three round-trips to HBM.

With fusion, all three operations are combined into a single kernel. The input is loaded from HBM once, all three operations are applied in registers or SRAM, and only the final result is written back. Two intermediate HBM round-trips are eliminated entirely.

Since these operations are memory-bound (arithmetic intensity near 1), eliminating HBM traffic translates directly into wall-clock speedup. Kernel fusion is how Flash Attention, xFormers, and most production ML inference engines achieve their performance.
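The residual-add, layer-norm, dropout chain above can be sketched in NumPy. This is a minimal model of the semantics only: NumPy does not actually fuse kernels, so the point is that the single-pass fused function produces the same result while, on a GPU, the intermediates t1 and t2 would never touch HBM. The dropout mask is fixed here for reproducibility; all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 1024
x = rng.standard_normal(N).astype(np.float32)
residual = rng.standard_normal(N).astype(np.float32)
keep = (rng.random(N) < 0.9).astype(np.float32)  # fixed dropout mask, p = 0.1

def layernorm(v, eps=1e-5):
    return (v - v.mean()) / np.sqrt(v.var() + eps)

# Unfused: three "kernels", two intermediate tensors round-tripped to HBM
t1 = x + residual           # kernel 1: write t1 to HBM
t2 = layernorm(t1)          # kernel 2: read t1, write t2
unfused = t2 * keep / 0.9   # kernel 3: read t2, write output

# Fused semantics: one pass, intermediates stay in registers/SRAM on a GPU
def fused(x, residual, keep, p=0.1, eps=1e-5):
    t = x + residual
    t = (t - t.mean()) / np.sqrt(t.var() + eps)
    return t * keep / (1 - p)

assert np.allclose(unfused, fused(x, residual, keep))
```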

Mental Model

Think of each GPU kernel as a factory worker who can only communicate by placing items on a shared conveyor belt (HBM). Without fusion, worker A puts intermediate results on the belt; worker B picks them up, processes them, and puts new results back on the belt; worker C picks those up. Most of the time is spent loading and unloading the belt.

With fusion, a single worker does all three jobs internally, touching the belt only for the original input and final output.

The Problem: Unfused Operations

Definition

GPU Kernel

A function compiled for and launched on the GPU. Each kernel has a launch overhead (CPU dispatch, scheduling) and performs its own HBM reads and writes. The GPU kernel is the unit of work in CUDA and similar frameworks.

A typical transformer forward pass without fusion:

  1. QKV projection: read input from HBM, matmul, write QKV to HBM
  2. Attention scores: read Q, K from HBM, matmul, write scores to HBM
  3. Softmax: read scores from HBM, compute softmax, write to HBM
  4. Attention output: read softmax output and V from HBM, matmul, write to HBM
  5. Residual add: read attention output and residual from HBM, add, write to HBM
  6. Layer norm: read from HBM, normalize, write to HBM
  7. FFN: read from HBM, two matmuls with activation, write to HBM

Steps 2-4 involve three separate HBM round-trips for attention alone. Steps 5-6 involve two more. Each round-trip is wasted time because the intermediate values are consumed immediately and then discarded.

Main Theorems

Proposition

IO Reduction from Kernel Fusion

Statement

Without fusion, the total HBM traffic is 2kN elements (each of the k operations reads N elements and writes N elements). With full fusion into a single kernel, the total HBM traffic is 2N elements (one read of the original input, one write of the final output). The IO reduction factor is k.

Intuition

Fusion eliminates k - 1 intermediate writes and k - 1 intermediate reads. The only HBM access that remains is loading the initial input and storing the final output. Everything in between stays in registers or SRAM.

Proof Sketch

Each unfused kernel performs one read and one write of N elements: total = k × 2N = 2kN. The fused kernel loads N elements, applies all k operations in fast memory, and writes N elements: total = 2N. Ratio: 2kN / 2N = k.
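The counting argument is small enough to check directly. A minimal sketch (function name and units are my own, not from any library) that tallies elements moved to and from HBM:

```python
def hbm_traffic(k: int, n: int, fused: bool) -> int:
    """Elements moved to/from HBM for a chain of k element-wise ops."""
    if fused:
        return 2 * n        # one read of the input, one write of the output
    return 2 * k * n        # each of the k kernels reads n and writes n

k, n = 5, 1 << 20
reduction = hbm_traffic(k, n, fused=False) / hbm_traffic(k, n, fused=True)
print(reduction)  # 5.0 — the IO reduction factor equals k
```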

Why It Matters

For a chain of 5 element-wise operations (common in transformer blocks: residual add, layer norm mean, layer norm variance, normalize, scale and shift), fusion reduces HBM traffic by 5×. Since these operations are memory-bound, this translates to roughly a 5× wall-clock speedup for this portion of the computation.

Failure Mode

Fusion does not help when: (1) the intermediate results are needed later by other kernels (not just the next operation), (2) the fused kernel requires more registers or SRAM than available per thread block, reducing occupancy, or (3) the operations are compute-bound rather than memory-bound (fusing compute-bound operations does not eliminate a bottleneck).

Examples of Fused Kernels in Practice

Fused attention (Flash Attention): the softmax, scaling, and matmul with V are fused into a single tiled kernel. This is the biggest single fusion in modern transformers, reducing attention from O(N²) HBM accesses (materializing the full score matrix) to O(N) by keeping score tiles in on-chip SRAM.
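The core of the Flash Attention fusion is the online (streaming) softmax: K and V are processed in tiles while a running row maximum and normalizer are carried along, so the full score matrix is never materialized. A NumPy sketch of that rescaling trick (sizes are arbitrary; this models the math, not the actual GPU tiling):

```python
import numpy as np

rng = np.random.default_rng(0)
Nq, Nk, d, B = 8, 64, 16, 16   # B = key/value tile size

Q = rng.standard_normal((Nq, d))
K = rng.standard_normal((Nk, d))
V = rng.standard_normal((Nk, d))

# Reference: materializes the full Nq x Nk score matrix ("unfused")
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V

# Tiled online softmax: one pass over K/V tiles with running max m, sum l
m = np.full(Nq, -np.inf)       # running row maximum
l = np.zeros(Nq)               # running softmax normalizer
acc = np.zeros((Nq, d))        # running (unnormalized) output
for j in range(0, Nk, B):
    S_j = Q @ K[j:j+B].T / np.sqrt(d)     # scores for this tile only
    m_new = np.maximum(m, S_j.max(axis=1))
    scale = np.exp(m - m_new)             # rescale previous partial results
    P_j = np.exp(S_j - m_new[:, None])
    l = l * scale + P_j.sum(axis=1)
    acc = acc * scale[:, None] + P_j @ V[j:j+B]
    m = m_new
out = acc / l[:, None]

assert np.allclose(out, ref)
```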

Fused layer norm + residual add: instead of writing the residual sum to HBM and reading it back for layer norm, combine both into one kernel. This eliminates one full tensor read-write cycle.

Fused activation + multiply (SwiGLU): the SwiGLU activation SwiGLU(x, W₁, W₂) = swish(xW₁) ⊙ (xW₂) involves an element-wise activation and an element-wise multiplication. Fusing these avoids writing the intermediate swish output to HBM.
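A NumPy sketch of the SwiGLU fusion semantics (shapes and weights are made up for illustration). The intermediate swish output s below is exactly the tensor that a fused GPU kernel keeps in registers instead of writing to HBM:

```python
import numpy as np

rng = np.random.default_rng(0)
d, h = 64, 256
x = rng.standard_normal(d).astype(np.float32)
W1 = rng.standard_normal((d, h)).astype(np.float32)
W2 = rng.standard_normal((d, h)).astype(np.float32)

def swish(z):
    return z / (1.0 + np.exp(-z))   # swish(z) = z * sigmoid(z)

# Unfused: the swish output is an intermediate tensor round-tripped to HBM
a = x @ W1
b = x @ W2
s = swish(a)         # intermediate that fusion would keep in registers
unfused = s * b

# Fused semantics: activation and multiply applied in a single pass
fused = swish(x @ W1) * (x @ W2)

assert np.allclose(unfused, fused)
```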

Fused softmax + cross-entropy loss: instead of computing softmax, writing probabilities to HBM, then reading them back to compute the loss, do both in one kernel. This also avoids materializing the full probability vector, which saves memory for large vocabularies.
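The fused form is just logsumexp(logits) - logits[target]: two reductions over the logits, with no V-sized probability tensor ever stored. A NumPy sketch of the equivalence (vocabulary size and target index chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(0)
V = 50_000                        # vocabulary size
logits = rng.standard_normal(V)
target = 1234

# Unfused: materialize the full probability vector, then index into it
probs = np.exp(logits - logits.max())
probs /= probs.sum()
loss_unfused = -np.log(probs[target])

# Fused: loss = logsumexp(logits) - logits[target]; only two reductions
# over the logits, never storing a V-sized intermediate
m = logits.max()
loss_fused = (m + np.log(np.exp(logits - m).sum())) - logits[target]

assert np.isclose(loss_unfused, loss_fused)
```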

Writing Fused Kernels with Triton

CUDA requires writing kernels in C/C++ with explicit thread and memory management. Triton (OpenAI) provides a Python-based DSL for writing GPU kernels that compile to efficient GPU code. Key features:

  • Automatic memory coalescing and shared memory management
  • Block-level programming model (operate on tiles, not individual threads)
  • Just-in-time compilation to PTX/SASS
  • Dramatically lower development effort than raw CUDA

A fused layer norm kernel in Triton is roughly 40 lines of Python. The equivalent CUDA kernel is 200+ lines of C++. Performance is comparable because Triton's compiler handles the low-level optimizations.
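As a concrete illustration, here is a hedged sketch of a fused layer norm kernel in Triton: one program instance handles one row, loading it from HBM once, computing the mean, variance, and the normalized, scaled output in registers, and writing back once. This is an untested sketch assuming current Triton APIs (triton.jit, tl.load, tl.store) and a CUDA GPU to launch on; the kernel and argument names are my own.

```python
import triton
import triton.language as tl

@triton.jit
def fused_layernorm_kernel(x_ptr, w_ptr, b_ptr, out_ptr,
                           N, eps, BLOCK: tl.constexpr):
    row = tl.program_id(0)                  # one program per row
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    # Single HBM read of the row; everything below stays in registers.
    x = tl.load(x_ptr + row * N + cols, mask=mask, other=0.0).to(tl.float32)
    mean = tl.sum(x, axis=0) / N
    diff = tl.where(mask, x - mean, 0.0)
    var = tl.sum(diff * diff, axis=0) / N
    inv_std = 1.0 / tl.sqrt(var + eps)
    w = tl.load(w_ptr + cols, mask=mask, other=1.0)
    b = tl.load(b_ptr + cols, mask=mask, other=0.0)
    # Single HBM write of the final result.
    tl.store(out_ptr + row * N + cols, (x - mean) * inv_std * w + b, mask=mask)

# Launch sketch (requires triton + a CUDA GPU, e.g. with PyTorch tensors):
# fused_layernorm_kernel[(M,)](x, w, b, out, N, 1e-5,
#                              BLOCK=triton.next_power_of_2(N))
```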

When Fusion Helps and When It Does Not

Fusion helps when:

  • Operations are memory-bound (low arithmetic intensity)
  • Intermediate results are used once and then discarded
  • The chain of operations fits in the register/SRAM budget per thread block

Fusion does not help when:

  • The bottleneck is compute, not memory (large matmuls already achieve high utilization)
  • Intermediate results must be reused by multiple downstream operations
  • The fused kernel becomes so large that register pressure reduces occupancy below the level where the GPU can hide memory latency
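The compute-bound case can be made concrete with a back-of-the-envelope arithmetic intensity check (a roofline-style calculation; the numbers assume fp32 and ignore caching, and the function names are illustrative):

```python
def arithmetic_intensity_add(n, bytes_per_elem=4):
    # Element-wise add: n FLOPs; reads 2n and writes n elements
    return n / (3 * n * bytes_per_elem)

def arithmetic_intensity_matmul(m, n, k, bytes_per_elem=4):
    # Matmul: 2*m*n*k FLOPs over (m*k + k*n + m*n) elements of traffic
    flops = 2 * m * n * k
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved

# Element-wise add is deeply memory-bound; a large matmul is not, so
# fusing small element-wise ops into it yields little speedup.
print(arithmetic_intensity_add(1 << 20))              # ~0.083 FLOP/byte
print(arithmetic_intensity_matmul(4096, 4096, 4096))  # ~683 FLOP/byte
```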

Common Confusions

Watch Out

Kernel fusion does not reduce FLOPs

Fusion performs the exact same floating-point operations as the unfused version. The speedup comes entirely from eliminating redundant HBM traffic. If anything, fusion may add a small overhead from more complex control flow within the single kernel.

Watch Out

Not all operations should be fused

Fusing a large matmul with a small element-wise operation may not help if the matmul already achieves high compute utilization. The element-wise operation's memory traffic is negligible compared to the matmul's compute time. Fusion adds implementation complexity without meaningful speedup.

Watch Out

Triton is not a replacement for all CUDA

Triton excels at element-wise, reduction, and attention-like kernels. For operations requiring warp-level primitives, shared memory bank conflict management, or tensor core scheduling, hand-written CUDA may still be necessary. Triton's abstraction level prevents some low-level optimizations.

Summary

  • Kernel fusion eliminates intermediate HBM reads and writes between consecutive operations
  • For a chain of k memory-bound operations, fusion reduces IO by a factor of k
  • Flash Attention is a fused kernel: softmax, scaling, and matmul with V in one kernel
  • Triton makes writing fused kernels accessible from Python
  • Fusion helps memory-bound operations; it does not help compute-bound operations

Exercises

Exercise (Core)

Problem

A transformer layer applies (in sequence): residual addition (N elements read and written), layer norm (N elements read and written), and dropout (N elements read and written). Each operation is a separate kernel. How many total HBM reads and writes occur? How many with full fusion?

Exercise (Advanced)

Problem

Flash Attention fuses the computation softmax(QKᵀ/√d)·V into a single tiled kernel. Explain why this fusion is more complex than fusing element-wise operations. What specific challenge does the softmax normalization introduce, and how is it resolved?

References

Canonical:

  • NVIDIA CUDA Programming Guide, Chapter on kernel optimization
  • Dao et al., FlashAttention (2022), Sections 3-4 (Flash Attention as a fused kernel)

Current:

  • Tillet, Kung, Cox, Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations (2019)
  • NVIDIA, Transformer Engine documentation (fused kernels for FP8 transformer training)

Last reviewed: April 2026
