There are hundreds of articles written on this topic, but I needed to come up with my own, with the understanding and insights I've gained. I keep going back to this mechanism, asking new questions, and here is a portion of the answers to them.
My intention is not to explain the attention mechanism from scratch but to focus on topics such as:
Let's start...
Before 2017, when the fundamental paper "Attention Is All You Need" was published, RNNs (Recurrent Neural Networks), and particularly their LSTM (Long Short-Term Memory) variant, were the dominant architectures. To learn more about how RNNs work, you can check this video.
Over time it became obvious that these architectures had significant drawbacks that prevented scaling:
The leading interest in the first half of the 2010s was machine translation. In attempts to improve existing mechanisms, research groups from Stanford University and the University of Montreal independently published papers (paper 1 and paper 2) that introduced early forms of the attention mechanism within an encoder-decoder architecture. The results showed a significant improvement in performance regardless of sequence length.
Two years later, in 2017, the "Attention Is All You Need" paper from Google introduced a new transformer architecture that put attention at its center. The researchers designed it for large-scale training on GPUs, which turned out to be a great decision and solved the existing parallelisation problem.
The new architecture also prevented gradients from vanishing and exploding by including normalization layers, as well as using attention to directly model dependencies between all tokens — even distant ones — in a single pass.
I do like this visualization of the nano-GPT architecture: it provides a comprehensive overview of every layer and makes it much easier to make sense of the whole architecture at once. I go back to it every time I want to remind myself of the basics.
Attention now has a bunch of variations. Let's talk a bit about functional roles.
Bidirectional attention – every token in a sequence can attend to all others (future and past). No masking.
Use cases: sentiment analysis, named entity recognition (NER), embedding generation for retrieval
Cross-attention (encoder-decoder) – the decoder output attends to the encoder outputs.
Use cases: machine translation, summarization, speech/text generation from audio, multimodal generation (audio, image, and video require modality-specific encoding, and the split is also good for separation of concerns)
Autoregressive – each token can attend to itself and past tokens.
Use cases: text generation, autocomplete, reinforcement learning
Note: It was interesting for me to learn that the encoder-decoder architecture results in better contextual understanding because the entire input is encoded all at once inside the encoder. Several content representations are built along the way: inside the encoder, inside the decoder, and then the final output is generated using both outputs (from the decoder and encoder attention).
Here I am going to talk about causal masked self-attention, as it is the most basic one.
"Build vectorized representations of each token in a sequence that can represent language nuances by attending to itself and all previous tokens, without access to future context."
Normalized token embeddings
The attention block includes n heads. Each head applies the same algorithm to the input. In the end, results are concatenated.
Calculating Q, K, V by projecting the inputs with matrices learnt during training
Q – what am I looking for?
K – what does each piece of information claim to be about?
V – what is the actual content?
Q x Kt dot product. Computing the attention scores
Simple example:
Let's process the word "delicious" in the sentence "Lunch was very delicious today"
Masking attention scores to prevent the model from seeing future tokens.
It is especially useful during training, but the same mask is actually applied during inference too. I was confused about that, because it seems redundant: we want to predict the next token, which we do not know, based on every previous token in the sentence, so there are no future tokens to mask. However, it is still useful in the "prefill" phase and for batch inference.
Calculate attention weights by applying softmax to attention scores.
The purpose:
Calculate attention output by multiplying attention weights with value (V) matrix
Concatenate attention outputs from all heads. Remember that the head output dimensionality is proportional to the embedding dimension:
emb_dim = 512
n_heads = 8
head_dim = emb_dim / n_heads = 64
After concatenation the final dimensionality goes back to the initial (emb_dim)
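To make the dimensionality bookkeeping concrete, here is a tiny NumPy sketch (the sequence length of 10 is just an assumption for illustration):

```python
import numpy as np

emb_dim, n_heads = 512, 8
head_dim = emb_dim // n_heads                      # 64
seq_len = 10                                       # illustrative sequence length

# each head produces an output of shape [seq_len, head_dim]
head_outputs = [np.zeros((seq_len, head_dim)) for _ in range(n_heads)]

# concatenating along the feature axis restores the embedding dimension
concatenated = np.concatenate(head_outputs, axis=-1)
print(concatenated.shape)                          # (10, 512) == (seq_len, emb_dim)
```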
Linear projection
The purpose:
After this quick overview of the basic concept, let's think more deeply about each operation and algorithm it contains.
Q = X * W_q
K = X * W_k
V = X * W_v
Where:
X – input of shape [seq_length, emb_dim]
W_q, W_k, W_v – learnt weight matrices
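Here is a minimal NumPy sketch of these projections for a single head; the sizes (5 tokens, emb_dim = 512, head_dim = 64) are assumptions matching the examples in this article, and the weights are random because nothing is trained here:

```python
import numpy as np

seq_length, emb_dim, head_dim = 5, 512, 64
rng = np.random.default_rng(0)

X = rng.standard_normal((seq_length, emb_dim))     # input token embeddings

# learnt weight matrices (random stand-ins here)
W_q = rng.standard_normal((emb_dim, head_dim))
W_k = rng.standard_normal((emb_dim, head_dim))
W_v = rng.standard_normal((emb_dim, head_dim))

Q = X @ W_q     # queries: [seq_length, head_dim]
K = X @ W_k     # keys:    [seq_length, head_dim]
V = X @ W_v     # values:  [seq_length, head_dim]
print(Q.shape, K.shape, V.shape)                   # (5, 64) (5, 64) (5, 64)
```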
One of my first questions always was, "What do those matrices really learn during training that allows them to transform embeddings into the essential elements of the attention mechanism?"
P.S. The role of each matrix was assigned by people for better understanding.
Before training:
W_q, W_k, W_v are initialized as matrices of random numbers. During training, each forward pass of the model produces feedback (the loss gradient) that identifies how much each weight in the matrix contributes to the result, and the weights are adjusted before the next pass (backpropagation).
After training:
W_q learns how to transform inputs into effective queries compressed to head dimensionality.
W_k learns how to create keys that best represent the content for matching, compressed to head dimensionality.
W_v learns how to extract the most useful information and compress to head dimensionality.
Below is an example for more intuitive understanding:
Let's represent the word "dog" as token embedding, considering the word "dog" is one token with 512 dimensions.
Query transformation:
Combines the dimensions that turn "dog" into a search query:
Real sentence example: In "The dog chased the cat", the dog query would strongly attend to "chased" because it's looking for actions it performs.
Key transformation:
Combines the dimensions that make "dog" discoverable for the relevant queries:
Real sentence example: In "The big dog ran", adjectives like "big" have queries that match with the dog's key because both have strong "modifiable entity" signals.
Value transformation:
Combines the dimensions with the actual informational content of "dog":
Now let's zoom out a bit. This is what happens for one single word, but the main advantage of the attention mechanism is parallelism. That means attention processes all queries, keys, and values for a sequence of tokens at once inside an attention head.
To zoom out even more, let's remember that there are usually a couple of heads computing Q, K, and V at the same time. Each head contains different W_q, W_k, and W_v matrices. After concatenating the results of each head, we receive an attention output. It combines all features of each head, creating a rich representation of each token.
That's how Q, K, and V matrices are computed.
Note: Why Kt (transposed)? Because otherwise the matrix multiplication is not possible.
[m, n] x [n, m] = [m, m], not [m,n] x [m, n]
By multiplying Q * Kt, we compute the dot product of every query vector with every key vector. The dot product shows how similar two vectors are. Similarity between vectors is important for understanding how much attention should be paid to each token with regard to another: the more similar the tokens are (the more their vectors point in a similar direction), the higher the attention score.
Example:
Input: "The dog chased the ball"
Note: for simplicity we will use 3-dimensional vectors
a = [ a1, a2, a3 ]
b = [ b1, b2, b3 ]
a * b = a1b1 + a2b2 + a3b3
Calculation:
"chased" (query) = [ 0.7, -0.5, 0.6 ]
"dog" (key) = [ 0.8, 0.6, 0.7 ]
Dot product = ( 0.7 * 0.8 ) + (-0.5 * 0.6 ) + ( 0.6 * 0.7 ) = 0.68
The high dot product indicates strong alignment between the query "chased" and the key "dog", resulting in high attention.
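The same arithmetic in NumPy: the two 3-dimensional vectors are taken from the example above, and the random 5-token matrices at the end only illustrate that every query meets every key in one multiplication.

```python
import numpy as np

chased_q = np.array([0.7, -0.5, 0.6])   # "chased" as a query
dog_k    = np.array([0.8,  0.6, 0.7])   # "dog" as a key
print(np.dot(chased_q, dog_k))          # ≈ 0.68

# for the whole sentence, every query meets every key at once
rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 3))         # 5 tokens ("The dog chased the ball"), 3 dims
K = rng.standard_normal((5, 3))
scores = Q @ K.T                        # [5, 5]: scores[i, j] = query i · key j
print(scores.shape)                     # (5, 5)
```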
Even though it seems to be obvious that we need to apply a mask to prevent the model from seeing future tokens, I want to stop here and provide as many details as possible:
Let's start with a more intuitive look at the causal mask, using the sentence "The quick brown fox". After transforming the 4-word sequence into embedding vectors and calculating attention scores, we get the following matrix:
Next step is to create a mask. But before that, we need to do a little bit of planning here:
What is happening? We are trying to train the model to predict the next words in the sentence.
What is the purpose of the mask? While training the model, we have a bunch of sequences in the training data. We want to use them to check whether the model can predict the next word. In order to do that, we should hide the future tokens in the sequence.
How to hide them and not ruin the flow? The next step of the attention mechanism is softmax. Softmax is a function that outputs row values that sum up to 1 and are always between 0 and 1. The formula looks like this:
In order to hide the scores, their resulting weights should be equal to 0 (no attention).
What input will result in a 0 output when put into this formula? That should be a really small number... How about -inf?
Fine, seems like we've come up with a plan!
Thus, our mask should look like this:
Now adding the mask to the attention scores:
Zeros did not change the unmasked values.
-Inf is in practice represented by a very large negative number, such as -1e9 or -1e4.
So the masked scores end up being huge negative numbers anyway, and their exponentials (and therefore their softmax outputs) are effectively zero.
Now applying softmax:
We got the result! Here is also a slightly more interactive visualization.
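Here is the whole masking-plus-softmax step as a NumPy sketch. The 4×4 score matrix is made up (the real values depend on the embeddings), but the mask and the softmax follow exactly the plan above:

```python
import numpy as np

tokens = ["The", "quick", "brown", "fox"]
scores = np.array([[0.9, 0.1, 0.3, 0.2],          # illustrative attention scores
                   [0.4, 1.1, 0.5, 0.3],
                   [0.2, 0.6, 1.3, 0.4],
                   [0.5, 0.3, 0.7, 1.0]])

# causal mask: 0 on and below the diagonal, -inf (in practice -1e9) above it
mask = np.triu(np.full_like(scores, -1e9), k=1)
masked = scores + mask

# softmax per row: masked positions get (virtually) zero weight
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(np.round(weights, 2))   # row 0 attends only to "The", row 3 attends to all four
```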
Now questions.
What is the model going to predict if the 4th token is the last token in the sentence?
The model should predict the end-of-sentence token. How? The training data should include it in samples to give the model a clear signal to stop generating (if the task is generating tokens).
OK, I have just described what usually happens during training. But during inference we want to have access to all the tokens at once; we need them to make the best prediction. So why apply the mask?
There is no good answer to this question. The transformer architecture is optimized for training due to the ability to make parallel computations and use matrix multiplication. At inference time, the code remains the same, and all the calculations are done for each token, even though we need to compute only the last one. This is one of the reasons why such optimizations as KV cache, Grouped Query Attention, Multi-query Attention, Sparse attention and some others exist.
Do we need to always apply the mask?
Well, if we are dealing with a decoder transformer, then yes. Otherwise, models that use bidirectional attention don't use causal masking at all. For example, the encoder-only architecture that I mentioned at the beginning: its tasks might include sentiment analysis or classification, which require access to all the tokens at once.
I also wanted to mention another type of mask called a padding mask. A padding mask is used in transformer models to handle varying sequence lengths in a batch. If a batch contains sequences of lengths 3, 4, and 5, they are padded to look like this:
The idea is to mark real tokens with 1 and < PAD > tokens with 0 like this:
After applying the mask to the existing sequence, we get:
The padding mask hides meaningless tokens and can be used in combination with a causal one.
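A small NumPy sketch of a padding mask for a batch with sequences of lengths 3, 4, and 5 (the token IDs are made up, and 0 plays the role of the < PAD > id):

```python
import numpy as np

PAD = 0
batch = np.array([[7, 2, 9, PAD, PAD],     # length 3, padded to 5
                  [4, 8, 3, 6,   PAD],     # length 4
                  [5, 1, 2, 8,   4  ]])    # length 5

padding_mask = (batch != PAD).astype(int)  # 1 for real tokens, 0 for < PAD >
print(padding_mask)
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]

# during attention, positions where the mask is 0 get -inf added to their scores,
# so after softmax the < PAD > tokens receive no attention
```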
Softmax is applied to masked attention scores through this formula:
Where:
x_i – the input value at position i.
e^(x_i) – the exponential of the input value.
denominator – the sum of the exponentials of all input values.
Softmax converts the input scores to weights so that larger scores get proportionally larger weights.
Why is it called "Softmax"?
The name "softmax" comes from it being a "soft" version of a "max" function:
Why does it work?
This should be an easy one. After the previous step, we got attention weights. Now, multiplying them by the V matrix, we come up with a new matrix that contains a weighted amount of real information across all token dimensions. It emphasizes the most relevant features for building a better understanding of every particular token in a given context.
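In code this is a single matrix multiplication, and each output row is a weighted average of the value vectors. The weights and values below are made up just to show the effect:

```python
import numpy as np

weights = np.array([[1.0, 0.0, 0.0],       # attention weights: [3 tokens, 3 tokens]
                    [0.5, 0.5, 0.0],
                    [0.2, 0.3, 0.5]])
V = np.array([[1.0, 0.0],                  # value vectors: [3 tokens, head_dim = 2]
              [0.0, 1.0],
              [2.0, 2.0]])

output = weights @ V                       # [3 tokens, head_dim]
print(output)
# the second row is the average of the first two value vectors: [0.5, 0.5]
```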
Remember that the final dimension of each head output equals emb_dim / head_count. In order to come back to the original embedding dimension, we simply concatenate all the head outputs together. Different heads captured different patterns during the forward pass, so the final output does not miss anything.
Isn't it strange to just concatenate all the head outputs? Suspiciously simple.
Anyway, we need the final projection for a reason. The final projection is a linear transformation whose main functions are:
A good analogy might be imagining several experts working on different pieces of a complex task. Each produces a highly specialized report on their topic. The projection layer plays the role of an editor, who combines them all into one concise but dense report, like a summary with highlighted connections between the findings.
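To tie all the steps together, here is a compact NumPy sketch of multi-head causal self-attention: per-head projections, scores, causal mask, softmax, weighting of the values, concatenation, and the final projection. All weights are random (nothing is trained), the sizes are illustrative, and I also include the division by the square root of the head dimension used in the original paper, which I skipped in the walkthrough above.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_self_attention(X, n_heads, rng):
    seq_len, emb_dim = X.shape
    head_dim = emb_dim // n_heads
    mask = np.triu(np.full((seq_len, seq_len), -1e9), k=1)   # hide future tokens

    head_outputs = []
    for _ in range(n_heads):
        # per-head learnt projections (random stand-ins here)
        W_q = rng.standard_normal((emb_dim, head_dim))
        W_k = rng.standard_normal((emb_dim, head_dim))
        W_v = rng.standard_normal((emb_dim, head_dim))

        Q, K, V = X @ W_q, X @ W_k, X @ W_v
        scores = Q @ K.T / np.sqrt(head_dim)      # scaling from the original paper
        weights = softmax(scores + mask)          # causal mask, then softmax
        head_outputs.append(weights @ V)          # [seq_len, head_dim]

    concatenated = np.concatenate(head_outputs, axis=-1)     # back to emb_dim
    W_o = rng.standard_normal((emb_dim, emb_dim))            # final linear projection
    return concatenated @ W_o

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 512))                            # 5 tokens, emb_dim = 512
print(causal_self_attention(X, n_heads=8, rng=rng).shape)    # (5, 512)
```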
I have already mentioned that the original attention mechanism has some inefficiencies that raise the computation cost. The KV cache is an optimization technique used during inference to avoid redundant computation of keys and values: they stay the same from step to step, so there is no need to recompute them.
The idea: Cache keys and values that have already been computed.
Speed improvement: 2-10x depending on model size.
Trade-offs: cache takes additional memory, especially with a large context. Check the visualization here.
Computational complexity: O(n²) → O(n)
Here, complexity means the number of operations.
Without KV cache:
Let:
L – sequence length
d - embedding dimensionality
Computing attention scores, each query attends to every key, resulting in quadratic complexity:
Q: shape [ L, d ]
K: shape [ L, d ]
Q*K_t: shape [ L, L ]
Complexity: O( L² * d )
This is followed by applying the attention weights to the value matrix:
V: shape [L, d]
Multiply [L, L] * [L, d] → O(L² * d)
The total complexity across all n_layers attention blocks: O( n_layers * L² * d )
With KV cache:
One new query attends to cached K/V.
Complexity O( L * d ) - linear
Shape: [ batch_size, head_count, seq_length, head_dim ].
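Here is a rough single-head NumPy sketch of the idea during token-by-token generation. Everything is untrained and illustrative; the point is that at each step only the new token's query, key, and value are computed, while the keys and values of earlier steps come from the cache:

```python
import numpy as np

rng = np.random.default_rng(0)
emb_dim, head_dim = 512, 64
W_q, W_k, W_v = (rng.standard_normal((emb_dim, head_dim)) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

k_cache, v_cache = [], []               # grows by one row per generated token

def attend_with_cache(x_new):
    """x_new: embedding of the newest token, shape [emb_dim]."""
    q = x_new @ W_q                     # only the new query is needed
    k_cache.append(x_new @ W_k)         # K/V are computed once, then reused
    v_cache.append(x_new @ W_v)
    K = np.stack(k_cache)               # [steps_so_far, head_dim]
    V = np.stack(v_cache)
    weights = softmax(q @ K.T / np.sqrt(head_dim))   # [steps_so_far]
    return weights @ V                  # [head_dim]

for step in range(4):                   # pretend we generate 4 tokens
    out = attend_with_cache(rng.standard_normal(emb_dim))
print(out.shape, len(k_cache))          # (64,) 4
```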
Here is also an implementation and an interactive visualization of a concept.
Thanks for reading!