Transformer 架构:大语言模型的"心脏"
💡 摘要:Transformer 是 ChatGPT、Claude 等所有大语言模型的基础架构。它通过"自注意力"机制解决了 RNN 的长遗忘问题,让 AI 真正能理解长文本。
引言
2017 年,Google 发表了一篇论文《Attention Is All You Need》,提出了一种全新的神经网络架构------Transformer。
你可能没读过这篇论文,但你每天使用的 ChatGPT、Claude、Gemini......它们的核心都是 Transformer。
在这之前,处理文本的主流是 RNN (循环神经网络),但它有个致命缺陷:文本越长,越记不住开头的内容。
Transformer 的出现彻底改变了这一切------它让模型能同时看到整段文本,并且任何两个词都能直接"对话"。
核心概念
为什么 Transformer 比 RNN 更强?
想象你在读一本 500 页的小说:
| 方式 | 类比 | 问题 |
|---|---|---|
| RNN | 一页一页读,读完后面忘了前面 | 长遗忘:第 500 页时,第 1 页的信息已经模糊了 |
| Transformer | 整本书摊开,随时翻到任意一页 | 无遗忘:第 500 页和第 1 页直接关联 |
关键差异:
| 维度 | RNN | Transformer |
|---|---|---|
| 计算方式 | 串行(一个词一个词处理) | 并行(整个序列同时处理) |
| 长距离依赖 | 路径长度 = N(随文本变长) | 路径长度 = 1(恒定) |
| 训练速度 | 慢(只能串行) | 快(GPU 并行) |
| 长文本效果 | 差(信息衰减) | 好(直接关联) |
Encoder 和 Decoder:两个核心角色
Transformer 由两部分组成:
| 角色 | 职责 | 类比 |
|---|---|---|
| Encoder(编码器) | 理解输入,生成上下文感知的表示 | 阅读理解:读懂整篇文章 |
| Decoder(解码器) | 基于编码器的理解,逐词生成输出 | 写作表达:根据理解写出答案 |
它们如何协作?
输入文本 → Encoder(理解)→ 编码表示 → Decoder(生成)→ 输出文本
在 Decoder 中,有一个关键机制叫交叉注意力(Cross-Attention)------它让 Decoder 在生成每个词时,都能"查阅" Encoder 的输出,找到最相关的信息。
核心机制:自注意力
Q、K、V 是什么?
自注意力的核心是三个向量:Q(Query)、K(Key)、V(Value)。
一句话记忆: Q 负责提问,K 负责响应匹配,V 负责给出答案。
类比数据库搜索:
| 角色 | 类比 | 作用 |
|---|---|---|
| Q(Query) | 搜索关键词 | "我想找什么?" |
| K(Key) | 文章标题/标签 | "我有什么信息?" |
| V(Value) | 文章正文 | "我的实际内容是什么?" |
示例: 在句子"苹果公司发布了新手机"中:
- "苹果"的 Q 可能想找"公司"相关信息
- "公司"的 K 匹配到"苹果"的 Q
- "公司"的 V 提供了实际语义内容
公式解读
Attention(Q, K, V) = softmax(Q × K^T / √d_k) × V
为什么要除以 √d_k?
这是为了防止 softmax 进入"饱和区"------当点积值太大时,softmax 会把几乎所有概率分配给最大值,导致梯度消失、训练停滞。
| 情况 | 效果 |
|---|---|
| 不缩放 | 点积方差大 → softmax 饱和 → 梯度消失 → 训练困难 |
| 除以 √d_k | 方差归一化为 1 → 梯度适中 → 训练稳定 |
多头注意力:为什么需要多个"头"?
单头 像一个专家只用自己的方式看问题;多头像一个委员会,每个成员从不同角度分析。
| 头的类型 | 关注的关系 | 示例 |
|---|---|---|
| 语法头 | 相邻词的局部依赖 | "红色的"修饰"苹果" |
| 句法头 | 主谓关系 | "猫"和"追"的连接 |
| 长距离头 | 语义相似的词 | "苹果"和"水果" |
| 位置头 | 固定偏移 | 关注前一个词 |
为什么每个头要降维?
假设总维度 d_model=512,头数 nhead=8,每个头维度 d_k=512/8=64:
- 保持总计算量不变
- 迫使每个头专注于不同类型的特征
- 输出拼接后仍为 512 维,完美匹配后续层
位置编码:让模型知道"顺序"
为什么需要位置编码?
Transformer 的自注意力有一个关键特性:置换不变性。
如果交换输入中两个词的位置,自注意力的计算结果会完全相同。这意味着模型不知道词的先后顺序。
后果:
- "I love you" 和 "You love I" 被视为相同
- "不好" 与 "好不" 无法区分
- 模型退化为"词袋模型"
解决方案
在词嵌入(表示"什么词")的基础上,叠加位置编码(表示"哪个位置"):
python
输出 = 词嵌入 + 位置编码
为什么是加法而不是拼接?
- 加法不增加维度,计算效率更高
- 语义和位置在同一向量空间中线性叠加
- 模型可以通过后续权重分别利用或忽略它们
残差连接和层归一化
残差连接(Skip Connection)
python
输出 = 子层(输入) + 输入
作用:
- 解决梯度消失:提供"短路路径",让梯度直接流回浅层
- 防止退化:即使子层学到零,至少还有原始输入
- 让网络"走得深"
层归一化(Layer Normalization)
python
LN(x) = γ × (x - μ) / σ + β
作用:
- 稳定训练:不依赖批次统计,适合可变长度序列
- 加速收敛:将输出拉回均值为 0、方差为 1
- 与残差配合:避免残差叠加导致的数值偏移
一句话总结: 残差连接让网络"走得深",层归一化让网络"走得稳"。
代码示例
位置编码实现
python
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
"""位置编码模块"""
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
# 正弦和余弦交替
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数位 sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数位 cos
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.pe[:, :x.size(1), :]
简易 Transformer 编码器
python
class SimpleTransformerEncoder(nn.Module):
"""简化的 Transformer 编码器"""
def __init__(self, vocab_size: int, d_model: int = 512,
nhead: int = 8, num_layers: int = 6):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model)
# 使用 PyTorch 自带的 Transformer 层
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=d_model * 4,
dropout=0.1,
batch_first=True
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer, num_layers=num_layers
)
def forward(self, src: torch.Tensor) -> torch.Tensor:
embedded = self.embedding(src)
embedded = self.pos_encoder(embedded)
return self.transformer_encoder(embedded)
测试运行
python
# 创建输入(2 个样本,每个 10 个词)
input_ids = torch.randint(0, 1000, (2, 10))
# 创建模型并运行
model = SimpleTransformerEncoder(vocab_size=1000)
output = model(input_ids)
print(f"输入形状:{input_ids.shape}") # torch.Size([2, 10])
print(f"输出形状:{output.shape}") # torch.Size([2, 10, 512])
最佳实践
- 利用并行计算:Transformer 可并行处理全序列,避免 RNN 的串行瓶颈
- 合理设置位置编码:根据任务选择固定编码或可学习编码
- 多层堆叠:使用 6 层以上编码器捕获多层次特征
- Pre-Norm vs Post-Norm:现代模型多用 Pre-Norm(LN → 子层 → 残差),训练更稳定
总结
核心要点回顾:
- 并行计算:整个序列同时处理,训练速度远超 RNN
- 自注意力:任意两个词直接关联,解决长遗忘问题
- Q/K/V:查询-匹配-提取三阶段,像数据库检索
- 多头:多个子空间同时捕捉不同类型的关系
- 位置编码:弥补置换不变性,让模型感知顺序
- 残差+归一化:让深度网络稳定训练
理解了 Transformer,你就理解了所有大语言模型的"心脏"。