Definition check·Difficulty 2/5·Target edge: softmax-and-numerical-stability → attention-mechanism-theory

Softmax axis in attention logits

Question

In multi-head attention, the logits tensor has shape [batch, heads, query_tokens, key_tokens]. Over which axis must softmax normalize to produce, for each query position, a probability distribution over key positions?

Why this matters

The scaled dot-product attention formula is softmax(QK^T / sqrt(d_k)) V, with softmax applied independently to each row of QK^T / sqrt(d_k). Each row corresponds to one query, and the row's entries are the dot products of that query with every key. Normalizing over the wrong axis silently destroys the per-query distribution: the resulting weights no longer sum to one over keys, and the attention output stops behaving like a soft dictionary lookup.
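The formula above can be sketched in a few lines of NumPy; shapes and the seed are illustrative assumptions, not taken from the source:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
B, H, Q, K, d_k = 2, 4, 5, 7, 16  # assumed toy sizes; Q != K on purpose
q = rng.normal(size=(B, H, Q, d_k))
k = rng.normal(size=(B, H, K, d_k))
v = rng.normal(size=(B, H, K, d_k))

# Logits: [B, H, Q, K]; each row holds one query's dot products with all keys.
logits = q @ k.swapaxes(-1, -2) / np.sqrt(d_k)

# Normalize over the key axis (the last axis) so each query's
# weights form a probability distribution over key positions.
weights = softmax(logits, axis=-1)

out = weights @ v  # [B, H, Q, d_k]: a convex combination of value rows per query
```

Choosing Q != K makes an axis mix-up fail loudly in shape checks rather than silently, which is a useful habit when testing attention code.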

Common mistake

Specifying axis=2 (the query_tokens axis) when calling softmax on the [B, H, Q, K] logits tensor. Most frameworks default to axis=-1 (key_tokens), which is correct here; on a 4-D tensor, axis=2 is the same as axis=-2, and either one silently produces the column-normalized object instead: each key's weight sums to one across queries, while a query's weights over keys no longer sum to one.
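A small NumPy check (sizes are assumed for illustration) makes the mistake concrete: only the last axis gives per-query distributions, while axis=-2 normalizes each column instead:

```python
import numpy as np

def softmax(x, axis):
    # Max-subtraction for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 5, 7))  # [B, H, Q, K] with Q != K

right = softmax(logits, axis=-1)  # normalizes over key_tokens (correct)
wrong = softmax(logits, axis=-2)  # normalizes over query_tokens (the bug)

# Correct: each query's weights over keys sum to 1.
assert np.allclose(right.sum(axis=-1), 1.0)
# Buggy: sums over the key axis are no longer 1...
assert not np.allclose(wrong.sum(axis=-1), 1.0)
# ...instead each fixed key's weights sum to 1 across queries.
assert np.allclose(wrong.sum(axis=-2), 1.0)
```

Both tensors have the same shape, so the downstream matmul with V still runs; only the sum-to-one check over the key axis exposes the bug.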

Source anchor

content/topics/attention-mechanism-theory.mdx#scaled-dot-product-attention