先搞清楚一个问题:为什么需要 Transformer?
在 Transformer 出现之前,NLP 领域是 RNN 和 LSTM 的天下。它们有个致命缺陷------只能顺序处理。
想象一下你在读一句话:
"我 昨天 去 了 北京 ,因为 那里 的 烤鸭 特别 好吃"
RNN 读到"烤鸭"的时候,已经"走过"了好几个词,对"北京"的记忆早就衰减了。虽然 LSTM 用门控机制缓解了这个问题,但本质上还是"一个一个排队看"。
Transformer 的思路完全不同:让每个词同时看到句子里的所有词,想关注谁就关注谁,一步到位。
这就是 Self-Attention(自注意力机制) 的核心思想。
一、Self-Attention:让每个词都"眼观六路"
1.1 核心直觉
自注意力要解决的问题是:在一个序列中,每个元素应该"关注"哪些其他元素?
用人话来说就是:当模型处理"它"这个词的时候,它需要知道"它"指的是"猫"还是"垫子"。
"The cat sat on the mat because it was tired."
人类一眼就能看出 "it" 指的是 "cat"。自注意力机制就是让模型学会这种能力。
1.2 Q、K、V 三剑客
自注意力的核心是三个矩阵:Query(查询) 、Key(键) 、Value(值)。
这个概念其实很好理解,拿搜索引擎来类比:
- Query:你在搜索框输入的关键词 → "我想找什么?"
- Key:每篇文章的标题/标签 → "我是什么?"
- Value:文章的实际内容 → "这是我的内容"
匹配过程:用 Query 去和所有 Key 计算相似度,相似度高的 Value 就多关注一些。
less
输入序列: [词1, 词2, 词3, ...]
↓ (乘以三个权重矩阵)
Q矩阵 K矩阵 V矩阵
↓ ↓ ↓
Q1,Q2.. K1,K2.. V1,V2..
↓
Q·K^T → 注意力分数 → softmax → 加权求和Value → 输出
1.3 数学公式(别怕,很简单)
python
# Scaled Dot-Product Attention
# 公式: Attention(Q, K, V) = softmax(QK^T / √d_k) · V
import torch
import torch.nn.functional as F
import math
def self_attention(Q, K, V):
"""
Q: (seq_len, d_k) 查询矩阵
K: (seq_len, d_k) 键矩阵
V: (seq_len, d_v) 值矩阵
"""
d_k = Q.size(-1)
# 第一步:Q 和 K 做点积,得到注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1))
# 第二步:除以 √d_k(缩放因子),防止分数过大导致 softmax 梯度消失
scores = scores / math.sqrt(d_k)
# 第三步:softmax 归一化,得到注意力权重(和为1)
attention_weights = F.softmax(scores, dim=-1)
# 第四步:用注意力权重对 V 加权求和
output = torch.matmul(attention_weights, V)
return output, attention_weights
为什么除以 √d_k?
当 d_k 比较大的时候,Q·K^T 的结果会很大,导致 softmax 进入饱和区(梯度接近0)。除以 √d_k 就是把数值拉回到合理范围。
1.4 一个直观的例子
假设我们处理这句话:["我", "爱", "AI"]
makefile
注意力权重矩阵(softmax后):
我 爱 AI
我 [0.7, 0.2, 0.1] ← "我"主要关注自己
爱 [0.3, 0.4, 0.3] ← "爱"同时关注"我"和"AI"
AI [0.1, 0.3, 0.6] ← "AI"主要关注自己和"爱"
看到没?模型自动学到了"爱"连接了"我"和"AI"这个语义关系。
二、Multi-Head Attention:多角度观察,信息更全面
2.1 为什么需要多头?
单头注意力只有一个 Q、K、V,意味着每个词只能从一个角度去关注其他词。
但语言是复杂的!一个词可能同时跟多个词有关系,而且关系的类型不同:
"The bank of the river" → bank 关注 river(位置关系)
"I went to the bank" → bank 关注 went、money(语义关系)
多头注意力的思路:用多组 Q、K、V,让模型从不同的"视角"同时关注信息。
2.2 结构图解
css
输入 X
│
├──→ Linear → Q₁, K₁, V₁ → Scaled Attention → Head₁ ──┐
├──→ Linear → Q₂, K₂, V₂ → Scaled Attention → Head₂ ──┤
├──→ Linear → Q₃, K₃, V₃ → Scaled Attention → Head₃ ──┼→ Concat → Linear → 输出
├──→ ... │
└──→ Linear → Q₈, K₈, V₈ → Scaled Attention → Head₈ ──┘
原论文用了 8 个头(h=8),每个头的维度是 d_model/h。比如 d_model=512,每个头的维度就是 64。
2.3 代码实现
python
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # 每个头的维度 = 64
# Q、K、V 的线性变换层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 输出的线性变换层
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# 1. 线性变换得到 Q、K、V
Q = self.W_q(x) # (batch, seq_len, d_model)
K = self.W_k(x)
V = self.W_v(x)
# 2. 拆分成多个头
# reshape 成 (batch, n_heads, seq_len, d_k)
Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# 3. 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 4. 拼接所有头
# (batch, n_heads, seq_len, d_k) → (batch, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
# 5. 最终线性变换
output = self.W_o(attn_output)
return output
踩坑提醒 :实现多头注意力的时候,
transpose和view的维度顺序特别容易搞混。建议先跑一个小例子(seq_len=4, d_model=16, n_heads=4)打印中间张量的 shape,确认无误再上大模型。
2.4 多头到底学到了什么?
不同头会关注不同类型的信息:
| 头编号 | 可能关注的模式 | 例子 |
|---|---|---|
| Head 1 | 相邻词的关系 | "The cat sat" |
| Head 2 | 语法依赖 | "The cat that ate the fish" |
| Head 3 | 指代关系 | "The cat because it was tired" |
| Head 4 | 位置关系 | 句首词关注句尾词 |
这就是多头的魅力------不指定关注什么,让模型自己学。
三、Positional Encoding:给词加上"位置坐标"
3.1 一个关键问题
前面说到,Transformer 是并行处理所有词的。这带来了一个严重问题:
它不知道词的顺序!
对,你没听错。把 "狗咬人" 和 "人咬狗" 输入 Transformer,如果不加位置信息,它觉得这两句话是一样的。
RNN/LSTM 天然有序(一个一个处理),所以不存在这个问题。但 Transformer 的并行计算是以丢失位置信息为代价的。
解决方案就是:手动给每个词加一个"位置编码"。
3.2 位置编码怎么做?
原论文用的是 正弦-余弦位置编码(Sinusoidal Positional Encoding):
python
# 位置编码公式
# PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model=512, max_len=5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float() # (max_len, 1)
# 计算分母项: 10000^(2i/d_model)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# 偶数位置用 sin,奇数位置用 cos
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
# (1, max_len, d_model) 方便广播
pe = pe.unsqueeze(0)
# 注册为 buffer(不参与梯度更新,但会保存到模型中)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch, seq_len, d_model)
# 把位置编码加到输入上
x = x + self.pe[:, :x.size(1), :]
return x
3.3 为什么用 sin/cos?
你可能会问:为什么不用简单的 0, 1, 2, 3... 表示位置?
原因有三:
① 可以处理比训练序列更长的文本
sin/cos 是周期函数,天然支持外推。而学习到的位置编码(如 BERT 用的)对超长文本就不太友好。
② 每个位置有唯一的编码
不同位置的正弦波组合是唯一的,就像指纹一样。
③ 模型可以学到相对位置
这个是 sin/cos 编码最巧妙的地方。通过三角函数公式:
scss
sin(α + β) = sin(α)cos(β) + cos(α)sin(β)
模型可以通过线性变换,从绝对位置编码中推导出相对位置关系。
实际工作中 :BERT 用的是可学习位置编码(Learned Position Embedding),就是直接把位置编码当参数训练。GPT 也是。只有原始 Transformer 和一些长文本模型(如 Transformer-XL)还在用 sin/cos。两种方式各有优劣,sin/cos 泛化性更好,可学习的编码在训练长度内表现更优。
3.4 位置编码可视化
ini
位置编码热力图(每个位置 × 每个维度):
维度 → 0 1 2 3 4 5 6 7
位置 ↓
0 [ 0.0 1.0 0.0 1.0 0.0 1.0 0.0 1.0 ]
1 [ 0.84 0.54 0.01 1.0 0.0 1.0 0.0 1.0 ]
2 [ 0.91 -0.42 0.02 1.0 0.0 1.0 0.0 1.0 ]
3 [ 0.14 -0.99 0.03 1.0 0.0 1.0 0.0 1.0 ]
...
50 [...变化越来越慢的低频部分...]
可以看到:低维度变化快(高频),高维度变化慢(低频)。这种多尺度的设计让模型能同时捕捉局部和全局的位置关系。
四、三大组件如何协同工作?
把上面三个部分串起来,Transformer 的编码器(Encoder)一层的完整流程是:
scss
输入词向量 (Word Embedding)
│
▼
┌─────────────────────┐
│ Positional Encoding │ ← 加上位置信息
└────────┬────────────┘
▼
┌──────────────────────┐
│ Multi-Head Attention │ ← 自注意力:词与词之间交互
└────────┬─────────────┘
▼
┌───────────────────────────┐
│ Add & Norm (残差+层归一化) │
└────────┬──────────────────┘
▼
┌──────────────────────┐
│ Feed Forward Network │ ← 全连接层:逐位置变换
└────────┬─────────────┘
▼
┌───────────────────────────┐
│ Add & Norm (残差+层归一化) │
└────────┬──────────────────┘
▼
输出到下一层
用代码串起来:
python
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 1. Multi-Head Self-Attention + 残差连接 + LayerNorm
attn_out = self.self_attn(x)
x = self.norm1(x + self.dropout(attn_out))
# 2. Feed Forward + 残差连接 + LayerNorm
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
return x
残差连接(Residual Connection) 的作用:把输入直接加到输出上,梯度可以"跳过"中间层直接回传,有效缓解深层网络的梯度消失问题。这个设计来自 ResNet,被 Transformer 借鉴过来了。
五、速查对照表
| 组件 | 解决什么问题 | 核心思想 | 一句话记忆 |
|---|---|---|---|
| Self-Attention | 词与词之间怎么交互 | Q·K^T 算相似度,加权取 V | "你跟谁最相关?" |
| Multi-Head Attention | 单一视角不够 | 多组 QKV 并行,多角度观察 | "多看几眼,看得更全" |
| Positional Encoding | 并行计算丢失位置 | sin/cos 注入位置信息 | "告诉模型你在哪" |