Kseniya Parkhamchuk

Attention. Again

There must be hundreds of articles written on this topic, but I wanted to write my own, with the understanding and insights I've gained. I keep coming back to this mechanism, asking new questions, and here is a portion of the answers.

My intention is not to explain the attention mechanism from scratch but to focus on the following topics:

  1. Historical context
  2. Difference between training and inference mode
  3. Meaning of each operation in the mechanism (matmul, dot product, and softmax are on the agenda)
  4. KV cache

Let's start...

Historical context

Before 2017, when the foundational paper "Attention Is All You Need" was published, RNNs (Recurrent Neural Networks), and particularly their LSTM (Long Short-Term Memory) variant, stood as the dominant architectures. To learn more about how RNNs work, you can check this video.

Over time it became obvious that this architecture had significant drawbacks that prevented scaling:

  • Exploding and vanishing gradients, which make it difficult to learn long-term dependencies in long sequences.
  • Difficulty in capturing dependencies between distant elements as sequence length grows.
  • RNNs process sequences sequentially rather than in parallel, resulting in inefficient training and poor scalability.

The leading research interest in the first half of the 2010s was machine translation. In attempts to improve existing approaches, research groups from Stanford University and the University of Montreal independently published papers (paper 1 and paper 2) that introduced early forms of the attention mechanism within the encoder-decoder architecture. The results showed a significant improvement in performance regardless of sequence length.

Two years later, in 2017, the "Attention Is All You Need" paper from Google introduced the transformer architecture, which put attention at its center. The researchers focused on scaling with GPUs, which turned out to be a great decision and solved the existing parallelisation problem.

The new architecture also prevented gradients from vanishing and exploding by including normalization layers, as well as using attention to directly model dependencies between all tokens — even distant ones — in a single pass.


Quick algorithm reminder

I really like this visualization of the nano-GPT architecture: it provides a comprehensive overview of every layer and makes it much easier to make sense of the whole architecture at once. I go back to it every time I want to remind myself of the basics.

Attention now has a bunch of variations. Let's talk a bit about functional roles.

Encoder self-attention

Bidirectional attention – every token in a sequence can attend to all others (future and past). No masking.

Use cases: sentiment analysis, named entity recognition (NER), embedding generation for retrieval


Encoder-decoder self-attention (cross-attention)

The decoder tokens attend to the encoder outputs.

Use cases: machine translation, summarization, text generation from speech/audio, multimodal generation (audio, image, and video require modality-specific encoding; the split also gives a good separation of concerns)


Decoder self-attention (causal masked attention)

Autoregressive – each token can attend to itself and past tokens.

Use cases: text generation, autocomplete, reinforcement learning


Note: It was interesting for me to learn that the encoder-decoder architecture results in better contextual understanding because the encoder processes the entire input at once. Several content representations are built along the way: inside the encoder, inside the decoder, and then the final output is generated using both outputs (from decoder self-attention and cross-attention over the encoder).

Here I am going to talk about causal masked self-attention, as it is the most basic one.

This is the idea:

"Build vectorized representations of each token in a sequence that can represent language nuances by attending to itself and all previous tokens, without access to future context."

Inputs:

Normalized token embeddings

The flow:

The attention block includes n heads. Each head applies the same algorithm to the input. In the end, results are concatenated.


Calculating Q, K, V by projecting the inputs with weight matrices learnt during training

Q - what am I looking for
K - what does each piece of information claim to be about
V - what is the actual content


Q x Kt dot product. Computing the attention scores

Simple example:

Let's process the word "delicious" in the sentence "Lunch was very delicious today"

  • "delicious" is the word the model should make sense of. So it becomes a query.
  • all other words are keys
  • dot product between "delicious" and all other words represents their relevance to each other. Geometrically, the dot product shows the similarity between vectors. The closer they are to each other, the higher the score.

image


Masking attention scores to prevent the model from seeing future tokens.

This is especially useful during training. The same mask is actually applied during inference, which confused me at first, because it seems redundant there: we want to predict the next token, which we do not know yet, based on every previous token in the sentence, so there are no future tokens to mask. However, it is still useful in the "prefill" phase and for batch inference.


Calculate attention weights by applying softmax to attention scores.

The purpose:

  • make scores positive
  • normalize values
  • preserve the proportions
  • sum each row to one
  • turn scores into weights to understand how much attention each token should get

Calculate attention output by multiplying attention weights with value (V) matrix


Concatenate attention outputs from all heads. Remember that the head output dimensionality is a fraction of the embedding dimension:

	emb_dim = 512
	n_heads = 8
	head_dim = emb_dim / n_heads = 64

After concatenation, the final dimensionality goes back to the initial one (emb_dim)

image


Linear projection

The purpose:

  • integration of information across heads. Even though the concatenated vectors already have the embedding dimensionality, the head outputs are still strictly separated into their own dimension ranges (1-64, 65-128, etc.). The projection learns how features from different heads should interact.
  • regulation of head weighting. Not all heads are equally important for all tasks or contexts. The projection learns the contribution of each head and amplifies or weakens its signal
  • training stabilization
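
To tie the steps above together, below is a minimal NumPy sketch of one causal masked multi-head self-attention block. The dimensions (emb_dim = 512, n_heads = 8) come from the example above; the weights are random stand-ins for trained ones, and the scaling by the square root of the head dimension follows the original paper.

	import numpy as np

	np.random.seed(0)
	seq_len, emb_dim, n_heads = 5, 512, 8
	head_dim = emb_dim // n_heads                                 # 64

	X = np.random.randn(seq_len, emb_dim)                         # normalized token embeddings
	W_q, W_k, W_v, W_o = (np.random.randn(emb_dim, emb_dim) * 0.02 for _ in range(4))

	def softmax(x):
	    e = np.exp(x - x.max(axis=-1, keepdims=True))             # stable softmax, rows sum to 1
	    return e / e.sum(axis=-1, keepdims=True)

	mask = np.triu(np.full((seq_len, seq_len), -1e9), k=1)        # "-inf" above the diagonal

	head_outputs = []
	for h in range(n_heads):
	    cols = slice(h * head_dim, (h + 1) * head_dim)            # this head's slice of the weights
	    Q, K, V = X @ W_q[:, cols], X @ W_k[:, cols], X @ W_v[:, cols]
	    scores = Q @ K.T / np.sqrt(head_dim)                      # attention scores
	    weights = softmax(scores + mask)                          # mask, then softmax
	    head_outputs.append(weights @ V)                          # context vectors, [seq_len, head_dim]

	out = np.concatenate(head_outputs, axis=-1) @ W_o             # concatenation + final projection
	print(out.shape)                                              # (5, 512) – back to emb_dim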

Meaning of each operation

After the quick overview of the basic concept, let's think deeply about each operation and algorithm it contains.

Q, K, V computation (inside one attention head)

Q = X * W_q
K = X * W_k
V = X * W_v

Where:

X – input of shape [seq_length, emb_dim]
W_q, W_k, W_v – learnt weight matrices
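
As a quick shape check, here is a minimal NumPy sketch of the three projections inside one head. The per-head output size (head_dim = 64) is an assumption taken from the earlier example; the weights are random stand-ins for trained ones.

	import numpy as np

	seq_length, emb_dim, head_dim = 6, 512, 64

	X = np.random.randn(seq_length, emb_dim)          # input: normalized token embeddings
	W_q = np.random.randn(emb_dim, head_dim) * 0.02   # learnt weight matrices (random here)
	W_k = np.random.randn(emb_dim, head_dim) * 0.02
	W_v = np.random.randn(emb_dim, head_dim) * 0.02

	Q = X @ W_q                                       # [seq_length, head_dim]
	K = X @ W_k
	V = X @ W_v
	print(Q.shape, K.shape, V.shape)                  # (6, 64) (6, 64) (6, 64)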

One of my first questions always was, "What do those matrices really learn during training that allows them to transform embeddings into the essential elements of the attention mechanism?"

P.S. The roles of these matrices were assigned by people for easier understanding.

Before training:

W_q, W_k, W_v are initialized as matrices of random numbers. During training, each forward pass of the model produces feedback that identifies how much impact each weight in the matrix has on the result, and backpropagation adjusts it before the next pass.

After training:

W_q learns how to transform inputs into effective queries, compressed to the head dimensionality.
W_k learns how to create keys that best represent the content for matching, compressed to the head dimensionality.
W_v learns how to extract the most useful information and compress it to the head dimensionality.

Below is an example for more intuitive understanding:

Let's represent the word "dog" as a token embedding, assuming that "dog" is a single token with 512 dimensions.

image

Query transformation:

Combines the dimensions that turn "dog" into a search query:

  1. What actions am I performing? (dim 4)
  2. What attributes describe me? (dim 1, 3)
  3. What belongs to me? (dim 128)

Real sentence example: In "The dog chased the cat", the dog query would strongly attend to "chased" because it's looking for actions it performs.


Key transformation:

Combines the dimensions that make "dog" discoverable for the relevant queries:

  1. "I can be a subject of actions" (dim 2, 4)
  2. "I can be modified by adjectives" (dim 1, 3, 391)

Real sentence example: In "The big dog ran", adjectives like "big" have queries that match with the dog's key because both have strong "modifiable entity" signals.


Value transformation:

Combines the dimensions with the actual informational content of "dog":

  1. Semantic meaning of the specific animal (dim 1, 128)
  2. Grammatical and functional role (dim 4, 247, 391)

Now zoom out a bit. This is what happens for one single word, but the main advantage of the attention mechanism is parallelism. That means attention processes all queries, keys, and values for the whole sequence of tokens at once inside an attention head.

To zoom out even more, let's remember that there are usually several heads computing Q, K, and V at the same time. Each head contains different W_q, W_k, and W_v matrices. After concatenating the results of all heads, we receive the attention output. It combines the features from each head, creating a rich representation of every token.

image

That's how Q, K, and V matrices are computed.

Attention scores (Q * Kt)

Note: Why Kt (transposed)? Because otherwise the matrix multiplication is not possible:
[m, n] x [n, m] = [m, m] works, while [m, n] x [m, n] is not defined.

By multiplying Q * Kt, we compute the dot product between every query vector and every key vector. The dot product shows how similar two vectors are. Similarity between vectors is important for understanding how much attention should be paid to each token with regard to another. The more similar the vectors are (the more they point in the same direction), the higher the attention score.

Example:

Input: "The dog chased the ball"

Note: for simplicity, we will use 3-dimensional vectors

a = [ a1, a2, a3 ]

b = [ b1, b2, b3 ]

a * b = a1b1 + a2b2 + a3b3

Calculation:

"chased" (query) = [ 0.7, -0.5, 0.6 ]  
"dog" (key) = [ 0.8, 0.6, 0.7 ]
Dot product = ( 0.7 * 0.8 ) + (-0.5 * 0.6 ) + ( 0.6 * 0.7 ) = 0.68

The high dot product indicates strong alignment between the query "chased" and the key "dog", resulting in high attention.
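
The same arithmetic in a small NumPy sketch, using the made-up 3-dimensional vectors from the example; the second half shows how Q * Kt produces all pairwise scores at once.

	import numpy as np

	chased_q = np.array([0.7, -0.5, 0.6])   # query vector for "chased"
	dog_k    = np.array([0.8,  0.6, 0.7])   # key vector for "dog"

	score = chased_q @ dog_k                 # (0.7*0.8) + (-0.5*0.6) + (0.6*0.7)
	print(round(score, 2))                   # 0.68

	# For the whole sequence, Q [L, d] times K transposed [d, L] gives every
	# (query, key) pair in a single matrix multiplication:
	L, d = 5, 3                              # "The dog chased the ball"
	Q = np.random.randn(L, d)
	K = np.random.randn(L, d)
	scores = Q @ K.T                         # [L, L] attention score matrix
	print(scores.shape)                      # (5, 5)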

Masking

Even though it seems to be obvious that we need to apply a mask to prevent the model from seeing future tokens, I want to stop here and provide as many details as possible:

  1. How is a mask applied during inference when we do not have any future tokens?
  2. What types of masks exist?
  3. Do we need to use it always, or are there exceptions?

Let's start with a more intuitive concept of a causal mask in a sentence "The quick brown fox". After transforming the 4-word sequence into embedding vectors and calculating attention scores, we will get the following matrix:

image

Next step is to create a mask. But before that, we need to do a little bit of planning here:

What is happening? We are trying to train the model to predict the next words in the sentence.
What is the purpose of the mask? While training the model, we have a bunch of sequences in the training data. We want to use them to check if the model can predict the next word. In order to do that, we should hide the future tokens in the sequence.
How do we hide them without ruining the flow? The next step of the attention mechanism is softmax. Softmax is a function that outputs row values that sum up to 1 and are always between 0 and 1. The formula looks like this:

image
In order to hide the scores, their resulting weights should be equal to 0 (no attention).
What input will result in a 0 output when put into this formula? That should be a really small number... How about -inf?
Fine, it seems like we have come up with a plan!

Thus, our mask should look like this:

image

Now adding the mask to the attention scores:

image

Zeros did not change the unmasked values.
In practice, -inf is represented by a very large negative value like -1e9 or -1e4.

So the masked scores end up as very small numbers anyway, and softmax turns them into weights of essentially zero.

Now applying softmax:

image

We got the result! Here is also a more interactive visualization.
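
Here is a minimal NumPy sketch of this whole masking step for a 4-token sequence like "The quick brown fox". The attention scores are random stand-ins for the real ones; the mask uses -1e9 as the practical -inf.

	import numpy as np

	def softmax(x):
	    e = np.exp(x - x.max(axis=-1, keepdims=True))   # numerically stable softmax
	    return e / e.sum(axis=-1, keepdims=True)

	L = 4                                               # "The quick brown fox"
	scores = np.random.randn(L, L)                      # attention scores (Q @ K.T)

	mask = np.triu(np.full((L, L), -1e9), k=1)          # 0 on/below the diagonal, -1e9 above
	weights = softmax(scores + mask)                    # masked positions get ~0 weight

	print(np.round(weights, 2))                         # upper triangle is 0
	print(weights.sum(axis=-1))                         # every row sums to 1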

Now questions.

What is the model going to predict if the 4th token is the last token in the sentence?

The model should predict the end-of-sentence token. How? The training data should include it in samples to give the model a clear signal to stop generating (if the task is generating tokens).

OK, I have just described what usually happens during training. But during inference we want access to all the tokens at once; we need them to make the best prediction. Then why apply the mask?

There is no good answer to this question. The transformer architecture is optimized for training thanks to parallel computation and matrix multiplication. At inference time, the code remains the same, and all the calculations are done for every token, even though we only need the output for the last one. This is one of the reasons why optimizations such as the KV cache, Grouped Query Attention, Multi-Query Attention, sparse attention and some others exist.

Do we need to always apply the mask?

Well, if we are dealing with a decoder transformer – then yes. Models that use bidirectional attention, on the other hand, don't use causal masking at all. For example, the encoder-only architecture that I mentioned at the beginning. Its tasks might include sentiment analysis or classification, which require access to all the tokens at once.

Another mask

I also want to mention another type of mask called a padding mask. A padding mask is used in transformer models to handle varying sequence lengths in a batch. If there are sequences of length 3, 4, and 5 in a batch, they are going to be padded to look like this:

image

The idea is to mark real tokens with 1 and < PAD > tokens with 0 like this:

image

After applying the mask to the existing sequence, we get:

image

The padding mask hides meaningless tokens and can be used in combination with a causal one.
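
Below is a minimal NumPy sketch of the idea. The token ids, the padding id (0) and the -1e9 trick are assumptions for illustration; the point is that padded key positions get no attention.

	import numpy as np

	PAD = 0                                     # assumed id of the <PAD> token
	batch = np.array([                          # sequences of length 3, 4, 5 padded to 5
	    [7, 2, 9, PAD, PAD],
	    [4, 8, 3, 6,   PAD],
	    [5, 1, 2, 8,   3  ],
	])

	padding_mask = (batch != PAD).astype(int)   # 1 for real tokens, 0 for <PAD>
	print(padding_mask)

	# During attention, padded key positions receive -1e9 so softmax gives them ~0 weight:
	L = batch.shape[1]
	scores = np.random.randn(len(batch), L, L)                # [batch, L, L] attention scores
	scores = scores + (1 - padding_mask)[:, None, :] * -1e9   # hide <PAD> keys from every query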


Softmax

Softmax is applied to masked attention scores through this formula:

image

Where:

x_i – input value at position i.
e^(x_i) – exponential of the input value.
denominator – sum of the exponentials of all input values.

Softmax converts the input scores into weights, so that larger scores get proportionally larger weights.

Why is it called "Softmax"?

The name "softmax" comes from it being a "soft" version of a "max" function:

  • "max" function simply selects the largest value and gives it a probability of 1, while others get 0 ([ 1, 2, 3, 4] -> [0, 0, 0, 1])
  • "softmax" instead assigns values proportionally to the original ones so that every item has a value and the sum still equals to 1 ([ 1, 2, 3, 4 ] -> [ 0.032, 0.087, 0.236, 0.645 ])

Why does it work?

  1. e^x is always positive
  2. e^x can easily be differentiated
  3. Division by the sum of exponentials always gives you values < 1
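
A tiny NumPy sketch reproducing the [ 1, 2, 3, 4 ] example above:

	import numpy as np

	def softmax(x):
	    e = np.exp(x)                     # exponentials are always positive
	    return e / e.sum()                # dividing by their sum gives values < 1 that sum to 1

	x = np.array([1.0, 2.0, 3.0, 4.0])
	print(np.round(softmax(x), 3))        # [0.032 0.087 0.237 0.644]
	print(softmax(x).sum())               # 1.0 (up to floating point)

	# The "hard" max for comparison: all the weight goes to the largest value
	print((x == x.max()).astype(float))   # [0. 0. 0. 1.]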

Context vectors (attention head output)

This should be an easy one. After the previous step, we have the attention weights. Multiplying them by the V matrix, we get a new matrix that contains a weighted mix of the actual content across all token dimensions. It emphasizes the most relevant features, building a better understanding of every particular token in its context.
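
In code, this step is a single matrix multiplication. A minimal sketch with the made-up head dimensions from earlier:

	import numpy as np

	L, head_dim = 5, 64
	weights = np.random.rand(L, L)
	weights /= weights.sum(axis=-1, keepdims=True)   # stand-in for softmax output: rows sum to 1
	V = np.random.randn(L, head_dim)                 # value matrix

	context = weights @ V                            # weighted mix of value vectors
	print(context.shape)                             # (5, 64) – one context vector per token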


Attention heads results concatenation

Remember that the final dimension of each head output equals emb_dim / head_count. In order to get back to the original embedding dimension, we simply concatenate all the outputs together. Different heads captured different patterns during the forward pass, so the combined output does not miss anything.


Final projection

Isn't it strange to just concatenate all the heads' outputs? Suspiciously simple.

Anyway, we need a final projection for a reason. The final projection is a linear transformation whose main functions are:

  • mixing the information received from the heads
  • enriching the attention outputs

A good analogy might be imagining several experts working on different pieces of a complex task. Each produces a highly specialized report on their topic. The projection layer plays the role of an editor who combines them all into one concise but dense report, like a summary with highlighted connections between the findings.
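
A minimal NumPy sketch of these last two steps, with the made-up dimensions from earlier (8 heads of size 64) and a random stand-in for the learnt projection matrix:

	import numpy as np

	L, emb_dim, n_heads = 5, 512, 8
	head_dim = emb_dim // n_heads

	head_outputs = [np.random.randn(L, head_dim) for _ in range(n_heads)]   # one "expert report" per head

	concatenated = np.concatenate(head_outputs, axis=-1)    # [L, emb_dim]; head features still sit
	                                                        # in separate column ranges (0-63, 64-127, ...)
	W_o = np.random.randn(emb_dim, emb_dim) * 0.02          # learnt output projection
	attention_output = concatenated @ W_o                   # every output dimension now mixes all heads
	print(attention_output.shape)                           # (5, 512)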


KV cache

I have already mentioned that the original attention mechanism has some inefficiencies that raise the computational cost. The KV cache is one of the optimization techniques used during inference to avoid redundant computation of keys and values, which stay the same across steps and do not need to be recomputed.

The idea: Cache keys and values that have already been computed.
Speed improvement: 2-10x depending on model size.
Trade-offs: cache takes additional memory, especially with a large context. Check the visualization here.
Computational complexity: O(n²) -> O(n)

Complexity here means the number of operations.

Without KV cache:

Let:

L – sequence length
d - embedding dimensionality

When computing attention scores, each query attends to every key, resulting in quadratic complexity:

Q: shape [ L, d ]
K: shape [ L, d ]
Q*K_t: shape [ L, L ]
Complexity: O( L² * d )

This is followed by applying the attention weights to the value matrix:

V: shape [L, d]
Multiply [L, L] * [L, d] → O(L² * d)
The total complexity across all layers: O( n_layers * L² * d )

With KV cache:

One new query attends to cached K/V.

Complexity: O( L * d ) per new token – linear

Cache shape: [ batch_size, head_count, seq_length, head_dim ].
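
Here is a minimal NumPy sketch of the idea for a single head (one sequence, no batching), so the cache here is just a growing list of key and value rows rather than the full 4-D tensor above. Dimensions and weights are made-up stand-ins.

	import numpy as np

	emb_dim, head_dim = 512, 64
	W_q, W_k, W_v = (np.random.randn(emb_dim, head_dim) * 0.02 for _ in range(3))

	def softmax(x):
	    e = np.exp(x - x.max())
	    return e / e.sum()

	k_cache, v_cache = [], []                            # grow by one row per generated token

	def decode_step(x_new):
	    """x_new: embedding of the newest token, shape [emb_dim]."""
	    q = x_new @ W_q                                  # only the new token's query is computed
	    k_cache.append(x_new @ W_k)                      # its key and value are appended to the cache;
	    v_cache.append(x_new @ W_v)                      # older K/V are reused, never recomputed
	    K, V = np.stack(k_cache), np.stack(v_cache)      # [t, head_dim]
	    weights = softmax(q @ K.T / np.sqrt(head_dim))   # [t] scores – O(L * d) per step, not O(L² * d)
	    return weights @ V                               # context vector for the newest token

	for _ in range(5):                                   # "generate" 5 tokens
	    out = decode_step(np.random.randn(emb_dim))
	print(out.shape, len(k_cache))                       # (64,) 5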

Here is also an implementation and an interactive visualization of the concept.

Thanks for reading!