

Forgetting Transformer (FoX)

FoX adds a data-dependent forget gate to softmax attention. The gate down-weights unnormalized attention scores between past and present positions, giving the transformer a learned, recency-biased decay. FoX is FlashAttention-compatible, works without positional embeddings, and improves long-context language modeling and length extrapolation.


Why This Matters

Standard softmax attention has no built-in notion of recency. Position is injected externally through positional embeddings (sinusoidal, learned, RoPE, ALiBi), and the attention weights themselves are computed from content alone. Recurrent models do the opposite: LSTMs and state-space models carry a learned forget gate that decides, at each step, how much of the past to keep.

FoX (Forgetting Transformer) asks whether softmax attention can borrow that mechanism. The answer is yes, with a small and clean modification: compute a scalar forget gate per token and per head, take a running product of these gates between positions $j$ and $i$, and add the log of this product to the pre-softmax attention logits. That is it. No positional embeddings are required. The resulting model trains at the same speed as a standard transformer because the modification is compatible with the FlashAttention kernel. On long-context language modeling and length extrapolation, FoX beats a tuned RoPE baseline. On short-context downstream tasks it performs competitively.

The paper is Lin, Nikishin, He, and Courville, Forgetting Transformer: Softmax Attention with a Forget Gate, ICLR 2025 (arXiv:2503.02130). The contribution is narrow but instructive: a single data-dependent scalar per token, injected in the right place, gives transformers the recency inductive bias of recurrent models without giving up parallel training or the FlashAttention fast path.

Mental Model

Think of softmax attention as a content-based lookup with no decay: token $i$ can attend equally to any earlier token $j$, and nothing in the attention weight itself punishes distance. Positional embeddings fix this indirectly by encoding $i - j$ as a content signal.

FoX adds a direct, multiplicative decay. Each token $t$ emits a forget value $f_t \in (0, 1)$ per attention head. The "keep mass" from token $j$ to token $i$ is the product

$$F_{ij} = \prod_{\ell = j+1}^{i} f_\ell.$$

If every intermediate $f_\ell$ is near 1, the keep mass is near 1 and FoX recovers ordinary attention. If one intermediate $f_\ell$ is near 0, it closes the gate and the contribution from all tokens before it is suppressed for all future queries. The gate is data-dependent, so the model can decide, based on the current input, when to wipe context.
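As a quick numerical sanity check on the keep mass (the gate values below are invented for illustration):

```python
# Hypothetical gate values for positions 1..5, chosen to illustrate the mechanism.
f = {1: 0.90, 2: 0.95, 3: 0.05, 4: 0.90, 5: 0.90}

def keep_mass(i, j):
    """F_ij = product of f_l for l = j+1 .. i (equals 1 when j == i)."""
    m = 1.0
    for l in range(j + 1, i + 1):
        m *= f[l]
    return m

print(keep_mass(2, 1))  # gates near 1 leave attention essentially intact
print(keep_mass(5, 2))  # 0.05 * 0.90 * 0.90 = 0.0405: the closed gate at 3 wipes earlier context
print(keep_mass(5, 4))  # tokens after the closed gate are unaffected
```

Note that a single near-zero gate at position 3 suppresses all of positions 1 and 2 for every later query, while positions 4 and 5 still see each other at full strength.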

In log space this decay becomes additive. That is the key to efficient implementation: adding $\log F_{ij}$ to the attention logits is the same kind of operation ALiBi already performs, except the slope is data-dependent instead of a fixed head-specific constant.

Formal Setup

Definition

Forget Gate (FoX)

For each attention head $h$ and each token position $t$, the forget gate is a scalar in $(0, 1)$ computed from the current hidden state $x_t$:

$$f_t^{(h)} = \sigma\bigl( w_f^{(h)\top} x_t + b_f^{(h)} \bigr), \qquad f_t^{(h)} \in (0, 1),$$

where $\sigma$ is the sigmoid, and $w_f^{(h)} \in \mathbb{R}^d$, $b_f^{(h)} \in \mathbb{R}$ are per-head learnable parameters. The gate is a scalar per head, not a vector, so its parameter cost is negligible.
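A minimal sketch of the gate computation, with hypothetical sizes and a positive bias initialization (the initialization value here is an assumption for illustration, not a claim about the paper's exact scheme):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_heads, seq_len = 64, 8, 16          # hypothetical sizes

# Per-head gate parameters w_f, b_f from the definition above.
w_f = 0.02 * rng.normal(size=(n_heads, d))
b_f = np.full(n_heads, 2.0)              # positive init keeps gates open early (assumption)

x = rng.normal(size=(seq_len, d))        # hidden states x_t

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

f = sigmoid(x @ w_f.T + b_f)             # f[t, h] = sigma(w_f^(h) . x_t + b_f^(h))
assert f.shape == (seq_len, n_heads)     # one scalar per position per head
assert np.all((0 < f) & (f < 1))
```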

Definition

Forgetting Attention

Let $q_i, k_j, v_j$ denote query, key, and value vectors at positions $i, j$ for a fixed head. Define the cumulative log forget gate

$$D_{ij} = \sum_{\ell = j+1}^{i} \log f_\ell \quad (j \leq i), \qquad D_{ij} = -\infty \quad (j > i).$$

Forgetting Attention is the causal attention

$$o_i = \frac{\sum_{j=1}^{i} \exp(q_i^{\top} k_j + D_{ij}) \, v_j}{\sum_{j=1}^{i} \exp(q_i^{\top} k_j + D_{ij})}.$$

Equivalently, letting $F_{ij} = \prod_{\ell = j+1}^{i} f_\ell = \exp(D_{ij})$,

$$o_i = \frac{\sum_{j=1}^{i} F_{ij} \exp(q_i^{\top} k_j) \, v_j}{\sum_{j=1}^{i} F_{ij} \exp(q_i^{\top} k_j)}.$$

The factor $F_{ij}$ down-weights the unnormalized attention score between positions $i$ and $j$, with no change to the softmax denominator structure.

The modification lives entirely inside the attention logits. Adding a bias $D_{ij}$ to $q_i^{\top} k_j$ is the same access pattern FlashAttention already handles for causal masks and ALiBi, so FoX reuses the FlashAttention kernel without a new fused implementation.
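The definition translates directly into a naive reference implementation. This is a quadratic-memory sketch for clarity, not the FlashAttention-compatible kernel; the bias matrix is built from a prefix sum of $\log f_\ell$:

```python
import numpy as np

def forgetting_attention(q, k, v, f):
    """Single-head Forgetting Attention, materializing the full score matrix.

    q, k: (n, d); v: (n, d_v); f: (n,) forget gates in (0, 1).
    A quadratic-memory reference sketch, not a fused kernel.
    """
    n = q.shape[0]
    c = np.cumsum(np.log(f))               # c[t] = sum_{l <= t} log f_l
    D = c[:, None] - c[None, :]            # D[i, j] = sum_{l = j+1}^{i} log f_l
    s = q @ k.T + D                        # shifted logits
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)  # causal mask
    s = s - s.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(s)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v
```

With `f` identically 1, `D` is zero and this reduces to standard causal softmax attention, as the proposition below states.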

Main Theorems

Proposition

Forgetting Attention is Attention With Log-Domain Decay

Statement

Let $F_{ij} = \prod_{\ell = j+1}^{i} f_\ell$ with $f_\ell \in (0, 1)$. Then Forgetting Attention can be written as standard softmax attention on modified logits

$$\tilde{s}_{ij} = q_i^{\top} k_j + \log F_{ij}, \qquad \mathrm{FoXAttn}(q_i, K, V) = \mathrm{softmax}_j(\tilde{s}_{ij}) \, V.$$

In particular, when $f_\ell \equiv 1$ the bias vanishes and FoX reduces to standard causal softmax attention.

Intuition

The forget gate never enters the softmax denominator separately. It is folded into the logits as a position-dependent bias. This is the structural reason FoX keeps FlashAttention compatibility: the kernel only needs one extra bias term per $(i, j)$ pair, which can be computed on the fly from a prefix sum of $\log f_\ell$.

Proof Sketch

Start from the Forgetting Attention definition and factor $F_{ij} = \exp(\log F_{ij})$. Then $F_{ij} \exp(q_i^{\top} k_j) = \exp(q_i^{\top} k_j + \log F_{ij})$. Both numerator and denominator share the same exponential, so the ratio is exactly the softmax of the shifted logits $\tilde{s}_{ij}$. When $f_\ell = 1$ for all $\ell$, $\log F_{ij} = 0$, recovering standard attention.
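The equivalence is easy to check numerically for one query position (shapes and values below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 3
q = rng.normal(size=(n, d))
k = rng.normal(size=(n, d))
f = rng.uniform(0.5, 0.99, size=n)

i = n - 1                                  # last query position
# Direct form: multiply unnormalized scores by F_ij, then normalize.
F = np.array([np.prod(f[j + 1:i + 1]) for j in range(i + 1)])  # F_ii = 1
num = F * np.exp(q[i] @ k[:i + 1].T)
direct = num / num.sum()

# Log-domain form: softmax over shifted logits q_i . k_j + log F_ij.
logits = q[i] @ k[:i + 1].T + np.log(F)
w = np.exp(logits - logits.max())
log_domain = w / w.sum()

assert np.allclose(direct, log_domain)     # same distribution, two forms
```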

Why It Matters

This reformulation is what makes FoX trainable at transformer speed. Prefix sums of scalars are cheap, and the shifted logits drop into any attention kernel that supports causal masking plus a bias. Compare to linear attention variants where the decay must be baked into a custom recurrence: FoX keeps the softmax and the associativity of the kernel untouched.

Failure Mode

If all gates saturate near 1, FoX becomes a plain transformer with wasted parameters. If all gates saturate near 0, only the immediate previous token contributes, collapsing the model to a very short effective context. Initialization that pushes $b_f^{(h)}$ to a positive value (analogous to LSTM forget-gate bias tricks) is important to keep gates open early in training.
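A toy calculation shows why the bias initialization matters (the numbers are illustrative, not the paper's):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Near init, w_f . x_t is roughly zero-mean, so b_f sets the typical gate value.
half_gate = sigmoid(0.0)    # 0.5 with zero bias
open_gate = sigmoid(2.0)    # ~0.88 with a positive bias init

# Cumulative keep mass across a 100-token gap:
print(half_gate ** 100)     # ~8e-31: essentially all context is wiped
print(open_gate ** 100)     # small, but many orders of magnitude more signal survives
```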

Proposition

Recency Bias of Forgetting Attention

Statement

Fix a head and a query position $i$. Suppose the forget gates satisfy $f_\ell \leq \bar{f} < 1$ for all $\ell$. Then the multiplicative weight on the key at position $j < i$ is at most

$$F_{ij} = \prod_{\ell = j+1}^{i} f_\ell \leq \bar{f}^{\,i - j}.$$

In particular, the effective attention weight $F_{ij} \exp(q_i^{\top} k_j) / Z_i$ decays at least geometrically in the gap $i - j$, uniformly in the content similarity $q_i^{\top} k_j$, where $Z_i$ is the normalizer.

Intuition

Once the gate stays strictly below 1, FoX imposes a data-independent upper bound on how much a far-away token can contribute, regardless of how well its key matches the query. Content similarity can still push attention toward older tokens, but only up to the ceiling set by the cumulative product. This is a provable recency bias that standard attention lacks.

Proof Sketch

Each factor in $\prod_{\ell = j+1}^{i} f_\ell$ is at most $\bar{f}$, and there are $i - j$ factors, giving $F_{ij} \leq \bar{f}^{\,i - j}$. The numerator contribution of key $j$ is $F_{ij} \exp(q_i^{\top} k_j) v_j$ and is divided by $Z_i \geq F_{ii} \exp(q_i^{\top} k_i) = \exp(q_i^{\top} k_i)$, so the normalized weight on $v_j$ is at most $\bar{f}^{\,i-j} \exp(q_i^{\top} k_j - q_i^{\top} k_i)$.
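The geometric ceiling is straightforward to verify numerically (gate values below are arbitrary draws capped at $\bar{f} = 0.9$):

```python
import numpy as np

rng = np.random.default_rng(0)
n, f_bar = 64, 0.9
f = rng.uniform(0.5, f_bar, size=n)     # every gate capped at f_bar < 1

c = np.cumsum(np.log(f))                # prefix sums of log gates
i = n - 1
for j in range(i):
    F_ij = np.exp(c[i] - c[j])          # F_ij = prod_{l = j+1 .. i} f_l
    assert F_ij <= f_bar ** (i - j) * (1 + 1e-9)   # geometric ceiling holds
```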

Why It Matters

Length extrapolation and stable long-context behavior require the model to refuse to attend to arbitrarily old tokens with arbitrary confidence. FoX enforces this structurally. This is one reason FoX generalizes past its training length without ALiBi-style slope tuning or RoPE tricks.

Failure Mode

If the gate is not well-regulated, a single near-zero $f_\ell$ kills all information from before position $\ell$. For tasks that require a specific long-range lookup (e.g., needle-in-a-haystack retrieval), this can be harmful. The paper reports FoX holds up on retrieval tasks, but the theoretical risk is real: the gate must learn to stay open for task-relevant anchors.

The Pro Block

The paper also introduces a "Pro" block design that layers several small components from recurrent and efficient-attention literature around the Forgetting Attention core. The components are:

  • Output gate: a sigmoid gate applied to the attention output before the residual, similar to gated linear attention variants.
  • Output normalization: an RMSNorm on the attention output prior to the output projection.
  • QK-norm: RMSNorm applied to queries and keys before the dot product, which stabilizes logits at long context.
  • KV-shift: a simplified, learned, data-dependent token-shift on keys and values, borrowed from the short-convolution tradition in RWKV and Mamba-like designs.

Each piece is cheap in parameters and compatible with standard attention kernels. The paper reports that Pro blocks improve both FoX and the ordinary transformer, but the improvement for FoX is larger. The takeaway: the forget gate is the architectural change of interest, and the Pro block is a collection of well-trodden stabilizers that happen to pair well with it.
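Two of the cheaper pieces can be sketched in a few lines. This is a simplified illustration with hypothetical shapes; the paper's exact parameterizations (learned scales, the precise KV-shift form) may differ:

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    """RMSNorm without a learned scale (simplified); stands in here for
    both output normalization and QK-norm."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def kv_shift(k, alpha):
    """Simplified data-dependent token shift: mix each key with its
    predecessor, k'_t = (1 - a_t) * k_t + a_t * k_{t-1}, with a_t in (0, 1).
    The first position has no predecessor and keeps itself."""
    k_prev = np.vstack([k[:1], k[:-1]])
    return (1.0 - alpha[:, None]) * k + alpha[:, None] * k_prev

rng = np.random.default_rng(0)
k = rng.normal(size=(8, 16))
alpha = rng.uniform(size=8)             # hypothetical per-token shift weights
k_shifted = kv_shift(k, alpha)
assert k_shifted.shape == k.shape
# After RMSNorm, every row has (approximately) unit mean square.
assert np.allclose(np.mean(rms_norm(k) ** 2, axis=-1), 1.0, atol=1e-4)
```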

Historical Context

Multiplicative gating in sequence models goes back to LSTM (Hochreiter and Schmidhuber, 1997), which introduced gates to stabilize gradient flow through long sequences; the forget gate itself was added by Gers, Schmidhuber, and Cummins (2000). Highway Networks (Srivastava et al., 2015) carried the idea into feedforward depth. GRU (Cho et al., 2014) merged the input and forget gates into a single update gate.

In attention, ALiBi (Press et al., 2022) added a fixed, non-learned linear bias in the attention logits that decays with distance. FoX can be read as a data-dependent generalization of ALiBi: instead of a constant per-head slope, each head gets a sequence of gates whose cumulative log acts as the bias. Linear-attention variants like RetNet, GLA, and Mamba reached similar conclusions from a different direction, building recurrent state-space models with data-dependent decays. FoX keeps softmax attention and inherits its expressivity while borrowing the decay primitive.

Common Confusions

Watch Out

FoX does not gate the FFN

An earlier draft of this page (and some secondhand descriptions online) claims FoX adds a forget gate to the feed-forward block. That is wrong. The forget gate modifies the unnormalized attention scores inside the softmax, not the FFN output. The FFN block in FoX is the standard MLP or SwiGLU that the baseline uses.

Watch Out

The gate is a scalar per head, not a vector per dimension

Unlike LSTM forget gates, which are vectors that gate each hidden dimension independently, the FoX forget gate is a single scalar per attention head at each position. The scalar multiplies the whole attention contribution between positions $i$ and $j$. This is what makes the cumulative product well-defined and FlashAttention-friendly.

Watch Out

FoX replaces positional embeddings, not positional information

FoX is trained and evaluated without RoPE, ALiBi, or sinusoidal positional embeddings. It is not the case that FoX is position-agnostic: the cumulative product $F_{ij}$ depends on the gap $i - j$ and on the intermediate content, which gives the model an implicit, data-shaped sense of distance. The paper's claim is that this implicit signal is strong enough to make explicit positional embeddings unnecessary.

Watch Out

FoX is not a linear-attention method

FoX keeps softmax over the attention logits. Its decay enters as an additive bias in log space, not as a state update in a recurrence. That means FoX retains the full $O(n^2)$ training cost of attention, but also retains its expressivity. Compare to Mamba or RWKV, which use recurrent state to achieve $O(n)$ at the cost of changing the computation.

Exercises

ExerciseCore

Problem

Show that when every forget gate $f_\ell = 1$, Forgetting Attention is exactly equal to standard causal softmax attention. Then show that when every $f_\ell = f \in (0, 1)$ is a head-specific constant, Forgetting Attention reduces to attention with an ALiBi-like linear bias. Give the exact slope.

ExerciseAdvanced

Problem

Let $f_\ell \in (0, 1)$ be the per-head forget gates, and let $F_{ij} = \prod_{\ell = j+1}^{i} f_\ell$. Suppose your hardware only supports causal attention kernels with a single additive bias term per $(i, j)$ pair. Describe an $O(n)$ preprocessing step that allows Forgetting Attention to be computed by such a kernel with no change to the kernel itself. Why does this preprocessing not apply to an FFN-level forget gate?

References

Canonical:

  • Lin, Nikishin, He, and Courville, Forgetting Transformer: Softmax Attention with a Forget Gate (2025), arXiv:2503.02130. Published at ICLR 2025. Primary source for FoX, Forgetting Attention, and the Pro block.
  • Hochreiter and Schmidhuber, Long Short-Term Memory (1997), Section 2. Origin of gated recurrence; the forget gate itself was introduced by Gers, Schmidhuber, and Cummins, Learning to Forget (2000).

Context for attention-level decay:

  • Press, Smith, and Lewis, Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation (ICLR 2022). ALiBi: fixed, non-data-dependent linear bias on attention logits. FoX generalizes this.
  • Su, Lu, Pan, et al., RoFormer: Enhanced Transformer with Rotary Position Embedding (2021), Sections 3 and 4. RoPE is the standard baseline FoX drops.

Context for recurrent decay:

  • Gu and Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023), Sections 3.2 and 3.5. Data-dependent selective decay in a state-space model, the closest non-attention counterpart to the FoX gate.
  • Peng et al., RWKV: Reinventing RNNs for the Transformer Era (EMNLP Findings 2023), Sections 4 and 5. Data-dependent decay in a recurrent formulation.

Infrastructure:

  • Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023). The kernel FoX reuses via its additive-bias structure.

Further reading:

  • Jurafsky and Martin, Speech and Language Processing (3rd ed. draft), Chapters 9 and 10. Background on attention and long-context modeling.

Last reviewed: April 2026
