July 25, 2025

Understanding Self-Attention Using PyTorch

A step-by-step guide to the self-attention mechanism from scratch.

Introduction

It is 2025, and LLMs have come a long way. Today we have reasoning models that take time to think before answering complex questions. These reasoning models are LLMs, and they share the same core components found in earlier language models used for text classification, such as BERT and RoBERTa.

The ideas behind those earlier models remain the foundation of today’s reasoning models, and one of the most important is the self-attention mechanism.

It powers GPT, BERT, Vision Transformers, and diffusion models. This mechanism allows models to weigh different parts of a sequence dynamically. Each element can attend to any other element regardless of distance.

Self-attention eliminated the sequential processing bottlenecks of RNNs and enabled parallel computation across entire sequences. The transformer architecture made this revolutionary change possible, and models now understand context far better than they did with previous approaches.

In this article, I will discuss how we transitioned from RNNs and CNNs to transformers. I will then implement a simple attention mechanism in PyTorch to build a clearer picture of how it works.

Note: The code explained in this blog is available in this Colab Notebook.

Quick Glossary: query, key, value, head, context window

  • Query (Q) represents what we're looking for in the sequence. It acts like a search term in a database lookup. Each position generates its own query vector.
  • Key (K) represents what we're comparing against. Keys help determine relevance between different positions. They work alongside queries to compute attention scores.
  • Value (V) contains the actual content to retrieve. Values get weighted by attention scores. The final output combines values based on computed importance.
  • Head refers to independent attention computations. Multiple heads capture different types of relationships. Each head focuses on different aspects of the sequence.
  • Context window defines the sequence length the model processes. It limits how many tokens can attend to each other. Larger windows enable longer-range dependencies but increase computational cost.

The Leap from RNN/CNN to the Transformer Encoder-Decoder Architecture

RNNs processed sequences one step at a time. This created computational bottlenecks for long sequences. Vanishing gradients made learning long-range dependencies difficult. Information from early tokens often got lost by the final output.

CNNs, on the other hand, used fixed receptive fields to capture local patterns. They needed many layers to see long-range dependencies. Each layer could only look at a small window of the sequence, so global context required deep stacking of convolutional layers.

The transformer architecture eliminated recurrence entirely. It allowed direct connections between any positions in a sequence. Parallel processing became possible across all sequence elements. Distance between tokens no longer mattered for computation.

Self-Attention vs Traditional Attention Mechanisms

Traditional attention mechanisms like Bahdanau attention computed relationships between encoder and decoder states. They were add-ons to existing RNN architectures. The attention helped decoders focus on relevant encoder outputs during generation.

RNN-based attention by Bahdanau et al., 2015. | Source: Attention? Attention!

Self-attention computes relationships within the same sequence. It became the core mechanism rather than an auxiliary component. Each position can attend to all positions in the input simultaneously.

Illustration of scaled dot-product attention, which does not use an RNN. | Source: Attention Is All You Need.

Traditional attention was limited by the underlying RNN processing. Self-attention enabled complete parallelization. Models could now capture complex patterns without sequential constraints. This fundamental shift made modern large language models possible.

How Self-Attention Works

Self-attention allows each word to examine every other word in the sequence. Each position decides how much to focus on other positions based on relevance. This creates rich contextual representations where word meanings are influenced by relationships with all other words.

The mechanism computes three matrices from input embeddings: queries (Q), keys (K), and values (V). Each word generates its own query vector to search for relevant information. Keys represent what each position offers to the search. Values contain the actual content to retrieve.

Illustration of how the attention mechanism can establish a contextual relationship between words. | Source: Attention? Attention!

Putting it together, Attention(Q, K, V) = softmax(QK^T / √d_k) V. First, QK^T computes similarity scores between queries and keys using dot products.

What is √d_k?

The √dk scaling prevents large values that push softmax into saturation regions with tiny gradients. Softmax converts raw scores into probability distributions summing to 1. Finally, these attention weights create a weighted average of values.

This mechanism has O(n²) complexity for sequence length n. Every position attends to every other position, creating an n×n attention matrix. This quadratic scaling becomes problematic for very long sequences, such as entire documents. The complexity is the price paid for allowing direct connections between all positions.
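To make this concrete, here is a minimal sketch of the computation in PyTorch on a toy three-token sequence. In a real model, Q, K, and V would come from learned linear projections; here they are taken directly from the input to keep the example short.

```python
import torch
import torch.nn.functional as F

d_k = 4
x = torch.randn(1, 3, d_k)                     # toy sequence: (batch=1, seq_len=3, dim=4)
Q = K = V = x                                  # in practice Q, K, V are separate linear projections of x

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (1, 3, 3) similarity matrix, scaled by sqrt(d_k)
weights = F.softmax(scores, dim=-1)            # each row is a probability distribution over positions
output = weights @ V                           # weighted average of values: (1, 3, 4)
```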

The Role of Multi-Head Self-Attention

Multiple attention heads enable the model to focus on different relationship types simultaneously. One head might capture syntactic patterns while another identifies semantic similarities. This parallel processing provides richer representations than single-head attention.

Illustration of multihead attention. | Source: Attention is all you need.

Each head uses independent projection matrices WQ, WK, WV to transform inputs into different query, key, and value spaces. The process runs h heads in parallel, concatenates their outputs, and applies a final linear transformation to combine them: MultiHead(Q, K, V) = Concat(head_1, …, head_h) W^O, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V).

Think of multiple experts examining the same data from different perspectives. BERT-base uses 12 heads while the original Transformer employed 8 heads. 

Head dimension typically equals d_model/num_heads to maintain reasonable computational cost. This distribution enables efficient parallel processing while capturing diverse attention patterns across multiple representation subspaces.
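As a small illustration of how the split works, the sketch below reshapes a BERT-base-sized tensor into 12 heads of dimension 64 and merges it back. The shapes are the point here; the learned projection layers are omitted.

```python
import torch

batch, seq_len, d_model, n_heads = 2, 10, 768, 12
d_k = d_model // n_heads                          # 64, the per-head dimension used by BERT-base

x = torch.randn(batch, seq_len, d_model)
# split the model dimension into heads: (batch, n_heads, seq_len, d_k)
heads = x.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
print(heads.shape)                                # torch.Size([2, 12, 10, 64])

# after attention, the heads are merged back and passed through a final linear layer
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
```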

Variants, Optimizations & When to Use Them

Self-attention has evolved beyond its original form to address specific computational constraints and use cases. Modern implementations introduce specialized variants that maintain the core mechanism while optimizing for different requirements.

Causal attention for autoregressive LLMs masks future positions by setting attention scores to -∞ before softmax. This prevents information leakage during training. Each token can only attend to previous tokens in the sequence. The lower triangular mask ensures autoregressive generation for language modeling tasks like GPT.
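Here is a minimal sketch of how such a mask can be built with torch.triu and applied before the softmax; the tensor shapes and values are purely illustrative.

```python
import torch
import torch.nn.functional as F

seq_len = 5
scores = torch.randn(seq_len, seq_len)                 # toy attention scores

# True above the diagonal = future positions that must be hidden
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float('-inf'))
weights = F.softmax(scores, dim=-1)                    # each token attends only to itself and earlier tokens
```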

Self-attention, cross-attention, and encoder-decoder attention serve different purposes. Self-attention operates within the same sequence for contextualization. Cross-attention operates between two different sequences, such as image features attending to text in multimodal models; encoder-decoder attention is the classic case, where decoder queries attend to encoder keys and values. BERT uses only self-attention, while the original encoder-decoder Transformer combines self-attention and cross-attention for machine translation.

FlashAttention & linear attention address efficiency concerns for long contexts. FlashAttention uses memory-efficient computation avoiding the full attention matrix materialization. Linear attention variants reduce complexity from O(n²) to O(n). Sliding window attention handles extremely long sequences. These enable processing books and entire documents that would be prohibitively expensive with standard attention.
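In recent PyTorch versions, torch.nn.functional.scaled_dot_product_attention can dispatch to a fused, FlashAttention-style kernel when one is available for the given device and dtype, so the full n×n attention matrix is never materialized. A minimal usage sketch:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# PyTorch picks an efficient backend (e.g. a FlashAttention-style kernel) when it can
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```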

Grouped-query & rotary embeddings optimize modern implementations. Grouped-query attention shares key/value heads across multiple query heads, reducing memory bandwidth. Rotary position embeddings encode position information directly into attention computation rather than adding to inputs. GQA reduces KV cache size for inference while RoPE provides better length extrapolation for sequences longer than training data.

PyTorch Tutorial: Building Self-Attention From Scratch

The best way to understand the attention mechanism is by coding it. I will use PyTorch to implement the attention mechanism and essentially build BERT, one of the most popular language models. Please note that this code is provided for educational purposes only; the attention mechanism can be much more complex, especially in today’s LLMs. The goal here is simply to give you a general idea of how it works.

To begin with, let’s import all the dependencies. 

Dependencies.

We will use the math library for a couple of helper functions, the re library to preprocess the raw text we feed into the model, numpy (np) for array operations on the CPU, random for sampling sentence pairs and mask positions, and torch for implementing the attention mechanism, BERT, and all the other neural network components.
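A minimal version of the imports could look like this (the exact list in the notebook may differ slightly):

```python
import math          # for sqrt in the attention scaling and the GELU formula
import re            # to clean and tokenize the raw text
import random        # to sample sentence pairs and mask positions
import numpy as np   # general array handling on the CPU
import torch
import torch.nn as nn
```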

Once we have imported the basic dependencies, we can now work on creating a function to make batches for the training data. 

make_batch creates training batches for BERT's two pre-training objectives: predicting masked words and sentence relationships.

make_batch.

What it does:

  • It takes two random sentences and combines them with special tokens ([CLS], [SEP]).
  • Masks 15% of words: 80% become [MASK], 10% become random words, 10% stay unchanged.
  • Creates segment IDs to distinguish sentence A from sentence B.
  • Pads sequences to fixed length with zeros.
  • Labels whether sentences are consecutive (IsNext/NotNext).

Output: Each batch contains equal numbers of consecutive and non-consecutive sentence pairs for training.
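Below is a condensed sketch of what make_batch can look like. It assumes the notebook has already built sentences, token_list, and word_dict from the raw text and defined the hyperparameters batch_size, max_pred, maxlen, and vocab_size; the notebook's version may differ in details.

```python
def make_batch():
    # assumes: sentences, token_list, word_dict, batch_size, max_pred, maxlen, vocab_size
    batch = []
    positive = negative = 0
    while positive != batch_size / 2 or negative != batch_size / 2:
        # pick two random sentences; they may or may not be consecutive
        tokens_a_index = random.randrange(len(sentences))
        tokens_b_index = random.randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

        # [CLS] sentence A [SEP] sentence B [SEP]
        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # mask 15% of the real tokens (at least 1, at most max_pred)
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
        cand_masked_pos = [i for i, token in enumerate(input_ids)
                           if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
        random.shuffle(cand_masked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random.random() < 0.8:          # 80%: replace with [MASK]
                input_ids[pos] = word_dict['[MASK]']
            elif random.random() < 0.5:        # 10% overall: replace with a random word id
                input_ids[pos] = random.randint(4, vocab_size - 1)  # assumes ids 0-3 are special tokens
            # remaining 10%: leave the token unchanged

        # pad sequences to fixed length with zeros
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        # keep equal numbers of consecutive (IsNext) and non-consecutive (NotNext) pairs
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])   # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])  # NotNext
            negative += 1
    return batch
```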

Let’s create a function to ignore padding tokens during BERT's self-attention computation.

get_attn_pad_mask.

The get_attn_pad_mask function:

  • Finds zero-padded positions in input sequences.
  • Creates a mask matrix marking padded positions as "ignore".
  • Expands mask to match attention matrix dimensions (batch_size × query_length × key_length).

Purpose: Prevents the model from attending to meaningless padding tokens during training.
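A sketch of the function, assuming padding tokens have id 0:

```python
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # True wherever the key sequence holds a padding token (id 0)
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)           # (batch_size, 1, len_k)
    # broadcast so every query position sees the same key mask
    return pad_attn_mask.expand(batch_size, len_q, len_k)   # (batch_size, len_q, len_k)
```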

Now, we can write an activation function. For BERT, we will use the GeLU activation function.

GeLU.

It is a smooth alternative to ReLU used in BERT. It multiplies the input by a probability based on how many standard deviations it is from zero in a normal distribution.
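One common way to write it uses the exact formulation based on the Gaussian error function (the notebook may use the tanh approximation instead):

```python
def gelu(x):
    # x * P(X <= x) where X ~ N(0, 1), computed via the error function
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
```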

We can now write an embedding class.

Embedding.

This class creates BERT's input embeddings by combining three types of information:

  • Token embeddings: Convert words to vectors.
  • Position embeddings: Add sequence position information.
  • Segment embeddings: Distinguish sentence A from sentence B.

Process: Sums all three embeddings and applies layer normalization to create final input representations.
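A sketch of the class, assuming the hyperparameters vocab_size, maxlen, n_segments, and d_model are defined earlier in the notebook:

```python
class Embedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)   # word identity
        self.pos_embed = nn.Embedding(maxlen, d_model)       # position in the sequence
        self.seg_embed = nn.Embedding(n_segments, d_model)   # sentence A vs sentence B
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)                  # (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)
```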

Now for the most important class: the attention itself, or scaled dot-product attention.

ScaledDotProductAttention.

This class implements BERT's core attention mechanism, which determines which words to focus on. It:

  • Computes attention scores by multiplying queries and keys, then dividing by √d_k.
  • Masks padding tokens by setting their scores to -∞.
  • Applies softmax to create attention weights.
  • Multiplies weights with values to produce contextualized representations.
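A sketch of the class, assuming d_k (the per-head key dimension) is defined as a hyperparameter; -1e9 stands in for -∞ so masked scores vanish after the softmax:

```python
class ScaledDotProductAttention(nn.Module):
    def forward(self, Q, K, V, attn_mask):
        # (batch, heads, len_q, len_k): similarity of every query with every key
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)       # padding positions get an effectively -inf score
        attn = nn.Softmax(dim=-1)(scores)          # attention weights sum to 1 over the keys
        context = torch.matmul(attn, V)            # weighted average of values
        return context, attn
```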

To enhance the functionality of the attention mechanism, we need to create a multi-head attention class.

MultiHeadAttention.

This class implements multi-head attention, allowing BERT to focus on different aspects of the sequence simultaneously. It:

  • Projects inputs into multiple query, key, value representations. 
  • Splits into n_heads parallel attention computations. 
  • Applies scaled dot-product attention to each head. 
  • Concatenates heads and projects back to original dimensions. 
  • Adds residual connection and layer normalization.
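A sketch of the class, assuming d_model, d_k, d_v, and n_heads are defined as hyperparameters and ScaledDotProductAttention is the class above:

```python
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)
        # project and split into heads: (batch, n_heads, seq_len, d_k or d_v)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)
        # reuse the same padding mask for every head
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        # concatenate heads and project back to d_model
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.fc(context)
        return self.norm(output + residual), attn   # residual connection + layer normalization
```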

Now, we process each token position independently using the PoswiseFeedForwardNet class.

PoswiseFeedForwardNet.

The PoswiseFeedForwardNet class implements the position-wise feed-forward network that:

  • Expands input dimensions from d_model to d_ff using first linear layer.
  • Applies GELU activation function.
  • Contracts back to d_model dimensions using second linear layer.
  • Processes each position separately with the same transformation.
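A sketch of the class, assuming d_model and d_ff are defined as hyperparameters and gelu is the activation defined earlier:

```python
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # expand from d_model to d_ff
        self.fc2 = nn.Linear(d_ff, d_model)   # contract back to d_model

    def forward(self, x):
        # the same two-layer MLP is applied to every position independently
        return self.fc2(gelu(self.fc1(x)))
```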

Let’s now create an encoder layer. 

EncoderLayer.

The encoder layer combines self-attention and feed-forward processing together. Here are the important components:

  • Multi-head self-attention: Tokens attend to all other tokens in sequence. 
  • Position-wise feed-forward network: Processes each token independently.
  • Uses same input as queries, keys, and values for self-attention.
  • Outputs contextualized representations and attention weights.
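A sketch of the class, reusing MultiHeadAttention and PoswiseFeedForwardNet from above:

```python
class EncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        # the same tensor serves as queries, keys, and values (self-attention)
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn
```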

Now, we can create the final class that combines everything into one model: BERT.

BERT.

This class implements the complete BERT model architecture for pre-training:

Architecture:

  • Embedding layer converts tokens to vectors.
  • Multiple encoder layers process sequences through self-attention.
  • Two output heads: masked language modeling and next sentence prediction.

Forward pass:

  • Processes embeddings through encoder stack.
  • Uses [CLS] token for sentence classification.
  • Extracts masked positions for word prediction.
  • Shares the token embedding weights with the output decoder used for masked word prediction.
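Below is a condensed sketch of how the pieces can fit together, closely following minimal BERT pre-training implementations. It assumes n_layers and the other hyperparameters from earlier; the notebook's version may differ in details.

```python
class BERT(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        # next-sentence prediction head on the pooled [CLS] representation
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.classifier = nn.Linear(d_model, 2)
        # masked-word prediction head, sharing weights with the token embedding
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, _ = layer(output, enc_self_attn_mask)

        # next-sentence prediction from the [CLS] token
        h_pooled = self.activ(self.fc(output[:, 0]))
        logits_clsf = self.classifier(h_pooled)

        # gather hidden states at the masked positions and predict the original words
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias
        return logits_lm, logits_clsf
```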

This is how you can implement a simple attention mechanism and use it to build BERT.

I have also trained this model in the same Colab Notebook. Please run the entire notebook to get a full picture of how each component works.

FAQ – People Also Ask