A Close Reading of Self-Attention

2025-07-12

Framing self-attention in terms of convex combinations and similarity matrices, aiming to precisely ground common intuitive explanations.

This article assumes comfort with linear algebra. It’s designed for those who have seen the self-attention formula but haven’t yet found the time to really understand it (i.e., me).

Self-attention is a mechanism in deep learning that allows positions in sequences to aggregate information from other positions of their ‘choosing’, depending on the specific input. We’ll make this intuition precise.

For starters, self-attention is a sequence-to-sequence map. Each element of the output sequence is a weighted average of feature vectors, where the weights and feature vectors are determined from the input sequence.

Mathematically: Self-attention maps an $S$-element sequence of $d$-vectors, $x \in \mathbb{R}^{S \times d}$, to another $S$-element sequence of $d_v$-vectors, $\text{Attention}(x) \in \mathbb{R}^{S \times d_v}$. Its elements are given by
$$\text{Attention}(x)_{i} = A_{i}V = \sum_{j} A_{ij} V_j,$$
for a weight matrix $A$ whose rows $A_i$ sum to one: $\sum_j A_{ij} = 1$. The $i$th output vector is the weighted average of feature vectors $V_{j} \in \mathbb{R}^{d_v}$.
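
To make the shapes concrete, here is a minimal NumPy sketch of this weighted-average structure; the matrices below are random placeholders rather than anything learned.

```python
import numpy as np

S, d_v = 4, 3                          # sequence length, feature dimension
V = np.random.randn(S, d_v)            # feature vectors V_j (one row per sequence element)

A = np.random.rand(S, S)
A = A / A.sum(axis=1, keepdims=True)   # normalize so each row of A sums to one

out = A @ V                            # Attention(x)_i = sum_j A_ij V_j, for every i at once
assert np.allclose(out[0], sum(A[0, j] * V[j] for j in range(S)))
```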

Feature vectors, $V$

The feature vectors are a linear projection of the input. For an input sequence $x \in \mathbb{R}^{S \times d}$, the feature matrix $V \in \mathbb{R}^{S \times d_{v}}$ is
$$V = xW^V,$$
for some low-dimensional projection matrix $W^V \in \mathbb{R}^{d \times d_{v}}$. Think of $V$ as a batched projection of $x$, so each sequence element is passed through the same projection.

[Figure: hand-drawn illustration of the equation $V = xW^V$]
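
A minimal sketch of this batched projection, again with a random placeholder for $W^V$:

```python
import numpy as np

S, d, d_v = 4, 8, 3
x = np.random.randn(S, d)              # input sequence
W_V = np.random.randn(d, d_v)          # projection matrix (random placeholder, not learned)

V = x @ W_V                            # feature matrix, shape (S, d_v)
# "Batched" projection: row i of V is just x_i pushed through the same map.
assert np.allclose(V[2], x[2] @ W_V)
```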

Attention matrix, $A$

The ‘weight’ matrix $A \in \mathbb{R}^{S \times S}$ is a normalized similarity matrix, where each entry $A_{ij}$ represents a measure of “alignment” between $x_{j}$ and $x_{i}$. Alignment is measured through the kernel matrix
$$xMx^T,$$
for some learnable matrix $M \in \mathbb{R}^{d \times d}$. Each element $ij$ represents the dot product (in the geometry defined by $M$) between input vectors $x_i$ and $x_j$. The matrix $M$ may be asymmetric, so
$$(xMx^T)_{ij} \neq (xMx^T)_{ji},$$
in general.

[Figure: hand-drawn illustration of the matrix $xMx^T$]

The attention matrix $A$ is constructed by passing this kernel matrix through an element-wise nonnegative function $\psi$, then normalizing each row to have unit sum. Specifically,
$$A_{ij} := \frac{\psi(x_{i}Mx_{j}^T)}{\sum_{k} \psi(x_{i}Mx_{k}^T)}.$$
In practice, $\psi$ is often chosen to be $\psi = \exp$.¹ With this choice, $A$ can be concisely expressed as the result of a softmax operation applied row-wise:
$$A = \text{softmax}_{\text{rowwise}}\left(xMx^T\right).$$
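
A sketch of this construction with $\psi = \exp$, using a random placeholder for $M$ (subtracting the rowwise max is a standard numerical-stability trick, not part of the math above):

```python
import numpy as np

S, d = 4, 8
x = np.random.randn(S, d)
M = np.random.randn(d, d)              # learnable bilinear form (random placeholder here)

scores = x @ M @ x.T                   # kernel matrix x M x^T, shape (S, S); generally asymmetric
scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability; doesn't change the result
A = np.exp(scores)                     # psi = exp, applied element-wise
A = A / A.sum(axis=1, keepdims=True)   # normalize each row to unit sum

assert np.allclose(A.sum(axis=1), 1.0)
```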

Convex combinations, $AV$

With VV and AA defined, we can pull together a precise understanding of self-attention.

Self-attention maps one sequence to another of the same length. Each element $\text{Attention}(x)_i$ in the output sequence is a convex combination of linear features $V$ of the input sequence $x$:
$$\text{Attention}(x)_{i} = A_{i}V = \sum_{j} A_{ij}V_{j}.$$
Each combination weight $A_{ij}$ is determined by the kernel similarity between $x_i$ and $x_j$, normalized by the similarities of $x_i$ to all other sequence elements.²

Parameterizing $M$

In practice, we don’t learn the bilinear form matrix $M$ directly. Instead, we constrain its rank (and hopefully make it easier to learn) by parameterizing it as a product of lower-dimensional factors.

We pre-specify a maximum rank $d_k$, and parameterize $M$ in terms of learnable $d \times d_{k}$ matrices $W^Q, W^K$, as follows:
$$M = \frac{1}{\sqrt{d_{k}}} (W^Q)(W^K)^T.$$
The product is scaled by $\frac{1}{\sqrt{d_{k}}}$ to limit the magnitude of the entries in practice. The goal is to avoid pushing the softmax inputs into regions where the nonlinearity has very small gradients.
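
A small sketch of the low-rank parameterization; the dimensions are arbitrary illustrations:

```python
import numpy as np

d, d_k = 8, 2
W_Q = np.random.randn(d, d_k)          # learnable factors (random placeholders here)
W_K = np.random.randn(d, d_k)

M = (W_Q @ W_K.T) / np.sqrt(d_k)       # shape (d, d), but rank at most d_k
assert np.linalg.matrix_rank(M) <= d_k
```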

The usual notation

Translating from our expression
$$\text{Attention}(x) = \text{softmax}_{\text{rowwise}}(xMx^T)V$$
to the form generally seen in papers,
$$\text{Attention}(Q, K, V) = \text{softmax}_{\text{rowwise}}\left( \frac{QK^T}{\sqrt{d_{k}}} \right)V,$$
is quite simple.

To do so, we define the usual $Q, K$ matrices by “pulling” the factors $W^Q$ and $W^K$ onto each $x$ term in the kernel matrix computation. Specifically, define $Q, K \in \mathbb{R}^{S \times d_{k}}$ as
$$Q = xW^Q \quad \text{and} \quad K = xW^K.$$
As with the projection matrix $W^V$ above, this operation is best understood as a “batched projection.” The two notations are then trivially equivalent, with
$$xMx^T = x \left( \frac{1}{\sqrt{d_{k}}} (W^Q)(W^K)^T \right)x^T = \frac{1}{\sqrt{d_{k}}} (xW^Q)(xW^K)^T = \frac{1}{\sqrt{d_{k}}} QK^T.$$
Altogether:
$$\text{Attention}(x) = AV = \text{softmax}(xMx^T)V = \text{softmax}\left( \frac{QK^T}{\sqrt{d_{k}}} \right)V.$$
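
Pulling everything together, here is a sketch that checks the two notations agree numerically. All weights are random placeholders, and `softmax_rows` is a small helper implementing the rowwise softmax.

```python
import numpy as np

def softmax_rows(z):
    z = z - z.max(axis=1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

S, d, d_k, d_v = 5, 8, 2, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(S, d))
W_Q, W_K = rng.normal(size=(d, d_k)), rng.normal(size=(d, d_k))
W_V = rng.normal(size=(d, d_v))

# "Bilinear form" view: Attention(x) = softmax(x M x^T) V
M = (W_Q @ W_K.T) / np.sqrt(d_k)
out_bilinear = softmax_rows(x @ M @ x.T) @ (x @ W_V)

# Usual view: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
Q, K, V = x @ W_Q, x @ W_K, x @ W_V
out_usual = softmax_rows(Q @ K.T / np.sqrt(d_k)) @ V

assert np.allclose(out_bilinear, out_usual)
```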

So what?

Abstractly, this article frames self-attention as the composition
$$\text{Attention}(x)_{i} = \big[ \text{Normalize}(\text{Kernel}(x_{i}, x)) \big]\, \text{Features}(x).$$
In words, self-attention computes an output sequence where each element $i$ is a convex combination of feature vectors. The contribution of feature vector $j$ to output $i$ (the magnitude of its combination weight, $A_{ij}$) is determined by a normalized kernel similarity between input elements $i$ and $j$.

This framing makes apparent some otherwise hand-wavy observations.

Observation 1: Self-attention reduces the ‘distance’ between tokens $i$ and $j$

Self-attention is a powerful mechanism because it calculates each output element $i$ by directly pulling information from other tokens $j$. Compare this to the dependency between tokens in a recurrent neural network. Define the basic RNN as
$$\text{RNN}(x)_{t} := g(h_{t}) \quad \text{and} \quad h_{t} := f(x_{t}, h_{t-1}).$$
Each output $\text{RNN}(x)_{i}$ is the result of $i$ recursive applications of $g$ and $f$ along the input sequence. Information about the earlier tokens $x_1, \ldots, x_i$ is bottlenecked through the single “hidden state” representation $h_i$.
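
To make the sequential bottleneck concrete, here is a sketch of this recurrence with a hypothetical tanh cell standing in for $f$ and a linear readout for $g$; these particular choices of $f$ and $g$ are assumptions for illustration only.

```python
import numpy as np

def f(x_t, h_prev, W_x, W_h):          # placeholder recurrence: a tanh "RNN cell"
    return np.tanh(x_t @ W_x + h_prev @ W_h)

def g(h_t, W_o):                       # placeholder readout
    return h_t @ W_o

S, d, d_h, d_o = 6, 8, 16, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(S, d))
W_x = rng.normal(size=(d, d_h))
W_h = rng.normal(size=(d_h, d_h))
W_o = rng.normal(size=(d_h, d_o))

h = np.zeros(d_h)
outputs = []
for t in range(S):                     # inherently sequential: h_t depends on h_{t-1}
    h = f(x[t], h, W_x, W_h)           # everything about x_0, ..., x_t is squeezed into h
    outputs.append(g(h, W_o))
```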

In theory, $h_i$ should be able to represent all the relevant information from the previous sequence tokens. But in practice, RNNs don’t learn to compress this information well.³

[Figure: graphic illustrating the internal distance between tokens in RNNs]

Self-attention sidesteps this compression problem. It removes the hidden-state bottleneck and directly mixes features from other tokens.

[Figure: graphic illustrating the shorter distance between tokens in self-attention]

In doing so, self-attention also parallelizes the sequence-to-sequence computation.

RNNs are inherently sequential, since their output at time $t+1$ depends on the previous hidden state $h_t$. But with self-attention, the output terms $\text{Attention}(x)_{i}$ and $\text{Attention}(x)_{j}$ can be computed completely in parallel.

This parallelization, however, does come at a cost.

Observation 2: Self-attention has no “interaction terms”

Because self-attention parallelizes the similarity computation for output $i$ across all input elements $j$ at once, it’s quite restricted in the way it can combine information from input tokens.

Specifically, self-attention can only ‘combine’ information from tokens $j, k$ into $\text{Attention}(x)_{i}$ through the linear mixing
$$A_{ij}V_{j} + A_{ik}V_{k}.$$
Even the linear mixing weights $A_{ij}$ and $A_{ik}$ depend (shared normalization aside) only on the individual similarities $\text{Kernel}(x_i, x_{j})$ and $\text{Kernel}(x_i, x_{k})$; no term depends jointly on $x_j$ and $x_k$. This turns out to limit the kinds of computations that (a single layer of) self-attention can perform.⁴

Various generalizations of self-attention have been proposed to address this limitation by incorporating interaction terms.

The basic concept of “higher-order attention” is simple, and fits nicely into the framework we’ve discussed. For concreteness, let’s consider third-order attention.⁵

Third-Order Attention

There are two key changes. First, rather than mixing individual features $V_j$ associated with individual tokens $x_j$, we mix pair features $V_{j,k}$ associated with token pairs $(x_j, x_k)$. Second, rather than defining the mixing weights for output $i$ using the similarity between $x_i$ and each other token $x_j$, we define them using the similarity between $x_i$ and each pair of tokens $(x_j, x_k)$.

The final form of third-order attention is structurally analogous to standard attention: it composes an output sequence by linearly mixing interaction features of the input,
$$\text{Attention}_{3}(x)_{i} = \sum_{j, k}\big[ \text{Normalize}(\text{Kernel}(x, x, x)) \big]_{i, j, k}\, \text{Features}_{j,k}(x).$$
Recent research has demonstrated practical performance improvements from incorporating higher-order attention layers (Roy et al., 2025).
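
As a rough illustration only: the sketch below instantiates the pair features as element-wise products and the three-way kernel as a trilinear (einsum) form. This follows the spirit of the “2-simplicial” construction in the papers cited above, but the exact parameterization there differs, so treat every projection and product here as an assumption.

```python
import numpy as np

def softmax_over_pairs(scores):
    """Normalize each i-slice of an (S, S, S) tensor over all (j, k) pairs."""
    S = scores.shape[0]
    flat = scores.reshape(S, -1)
    flat = flat - flat.max(axis=1, keepdims=True)
    e = np.exp(flat)
    return (e / e.sum(axis=1, keepdims=True)).reshape(scores.shape)

S, d, d_k, d_v = 5, 8, 4, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(S, d))
# Hypothetical projections; the cited papers parameterize third-order attention differently.
W_Q, W_K1, W_K2 = (rng.normal(size=(d, d_k)) for _ in range(3))
W_V1, W_V2 = (rng.normal(size=(d, d_v)) for _ in range(2))

Q, K1, K2 = x @ W_Q, x @ W_K1, x @ W_K2
V1, V2 = x @ W_V1, x @ W_V2

# One similarity score per (i, j, k) triple: an S x S x S attention tensor.
scores = np.einsum("id,jd,kd->ijk", Q, K1, K2)
A3 = softmax_over_pairs(scores)

# Pair features V_{j,k} (here: an element-wise product), mixed by the attention tensor.
pair_features = np.einsum("jv,kv->jkv", V1, V2)     # shape (S, S, d_v)
out = np.einsum("ijk,jkv->iv", A3, pair_features)   # shape (S, d_v)

assert np.allclose(A3.sum(axis=(1, 2)), 1.0)
```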

But unfortunately, there are no free lunches. Higher-order attention is significantly more expensive. Whereas standard self-attention is $O(S^2)$, third-order attention is $O(S^3)$, since we now need to compute an $S \times S \times S$ attention tensor rather than an $S \times S$ attention matrix.

Conclusion

We’ve now thoroughly examined each component of self-attention. The self-attention mechanism is a sequence-to-sequence transformation that computes each output as a weighted average of features, where the weights are based on normalized similarities between input sequence elements. This framing precisely undergirds the intuition that self-attention allows positions to “look at” and “choose” information from other positions.


  1. Some interesting papers exploring other nonlinearities: 1, 2, 3. Maybe another blog post to come on this. 

  2. In practice, we restrict the $i$th output to only pull from the first $i$ tokens. This is called causal masking, and is another main reason why transformers are so good at autoregressive generation. 

  3. There are mathematical arguments for the difficulty of learning these “long-range dependencies,” as well. The rough idea is that the contribution of an input token $i$ to a loss at output token $j$ decays to zero for $i \ll j$. See (Bengio et al., 1994) and (Hochreiter and Schmidhuber, 1997) for more. 

  4. For example: (Sanford et al., 2023) introduce a task, `Match3`, that is very difficult for standard attention but trivially solvable for a “higher-order” attention taking into account pairwise interactions. For a single self-attention layer, solving `Match3` requires an embedding dimension or number of heads that is polynomial in the input sequence length, $S$. By contrast, a “third-order” self-attention layer can solve `Match3` for any sequence length with constant embedding dimension. 

  5. The exposition here is roughly based on (Clift et al., 2019) and its recent computationally efficient reformulation in (Roy et al., 2025). These papers refer to third-order attention as “2-simplicial.”