Softmax axis in attention logits
Question
In multi-head attention, the logits tensor has shape [batch, heads, query_tokens, key_tokens]. Over which axis must softmax normalize to produce, for each query position, a probability distribution over key positions?
Why this matters
The scaled dot-product attention formula is softmax(QK^T / sqrt(d_k)) V, with softmax applied independently to each row of QK^T / sqrt(d_k). Each row corresponds to one query, and the row's entries are the dot products of that query with every key. Normalizing over the wrong axis silently destroys the per-query distribution: the resulting weights no longer sum to one over keys, and the attention output stops behaving like a soft dictionary lookup.
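The formula above can be sketched directly. This is a minimal NumPy version for illustration (shapes and names are my own, not from the source); the key point is that softmax runs over the last axis, so each query row of the logits becomes a distribution over keys.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the per-row max for numerical stability, then normalize.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q: [B, H, Tq, d_k], K: [B, H, Tk, d_k], V: [B, H, Tk, d_v]
    d_k = Q.shape[-1]
    logits = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)  # [B, H, Tq, Tk]
    weights = softmax(logits, axis=-1)              # normalize over keys
    return weights @ V                              # [B, H, Tq, d_v]

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 4, 5, 8))
K = rng.standard_normal((2, 4, 7, 8))
V = rng.standard_normal((2, 4, 7, 8))
out = attention(Q, K, V)
print(out.shape)  # (2, 4, 5, 8): one d_v-sized output per query position
```

Note that the output has 5 rows per head (one per query), each a convex combination of the 7 value rows.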
Common mistake
Specifying axis=2 (equivalently axis=-2 for a 4-D tensor: the query_tokens axis) when calling softmax on the [B, H, Q, K] logits tensor. Most frameworks default to axis=-1, which is correct here since key_tokens is the last axis; normalizing over the query axis instead silently produces column-normalized weights, so a query's weights over keys no longer sum to one.
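The failure mode is easy to demonstrate with random logits (a toy sketch, shapes chosen arbitrarily): only normalization over the last axis makes each query's weights sum to one over keys.

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

logits = np.random.default_rng(1).standard_normal((2, 3, 4, 5))  # [B, H, Q, K]

good = softmax(logits, axis=-1)  # over key_tokens: correct
bad = softmax(logits, axis=-2)   # over query_tokens: the common mistake

# Only the correct version gives per-query distributions over keys.
print(np.allclose(good.sum(axis=-1), 1.0))  # True
print(np.allclose(bad.sum(axis=-1), 1.0))   # False
```

Both calls run without error and return a tensor of the same shape, which is why the bug is silent: nothing fails until the attention outputs degrade downstream.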
Source anchor
content/topics/attention-mechanism-theory.mdx#scaled-dot-product-attention