LLM Study Notes: From NLP Basics to the Transformer Architecture and Code Implementation

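Natural Language Processing (NLP)

NLP studies the interaction between human language and computers; its goal is to enable computers to understand, process, and generate natural language.

Text Representation

Text representation is the technique of converting natural language into a numeric form (vectors) that a computer can store and compute on.

The vector space model (VSM) maps each text to a vector in a high-dimensional space. This mapping from the symbolic domain of language to a numeric domain is what lets a computer apply mathematical operations and analysis to text.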
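
As a rough illustration of the vector space idea, the sketch below turns sentences into bag-of-words count vectors and compares them with cosine similarity. The helper names `bag_of_words` and `cosine` and the three sample sentences are assumptions made for this example; counting words is only the simplest possible text representation, not the one used by the models later in these notes.

```python
import math
from collections import Counter

def bag_of_words(text, vocab):
    """Map a sentence to a count vector over a fixed vocabulary."""
    counts = Counter(text.lower().split())
    return [counts[w] for w in vocab]

def cosine(u, v):
    """Cosine similarity; returns 0.0 when either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)

# Hypothetical sample sentences
docs = ["the cat likes fish", "the dog likes fish", "stocks fell sharply today"]
vocab = sorted({w for d in docs for w in d.split()})
vectors = [bag_of_words(d, vocab) for d in docs]

sim_cat_dog = cosine(vectors[0], vectors[1])    # the two sentences share 3 of 4 words
sim_cat_stock = cosine(vectors[0], vectors[2])  # the two sentences share no words
```

Sentences that share words end up close together in the space, while sentences with no words in common are orthogonal.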

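Text representation relies on a language model, whose core job is to model the probability distribution of language given a context.

Statistical Models of the Language Probability Distribution

  1. N-gram models: based on counted frequencies; the current word depends only on the preceding N-1 words.
  2. RNN / LSTM models: a recurrent structure carries a memory of the words seen so far.
  3. Transformer: built around the attention mechanism (self-attention). It builds a vector representation for every word, computes the relations between words with attention, can be computed in parallel, and captures context better than the recurrent models above.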
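
To make the first item concrete, here is a minimal sketch of a bigram model (N = 2) estimated by counting. The corpus, the `<s>` / `</s>` boundary markers, and the function name `bigram_prob` are assumptions chosen for illustration, and no smoothing is applied.

```python
from collections import Counter, defaultdict

corpus = ["the cat likes fish", "the dog hates fish", "the cat eats fish"]

# Count how often each word follows each one-word context.
follow_counts = defaultdict(Counter)
for sentence in corpus:
    tokens = ["<s>"] + sentence.split() + ["</s>"]
    for prev, cur in zip(tokens, tokens[1:]):
        follow_counts[prev][cur] += 1

def bigram_prob(prev, cur):
    """Maximum-likelihood estimate P(cur | prev) = count(prev, cur) / count(prev)."""
    total = sum(follow_counts[prev].values())
    return follow_counts[prev][cur] / total if total else 0.0

p_cat = bigram_prob("the", "cat")  # "the" is followed by "cat" in 2 of 3 sentences
p_dog = bigram_prob("the", "dog")  # and by "dog" in 1 of 3
```

Because the context is a single word, the model cannot tell "the cat likes" from "the dog likes"; that limited window is what the recurrent and attention-based models relax.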

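Transformer Architecture

The Transformer was introduced by Vaswani et al. in the 2017 paper "Attention Is All You Need" and is now the core architecture of NLP and multimodal large models.

It is built entirely on attention and drops the step-by-step recurrence of RNNs, which gives it stronger parallelism and stronger context modeling.

The core design idea of full attention: every element of a sequence can dynamically attend to every other element of the sequence, and so build a better picture of its context.

Mathematically, context is expressed as the weights of the pairwise relations between elements.

Attention Mechanism

Given a text sequence, attention computes how strongly each word attends to every other word (the relation weights); these weights are what the model uses to interpret the sequence.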

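Input Vector Representation

Each input token gets its vector from an embedding plus a positional encoding:

$$X = [x_1, x_2, \dots, x_n], \quad x_i \in \mathbb{R}^d$$

Three learnable linear projection matrices then produce Q, K, and V:

$$Q = XW_Q,\quad K = XW_K,\quad V = XW_V$$

  • Q vector (query): determines "what do I need to attend to"
  • K vector (key): determines "what features do I expose for matching"
  • V vector (value): determines "what content do I carry"

Q, K, and V all come from the same input X, projected by the three learnable matrices W_Q, W_K, and W_V.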

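Attention Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

  • QKᵀ: a relevance score of every word against every other word (dot-product similarity)
  • Dividing by √dₖ: keeps the scores from growing with the dimension. For components with unit variance, a dot product over dₖ dimensions has variance of about dₖ, and very large scores push softmax into a saturated region with tiny gradients
  • Softmax: turns each row of scores into attention weights that sum to 1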
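
The effect of the √dₖ factor can be checked numerically. The sketch below assumes query and key components drawn independently from a standard normal distribution (an idealization, not a property of trained models) and measures the variance of the dot product with and without scaling. The function name `dot_variance` is made up for this example.

```python
import math
import random

random.seed(0)

def dot_variance(d_k, trials=2000, scale=False):
    """Empirical variance of q·k for q, k with i.i.d. N(0, 1) components."""
    values = []
    for _ in range(trials):
        q = [random.gauss(0, 1) for _ in range(d_k)]
        k = [random.gauss(0, 1) for _ in range(d_k)]
        s = sum(a * b for a, b in zip(q, k))
        values.append(s / math.sqrt(d_k) if scale else s)
    mean = sum(values) / trials
    return sum((v - mean) ** 2 for v in values) / trials

var_raw = dot_variance(64)                 # close to d_k = 64
var_scaled = dot_variance(64, scale=True)  # close to 1
```

Under this assumption the unscaled variance tracks dₖ, while the scaled scores stay near unit variance regardless of the head width.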

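Attention does not "explicitly look for the subject": the model does not represent grammar rules symbolically. What it learns is that, in the projected space, the Q projection of a referring word and the K projection of its referent end up geometrically compatible on the training data, which produces a high dot-product score. This is an effect of **statistics (distributed representations)**, not of symbolic rules.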

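What a Single Attention Head Computes

"The animal didn't cross the street because it was too tired."

For each token (taking "it" as the example):

  1. embedding + positional encoding → input vector x_it
  2. x_it is projected by W_Q → Q_it
  3. all tokens are projected by W_K → the matrix K
  4. all tokens are projected by W_V → the matrix V

Step 1: dot-product scores (similarity):

$$s_j = \frac{Q_{it} \cdot K_j}{\sqrt{d_k}}$$

where j ranges over every token in the sentence: The, animal, didn't, …, tired

Step 2: softmax (turn scores into weights):

$$a_j = \frac{e^{s_j}}{\sum_{k} e^{s_k}}$$

Illustrative result (made-up values; the tokens not listed share the negligible remainder):

| token | attention weight (example) |
| --- | --- |
| animal | 0.65 |
| tired | 0.18 |
| it (self) | 0.10 |
| street | 0.05 |
| because | 0.02 |

Step 3: weighted sum (aggregate the V vectors):

$$\text{output}_{it} = \sum_{j} a_j V_j$$

This output vector then passes through the residual connection, LayerNorm, and the FFN and on into the next layer. After several layers, the representation of "it" carries the information that it refers to "animal".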
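
The three steps can be traced by hand in a few lines of plain Python. The scores and the 2-dimensional value vectors below are hypothetical numbers picked so that "animal" wins; they are not taken from any trained model.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Step 1 (assumed result): scaled scores s_j for the query "it".
tokens = ["animal", "street", "because", "it", "tired"]
scores = [3.0, 0.4, -0.5, 1.1, 1.7]
# Hypothetical 2-dimensional value vectors V_j, one per token.
values = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.5, 0.5], [0.8, 0.1]]

# Step 2: turn the scores into weights that sum to 1.
weights = softmax(scores)

# Step 3: the output is the weighted sum of the value vectors.
output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(2)]
top_token = tokens[weights.index(max(weights))]
```

The output for "it" is dominated by the value vector of whichever token received the largest weight, which is how content from "animal" flows into the representation of "it".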

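Summary of Q, K, and V

| token | Q vector (query) | K vector (key) | V vector (value) |
| --- | --- | --- | --- |
| animal | X_animal · W_Q (no fixed meaning) | X_animal · W_K (just a vector) | X_animal · W_V (carries the semantic features of "animal") |
| it | X_it · W_Q (the model has learned that "it" usually needs to find an antecedent) | X_it · W_K | X_it · W_V |

All of this is linear computation; there are no labels or semantic tags anywhere.

The Transformer hard-codes no linguistic rules. It only learns projections under which:

  • Q vectors tend to "ask a question"
  • K vectors tend to "expose retrievable features"
  • V vectors tend to "carry content"

W_Q, W_K, and W_V are trainable parameters: they are randomly initialized and then updated by backpropagation throughout training.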

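# Imports shared by the code snippets in these notes
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Scaled Dot-Product Attention (single head)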
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (..., seq_len, d_k)  (supports batch and head dims via leading dims)
    mask: None or (..., seq_q, seq_k)  (True where we should mask)
    Returns: attn_output, attn_weights
    """
    d_k = Q.size(-1)
    # scores: (..., seq_q, seq_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # mask True = blocked -> set to -inf
        scores = scores.masked_fill(mask, float('-inf'))
    attn = F.softmax(scores, dim=-1)  # attention weights across keys
    output = torch.matmul(attn, V)
    return output, attn

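Transformer Structure in Detail

1. Input Embedding + Positional Encoding

1.1 Token Embedding

Input sentence:

the animal sleeps

Token → vector (embedding): each integer token id is mapped to a vector, $$\mathbb{Z} \to \mathbb{R}^{d_{model}}$$

With model dimension d_model = 512 and N tokens: $$X \in \mathbb{R}^{N \times d_{model}}$$

The input sentence therefore has shape (3, 512):

| Token | Embedding shape |
| --- | --- |
| the | (512,) |
| animal | (512,) |
| sleeps | (512,) |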
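
An embedding layer is, in essence, a lookup table with one trainable row per vocabulary entry. The sketch below assumes a three-word vocabulary and uses d_model = 8 instead of 512 to keep the numbers readable. The names `embedding_table` and `embed` are illustrative and the rows are random rather than trained.

```python
import random

random.seed(0)

vocab = {"the": 0, "animal": 1, "sleeps": 2}
d_model = 8  # 512 in the text; kept small here for readability

# One row per vocabulary entry (randomly initialized, as before training).
embedding_table = [[random.uniform(-0.1, 0.1) for _ in range(d_model)]
                   for _ in range(len(vocab))]

def embed(sentence):
    """Token ids select rows of the table; the result has shape (N, d_model)."""
    ids = [vocab[w] for w in sentence.split()]
    return [embedding_table[i] for i in ids]

X = embed("the animal sleeps")
shape = (len(X), len(X[0]))  # (3, 8)
```

The same token always maps to the same row, which is why position information has to be added separately.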

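1.2 Positional Encoding

Self-attention has no recurrence and treats its input as an unordered set, so word positions have to be supplied explicitly. The original design uses sine and cosine functions:

$$PE(pos, 2i) = \sin\bigg(\frac{pos}{10000^{2i/d_{model}}}\bigg)$$

$$PE(pos, 2i+1) = \cos\bigg(\frac{pos}{10000^{2i/d_{model}}}\bigg)$$

The encoding has the same shape as the embeddings, (3, 512).

Final input: $$X_{input} = V_{Embedding} + V_{PositionalEncoding}$$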
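
The formula can be evaluated directly for a single position, which is a handy cross-check for a vectorized implementation. This is a plain-Python sketch with d_model = 8 assumed for brevity; `sinusoidal_pe` is a name invented for this example.

```python
import math

def sinusoidal_pe(pos, d_model):
    """Return the d_model-dimensional sinusoidal encoding for one position."""
    pe = []
    for k in range(d_model):
        i = k // 2  # pair index shared by the sin (even k) and cos (odd k) entries
        angle = pos / (10000 ** (2 * i / d_model))
        pe.append(math.sin(angle) if k % 2 == 0 else math.cos(angle))
    return pe

pe0 = sinusoidal_pe(0, 8)  # position 0: sin(0) = 0 and cos(0) = 1 alternate
pe1 = sinusoidal_pe(1, 8)  # position 1: first pair is (sin 1, cos 1)
```

Each (sin, cos) pair has unit length, so every position vector has the same squared norm, d_model / 2. Only the direction changes with position.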

class PositionalEncoding(nn.Module):
    """
    Implements the classic sinusoidal positional encoding.
    Input:
      d_model: embedding dimension
      max_len: maximum sequence length to precompute
    Forward input: x shape (batch, seq_len, d_model)
    Returns: x + pos_encoding (same shape)
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Create a (max_len, d_model) matrix of positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # div_term: 10000^{-2i/d_model}, computed as exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)      # even dims
        pe[:, 1::2] = torch.cos(position * div_term)      # odd dims
        pe = pe.unsqueeze(0)  # shape (1, max_len, d_model)
        self.register_buffer("pe", pe)  # not a parameter, but saved with the model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len].to(x.dtype)

2. Multi-Head Self-Attention

Q, K, and V are split across h heads, and each head has its own three projection matrices:

$$W^{(i)}_Q, W^{(i)}_K, W^{(i)}_V$$

Attention is computed for all heads in parallel and the results are concatenated:

$$\text{MultiHead}(X) = \text{Concat}(head_1, \dots, head_h) W_O$$

where: $$W_O \in \mathbb{R}^{(h \cdot d_v) \times d_{model}}$$
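
The dimension bookkeeping is easier to see on a single row. The sketch below assumes d_model = 512, h = 8, and d_v = d_k, and uses plain lists in place of tensors. It shows that splitting one d_model-wide projection into h slices of width d_k and concatenating them again loses nothing, which is why an implementation can store the per-head matrices as slices of one d_model × d_model projection.

```python
d_model, num_heads = 512, 8
d_k = d_model // num_heads  # per-head width, 64 here

def split_heads(vec, num_heads):
    """Split one d_model-long projected vector into num_heads chunks of width d_k."""
    width = len(vec) // num_heads
    return [vec[h * width:(h + 1) * width] for h in range(num_heads)]

def concat_heads(chunks):
    """Inverse of split_heads: lay the per-head vectors side by side again."""
    return [x for chunk in chunks for x in chunk]

q = list(range(d_model))  # stand-in for one row of Q = X W_Q
heads = split_heads(q, num_heads)
restored = concat_heads(heads)

# With d_v = d_k, W_O has shape (h * d_v) x d_model, i.e. it is square.
w_o_shape = (num_heads * d_k, d_model)
```

Each head then runs the attention formula on its own 64-wide slice, so the heads can specialize without increasing the total projection size.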

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention.
    - d_model: model dimension
    - num_heads: number of heads (d_model must be divisible by num_heads)
    Returns context vectors of shape (batch, seq_len, d_model)
    Also returns attention weights per head for inspection (batch, num_heads, seq, seq)
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear layers to produce Q, K, V from input X
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # Output linear layer
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        """
        x: (batch, seq_len, d_model)
        mask: optional mask, shape broadcastable to (batch, num_heads, seq_len, seq_len)
        """
        B, S, D = x.size()
        # Project input to Q/K/V of shape (B, S, d_model)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split heads: (B, S, num_heads, d_k) -> (B, num_heads, S, d_k)
        def split_heads(tensor):
            return tensor.view(B, S, self.num_heads, self.d_k).transpose(1, 2)

        Qh = split_heads(Q)
        Kh = split_heads(K)
        Vh = split_heads(V)

        # scaled dot-product per head
        attn_out, attn_weights = scaled_dot_product_attention(Qh, Kh, Vh, mask)
        # concat heads: (B, S, num_heads, d_k) after transpose back
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, D)
        # final linear projection
        out = self.W_o(attn_out)
        return out, attn_weights  # attn_weights: (B, num_heads, S, S)

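3. Feed-Forward Network (FFN)

  • Self-attention exchanges information across words
  • The FFN applies a nonlinear transformation to each word independently

Mathematical form (a two-layer MLP):

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$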

class PositionwiseFeedForward(nn.Module):
    """
    Implements FFN: two linear layers with an activation in between.
    Applied independently at each position.
    """
    def __init__(self, d_model: int, d_hidden: int, activation=F.relu):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)
        self.activation = activation

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return self.fc2(self.activation(self.fc1(x)))

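4. Residual Connections + Layer Normalization

  • Residual connections give gradients a direct path through the network, which mitigates vanishing gradients
  • Each sublayer only has to learn a small update on top of its input
  • Layer normalization keeps activation scales stable, so deep stacks train reliably
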
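
Both ideas fit in a few lines when written for a single token vector. The sketch below is a simplified version that omits the learned gain and bias of a real LayerNorm and uses the pre-norm ordering; `layer_norm` and `residual` are names chosen for this illustration.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one token vector to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def residual(x, sublayer):
    """Pre-norm residual step: x + sublayer(LayerNorm(x))."""
    update = sublayer(layer_norm(x))
    return [a + b for a, b in zip(x, update)]

x = [2.0, 4.0, 6.0, 8.0]
normed = layer_norm(x)

# A sublayer that outputs zeros leaves x unchanged: the identity path is always available.
unchanged = residual(x, lambda h: [0.0] * len(h))
```

Because the sublayer output is added to x rather than replacing it, a layer that has learned nothing yet simply passes its input through.
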
class TransformerEncoderBlock(nn.Module):
    """
    One Transformer encoder block:
    y = x + Dropout(MultiHead(LN(x)));  out = y + Dropout(FFN(LN(y)))
    Note: We use 'pre-norm' style: layer norm before sublayer, which is often more stable.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-norm for attention
        x_norm = self.norm1(x)
        attn_out, attn_weights = self.self_attn(x_norm, mask=mask)
        x = x + self.dropout(attn_out)  # residual connection

        # Pre-norm for FFN
        x_norm = self.norm2(x)
        ffn_out = self.ffn(x_norm)
        x = x + self.dropout(ffn_out)  # residual connection
        return x, attn_weights

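5. Stacking Layers into a Deep Model (commonly 6, 12, or 24 layers)

  • Each layer performs one round of "information integration + nonlinear transformation". A common reading is that lower layers pick up low-level patterns (morphology, phrases), middle layers compose them into richer semantic structure, and upper layers capture sentence- or paragraph-level meaning and relations
  • A single attention layer already lets every token see every other token (a global view), but stacking lets the information be rewritten over several steps: the first layer produces an initial allocation of attention, and the second layer adjusts, reinforces, or suppresses it based on the first layer's output. In effect the stack performs iterative refinement of the relations
  • Stacking, together with the nonlinearity in the FFN, lets the network approximate more complex functions; composing layers can express higher-order interactions
  • Different layers capture features at different "scales", and this layering tends to help generalization and interpretability

A Minimal Runnable Transformer Encoder Stack (PyTorch)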

# transformer_from_scratch.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------
# Positional Encoding
# ------------------------
class PositionalEncoding(nn.Module):
    """
    Implements the classic sinusoidal positional encoding.
    Input:
      d_model: embedding dimension
      max_len: maximum sequence length to precompute
    Forward input: x shape (batch, seq_len, d_model)
    Returns: x + pos_encoding (same shape)
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Create a (max_len, d_model) matrix of positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # div_term: 10000^{-2i/d_model}, computed as exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)      # even dims
        pe[:, 1::2] = torch.cos(position * div_term)      # odd dims
        pe = pe.unsqueeze(0)  # shape (1, max_len, d_model)
        self.register_buffer("pe", pe)  # not a parameter, but saved with the model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len].to(x.dtype)

# ------------------------
# Scaled Dot-Product Attention (single head)
# ------------------------
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (..., seq_len, d_k)  (supports batch and head dims via leading dims)
    mask: None or (..., seq_q, seq_k)  (True where we should mask, typically)
    Returns: attn_output, attn_weights
    """
    d_k = Q.size(-1)
    # scores: (..., seq_q, seq_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # mask True = blocked -> set to -inf
        scores = scores.masked_fill(mask, float('-inf'))
    attn = F.softmax(scores, dim=-1)  # attention weights across keys
    output = torch.matmul(attn, V)
    return output, attn

# ------------------------
# Multi-Head Attention
# ------------------------
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention.
    - d_model: model dimension
    - num_heads: number of heads (d_model must be divisible by num_heads)
    Returns context vectors of shape (batch, seq_len, d_model)
    Also returns attention weights per head for inspection (batch, num_heads, seq, seq)
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear layers to produce Q, K, V from input X
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # Output linear layer
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        """
        x: (batch, seq_len, d_model)
        mask: optional mask, shape broadcastable to (batch, num_heads, seq_len, seq_len)
        """
        B, S, D = x.size()
        # Project input to Q/K/V of shape (B, S, d_model)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split heads: (B, S, num_heads, d_k) -> (B, num_heads, S, d_k)
        def split_heads(tensor):
            return tensor.view(B, S, self.num_heads, self.d_k).transpose(1, 2)

        Qh = split_heads(Q)
        Kh = split_heads(K)
        Vh = split_heads(V)

        # scaled dot-product per head
        attn_out, attn_weights = scaled_dot_product_attention(Qh, Kh, Vh, mask)
        # concat heads: (B, S, num_heads, d_k) after transpose back
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, D)
        # final linear projection
        out = self.W_o(attn_out)
        return out, attn_weights  # attn_weights: (B, num_heads, S, S)

# ------------------------
# Position-wise Feed-Forward Network (FFN)
# ------------------------
class PositionwiseFeedForward(nn.Module):
    """
    Implements FFN: two linear layers with an activation in between.
    Applied independently at each position.
    """
    def __init__(self, d_model: int, d_hidden: int, activation=F.relu):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)
        self.activation = activation

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return self.fc2(self.activation(self.fc1(x)))

# ------------------------
# Transformer Encoder Block
# ------------------------
class TransformerEncoderBlock(nn.Module):
    """
    One Transformer encoder block:
    y = x + Dropout(MultiHead(LN(x)));  out = y + Dropout(FFN(LN(y)))
    Note: We use 'pre-norm' style: layer norm before sublayer, which is often more stable.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-norm for attention
        x_norm = self.norm1(x)
        attn_out, attn_weights = self.self_attn(x_norm, mask=mask)
        x = x + self.dropout(attn_out)  # residual connection

        # Pre-norm for FFN
        x_norm = self.norm2(x)
        ffn_out = self.ffn(x_norm)
        x = x + self.dropout(ffn_out)  # residual connection
        return x, attn_weights

# ------------------------
# Transformer Encoder (stack N layers)
# ------------------------
class TransformerEncoder(nn.Module):
    """
    Stacks multiple TransformerEncoderBlock layers.
    Also contains embedding + positional encoding.
    """
    def __init__(self, vocab_size: int, d_model: int, num_heads: int,
                 d_ff: int, num_layers: int, max_len: int = 512, dropout=0.1):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, mask=None):
        """
        input_ids: (batch, seq_len) token ids
        mask: optional boolean key-padding mask of shape (batch, seq_len),
              True at positions to ignore (e.g., padding)
        """
        x = self.embed_tokens(input_ids)  # (B, S, d_model)
        x = self.pos_enc(x)

        # (B, S) -> (B, 1, 1, S): broadcasts over heads and query positions
        # against the (B, num_heads, S, S) score tensor.
        layer_mask = mask.unsqueeze(1).unsqueeze(2) if mask is not None else None

        attn_maps = []  # collect attention maps for debugging/inspection
        for layer in self.layers:
            x, attn = layer(x, mask=layer_mask)
            attn_maps.append(attn.detach() if attn is not None else None)

        x = self.norm(x)
        return x, attn_maps  # attn_maps: list length=num_layers, each (B, num_heads, S, S)

# ------------------------
# Demo: run a tiny example and print attention maps per head
# ------------------------
if __name__ == "__main__":
    # tiny vocab and short sequence to demonstrate
    vocab_size = 50
    d_model = 64
    num_heads = 4
    d_ff = 256
    num_layers = 3
    batch = 2
    seq_len = 6

    model = TransformerEncoder(vocab_size, d_model, num_heads, d_ff, num_layers, max_len=32)
    # random toy token ids
    input_ids = torch.randint(0, vocab_size, (batch, seq_len))
    # padding mask example: suppose last two tokens in batch index 1 are padding
    pad_mask = torch.zeros((batch, seq_len), dtype=torch.bool)
    pad_mask[1, -2:] = True

    outputs, attn_maps = model(input_ids, mask=pad_mask)

    print("Outputs shape:", outputs.shape)  # (batch, seq_len, d_model)
    for layer_idx, attn in enumerate(attn_maps):
        print(f"Layer {layer_idx} attention shape:", attn.shape)
        avg_attn = attn[0].mean(dim=0)  # (S, S)
        print(f"Layer {layer_idx} avg attention (batch 0) shape:", avg_attn.shape)
        print(avg_attn)

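Training + Attention Visualization

  1. Build a tiny training corpus (hand-written is fine)
  2. Train a tiny Transformer (2 layers, 4 heads) for a small number of steps
  3. Record the attention weights of every head
  4. Plot an attention heatmap for every head
  5. Inspect which patterns each head has picked up

Task: train the Transformer to predict the next word (language-model training)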
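
Next-word prediction turns one sentence into several training examples by shifting it one position. The sketch below spells out that shift on one sentence in plain Python. The variable names are illustrative.

```python
sentence = "the cat likes fish".split()

# Inputs are all tokens but the last; targets are the same tokens shifted left by one.
inputs = sentence[:-1]   # ["the", "cat", "likes"]
targets = sentence[1:]   # ["cat", "likes", "fish"]

# At step t the model may only use inputs[: t + 1] to predict targets[t].
training_pairs = [(inputs[: t + 1], targets[t]) for t in range(len(inputs))]
```

The target at step t is the input at step t + 1, so the prediction is only meaningful if position t cannot attend to later positions. A causal mask on the attention scores enforces exactly that.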

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np


# =========================
# 1. Build a tiny corpus
# =========================
sentences = [
    "the cat likes fish",
    "the dog hates fish",
    "the cat eats fish",
    "the dog likes meat",
    "the girl likes cat",
    "the boy hates dog",
]

# Build the vocabulary
words = sorted(list(set(" ".join(sentences).split())))
stoi = {w:i for i,w in enumerate(words)}
itos = {i:w for w,i in stoi.items()}
vocab_size = len(words)
print("vocab:", words)

def encode(sentence):
    return torch.tensor([stoi[w] for w in sentence.split()])

encoded = [encode(s) for s in sentences]


# =========================
# 2. Positional encoding (sinusoidal)
# =========================
def positional_encoding(seq_len, dim):
    pe = torch.zeros(seq_len, dim)
    pos = torch.arange(0, seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2) * -(np.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(pos * div_term)
    pe[:, 1::2] = torch.cos(pos * div_term)
    return pe


# =========================
# 3. Multi-Head Self-Attention (records weights for visualization)
# =========================
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=32, num_heads=4):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_V = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_O = nn.Linear(embed_dim, embed_dim, bias=False)

        # Holds the most recent attention weights for visualization
        self.last_attention = None  # (batch, heads, seq, seq)

    def forward(self, x):
        B, T, C = x.shape

        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

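        scores = (Q @ K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        # Causal mask: position t must not see later tokens, otherwise the
        # next-word target leaks into the input of the language model.
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        attn = torch.softmax(scores, dim=-1)

        # Record the attention weights for visualization
        self.last_attention = attn.detach().cpu()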

        out = attn @ V
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.W_O(out)


# =========================
# 4. Feed-forward network (FFN)
# =========================
class FeedForward(nn.Module):
    def __init__(self, dim=32, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim)
        )
    def forward(self, x):
        return self.net(x)


# =========================
# 5. Transformer block (residual + LayerNorm)
# =========================
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=32, heads=4, hidden=64):
        super().__init__()
        self.mha = MultiHeadAttention(embed_dim, heads)
        self.ln1 = nn.LayerNorm(embed_dim)

        self.ffn = FeedForward(embed_dim, hidden)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


# =========================
# 6. Tiny Transformer model (2 layers)
# =========================
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, seq_len=4, embed_dim=32, n_layers=2):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        # Buffer (not a parameter): follows the model across devices, is not trained
        self.register_buffer("pos_emb", positional_encoding(seq_len, embed_dim))
        self.layers = nn.ModuleList([TransformerBlock(embed_dim, 4, 64) for _ in range(n_layers)])
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        B, T = x.shape
        h = self.token_emb(x) + self.pos_emb[:T]

        for layer in self.layers:
            h = layer(h)

        logits = self.fc(h)
        return logits


# =========================
# 7. Train the model
# =========================
model = TinyTransformer(vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step():
    total_loss = 0
    for seq in encoded:
        x = seq[:-1].unsqueeze(0)
        y = seq[1:].unsqueeze(0)

        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(encoded)

# Train for 30 epochs
for epoch in range(30):
    print("epoch", epoch, "loss", train_step())


# =========================
# 8. Attention visualization helper
# =========================
def plot_attention(attn, sentence_tokens):
    num_heads = attn.shape[0]
    seq_len = len(sentence_tokens)

    fig, axes = plt.subplots(1, num_heads, figsize=(3*num_heads, 3))

    for h in range(num_heads):
        ax = axes[h]
        ax.imshow(attn[h], cmap="hot")
        ax.set_xticks(range(seq_len))
        ax.set_yticks(range(seq_len))
        ax.set_xticklabels(sentence_tokens)
        ax.set_yticklabels(sentence_tokens)
        ax.set_title(f"Head {h}")

    plt.show()


# =========================
# 9. Test + visualize
# =========================
test = encode("the cat likes fish")[:-1].unsqueeze(0)
_ = model(test)  # one forward pass; the attention weights are now recorded

attn = model.layers[0].mha.last_attention[0]  # layer 1, first item in the batch
tokens = "the cat likes".split()

plot_attention(attn, tokens)

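Encoder-Decoder Transformer

Used for translation.

  • Implements the Encoder and Decoder (self-attention + encoder-decoder attention), positional encoding, FFN, and residual connections + LayerNorm
  • Adds the decoder's causal mask (autoregression) and the padding mask
  • Trains on a tiny parallel corpus as an example (teacher forcing)
  • Visualizes encoder-decoder attention at inference time (a heatmap per decoder layer and head), labeled with the source/target tokens
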
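
The two masks mentioned in the list combine with a logical OR. The sketch below builds them as nested Python lists, with True meaning "blocked" as in the rest of these notes. It masks padded key positions, which is one common convention and an assumption of this example; the helper functions in the script that follows shape their tensors in their own way. The token ids are hypothetical.

```python
def causal_mask(t):
    """mask[i][j] is True when query i must NOT see key j (j lies in the future)."""
    return [[j > i for j in range(t)] for i in range(t)]

def key_padding_mask(ids, pad_idx=0):
    """True at key positions that hold padding."""
    return [tok == pad_idx for tok in ids]

def combine(causal, key_pad):
    """A key is blocked if it is in the future OR is padding."""
    return [[c or p for c, p in zip(row, key_pad)] for row in causal]

tgt_ids = [5, 7, 9, 0]  # hypothetical ids; the last one is <pad> = 0
mask = combine(causal_mask(len(tgt_ids)), key_padding_mask(tgt_ids))
# Row i lists which keys query i is not allowed to attend to.
```

Every row keeps at least its first key unblocked, so the softmax over a row never sees only masked entries.
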
# transformer_enc_dec_translation.py
# Minimal Encoder-Decoder Transformer for toy translation + encoder-decoder attention visualization
# Requires: torch, matplotlib
# Run: python transformer_enc_dec_translation.py

import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# =========================
# Utilities: Positional Encoding
# =========================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x shape: (B, T, d_model)
        return x + self.pe[:, : x.size(1)].to(x.dtype)

# =========================
# Scaled dot-product attention (supports Q, K, V with different lengths)
# Returns (context, attn_weights)
# attn_weights shape: (B, num_heads, T_q, T_k)
# =========================
def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (..., T, d_k) with leading dims (B, heads)
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # (..., T_q, T_k)
    if mask is not None:
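        # mask shape expected broadcastable to scores (True = mask out).
        # A large finite negative is used instead of -inf so that a fully
        # masked row yields uniform weights rather than NaN.
        scores = scores.masked_fill(mask, -1e9)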
    attn = F.softmax(scores, dim=-1)
    out = torch.matmul(attn, V)
    return out, attn

# =========================
# MultiHeadAttention (can be used for self-attn and enc-dec attn)
# If kv is provided (kv != x) then it's enc-dec use: Q from x, K/V from kv
# =========================
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # projectors
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv=None, mask=None):
        """
        x: (B, T_q, d_model)  -> queries
        kv: None or tensor (B, T_k, d_model) -> if None then self-attend (keys & vals from x)
        mask: broadcastable mask (B, 1 or heads, T_q, T_k) True where to mask
        """
        if kv is None:
            kv = x
        B, T_q, _ = x.size()
        T_k = kv.size(1)

        Q = self.W_q(x)                       # (B, T_q, d_model)
        K = self.W_k(kv)                      # (B, T_k, d_model)
        V = self.W_v(kv)                      # (B, T_k, d_model)

        # reshape -> (B, heads, T, d_k)
        def split_heads(t):
            return t.view(B, -1, self.num_heads, self.d_k).transpose(1, 2)

        Qh = split_heads(Q)   # (B, heads, T_q, d_k)
        Kh = split_heads(K)   # (B, heads, T_k, d_k)
        Vh = split_heads(V)   # (B, heads, T_k, d_k)

        out, attn = scaled_dot_product_attention(Qh, Kh, Vh, mask=mask)
        # out: (B, heads, T_q, d_k)
        out = out.transpose(1, 2).contiguous().view(B, T_q, self.d_model)
        return self.W_o(out), attn  # attn: (B, heads, T_q, T_k)

# =========================
# Feed-forward network
# =========================
class PositionwiseFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

# =========================
# Encoder Layer (self-attn + ffn)
# =========================
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        # Pre-norm style
        q = self.norm1(x)
        attn_out, attn_w = self.self_attn(q, kv=None, mask=src_mask)
        x = x + self.drop(attn_out)
        q2 = self.norm2(x)
        f = self.ffn(q2)
        x = x + self.drop(f)
        return x, attn_w

# =========================
# Decoder Layer (self-attn (causal) + enc-dec attn + ffn)
# We will return both self-attn weights and encoder-decoder attn weights for visualization
# =========================
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, enc_out, tgt_mask=None, enc_mask=None):
        # x: (B, T_tgt, d)
        # enc_out: (B, T_src, d)
        q1 = self.norm1(x)
        sa_out, sa_w = self.self_attn(q1, kv=None, mask=tgt_mask)  # causal self-attn
        x = x + self.drop(sa_out)

        q2 = self.norm2(x)
        ed_out, ed_w = self.enc_dec_attn(q2, kv=enc_out, mask=enc_mask)  # enc-dec attn
        x = x + self.drop(ed_out)

        q3 = self.norm3(x)
        f = self.ffn(q3)
        x = x + self.drop(f)
        return x, sa_w, ed_w

# =========================
# Full Encoder & Decoder stacks
# =========================
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_len=50):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src_ids, src_mask=None):
        # src_ids: (B, T_src)
        x = self.tok_emb(src_ids)  # (B, T_src, d)
        x = self.pos(x)
        attn_maps = []
        for layer in self.layers:
            x, attn = layer(x, src_mask)
            attn_maps.append(attn)
        x = self.norm(x)
        return x, attn_maps

class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_len=50):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, tgt_ids, enc_out, tgt_mask=None, enc_mask=None):
        x = self.tok_emb(tgt_ids)
        x = self.pos(x)
        all_self_attn = []
        all_enc_dec_attn = []
        for layer in self.layers:
            x, sa_w, ed_w = layer(x, enc_out, tgt_mask=tgt_mask, enc_mask=enc_mask)
            all_self_attn.append(sa_w)
            all_enc_dec_attn.append(ed_w)
        x = self.norm(x)
        logits = self.fc_out(x)  # (B, T_tgt, vocab)
        return logits, all_self_attn, all_enc_dec_attn

# =========================
# Full Seq2Seq model wrapper
# =========================
class TinyTransformerSeq2Seq(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=64, num_layers=2, num_heads=4, d_ff=128, max_len=50):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, num_layers, num_heads, d_ff, max_len=max_len)
        self.decoder = Decoder(tgt_vocab, d_model, num_layers, num_heads, d_ff, max_len=max_len)

    def forward(self, src_ids, tgt_ids, src_mask=None, tgt_mask=None, enc_mask=None):
        enc_out, enc_attn = self.encoder(src_ids, src_mask)
        logits, self_attn, enc_dec_attn = self.decoder(tgt_ids, enc_out, tgt_mask=tgt_mask, enc_mask=enc_mask)
        return logits, enc_attn, self_attn, enc_dec_attn

# =========================
# Masks
# =========================
def make_src_padding_mask(src_ids, pad_idx=0):
    # True where padding (to be masked)
    return (src_ids == pad_idx).unsqueeze(1).unsqueeze(2)  # (B,1,1,T_src)

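def make_tgt_padding_mask(tgt_ids, pad_idx=0):
    # True where padding; shape (B,1,T_tgt,1) marks padded *query* rows and broadcasts against the causal mask.
    # Caution: a padded query row ends up fully masked, which can produce NaN if masked scores are filled with -inf.
    return (tgt_ids == pad_idx).unsqueeze(1).unsqueeze(3)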

def make_causal_mask(tgt_len):
    # causal mask: True where j > i (mask future)
    # shape (1, 1, T_tgt, T_tgt)
    mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.bool), diagonal=1)
    return mask.unsqueeze(0).unsqueeze(0)

# =========================
# Toy parallel corpus (English -> "Target")
# =========================
src_sentences = [
    "i eat fish",
    "i like fish",
    "you eat meat",
    "i eat meat",
    "she likes fish",
    "he hates meat"
]

tgt_sentences = [
    "je mange poisson",     # pretend target language tokens
    "je aime poisson",
    "tu mange viande",
    "je mange viande",
    "elle aime poisson",
    "il deteste viande"
]

# build src vocab
src_tokens = sorted(list({tok for s in src_sentences for tok in s.split()}))
src_stoi = {w:i+1 for i,w in enumerate(src_tokens)}  # reserve 0 for PAD
src_stoi["<bos>"] = len(src_stoi)+1
src_stoi["<eos>"] = len(src_stoi)+1
src_itos = {i:w for w,i in src_stoi.items()}

# build tgt vocab
tgt_tokens = sorted(list({tok for s in tgt_sentences for tok in s.split()}))
tgt_stoi = {w:i+1 for i,w in enumerate(tgt_tokens)}  # 0 pad
tgt_stoi["<bos>"] = len(tgt_stoi)+1
tgt_stoi["<eos>"] = len(tgt_stoi)+1
tgt_itos = {i:w for w,i in tgt_stoi.items()}

# encode helpers
def encode_src(s):
    toks = s.split()
    ids = [src_stoi["<bos>"]] + [src_stoi[t] for t in toks] + [src_stoi["<eos>"]]
    return torch.tensor(ids, dtype=torch.long)

def encode_tgt(s):
    toks = s.split()
    ids = [tgt_stoi["<bos>"]] + [tgt_stoi[t] for t in toks] + [tgt_stoi["<eos>"]]
    return torch.tensor(ids, dtype=torch.long)

src_data = [encode_src(s) for s in src_sentences]
tgt_data = [encode_tgt(s) for s in tgt_sentences]

# pad sequences to max lengths
max_src_len = max([x.size(0) for x in src_data])
max_tgt_len = max([x.size(0) for x in tgt_data])

def pad_batch(seq_list, max_len):
    padded = []
    for s in seq_list:
        if s.size(0) < max_len:
            pad = F.pad(s, (0, max_len - s.size(0)), value=0)
            padded.append(pad)
        else:
            padded.append(s)
    return torch.stack(padded)

src_batch = pad_batch(src_data, max_src_len)  # (N, T_src)
tgt_batch = pad_batch(tgt_data, max_tgt_len)  # (N, T_tgt)

# =========================
# Model instantiation & training (tiny, demo)
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyTransformerSeq2Seq(
    src_vocab=max(src_stoi.values())+1,
    tgt_vocab=max(tgt_stoi.values())+1,
    d_model=64,
    num_layers=2,
    num_heads=4,
    d_ff=128,
    max_len=max(max_src_len, max_tgt_len)+2
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training with teacher forcing: predict next target token
epochs = 200
model.train()
for ep in range(epochs):
    optimizer.zero_grad()
    # full-batch training for simplicity
    src_ids = src_batch.to(device)
    tgt_ids = tgt_batch.to(device)
    # inputs for decoder are all tokens except last; targets are all tokens except first
    decoder_input = tgt_ids[:, :-1]   # (B, T_tgt-1)
    decoder_target = tgt_ids[:, 1:]   # (B, T_tgt-1)

    # masks
    src_pad_mask = make_src_padding_mask(src_ids, pad_idx=0).to(device)  # (B,1,1,T_src)
    tgt_pad_mask = make_tgt_padding_mask(decoder_input, pad_idx=0).to(device)  # (B,1,T_tgt-1,1)
    causal = make_causal_mask(decoder_input.size(1)).to(device)  # (1,1,T_tgt-1,T_tgt-1)
    tgt_mask = (tgt_pad_mask | causal)  # broadcastable to (B, heads, Tq, Tk)
    enc_mask = src_pad_mask  # mask keys in encoder-decoder attention

    logits, enc_attn, self_attn, enc_dec_attn = model(src_ids, decoder_input, src_mask=src_pad_mask, tgt_mask=tgt_mask, enc_mask=enc_mask)
    # logits: (B, T_tgt-1, V_tgt)
    # ignore_index=0 keeps padded target positions out of the loss
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), decoder_target.reshape(-1), ignore_index=0)
    loss.backward()
    optimizer.step()
    if (ep+1) % 50 == 0 or ep == 0:
        print(f"Epoch {ep+1}/{epochs} loss={loss.item():.4f}")

print("Training done.")

# =========================
# Simple greedy inference (for demo) and capture encoder-decoder attn
# =========================
model.eval()
with torch.no_grad():
    example_idx = 0  # pick first sentence to visualize
    src_ids = src_batch[example_idx:example_idx+1].to(device)  # (1, T_src)
    # encode
    src_pad_mask = make_src_padding_mask(src_ids, pad_idx=0).to(device)
    enc_out, enc_attn_maps = model.encoder(src_ids, src_mask=src_pad_mask)

    # Start decoder with <bos>
    cur = torch.tensor([[tgt_stoi["<bos>"]]], dtype=torch.long).to(device)
    generated = [tgt_stoi["<bos>"]]
    collected_enc_dec_attn = []  # per decode step: list over decoder layers of arrays shaped (1, heads, T_src)
    max_gen_len = max_tgt_len
    for step in range(max_gen_len):
        # build masks for current decoder input
        tgt_pad_mask = make_tgt_padding_mask(cur, pad_idx=0).to(device)  # (B,1,Tcur,1)
        causal = make_causal_mask(cur.size(1)).to(device)  # (1,1,Tcur,Tcur)
        tgt_mask = (tgt_pad_mask | causal)

        logits, self_attn_maps, enc_dec_attn_maps = model.decoder(cur, enc_out, tgt_mask=tgt_mask, enc_mask=src_pad_mask)
        # logits: (1, Tcur, V)
        next_tok_logits = logits[:, -1, :]  # (1, V)
        next_id = next_tok_logits.argmax(dim=-1).item()
        generated.append(next_id)
        # collect the encoder-decoder attention maps *for the last decoder position* of each layer
        # enc_dec_attn_maps is list[layer] each (B, heads, Tcur, T_src)
        step_attns = [m[:, :, -1, :].cpu().numpy() for m in enc_dec_attn_maps]  # per-layer list of (1,heads,T_src)
        collected_enc_dec_attn.append(step_attns)  # append per-step
        cur = torch.cat([cur, torch.tensor([[next_id]], device=device)], dim=1)
        if next_id == tgt_stoi["<eos>"]:
            break

# decode ids to tokens
gen_tokens = [tgt_itos.get(i, "<unk>") for i in generated]
# keep padding positions (labelled "<pad>") so the labels line up with the T_src attention columns
src_tokens = [src_itos.get(i, "<pad>") for i in src_batch[example_idx].tolist()]
print("SRC tokens:", src_tokens)
print("Generated tgt tokens:", gen_tokens)

# =========================
# Visualization of encoder-decoder attention
# =========================
def plot_enc_dec_attn_for_step(step_idx):
    """
    Plot for a given decode step (0-based). For that step, we have per-layer: (1, heads, T_src)
    We'll create a figure with rows = layers, cols = heads.
    """
    step_attns = collected_enc_dec_attn[step_idx]  # list len=num_layers, each shape (1,heads,T_src)
    num_layers = len(step_attns)
    num_heads = step_attns[0].shape[1]
    T_src = step_attns[0].shape[2]

    # squeeze=False always returns a 2-D array of axes, even for a single layer or a single head
    fig, axes = plt.subplots(num_layers, num_heads, figsize=(3*num_heads, 2.5*num_layers), squeeze=False)

    for li in range(num_layers):
        for hi in range(num_heads):
            ax = axes[li, hi]
            att = step_attns[li][0, hi, :]  # (T_src,)
            ax.imshow(att[np.newaxis, :], aspect="auto", cmap="viridis")
            ax.set_xticks(range(T_src))
            ax.set_xticklabels(src_tokens, rotation=45)
            ax.set_yticks([])
            if hi == 0:
                ax.set_ylabel(f"Layer {li}")
            ax.set_title(f"Head {hi}")
    plt.suptitle(f"Encoder-Decoder Attention for decode step {step_idx} (generated token: {gen_tokens[step_idx+1]})")
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# plot attention for each decode step
for step in range(len(collected_enc_dec_attn)):
    plot_enc_dec_attn_for_step(step)
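
The script above builds its decoder mask by OR-ing a padding mask with a causal mask and relying on broadcasting. The following is a minimal, self-contained sketch of that idea. Note one assumption: the helper `make_key_padding_mask` is hypothetical and masks padded *keys* with shape (B, 1, 1, T), which differs from `make_tgt_padding_mask` in the script. The dummy scores and the `-inf` fill are also assumptions, since how masked scores are filled depends on the attention layer defined earlier in the script.

```python
import torch

def make_key_padding_mask(ids, pad_idx=0):
    # Hypothetical helper: True where the *key* position is padding -> (B, 1, 1, T)
    return (ids == pad_idx).unsqueeze(1).unsqueeze(2)

def make_causal_mask(t):
    # True strictly above the diagonal (future positions) -> (1, 1, T, T)
    upper = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)
    return upper.unsqueeze(0).unsqueeze(0)

ids = torch.tensor([[5, 6, 7, 0]])  # one sequence; the last position is padding
# Broadcasting (1,1,1,4) | (1,1,4,4) gives a (1,1,4,4) mask
mask = make_key_padding_mask(ids) | make_causal_mask(ids.size(1))

scores = torch.zeros(1, 1, 4, 4)    # dummy attention scores (assumption)
weights = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)

print(mask[0, 0].int())   # 1 marks a masked (query, key) pair
print(weights[0, 0])      # each row sums to 1; the padded key column gets weight 0
```

Because every query row keeps at least one visible key (position 0), the softmax never sees a fully masked row in this sketch, which is the property to check for when combining masks.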

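The training loop above uses teacher forcing: the decoder input is the target sequence without its last token, and the prediction target is the same sequence without its first token. A small sketch of that shift follows, using hypothetical token ids and random logits as a stand-in for model output. It also compares the loss with and without the `ignore_index` option of `F.cross_entropy`, which excludes padded target positions from the average.

```python
import torch
import torch.nn.functional as F

PAD, BOS, EOS = 0, 1, 2  # hypothetical special-token ids
tgt = torch.tensor([[BOS, 5, 6, EOS],
                    [BOS, 7, EOS, PAD]])  # the second sequence is padded

decoder_input = tgt[:, :-1]   # what the decoder sees:    [[1, 5, 6], [1, 7, 2]]
decoder_target = tgt[:, 1:]   # what it must predict:     [[5, 6, 2], [7, 2, 0]]

vocab_size = 10
torch.manual_seed(0)
logits = torch.randn(2, 3, vocab_size)  # stand-in for model output (B, T-1, V)

flat_logits = logits.reshape(-1, vocab_size)
flat_target = decoder_target.reshape(-1)

# Averages over all 6 positions, including the padded one
loss_all = F.cross_entropy(flat_logits, flat_target)
# Averages over the 5 non-padding positions only
loss_no_pad = F.cross_entropy(flat_logits, flat_target, ignore_index=PAD)

print(loss_all.item(), loss_no_pad.item())
```

With padded batches the two values generally differ; ignoring padding avoids rewarding the model for predicting the pad token.
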
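Glossary

| Abbreviation | Full name | Chinese term |
| --- | --- | --- |
| NLP | Natural Language Processing | 自然语言处理 |
| LSTM | Long Short-Term Memory | 长短期记忆网络 |
| ELMo | Embeddings from Language Models | 预训练的上下文相关词嵌入模型 (pretrained contextual word-embedding model) |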

hzzhu