搞懂Transformer,是理解GPT、BERT、Claude等所有大模型的第一步。这篇文章用最直白的方式讲透核心机制,配合架构图和代码,看完就能跟面试官掰扯。
一句话总结
Transformer = 自注意力机制 + 前馈网络 + 残差连接 + 层归一化,堆叠N层,不需要循环,纯靠注意力建模序列关系。
0. 为什么需要Transformer?
2017年之前,NLP的主流是RNN/LSTM。它们有个致命问题:必须逐步处理序列,无法并行。第t个词必须等前t-1个词处理完,训练极慢。
Transformer的解法:扔掉循环,用注意力机制直接建模任意两个位置的关系。无论两个词隔多远,一步到位。
论文标题《Attention Is All You Need》------"注意力就是你需要的一切",霸气。
1. 整体架构
Transformer由编码器栈 和解码器栈两部分组成:
- 编码器(Encoder):读取输入序列,提取特征
- 解码器(Decoder):基于编码器输出,逐步生成目标序列
原始论文中,编码器6层,解码器6层。
编码器单层结构
Input → Self-Attention → Add & Norm → Feed-Forward → Add & Norm → Output
↓ ↓
残差连接 残差连接
解码器单层结构
Input → Masked Self-Attn → Add & Norm → Cross-Attention → Add & Norm → FFN → Add & Norm → Output
↑
Encoder输出(K,V)
解码器比编码器多了一个Cross-Attention层,用编码器的输出作为K和V,解码器的自注意力输出作为Q------这就是"翻译"的关键。
2. Self-Attention:Transformer的灵魂
2.1 直觉理解
一句话:每个词去看序列中所有其他词,决定自己应该"关注"谁。
例如:"我 喜欢 吃 苹果 手机"
"苹果"这个词需要看上下文才能判断是水果还是品牌。Self-Attention让"苹果"同时关注"吃"和"手机",动态计算注意力权重。
2.2 Q、K、V是什么?
借检索系统的类比:
| 角色 | 类比 | 含义 |
|---|---|---|
| Q (Query) | 搜索关键词 | "我想找什么" |
| K (Key) | 文档标题/标签 | "我有什么特征" |
| V (Value) | 文档内容 | "我的实际内容" |
每个输入向量X,通过三个可学习的权重矩阵,生成Q、K、V:
Q = X * W_Q (n × d_k)
K = X * W_K (n × d_k)
V = X * W_V (n × d_v)
2.3 计算过程(四步)
Step 1:打分 --- Q和K做点积,衡量相似度
scores = Q * K^T # (n × d_k) × (d_k × n) = (n × n)
Step 2:缩放 --- 除以 sqrt(d_k),防止点积过大导致softmax梯度消失
scores = scores / sqrt(d_k)
为什么缩放?当d_k=64时,两个随机向量的点积期望值约为d_k=64,方差也是d_k。不缩放的话,softmax输入值偏大,输出接近one-hot,梯度趋近于0。
Step 3:归一化 --- Softmax转为概率分布
attn_weights = softmax(scores) # (n × n),每行之和为1
Step 4:加权求和 --- 用注意力权重对V加权
output = attn_weights * V # (n × n) × (n × d_v) = (n × d_v)
2.4 完整公式
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V
这就是Scaled Dot-Product Attention。
2.5 代码实现
python
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: (batch, n, d_k)
K: (batch, n, d_k)
V: (batch, n, d_v)
"""
d_k = Q.size(-1)
# Step 1+2: 点积 + 缩放
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# Mask(解码器用,防止看到未来)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Step 3: Softmax
attn_weights = F.softmax(scores, dim=-1)
# Step 4: 加权求和
output = torch.matmul(attn_weights, V)
return output, attn_weights
3. Multi-Head Attention:多角度看世界
单头注意力只能学一种关注模式。多头注意力让模型同时从不同子空间关注不同信息。
3.1 原理
- 将Q、K、V分别投影到h个不同的子空间
- 每个子空间独立计算Scaled Dot-Product Attention
- 将h个头的输出拼接
- 通过线性变换W_O映射回d_model维
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) ⋅ W O \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) \cdot W_O MultiHead(Q,K,V)=Concat(head1,...,headh)⋅WO
其中 head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)
3.2 维度变化(原始论文参数)
| 参数 | 值 | 说明 |
|---|---|---|
| d_model | 512 | 模型维度 |
| h (头数) | 8 | 多头数量 |
| d_k = d_v | 64 | 每个头的维度 = 512/8 |
关键点:多头注意力的总计算量和单头差不多,因为每个头的维度缩小了(512 → 64),但能学到更丰富的特征。
3.3 代码实现
python
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# Q, K, V 的投影矩阵
self.W_Q = torch.nn.Linear(d_model, d_model)
self.W_K = torch.nn.Linear(d_model, d_model)
self.W_V = torch.nn.Linear(d_model, d_model)
# 输出投影
self.W_O = torch.nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 投影并拆分成多头: (batch, n, d_model) -> (batch, n_heads, n, d_k)
Q = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# 计算注意力
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 拼接多头: (batch, n_heads, n, d_k) -> (batch, n, d_model)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 输出投影
output = self.W_O(attn_output)
return output, attn_weights
4. 位置编码:让模型知道"谁在哪儿"
Self-Attention本身是位置无关的------"我爱你"和"你爱我"对它来说是一样的。这显然不行。
4.1 正弦位置编码
原始论文使用正弦/余弦函数生成位置编码:
P E ( p o s , 2 i ) = sin ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)
P E ( p o s , 2 i + 1 ) = cos ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)
pos:词在序列中的位置(0, 1, 2, ...)2i/2i+1:维度索引(偶数用sin,奇数用cos)
4.2 为什么用三角函数?
- 有界性:sin/cos值在[-1, 1],不会爆炸
- 相对位置可推导:PE(pos+k)可以用PE(pos)的线性变换表示,模型能学到相对位置关系
- 外推性:对训练时没见过的序列长度,也能生成位置编码
4.3 代码实现
python
class PositionalEncoding(torch.nn.Module):
def __init__(self, d_model=512, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
self.register_buffer('pe', pe.unsqueeze(0)) # (1, max_len, d_model)
def forward(self, x):
# x: (batch, seq_len, d_model)
return x + self.pe[:, :x.size(1)]
位置编码直接加到输入Embedding上,不是拼接。因为两者维度相同(都是d_model),加法保留了原始信息,同时注入了位置信号。
5. Add & Norm:残差连接 + 层归一化
每个子层(Self-Attention、Feed-Forward)后面都有:
output = LayerNorm(x + Sublayer(x))
5.1 残差连接(Add)
直接把输入加到输出上,解决深层网络梯度消失问题。
直觉理解:如果某一层没用,残差连接让它学恒等映射(输出≈输入),至少不会变差。
5.2 层归一化(LayerNorm)
对每个样本的特征维度做归一化(均值0、方差1),稳定训练。
python
# BatchNorm: 对batch维度归一化(CV常用)
# LayerNorm: 对feature维度归一化(NLP常用)
class LayerNorm(torch.nn.Module):
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.gamma = torch.nn.Parameter(torch.ones(d_model))
self.beta = torch.nn.Parameter(torch.zeros(d_model))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
为什么NLP用LayerNorm而不是BatchNorm?因为序列长度可变,BatchNorm在不同位置的统计量不稳定;LayerNorm在每个位置独立归一化,更适合。
6. Feed-Forward Network:两层全连接
每个编码器/解码器层都有一个位置级前馈网络:
FFN ( x ) = max ( 0 , x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2
- 第一层:d_model → d_ff(512 → 2048),升维 + ReLU激活
- 第二层:d_ff → d_model(2048 → 512),降维回原尺寸
关键特征:这个FFN对每个位置独立作用,不同位置共享参数。可以理解为"对每个词独立做一次特征变换"。
python
class FeedForward(torch.nn.Module):
def __init__(self, d_model=512, d_ff=2048):
super().__init__()
self.linear1 = torch.nn.Linear(d_model, d_ff)
self.linear2 = torch.nn.Linear(d_ff, d_model)
def forward(self, x):
return self.linear2(F.relu(self.linear1(x)))
7. Mask机制:防止"偷看"
7.1 Padding Mask
序列长度不一,短序列用0填充。Padding Mask把填充位置设为0,不参与注意力计算。
7.2 Sequence Mask(因果Mask)
解码器在训练时,位置t不能看到t+1及之后的信息(否则就是"作弊")。
实现方式:创建一个上三角为0、下三角为1的矩阵:
python
def generate_causal_mask(seq_len):
"""生成因果mask,防止解码器看到未来信息"""
mask = torch.tril(torch.ones(seq_len, seq_len)) # 下三角矩阵
return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
# 示例: seq_len=4
# [[1, 0, 0, 0],
# [1, 1, 0, 0],
# [1, 1, 1, 0],
# [1, 1, 1, 1]]
在计算注意力分数时,mask=0的位置加上 -1e9(负无穷),Softmax后权重趋近于0。
8. 编码器 vs 解码器:三种注意力
| 注意力类型 | 所在位置 | Q来源 | K来源 | V来源 | 作用 |
|---|---|---|---|---|---|
| Self-Attention | 编码器 | 输入X | 输入X | 输入X | 输入序列内部关系 |
| Masked Self-Attn | 解码器 | 已生成 | 已生成 | 已生成 | 只看已生成部分 |
| Cross-Attention | 解码器 | 解码器 | 编码器 | 编码器 | 桥接输入和输出 |
Cross-Attention是编码器-解码器架构的精髓:解码器通过Q"提问",编码器通过K"匹配"、V"回答"。
9. 从Transformer到GPT:架构进化
原始Transformer(2017)是编码器-解码器架构,用于机器翻译。后来的演化:
| 模型 | 年份 | 架构 | 特点 |
|---|---|---|---|
| Transformer | 2017 | Encoder-Decoder | 原版,用于翻译 |
| BERT | 2018 | Encoder-only | 双向理解,适合分类/NLU |
| GPT-1/2/3 | 2018-2020 | Decoder-only | 单向生成,自回归 |
| GPT-4 / Claude | 2023-2025 | Decoder-only | 规模放大 + RLHF对齐 |
GPT为什么只用Decoder? 因为语言生成的本质是"根据前文预测下一个词",天然是因果的。编码器的双向注意力反而会"泄露"未来信息,不适合生成任务。
Decoder-only架构的简化:
- 去掉Cross-Attention(没有编码器)
- 只保留Masked Self-Attention + FFN
- 堆叠层数更多(GPT-3有96层)
10. 计算复杂度分析
| 操作 | 复杂度 | 说明 |
|---|---|---|
| Self-Attention | O(n² × d) | n是序列长度,两两计算 |
| Feed-Forward | O(n × d²) | d是模型维度 |
| 总体 | O(n² × d + n × d²) | 取决于n和d的大小 |
Self-Attention的O(n²)是瓶颈------序列长度翻倍,计算量翻四倍。这也是为什么GPT-4的上下文窗口从8K扩展到128K需要大量工程优化(Flash Attention、KV Cache等)。
11. 面试高频问题速查
Q1:为什么用缩放因子 sqrt(d_k)?
点积的方差随d_k增长。d_k=64时,点积值可能在±8范围;d_k=1024时,可能到±32。这导致Softmax进入饱和区,梯度趋近于0。除以sqrt(d_k)把方差拉回1。
Q2:Multi-Head Attention和单头有什么区别?
单头只能学一种关注模式。多头让模型同时关注不同子空间的信息------一个头关注语法关系,另一个关注语义相似度。实验证明,h=8比单头效果显著更好。
Q3:位置编码为什么用加法而不是拼接?
拼接会多出d_model维参数,且位置信息可能被淹没。加法更简洁,且实验效果相当。近期也有RoPE(旋转位置编码)等改进方案。
Q4:LayerNorm和BatchNorm的区别?
| BatchNorm | LayerNorm | |
|---|---|---|
| 归一化维度 | batch维度 | feature维度 |
| 对序列长度敏感 | 敏感(变长不行) | 不敏感 |
| 典型应用 | CV | NLP |
| 推理时 | 需要running mean | 不需要 |
Q5:Transformer为什么比LSTM快?
LSTM必须逐步计算,O(n)的串行。Transformer的Self-Attention可以用矩阵乘法一步完成,GPU并行度极高。训练速度差10-100倍。
12. 动手实验
用PyTorch跑一个最小的Transformer Encoder:
python
import torch
import torch.nn as nn
import math
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads)
self.ffn = FeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-Attention + Add & Norm
attn_output, _ = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_output))
# FFN + Add & Norm
ff_output = self.ffn(x)
x = self.norm2(x + self.dropout2(ff_output))
return x
# 使用示例
encoder_layer = TransformerEncoderLayer(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512) # batch=2, seq_len=10, d_model=512
output = encoder_layer(x)
print(output.shape) # torch.Size([2, 10, 512])
总结
一张表回顾Transformer核心组件:
| 组件 | 作用 | 关键公式 |
|---|---|---|
| Self-Attention | 建模序列内关系 | softmax(QK^T / sqrt(d_k)) * V |
| Multi-Head | 多子空间并行关注 | Concat(head_1,...,head_h) * W_O |
| 位置编码 | 注入位置信息 | sin/cos(pos / 10000^(2i/d)) |
| Add & Norm | 稳定训练 | LayerNorm(x + Sublayer(x)) |
| Feed-Forward | 位置级特征变换 | max(0, xW_1)W_2 |
| Mask | 防止信息泄露 | 上三角置零 |
Transformer的核心思想只有一条:用注意力替代循环,让序列中的每个位置都能直接看到所有其他位置。所有的组件------Multi-Head、位置编码、残差连接------都是为了让这个核心思想更好地工作。
搞懂这些,GPT、BERT、T5的架构差异就是在这个基础上的加减法了。
路易乔布斯 © 2026 | AI Agent & RAG学习计划 · 模块03-LLM基础 · 第一篇
参考文献:Vaswani et al., "Attention Is All You Need", NeurIPS 2017 --- arxiv.org/abs/1706.03762
