目录
第一部分:基础与核心思想
第二部分:核心组件详解 (见transformer_guide_part2.md)
第三部分:编码器与解码器 (见transformer_guide_part3.md)
第四部分:训练与优化 (见transformer_guide_part4.md)
第五部分:变体与演进 (见transformer_guide_part5.md)
第六部分:应用与实践 (见transformer_guide_part6.md)
第一部分:基础与核心思想
1. 概述
1.1 什么是Transformer
Transformer 是一种基于自注意力机制(Self-Attention) 的深度学习架构,由 Google 团队在 2017 年论文 "Attention Is All You Need" 中首次提出。它彻底改变了自然语言处理(NLP)领域,并逐渐扩展到计算机视觉、语音处理、多模态学习等多个领域。
1.2 核心特征
| 特征 | 说明 |
|---|---|
| 并行计算 | 不像RNN需要顺序处理,Transformer可以并行处理整个序列 |
| 长距离依赖 | 通过自注意力机制直接建模任意距离的依赖关系 |
| 可扩展性 | 易于扩展到大规模模型(如GPT-3的175B参数) |
| 通用性 | 同一架构适用于多种任务和模态 |
1.3 影响力
Transformer架构催生了一系列突破性模型:
Transformer (2017)
├── NLP领域
│ ├── BERT (2018) - 预训练语言理解
│ ├── GPT系列 (2018-2023) - 大规模语言生成
│ ├── T5 (2019) - 统一文本到文本框架
│ └── LLaMA (2023) - 开源大语言模型
│
├── CV领域
│ ├── ViT (2020) - 视觉Transformer
│ ├── DETR (2020) - 目标检测
│ └── Swin Transformer (2021) - 层次化视觉Transformer
│
├── 多模态
│ ├── CLIP (2021) - 图文对比学习
│ ├── DALL-E (2021) - 文本到图像生成
│ └── GPT-4V (2023) - 多模态大语言模型
│
└── 其他领域
├── AlphaFold2 (2021) - 蛋白质结构预测
├── Jukebox (2020) - 音乐生成
└── Whisper (2022) - 语音识别
2. 历史背景与发展
2.1 序列建模的演进
2.1.1 早期方法:N-gram
N-gram模型(1980s-2000s):
P(w_t | w_1, ..., w_{t-1}) ≈ P(w_t | w_{t-n+1}, ..., w_{t-1})
例如:3-gram
P("今天" | "我", "喜欢") = Count("我", "喜欢", "今天") / Count("我", "喜欢")
局限性:
- 数据稀疏问题
- 无法捕捉长距离依赖
- 存储需求随n指数增长
2.1.2 循环神经网络 (RNN)
基本RNN(1990s):
隐藏状态更新:
h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h)
输出:
y_t = W_hy * h_t + b_y
架构图:
x_1 → [RNN] → h_1 → [RNN] → h_2 → ... → [RNN] → h_T
↑ ↑ ↑
| | |
h_0 h_1 h_{T-1}
问题:
- 梯度消失/爆炸问题
- 无法并行计算
- 长距离依赖建模困难
2.1.3 LSTM与GRU
LSTM (Long Short-Term Memory, 1997):
遗忘门: f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
输入门: i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
候选值: C̃_t = tanh(W_C · [h_{t-1}, x_t] + b_C)
细胞状态: C_t = f_t * C_{t-1} + i_t * C̃_t
输出门: o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
隐藏状态: h_t = o_t * tanh(C_t)
GRU (Gated Recurrent Unit, 2014):
更新门: z_t = σ(W_z · [h_{t-1}, x_t])
重置门: r_t = σ(W_r · [h_{t-1}, x_t])
候选隐藏状态: h̃_t = tanh(W · [r_t * h_{t-1}, x_t])
隐藏状态: h_t = (1 - z_t) * h_{t-1} + z_t * h̃_t
改进:
- 缓解梯度消失问题
- 能够学习更长的依赖
仍然存在的问题:
- 顺序计算,无法并行
- 长距离依赖仍然困难(100-200 tokens)
2.1.4 注意力机制的引入
Bahdanau Attention (2014):
在机器翻译中,不再只依赖编码器最后一个隐藏状态,而是动态关注源序列的不同部分:
上下文向量: c_i = Σ α_{ij} * h_j
注意力权重: α_{ij} = exp(e_{ij}) / Σ_k exp(e_{ik})
对齐分数: e_{ij} = a(s_{i-1}, h_j) # 对齐模型
意义:
- 首次在NLP中引入注意力机制
- 显著提升机器翻译性能
- 为Transformer奠定基础
2.2 Transformer的诞生
2.2.1 论文背景
论文信息:
- 标题: "Attention Is All You Need"
- 作者: Vaswani et al. (Google Brain & Google Research)
- 发表: NeurIPS 2017
- 引用数: 100,000+ (截至2024)
核心创新:
- 完全基于注意力机制,抛弃RNN/CNN
- 自注意力(Self-Attention)机制
- 多头注意力(Multi-Head Attention)
- 位置编码(Positional Encoding)
2.2.2 设计动机
RNN的问题:
RNN的顺序计算:
x_1 → h_1 → x_2 → h_2 → x_3 → h_3 → ...
问题:
1. 无法并行: 必须等h_1计算完才能算h_2
2. 长距离依赖: h_1000需要经过999步传递
3. 训练慢: 无法利用GPU并行能力
解决方案的思路:
如果能直接建立任意两个位置之间的连接...
x_1 ←→ x_2
↕ ↕
x_3 ←→ x_4
这样:
1. 可以并行计算所有位置
2. 任意距离的依赖都是1步
3. 充分利用GPU
2.3 发展历程
2.3.1 编码器时代 (2018-2019)
BERT (Bidirectional Encoder Representations from Transformers)
发布: 2018年10月 (Google)
架构: 仅编码器
参数: 110M (Base), 340M (Large)
训练: Masked Language Model + Next Sentence Prediction
影响: 刷新11项NLP基准
核心思想:
传统语言模型: 单向(从左到右)
P(w_t | w_1, ..., w_{t-1})
BERT: 双向
P(w_t | w_1, ..., w_{t-1}, w_{t+1}, ..., w_n)
通过Masked Language Model实现:
输入: "我 喜欢 [MASK] 学习"
预测: "喜欢" → "深度"
RoBERTa (2019):
- 移除NSP任务
- 更多训练数据
- 更大批次大小
- 动态masking
2.3.2 解码器时代 (2018-2020)
GPT (Generative Pre-trained Transformer)
发布: 2018年6月 (OpenAI)
架构: 仅解码器
参数: 117M
训练: 自回归语言模型
特点: 生成式预训练 + 判别式微调
GPT-2 (2019):
参数: 1.5B
训练数据: WebText (40GB文本)
特点:
- 零样本学习能力
- 多任务学习
- 引发"太危险而不能发布"的讨论
GPT-3 (2020):
参数: 175B
训练数据: 300B tokens
特点:
- Few-shot学习
- In-context学习
- 涌现能力
- 开启大语言模型时代
2.3.3 编码器-解码器时代 (2019-2020)
T5 (Text-to-Text Transfer Transformer)
发布: 2019年 (Google)
架构: 编码器-解码器
参数: 11B
创新: 将所有NLP任务统一为文本到文本格式
统一框架:
翻译: "translate English to German: The house is wonderful" → "Das Haus ist wunderbar"
摘要: "summarize: {长文本}" → "{摘要}"
问答: "question: {问题} context: {上下文}" → "{答案}"
分类: "sentiment: {文本}" → "positive/negative"
2.3.4 视觉Transformer时代 (2020-至今)
ViT (Vision Transformer)
发布: 2020年10月 (Google)
创新: 将Transformer直接应用于图像
方法: 图像分割为patches,作为序列处理
影响: 开创视觉Transformer时代
Swin Transformer (2021):
- 层次化结构
- 窗口注意力
- 移位窗口
- 适用于各种视觉任务
2.3.5 大语言模型时代 (2022-至今)
ChatGPT (2022.11):
- GPT-3.5 + RLHF
- 对话式交互
- 引发AI热潮
GPT-4 (2023.3):
- 多模态(文本+图像)
- 更强的推理能力
- 更好的安全性
开源大模型:
LLaMA (2023.2) - Meta
├── LLaMA 2 (2023.7)
├── Alpaca (2023.3) - Stanford
├── Vicuna (2023.3) - LMSYS
└── Mistral (2023.9)
Qwen (2023.8) - 阿里
ChatGLM (2023.3) - 清华
DeepSeek (2024) - 幻方
3. 核心思想
3.1 注意力机制的直觉
3.1.1 人类注意力的类比
想象你在阅读一段话:
"The animal didn't cross the street because it was too tired"
当你读到 "it" 时,你的大脑会自动关注 "animal" 而不是 "street"
这就是注意力机制的核心思想:根据上下文动态聚焦相关信息
3.1.2 注意力的数学表达
查询 (Query): "我在找什么?" - 当前位置的信息需求
键 (Key): "我有什么?" - 每个位置的特征描述
值 (Value): "我能提供什么?" - 每个位置的实际内容
注意力 = f(Query, Key) × Value
3.2 自注意力机制
3.2.1 核心概念
自注意力(Self-Attention)是Transformer的核心创新:
传统注意力: Query来自一个序列,Key/Value来自另一个序列
自注意力: Query、Key、Value都来自同一个序列
作用: 让序列中的每个位置都能"看到"并"关注"序列中的所有其他位置
3.2.2 直观理解
句子: "我 喜欢 在 图书馆 学习"
对于每个词,计算它与其他所有词的相关性:
"我": 与"喜欢"(主谓关系)、"学习"(施事关系)相关
"喜欢": 与"我"(主谓关系)、"学习"(动宾关系)相关
"图书馆": 与"学习"(地点关系)相关
...
自注意力让每个词都能获取整个句子的上下文信息
3.2.3 与RNN的对比
| 特性 | RNN | 自注意力 |
|---|---|---|
| 计算路径长度 | O(n) | O(1) |
| 并行性 | 顺序 | 完全并行 |
| 长距离依赖 | 困难 | 直接建模 |
| 计算复杂度 | O(n·d²) | O(n²·d) |
3.3 位置编码
3.3.1 为什么需要位置编码
自注意力机制是"置换等变"的(permutation equivariant)
即:打乱输入顺序,输出也会相应打乱,但不会改变内容
问题:无法区分
"猫 追 狗" vs "狗 追 猫"
解决方案:添加位置信息
3.3.2 正弦位置编码
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中:
- pos: 位置索引
- i: 维度索引
- d_model: 模型维度
设计直觉:
1. 每个位置有唯一的编码
2. 相对位置可以通过线性变换获得
3. 可以泛化到更长的序列
4. 值有界,不会发散
3.4 编码器-解码器架构
3.4.1 整体设计思想
编码器 (Encoder):
- 理解输入序列
- 提取特征表示
- 双向注意力(可以看到完整输入)
解码器 (Decoder):
- 生成输出序列
- 自回归生成(逐token生成)
- 因果注意力(只能看到之前的输出)
编码器-解码器注意力:
- 解码器关注编码器的输出
- 实现输入到输出的映射
3.4.2 信息流
输入序列 → [编码器] → 编码表示
↓
[编码器-解码器注意力]
↓
已生成序列 → [解码器] → 下一个token概率
3.5 并行计算
3.5.1 RNN的瓶颈
python
# RNN的顺序计算
h = zeros(hidden_size)
for x in sequence:
h = rnn_cell(x, h) # 必须等待前一步完成
# 无法并行!
3.5.2 Transformer的并行性
python
# Transformer的并行计算
# 所有位置可以同时计算
Q = X @ W_Q # [seq_len, d_model] @ [d_model, d_k] = [seq_len, d_k]
K = X @ W_K # 同上
V = X @ W_V # 同上
# 注意力计算也是并行的
attention = softmax(Q @ K.T / sqrt(d_k)) @ V # 矩阵乘法,并行计算
4. 整体架构概览
4.1 原始Transformer架构
┌─────────────────────────────────────────────────────────────────┐
│ Transformer 整体架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 输入序列 输出序列 │
│ ↓ ↑ │
│ ┌────────┐ ┌────────┐│
│ │ 输入嵌入│ │输出嵌入 ││
│ └────┬───┘ └────┬───┘│
│ ↓ ↑ │
│ ┌────────┐ ┌────────┐│
│ │位置编码 │ │位置编码 ││
│ └────┬───┘ └────┬───┘│
│ ↓ ↑ │
│ ┌─────────────────────────────────────────────────────────────┐│
│ │ 编码器 × N ││
│ │ ┌──────────────────────────────────────────────────────┐ ││
│ │ │ 多头自注意力层 │ ││
│ │ │ Q = XW_Q, K = XW_K, V = XW_V │ ││
│ │ │ Attention = softmax(QK^T/√d_k)V │ ││
│ │ └──────────────────────────────────────────────────────┘ ││
│ │ ↓ + 残差连接 ││
│ │ 层归一化 ││
│ │ ↓ ││
│ │ ┌──────────────────────────────────────────────────────┐ ││
│ │ │ 前馈神经网络 │ ││
│ │ │ FFN = max(0, xW_1 + b_1)W_2 + b_2 │ ││
│ │ └──────────────────────────────────────────────────────┘ ││
│ │ ↓ + 残差连接 ││
│ │ 层归一化 ││
│ └─────────────────────────────────────────────────────────────┘│
│ ↓ ↑ │
│ ↓ ┌───────────────────────────────────────┐ │
│ └──────────────→ 解码器 × N │ │
│ │ ┌──────────────────────────────────┐ │ │
│ │ │ 掩码多头自注意力层 │ │ │
│ │ │ (因果掩码,防止看到未来) │ │ │
│ │ └──────────────────────────────────┘ │ │
│ │ ↓ + 残差连接 │ │
│ │ 层归一化 │ │
│ │ ↓ │ │
│ │ ┌──────────────────────────────────┐ │ │
│ │ │ 编码器-解码器注意力层 │ │ │
│ │ │ Q来自解码器,K,V来自编码器 │ │ │
│ │ └──────────────────────────────────┘ │ │
│ │ ↓ + 残差连接 │ │
│ │ 层归一化 │ │
│ │ ↓ │ │
│ │ ┌──────────────────────────────────┐ │ │
│ │ │ 前馈神经网络 │ │ │
│ │ └──────────────────────────────────┘ │ │
│ │ ↓ + 残差连接 │ │
│ │ 层归一化 │ │
│ └───────────────────┬───────────────────┘ │
│ ↓ │
│ 线性层 + Softmax │
│ ↓ │
│ 输出概率 │
└─────────────────────────────────────────────────────────────────┘
4.2 编码器详解
4.2.1 单个编码器层
输入: X ∈ R^{n×d_model}
↓
┌─────────────────────────────────────────┐
│ 多头自注意力 (Multi-Head) │
│ │
│ 对于每个头 h_i: │
│ Q_i = XW_i^Q, K_i = XW_i^K, V_i = XW_i^V │
│ head_i = Attention(Q_i, K_i, V_i) │
│ │
│ MultiHead = Concat(head_1, ..., head_h)W^O │
└─────────────────────────────────────────┘
↓
Add & Norm: LayerNorm(X + MultiHead(X))
↓
┌─────────────────────────────────────────┐
│ 前馈神经网络 (FFN) │
│ │
│ FFN(x) = max(0, xW_1 + b_1)W_2 + b_2 │
│ │
│ 或使用GLU变体: │
│ FFN(x) = (xW_1) ⊙ σ(xW_3) W_2 │
└─────────────────────────────────────────┘
↓
Add & Norm: LayerNorm(x + FFN(x))
↓
输出: Y ∈ R^{n×d_model}
4.2.2 编码器堆叠
python
class Encoder(nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout):
super().__init__()
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
for layer in self.layers:
x = layer(x, mask)
return self.norm(x) # 最终层归一化
4.3 解码器详解
4.3.1 单个解码器层
输入: Y ∈ R^{m×d_model}
↓
┌─────────────────────────────────────────┐
│ 掩码多头自注意力 (Masked) │
│ │
│ 防止位置关注后续位置(因果掩码) │
│ Mask[i,j] = -∞ if j > i else 0 │
└─────────────────────────────────────────┘
↓
Add & Norm
↓
┌─────────────────────────────────────────┐
│ 编码器-解码器注意力 │
│ │
│ Q = 解码器输出 │
│ K, V = 编码器输出 │
│ │
│ 让解码器关注输入序列的相关部分 │
└─────────────────────────────────────────┘
↓
Add & Norm
↓
┌─────────────────────────────────────────┐
│ 前馈神经网络 (FFN) │
└─────────────────────────────────────────┘
↓
Add & Norm
↓
输出: Y' ∈ R^{m×d_model}
4.3.2 因果掩码
python
def create_causal_mask(seq_len):
"""
创建因果掩码,防止解码器看到未来位置
"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
# 示例:seq_len = 4
# tensor([[0., -inf, -inf, -inf],
# [0., 0., -inf, -inf],
# [0., 0., 0., -inf],
# [0., 0., 0., 0.]])
4.4 参数规模
4.4.1 原始Transformer参数
原始Transformer (base):
- d_model = 512
- num_heads = 8
- d_ff = 2048
- num_layers = 6 (编码器) + 6 (解码器)
- vocab_size = 37000 (shared)
参数计算:
- 嵌入层: 37000 × 512 = 18.9M
- 位置编码: 512 × 512 = 0.26M (可学习版本)
- 每个编码器层:
- 自注意力: 4 × 512² = 1.05M
- FFN: 2 × 512 × 2048 = 2.10M
- 总计: ~3.15M
- 6个编码器层: 6 × 3.15M = 18.9M
- 解码器类似,加上交叉注意力: ~25M
- 总参数: ~63M
4.4.2 不同规模的Transformer
| 模型 | 参数量 | d_model | 层数 | 头数 | d_ff |
|---|---|---|---|---|---|
| Transformer-Base | 65M | 512 | 6+6 | 8 | 2048 |
| Transformer-Big | 213M | 1024 | 6+6 | 16 | 4096 |
| BERT-Base | 110M | 768 | 12 | 12 | 3072 |
| BERT-Large | 340M | 1024 | 24 | 16 | 4096 |
| GPT-2 | 1.5B | 1600 | 48 | 25 | 6400 |
| GPT-3 | 175B | 12288 | 96 | 96 | 49152 |
| LLaMA-65B | 65B | 8192 | 80 | 64 | 22016 |
5. 输入表示
5.1 词嵌入 (Word Embedding)
5.1.1 基本概念
将离散的词(或token)映射为连续的向量表示
词汇表 V = {the, a, cat, dog, ...}
嵌入矩阵 E ∈ R^{|V|×d_model}
词"cat" → 嵌入向量 e_cat ∈ R^{d_model}
5.1.2 实现
python
class TokenEmbedding(nn.Module):
def __init__(self, vocab_size, d_model):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.d_model = d_model
def forward(self, x):
# x: [batch_size, seq_len] (token indices)
# output: [batch_size, seq_len, d_model]
return self.embedding(x) * math.sqrt(self.d_model)
注意 :乘以 d m o d e l \sqrt{d_{model}} dmodel 是为了平衡嵌入和位置编码的尺度。
5.2 Tokenization方法
5.2.1 词级Tokenization (Word-level)
输入: "I love natural language processing"
输出: ["I", "love", "natural", "language", "processing"]
优点: 语义清晰
缺点: 词汇表大,OOV问题
5.2.2 字符级Tokenization (Character-level)
输入: "I love"
输出: ["I", " ", "l", "o", "v", "e"]
优点: 词汇表小,无OOV
缺点: 序列长,语义不清晰
5.2.3 子词Tokenization (Subword-level)
BPE (Byte Pair Encoding):
训练过程:
1. 初始化: 所有字符作为基本token
2. 统计相邻token对的频率
3. 合并最频繁的token对
4. 重复直到达到目标词汇表大小
示例:
初始: {'l', 'o', 'w', 'e', 'r', 'n', 'e', 's', 's'}
频率: {'lo': 5, 'ow': 3, 'we': 4, 'er': 7, 'ss': 6}
合并 'er': {'l', 'o', 'w', 'e', 'r', 'n', 'es', 's', 'er'}
继续合并...
WordPiece(BERT使用):
类似BPE,但使用不同的合并准则
选择使语言模型概率最大化的合并
标识: ## 表示子词的延续
"playing" → ["play", "##ing"]
"unhappy" → ["un", "happy"]
SentencePiece:
语言无关的分词器
直接在原始文本上训练
支持BPE和Unigram两种算法
特点:
- 不需要预分词
- 可逆编码
- 多语言支持
5.2.4 现代Tokenization示例
python
from transformers import AutoTokenizer
# GPT-2 Tokenizer (BPE)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Hello, how are you?"
tokens = tokenizer.tokenize(text)
# ['Hello', ',', 'how', 'are', 'you', '?']
input_ids = tokenizer.encode(text)
# [15496, 11, 703, 389, 345, 30]
# BERT Tokenizer (WordPiece)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.tokenize(text)
# ['hello', ',', 'how', 'are', 'you', '?']
5.3 位置编码
5.3.1 正弦位置编码
数学公式:
P E ( p o s , 2 i ) = sin ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)
P E ( p o s , 2 i + 1 ) = cos ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)
实现:
python
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
# x: [batch_size, seq_len, d_model]
return x + self.pe[:, :x.size(1), :]
设计直觉:
1. 唯一性: 每个位置有唯一的编码
2. 有界性: 值在[-1, 1]之间
3. 相对位置: PE(pos+k)可以表示为PE(pos)的线性函数
4. 泛化性: 可以处理比训练时更长的序列
5.3.2 可学习位置编码
python
class LearnedPositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
self.embedding = nn.Embedding(max_len, d_model)
def forward(self, x):
# x: [batch_size, seq_len, d_model]
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device)
return x + self.embedding(positions)
对比:
| 特性 | 正弦编码 | 可学习编码 |
|---|---|---|
| 参数量 | 0 | max_len × d_model |
| 泛化性 | 可泛化到更长序列 | 限于训练长度 |
| 性能 | 与可学习版本相当 | 与正弦版本相当 |
| 灵活性 | 固定 | 可学习最优表示 |
5.3.3 旋转位置编码 (RoPE)
RoPE (Rotary Position Embedding) 被现代大语言模型广泛使用:
python
class RotaryPositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 计算频率
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
self.register_buffer('inv_freq', inv_freq)
# 预计算位置编码
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached', emb.cos())
self.register_buffer('sin_cached', emb.sin())
def forward(self, x, seq_len):
return (
self.cos_cached[:seq_len].to(x.device),
self.sin_cached[:seq_len].to(x.device)
)
def rotate_half(x):
"""旋转一半的维度"""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
"""应用旋转位置编码"""
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
RoPE的优势:
- 相对位置信息自然融入注意力计算
- 无需额外参数
- 良好的外推能力
5.4 段落编码与类型编码
5.4.1 段落编码 (Segment Embedding)
用于区分不同的句子或段落
例如BERT:
[CLS] I love cats [SEP] I love dogs [SEP]
A A A A A B B B B
段落编码: A → embedding_A, B → embedding_B
5.4.2 类型编码 (Type Embedding)
用于区分不同类型的token
例如:
- 文本token vs 图像token
- 用户token vs 助手token
第二部分:核心组件详解
6. 自注意力机制
6.1 缩放点积注意力
6.1.1 数学公式
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
其中:
- Q Q Q (Query): 查询矩阵,形状 [ n , d k ] [n, d_k] [n,dk]
- K K K (Key): 键矩阵,形状 [ m , d k ] [m, d_k] [m,dk]
- V V V (Value): 值矩阵,形状 [ m , d v ] [m, d_v] [m,dv]
- d k d_k dk: 键的维度
- n n n: 查询序列长度
- m m m: 键/值序列长度
6.1.2 计算步骤详解
python
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
缩放点积注意力
Args:
Q: [batch_size, num_heads, seq_len_q, d_k]
K: [batch_size, num_heads, seq_len_k, d_k]
V: [batch_size, num_heads, seq_len_k, d_v]
mask: [batch_size, 1, seq_len_q, seq_len_k] 或 None
Returns:
output: [batch_size, num_heads, seq_len_q, d_v]
attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k]
"""
d_k = Q.size(-1)
# 步骤1: 计算注意力分数
# scores = Q @ K^T / sqrt(d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# scores: [batch_size, num_heads, seq_len_q, seq_len_k]
# 步骤2: 应用掩码(可选)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# 步骤3: Softmax归一化
attention_weights = F.softmax(scores, dim=-1)
# attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k]
# 步骤4: 加权求和
output = torch.matmul(attention_weights, V)
# output: [batch_size, num_heads, seq_len_q, d_v]
return output, attention_weights
6.1.3 为什么要缩放?
问题 :当 d k d_k dk 较大时, Q K T QK^T QKT 的值会很大
假设 Q, K 的元素独立采样自 N(0, 1)
则 Q·K = Σ(q_i * k_i) 的方差为 d_k
当 d_k = 64 时,Q·K 的标准差约为 8
这会导致 softmax 的输入值很大,梯度接近于0
解决方案 :除以 d k \sqrt{d_k} dk
缩放后,Q·K/sqrt(d_k) 的方差为 1
保持 softmax 在合理的范围内
6.2 注意力机制的直觉
6.2.1 信息检索类比
注意力机制类似于信息检索系统:
数据库中的每条记录有:
- Key: 用于匹配查询的索引
- Value: 记录的实际内容
查询过程:
1. 用户输入 Query
2. 计算 Query 与每个 Key 的相似度
3. 根据相似度对 Value 进行加权求和
4. 返回结果
6.2.2 软寻址
传统哈希表: 硬寻址
Query → 精确匹配一个Key → 返回对应的Value
注意力机制: 软寻址
Query → 与所有Key计算相似度 → 返回所有Value的加权和
优点:
- 可以从多个位置获取信息
- 可微分,可以端到端训练
- 可以处理模糊匹配
6.3 自注意力 vs 交叉注意力
6.3.1 自注意力 (Self-Attention)
Q, K, V 都来自同一个序列
输入: X ∈ R^{n×d}
Q = XW_Q, K = XW_K, V = XW_V
作用: 建模序列内部的依赖关系
示例:
"猫 坐在 垫子 上"
自注意力让"坐在"关注"猫"(主语)和"垫子"(地点)
6.3.2 交叉注意力 (Cross-Attention)
Q 来自一个序列,K, V 来自另一个序列
输入: X_src (源), X_tgt (目标)
Q = X_tgt W_Q, K = X_src W_K, V = X_src W_V
作用: 建模两个序列之间的对应关系
示例(机器翻译):
源: "I love cats"
目标: "我 喜欢 猫"
交叉注意力让"喜欢"关注"love"(翻译对应)
6.4 注意力权重可视化
6.4.1 可视化示例
python
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attention_weights, source_tokens, target_tokens):
"""
可视化注意力权重
Args:
attention_weights: [target_len, source_len]
source_tokens: 源序列的token列表
target_tokens: 目标序列的token列表
"""
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
attention_weights,
xticklabels=source_tokens,
yticklabels=target_tokens,
cmap='viridis',
ax=ax
)
ax.set_xlabel('Source')
ax.set_ylabel('Target')
ax.set_title('Attention Weights')
plt.tight_layout()
plt.show()
6.4.2 注意力模式
模式1: 对角线模式
- 位置i主要关注位置i
- 常见于局部依赖强的任务
模式2: 垂直模式
- 所有位置关注同一个位置(如[CLS])
- 常见于分类任务
模式3: 块状模式
- 形成明显的块状结构
- 常见于有明确语义边界的任务
模式4: 分散模式
- 注意力分散在多个位置
- 常见于需要综合信息的任务
6.5 注意力的变体
6.5.1 加性注意力 (Additive Attention)
Attention ( Q , K , V ) = softmax ( v T tanh ( W 1 Q + W 2 K ) ) V \text{Attention}(Q, K, V) = \text{softmax}\left(v^T \tanh(W_1 Q + W_2 K)\right) V Attention(Q,K,V)=softmax(vTtanh(W1Q+W2K))V
python
class AdditiveAttention(nn.Module):
def __init__(self, d_model, d_hidden):
super().__init__()
self.W1 = nn.Linear(d_model, d_hidden)
self.W2 = nn.Linear(d_model, d_hidden)
self.v = nn.Linear(d_hidden, 1, bias=False)
def forward(self, Q, K, V):
# Q: [batch, n, d], K: [batch, m, d], V: [batch, m, d]
# 扩展维度以广播
Q_expanded = Q.unsqueeze(2) # [batch, n, 1, d]
K_expanded = K.unsqueeze(1) # [batch, 1, m, d]
# 计算注意力分数
scores = self.v(torch.tanh(
self.W1(Q_expanded) + self.W2(K_expanded)
)).squeeze(-1) # [batch, n, m]
# Softmax和加权求和
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, V)
return output, weights
6.5.2 线性注意力 (Linear Attention)
目标 :将 O ( n 2 ) O(n^2) O(n2) 复杂度降为 O ( n ) O(n) O(n)
Attention ( Q , K , V ) = ϕ ( Q ) ( ϕ ( K ) T V ) ϕ ( Q ) ∑ i ϕ ( k i ) \text{Attention}(Q, K, V) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)\sum_i \phi(k_i)} Attention(Q,K,V)=ϕ(Q)∑iϕ(ki)ϕ(Q)(ϕ(K)TV)
其中 ϕ \phi ϕ 是特征映射函数(如 elu + 1)
python
def linear_attention(Q, K, V):
"""
线性注意力,复杂度 O(n·d²)
"""
# 特征映射
Q = F.elu(Q) + 1
K = F.elu(K) + 1
# 先计算 K^T V
KV = torch.einsum('bnm,bnd->bmd', K, V)
# 计算 Q(K^T V)
output = torch.einsum('bnm,bmd->bnd', Q, KV)
# 归一化
Z = torch.einsum('bnm,bm->bn', Q, K.sum(dim=1))
output = output / (Z.unsqueeze(-1) + 1e-6)
return output
6.5.3 稀疏注意力 (Sparse Attention)
全注意力: 每个位置关注所有位置
稀疏注意力: 每个位置只关注部分位置
常见模式:
1. 局部窗口: 只关注附近的k个位置
2. 跨步模式: 关注固定间隔的位置
3. 随机模式: 随机选择部分位置
4. 组合模式: 结合多种模式
Longformer的稀疏注意力:
python
# 局部窗口 + 全局注意力
# 每个位置关注: 局部窗口内的位置 + 特定的全局位置(如[CLS])
def longformer_attention(Q, K, V, window_size, global_mask):
seq_len = Q.size(2)
# 局部注意力
local_attn = sliding_window_attention(Q, K, V, window_size)
# 全局注意力
global_attn = full_attention(Q, K, V, global_mask)
# 组合
output = local_attn + global_attn
return output
7. 多头注意力
7.1 核心思想
7.1.1 为什么需要多头?
单头注意力的局限:
- 只能学习一种注意力模式
- 不同类型的关系可能需要不同的关注方式
多头注意力的优势:
- 每个头可以学习不同的注意力模式
- 有些头关注语法关系,有些关注语义关系
- 提供更丰富的表示
7.1.2 直观理解
句子: "The cat sat on the mat"
头1 (语法头): 关注主谓关系
"The" → "cat", "cat" → "sat"
头2 (语义头): 关注语义相似性
"cat" → "mat" (都是名词,且押韵)
头3 (位置头): 关注局部依赖
"sat" → "on" → "the" → "mat"
多头组合: 综合多种关系信息
7.2 数学公式
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个头:
head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)
参数维度:
- W i Q ∈ R d m o d e l × d k W_i^Q \in \mathbb{R}^{d_{model} \times d_k} WiQ∈Rdmodel×dk
- W i K ∈ R d m o d e l × d k W_i^K \in \mathbb{R}^{d_{model} \times d_k} WiK∈Rdmodel×dk
- W i V ∈ R d m o d e l × d v W_i^V \in \mathbb{R}^{d_{model} \times d_v} WiV∈Rdmodel×dv
- W O ∈ R h d v × d m o d e l W^O \in \mathbb{R}^{hd_v \times d_{model}} WO∈Rhdv×dmodel
通常 d k = d v = d m o d e l / h d_k = d_v = d_{model} / h dk=dv=dmodel/h
7.3 实现
python
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 线性投影层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""
缩放点积注意力
"""
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, V)
return output, attention_weights
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 1. 线性投影
Q = self.W_q(Q) # [batch, seq_len, d_model]
K = self.W_k(K)
V = self.W_v(V)
# 2. 分割成多头
# [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 3. 计算注意力
output, attention_weights = self.scaled_dot_product_attention(
Q, K, V, mask
)
# 4. 合并多头
# [batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]
output = output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
# 5. 最终线性投影
output = self.W_o(output)
return output, attention_weights
7.4 注意力头的分析
7.4.1 头的功能分化
研究发现,不同的注意力头学习到不同的功能:
语法头 (Syntactic Heads):
- 关注主语-谓语关系
- 关注形容词-名词关系
- 关注介词-宾语关系
位置头 (Positional Heads):
- 关注前一个/后一个位置
- 关注特定位置(如句首)
语义头 (Semantic Heads):
- 关注语义相似的词
- 关注同义词/反义词
稀有头 (Rare Heads):
- 几乎不关注任何位置
- 可能是冗余的
7.4.2 头的重要性
python
def compute_head_importance(model, dataloader):
"""
计算每个注意力头的重要性
"""
head_importance = torch.zeros(model.num_layers, model.num_heads)
for batch in dataloader:
outputs = model(batch, output_attentions=True)
loss = outputs.loss
for layer in range(model.num_layers):
for head in range(model.num_heads):
# 使用梯度衡量重要性
attn = outputs.attentions[layer][:, head]
attn.retain_grad()
loss.backward(retain_graph=True)
head_importance[layer, head] += attn.grad.abs().mean()
return head_importance
7.5 多头注意力的变体
7.5.1 分组查询注意力 (GQA)
Grouped Query Attention 被 LLaMA 2、Mistral 等模型使用:
python
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_kv_heads):
"""
Args:
num_heads: 查询头数 (如 32)
num_kv_heads: 键值头数 (如 8)
每 num_heads/num_kv_heads 个查询头共享一组键值头
"""
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.num_heads_per_group = num_heads // num_kv_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, num_heads * self.d_k)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 投影
Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.num_kv_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.num_kv_heads, self.d_k).transpose(1, 2)
# 扩展K, V以匹配Q的头数
K = K.repeat_interleave(self.num_heads_per_group, dim=1)
V = V.repeat_interleave(self.num_heads_per_group, dim=1)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
# 合并
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
output = self.W_o(output)
return output
优势:
- 减少KV缓存大小
- 推理速度更快
- 性能损失很小
7.5.2 多查询注意力 (MQA)
Multi-Query Attention 是GQA的极端情况(num_kv_heads=1):
所有查询头共享同一组键值头
优势:
- KV缓存最小
- 推理最快
劣势:
- 可能损失一些性能
8. 前馈神经网络
8.1 基本FFN
8.1.1 结构
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
或使用GELU激活:
FFN(x) = GELU(xW_1 + b_1)W_2 + b_2
维度变化:
输入: x ∈ R^{d_model}
中间: h ∈ R^{d_ff} (通常 d_ff = 4 × d_model)
输出: y ∈ R^{d_model}
参数:
W_1 ∈ R^{d_model × d_ff}, b_1 ∈ R^{d_ff}
W_2 ∈ R^{d_ff × d_model}, b_2 ∈ R^{d_model}
8.1.2 实现
python
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1, activation='gelu'):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'gelu':
self.activation = nn.GELU()
elif activation == 'silu':
self.activation = nn.SiLU()
def forward(self, x):
# x: [batch, seq_len, d_model]
x = self.linear1(x) # [batch, seq_len, d_ff]
x = self.activation(x) # [batch, seq_len, d_ff]
x = self.dropout(x) # [batch, seq_len, d_ff]
x = self.linear2(x) # [batch, seq_len, d_model]
return x
8.2 GLU变体
Gated Linear Unit (GLU) 及其变体被广泛用于现代大语言模型:
8.2.1 GLU
GLU ( x ) = ( x W 1 ) ⊗ σ ( x W 3 ) \text{GLU}(x) = (xW_1) \otimes \sigma(xW_3) GLU(x)=(xW1)⊗σ(xW3)
其中 ⊗ \otimes ⊗ 是逐元素乘法, σ \sigma σ 是sigmoid函数
8.2.2 SwiGLU (LLaMA, PaLM使用)
SwiGLU ( x ) = Swish ( x W 1 ) ⊗ ( x W 3 ) \text{SwiGLU}(x) = \text{Swish}(xW_1) \otimes (xW_3) SwiGLU(x)=Swish(xW1)⊗(xW3)
其中 Swish ( x ) = x ⋅ σ ( x ) \text{Swish}(x) = x \cdot \sigma(x) Swish(x)=x⋅σ(x)
python
class SwiGLUFFN(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w3 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
8.2.3 GeGLU
GeGLU ( x ) = GELU ( x W 1 ) ⊗ ( x W 3 ) \text{GeGLU}(x) = \text{GELU}(xW_1) \otimes (xW_3) GeGLU(x)=GELU(xW1)⊗(xW3)
python
class GeGLUFFN(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w3 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
def forward(self, x):
return self.w2(F.gelu(self.w1(x)) * self.w3(x))
8.3 FFN的作用
8.3.1 记忆存储
研究发现FFN可以看作"键值记忆":
第一层 (W_1): 键匹配
h = xW_1 → 计算x与每个"键"的匹配度
激活函数: 选择性
只有匹配度高的"键"被激活
第二层 (W_2): 值读取
y = hW_2 → 读取对应的"值"
因此,FFN存储了大量的事实知识
8.3.2 与注意力的互补
注意力层: 信息聚合
- 从其他位置获取信息
- 建模位置间的关系
FFN层: 信息处理
- 对每个位置独立处理
- 应用非线性变换
- 存储和检索知识
两者交替使用,形成强大的表示能力
9. 残差连接与层归一化
9.1 残差连接
9.1.1 动机
问题: 深度网络的退化问题
- 理论上更深的网络应该不比浅网络差
- 实际上,过深的网络训练困难,性能下降
原因: 梯度消失/爆炸
- 深层梯度在反向传播中逐渐衰减或膨胀
9.1.2 残差连接
output = Layer ( x ) + x \text{output} = \text{Layer}(x) + x output=Layer(x)+x
python
class ResidualBlock(nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer
def forward(self, x):
return self.layer(x) + x
优势:
- 梯度可以直接流过跳跃连接
- 缓解梯度消失问题
- 允许训练更深的网络
9.1.3 预归一化 vs 后归一化
后归一化 (Post-Norm) - 原始Transformer:
x o u t = LayerNorm ( x + SubLayer ( x ) ) x_{out} = \text{LayerNorm}(x + \text{SubLayer}(x)) xout=LayerNorm(x+SubLayer(x))
预归一化 (Pre-Norm) - 现代Transformer:
x o u t = x + SubLayer ( LayerNorm ( x ) ) x_{out} = x + \text{SubLayer}(\text{LayerNorm}(x)) xout=x+SubLayer(LayerNorm(x))
python
# 后归一化
class PostNormBlock(nn.Module):
def __init__(self, sublayer, d_model):
super().__init__()
self.sublayer = sublayer
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
return self.norm(x + self.sublayer(x))
# 预归一化
class PreNormBlock(nn.Module):
def __init__(self, sublayer, d_model):
super().__init__()
self.sublayer = sublayer
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
return x + self.sublayer(self.norm(x))
对比:
| 特性 | 后归一化 | 预归一化 |
|---|---|---|
| 训练稳定性 | 需要warmup | 更稳定 |
| 最终性能 | 略高 | 略低 |
| 实现复杂度 | 简单 | 简单 |
| 主流选择 | 早期模型 | 现代大模型 |
9.2 层归一化
9.2.1 批归一化 vs 层归一化
批归一化 (Batch Normalization) :
x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵ xi−μB
其中 μ B , σ B \mu_B, \sigma_B μB,σB 是batch维度的统计量
层归一化 (Layer Normalization) :
x ^ i = x i − μ L σ L 2 + ϵ \hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}} x^i=σL2+ϵ xi−μL
其中 μ L , σ L \mu_L, \sigma_L μL,σL 是特征维度的统计量
python
class LayerNorm(nn.Module):
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
self.eps = eps
def forward(self, x):
# x: [batch, seq_len, d_model]
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
9.2.2 为什么用层归一化?
批归一化的问题:
1. 依赖batch size,小batch时统计量不稳定
2. 训练和推理时行为不同
3. 序列长度变化时需要重新计算
4. 难以处理变长序列
层归一化的优势:
1. 对每个样本独立归一化
2. 训练和推理行为一致
3. 不依赖batch size
4. 适合序列任务
9.2.3 RMSNorm
RMSNorm (Root Mean Square Normalization) 被现代大语言模型广泛使用:
RMSNorm ( x ) = x 1 d ∑ i = 1 d x i 2 + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2 + \epsilon}} \cdot \gamma RMSNorm(x)=d1∑i=1dxi2+ϵ x⋅γ
python
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return x / rms * self.weight
优势:
- 计算更快(不需要计算均值)
- 效果相当
- 被LLaMA、Mistral等采用
10. 位置编码
10.1 绝对位置编码
10.1.1 正弦位置编码
(已在5.3.1节详细介绍)
10.1.2 可学习位置编码
(已在5.3.2节详细介绍)
10.2 相对位置编码
10.2.1 动机
绝对位置编码的问题:
1. 无法直接建模相对位置关系
2. 泛化到更长序列的能力有限
相对位置编码的优势:
1. 直接建模位置间的相对关系
2. 更好的长度泛化能力
10.2.2 Shaw et al. (2018) 的相对位置编码
在注意力计算中加入相对位置信息:
e_{ij} = (x_i W^Q)(x_j W^K + a_{ij}^K)^T
其中 a_{ij}^K 是可学习的相对位置偏置
10.2.3 T5的相对位置偏置
python
class T5RelativePositionBias(nn.Module):
def __init__(self, num_heads, num_buckets=32, max_distance=128):
super().__init__()
self.num_heads = num_heads
self.num_buckets = num_buckets
self.max_distance = max_distance
self.bias = nn.Embedding(num_buckets, num_heads)
@staticmethod
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
"""
将相对位置映射到有限的桶中
"""
# 对于近距离,使用线性映射
# 对于远距离,使用对数映射
ret = 0
n = -relative_position
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) /
math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, query_length, key_length):
# 计算相对位置矩阵
context_position = torch.arange(query_length)[:, None]
memory_position = torch.arange(key_length)[None, :]
relative_position = memory_position - context_position
# 映射到桶
relative_position_bucket = self._relative_position_bucket(relative_position)
# 查表获取偏置
values = self.bias(relative_position_bucket)
# [query_length, key_length, num_heads] -> [1, num_heads, query_length, key_length]
values = values.permute(2, 0, 1).unsqueeze(0)
return values
10.3 旋转位置编码 (RoPE)
10.3.1 核心思想
将位置信息编码为旋转矩阵,使得:
- 内积自然包含相对位置信息
- 无需额外参数
- 良好的外推能力
10.3.2 数学推导
对于位置 m m m 的查询 q m q_m qm 和位置 n n n 的键 k n k_n kn,应用旋转:
f ( q , m ) = R Θ , m q f(q, m) = R_{\Theta, m} q f(q,m)=RΘ,mq
其中旋转矩阵:
R Θ , m = ( cos m θ 1 − sin m θ 1 0 0 ⋯ sin m θ 1 cos m θ 1 0 0 ⋯ 0 0 cos m θ 2 − sin m θ 2 ⋯ 0 0 sin m θ 2 cos m θ 2 ⋯ ⋮ ⋮ ⋮ ⋮ ⋱ ) R_{\Theta, m} = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots \\ \vdots & \vdots & \vdots & \vdots & \ddots \end{pmatrix} RΘ,m= cosmθ1sinmθ100⋮−sinmθ1cosmθ100⋮00cosmθ2sinmθ2⋮00−sinmθ2cosmθ2⋮⋯⋯⋯⋯⋱
关键性质:
⟨ f ( q , m ) , f ( k , n ) ⟩ = g ( q , k , m − n ) \langle f(q, m), f(k, n) \rangle = g(q, k, m-n) ⟨f(q,m),f(k,n)⟩=g(q,k,m−n)
即内积只依赖于相对位置 m − n m-n m−n
10.3.3 实现
python
def precompute_freqs_cis(dim, max_len, theta=10000.0):
"""
预计算旋转位置编码的频率
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # 复数形式
return freqs_cis
def apply_rotary_emb(xq, xk, freqs_cis):
"""
应用旋转位置编码
"""
# 将实数转换为复数
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 应用旋转
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)