Transformer架构详解 - 第一、二部分:基础与核心思想、核心组件详解

目录

第一部分:基础与核心思想

  1. 概述
  2. 历史背景与发展
  3. 核心思想
  4. 整体架构概览
  5. 输入表示

第二部分:核心组件详解 (见transformer_guide_part2.md)

  1. 自注意力机制

  2. 多头注意力

  3. 前馈神经网络

  4. 残差连接与层归一化

  5. 位置编码

第三部分:编码器与解码器 (见transformer_guide_part3.md)

  1. 编码器详解

  2. 解码器详解

  3. 编码器-解码器交互

第四部分:训练与优化 (见transformer_guide_part4.md)

  1. 训练方法

  2. 优化技术

  3. 正则化技术

第五部分:变体与演进 (见transformer_guide_part5.md)

  1. Transformer变体

  2. 高效Transformer

  3. 视觉Transformer

第六部分:应用与实践 (见transformer_guide_part6.md)

  1. 自然语言处理应用

  2. 计算机视觉应用

  3. 多模态应用

  4. 代码实现

  5. 参考资料


第一部分:基础与核心思想


1. 概述

1.1 什么是Transformer

Transformer 是一种基于自注意力机制(Self-Attention) 的深度学习架构,由 Google 团队在 2017 年论文 "Attention Is All You Need" 中首次提出。它彻底改变了自然语言处理(NLP)领域,并逐渐扩展到计算机视觉、语音处理、多模态学习等多个领域。

1.2 核心特征

特征 说明
并行计算 不像RNN需要顺序处理,Transformer可以并行处理整个序列
长距离依赖 通过自注意力机制直接建模任意距离的依赖关系
可扩展性 易于扩展到大规模模型(如GPT-3的175B参数)
通用性 同一架构适用于多种任务和模态

1.3 影响力

Transformer架构催生了一系列突破性模型:

复制代码
Transformer (2017)
    ├── NLP领域
    │   ├── BERT (2018) - 预训练语言理解
    │   ├── GPT系列 (2018-2023) - 大规模语言生成
    │   ├── T5 (2019) - 统一文本到文本框架
    │   └── LLaMA (2023) - 开源大语言模型
    │
    ├── CV领域
    │   ├── ViT (2020) - 视觉Transformer
    │   ├── DETR (2020) - 目标检测
    │   └── Swin Transformer (2021) - 层次化视觉Transformer
    │
    ├── 多模态
    │   ├── CLIP (2021) - 图文对比学习
    │   ├── DALL-E (2021) - 文本到图像生成
    │   └── GPT-4V (2023) - 多模态大语言模型
    │
    └── 其他领域
        ├── AlphaFold2 (2021) - 蛋白质结构预测
        ├── Jukebox (2020) - 音乐生成
        └── Whisper (2022) - 语音识别

2. 历史背景与发展

2.1 序列建模的演进

2.1.1 早期方法:N-gram

N-gram模型(1980s-2000s):

复制代码
P(w_t | w_1, ..., w_{t-1}) ≈ P(w_t | w_{t-n+1}, ..., w_{t-1})

例如:3-gram
P("今天" | "我", "喜欢") = Count("我", "喜欢", "今天") / Count("我", "喜欢")

局限性

  • 数据稀疏问题
  • 无法捕捉长距离依赖
  • 存储需求随n指数增长
2.1.2 循环神经网络 (RNN)

基本RNN(1990s):

复制代码
隐藏状态更新:
h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h)

输出:
y_t = W_hy * h_t + b_y

架构图

复制代码
x_1 → [RNN] → h_1 → [RNN] → h_2 → ... → [RNN] → h_T
        ↑              ↑                      ↑
        |              |                      |
       h_0            h_1                    h_{T-1}

问题

  • 梯度消失/爆炸问题
  • 无法并行计算
  • 长距离依赖建模困难
2.1.3 LSTM与GRU

LSTM (Long Short-Term Memory, 1997)

复制代码
遗忘门: f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
输入门: i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
候选值: C̃_t = tanh(W_C · [h_{t-1}, x_t] + b_C)
细胞状态: C_t = f_t * C_{t-1} + i_t * C̃_t
输出门: o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
隐藏状态: h_t = o_t * tanh(C_t)

GRU (Gated Recurrent Unit, 2014)

复制代码
更新门: z_t = σ(W_z · [h_{t-1}, x_t])
重置门: r_t = σ(W_r · [h_{t-1}, x_t])
候选隐藏状态: h̃_t = tanh(W · [r_t * h_{t-1}, x_t])
隐藏状态: h_t = (1 - z_t) * h_{t-1} + z_t * h̃_t

改进

  • 缓解梯度消失问题
  • 能够学习更长的依赖

仍然存在的问题

  • 顺序计算,无法并行
  • 长距离依赖仍然困难(100-200 tokens)
2.1.4 注意力机制的引入

Bahdanau Attention (2014)

复制代码
在机器翻译中,不再只依赖编码器最后一个隐藏状态,而是动态关注源序列的不同部分:

上下文向量: c_i = Σ α_{ij} * h_j

注意力权重: α_{ij} = exp(e_{ij}) / Σ_k exp(e_{ik})
对齐分数: e_{ij} = a(s_{i-1}, h_j)  # 对齐模型

意义

  • 首次在NLP中引入注意力机制
  • 显著提升机器翻译性能
  • 为Transformer奠定基础

2.2 Transformer的诞生

2.2.1 论文背景

论文信息

  • 标题: "Attention Is All You Need"
  • 作者: Vaswani et al. (Google Brain & Google Research)
  • 发表: NeurIPS 2017
  • 引用数: 100,000+ (截至2024)

核心创新

  1. 完全基于注意力机制,抛弃RNN/CNN
  2. 自注意力(Self-Attention)机制
  3. 多头注意力(Multi-Head Attention)
  4. 位置编码(Positional Encoding)
2.2.2 设计动机

RNN的问题

复制代码
RNN的顺序计算:
x_1 → h_1 → x_2 → h_2 → x_3 → h_3 → ...

问题:
1. 无法并行: 必须等h_1计算完才能算h_2
2. 长距离依赖: h_1000需要经过999步传递
3. 训练慢: 无法利用GPU并行能力

解决方案的思路

复制代码
如果能直接建立任意两个位置之间的连接...

x_1 ←→ x_2
 ↕       ↕
x_3 ←→ x_4

这样:
1. 可以并行计算所有位置
2. 任意距离的依赖都是1步
3. 充分利用GPU

2.3 发展历程

2.3.1 编码器时代 (2018-2019)

BERT (Bidirectional Encoder Representations from Transformers)

复制代码
发布: 2018年10月 (Google)
架构: 仅编码器
参数: 110M (Base), 340M (Large)
训练: Masked Language Model + Next Sentence Prediction
影响: 刷新11项NLP基准

核心思想

复制代码
传统语言模型: 单向(从左到右)
P(w_t | w_1, ..., w_{t-1})

BERT: 双向
P(w_t | w_1, ..., w_{t-1}, w_{t+1}, ..., w_n)

通过Masked Language Model实现:
输入: "我 喜欢 [MASK] 学习"
预测: "喜欢" → "深度"

RoBERTa (2019)

  • 移除NSP任务
  • 更多训练数据
  • 更大批次大小
  • 动态masking
2.3.2 解码器时代 (2018-2020)

GPT (Generative Pre-trained Transformer)

复制代码
发布: 2018年6月 (OpenAI)
架构: 仅解码器
参数: 117M
训练: 自回归语言模型
特点: 生成式预训练 + 判别式微调

GPT-2 (2019)

复制代码
参数: 1.5B
训练数据: WebText (40GB文本)
特点:
- 零样本学习能力
- 多任务学习
- 引发"太危险而不能发布"的讨论

GPT-3 (2020)

复制代码
参数: 175B
训练数据: 300B tokens
特点:
- Few-shot学习
- In-context学习
- 涌现能力
- 开启大语言模型时代
2.3.3 编码器-解码器时代 (2019-2020)

T5 (Text-to-Text Transfer Transformer)

复制代码
发布: 2019年 (Google)
架构: 编码器-解码器
参数: 11B
创新: 将所有NLP任务统一为文本到文本格式

统一框架

复制代码
翻译: "translate English to German: The house is wonderful" → "Das Haus ist wunderbar"
摘要: "summarize: {长文本}" → "{摘要}"
问答: "question: {问题} context: {上下文}" → "{答案}"
分类: "sentiment: {文本}" → "positive/negative"
2.3.4 视觉Transformer时代 (2020-至今)

ViT (Vision Transformer)

复制代码
发布: 2020年10月 (Google)
创新: 将Transformer直接应用于图像
方法: 图像分割为patches,作为序列处理
影响: 开创视觉Transformer时代

Swin Transformer (2021)

  • 层次化结构
  • 窗口注意力
  • 移位窗口
  • 适用于各种视觉任务
2.3.5 大语言模型时代 (2022-至今)

ChatGPT (2022.11)

  • GPT-3.5 + RLHF
  • 对话式交互
  • 引发AI热潮

GPT-4 (2023.3)

  • 多模态(文本+图像)
  • 更强的推理能力
  • 更好的安全性

开源大模型

复制代码
LLaMA (2023.2) - Meta
    ├── LLaMA 2 (2023.7)
    ├── Alpaca (2023.3) - Stanford
    ├── Vicuna (2023.3) - LMSYS
    └── Mistral (2023.9)

Qwen (2023.8) - 阿里
ChatGLM (2023.3) - 清华
DeepSeek (2024) - 幻方

3. 核心思想

3.1 注意力机制的直觉

3.1.1 人类注意力的类比
复制代码
想象你在阅读一段话:

"The animal didn't cross the street because it was too tired"

当你读到 "it" 时,你的大脑会自动关注 "animal" 而不是 "street"
这就是注意力机制的核心思想:根据上下文动态聚焦相关信息
3.1.2 注意力的数学表达
复制代码
查询 (Query): "我在找什么?" - 当前位置的信息需求
键 (Key): "我有什么?" - 每个位置的特征描述
值 (Value): "我能提供什么?" - 每个位置的实际内容

注意力 = f(Query, Key) × Value

3.2 自注意力机制

3.2.1 核心概念

自注意力(Self-Attention)是Transformer的核心创新:

复制代码
传统注意力: Query来自一个序列,Key/Value来自另一个序列
自注意力: Query、Key、Value都来自同一个序列

作用: 让序列中的每个位置都能"看到"并"关注"序列中的所有其他位置
3.2.2 直观理解
复制代码
句子: "我 喜欢 在 图书馆 学习"

对于每个词,计算它与其他所有词的相关性:

"我": 与"喜欢"(主谓关系)、"学习"(施事关系)相关
"喜欢": 与"我"(主谓关系)、"学习"(动宾关系)相关
"图书馆": 与"学习"(地点关系)相关
...

自注意力让每个词都能获取整个句子的上下文信息
3.2.3 与RNN的对比
特性 RNN 自注意力
计算路径长度 O(n) O(1)
并行性 顺序 完全并行
长距离依赖 困难 直接建模
计算复杂度 O(n·d²) O(n²·d)

3.3 位置编码

3.3.1 为什么需要位置编码
复制代码
自注意力机制是"置换等变"的(permutation equivariant)
即:打乱输入顺序,输出也会相应打乱,但不会改变内容

问题:无法区分
"猫 追 狗" vs "狗 追 猫"

解决方案:添加位置信息
3.3.2 正弦位置编码
复制代码
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:
- pos: 位置索引
- i: 维度索引
- d_model: 模型维度

设计直觉

复制代码
1. 每个位置有唯一的编码
2. 相对位置可以通过线性变换获得
3. 可以泛化到更长的序列
4. 值有界,不会发散

3.4 编码器-解码器架构

3.4.1 整体设计思想
复制代码
编码器 (Encoder):
- 理解输入序列
- 提取特征表示
- 双向注意力(可以看到完整输入)

解码器 (Decoder):
- 生成输出序列
- 自回归生成(逐token生成)
- 因果注意力(只能看到之前的输出)

编码器-解码器注意力:
- 解码器关注编码器的输出
- 实现输入到输出的映射
3.4.2 信息流
复制代码
输入序列 → [编码器] → 编码表示
                          ↓
                     [编码器-解码器注意力]
                          ↓
已生成序列 → [解码器] → 下一个token概率

3.5 并行计算

3.5.1 RNN的瓶颈
python 复制代码
# RNN的顺序计算
h = zeros(hidden_size)
for x in sequence:
    h = rnn_cell(x, h)  # 必须等待前一步完成
    # 无法并行!
3.5.2 Transformer的并行性
python 复制代码
# Transformer的并行计算
# 所有位置可以同时计算
Q = X @ W_Q  # [seq_len, d_model] @ [d_model, d_k] = [seq_len, d_k]
K = X @ W_K  # 同上
V = X @ W_V  # 同上

# 注意力计算也是并行的
attention = softmax(Q @ K.T / sqrt(d_k)) @ V  # 矩阵乘法,并行计算

4. 整体架构概览

4.1 原始Transformer架构

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                    Transformer 整体架构                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  输入序列                                              输出序列   │
│      ↓                                                      ↑    │
│  ┌────────┐                                           ┌────────┐│
│  │ 输入嵌入│                                           │输出嵌入 ││
│  └────┬───┘                                           └────┬───┘│
│       ↓                                                      ↑    │
│  ┌────────┐                                           ┌────────┐│
│  │位置编码 │                                           │位置编码 ││
│  └────┬───┘                                           └────┬───┘│
│       ↓                                                      ↑    │
│  ┌─────────────────────────────────────────────────────────────┐│
│  │                    编码器 × N                                ││
│  │  ┌──────────────────────────────────────────────────────┐  ││
│  │  │           多头自注意力层                               │  ││
│  │  │    Q = XW_Q, K = XW_K, V = XW_V                     │  ││
│  │  │    Attention = softmax(QK^T/√d_k)V                   │  ││
│  │  └──────────────────────────────────────────────────────┘  ││
│  │                          ↓ + 残差连接                        ││
│  │                       层归一化                               ││
│  │                          ↓                                  ││
│  │  ┌──────────────────────────────────────────────────────┐  ││
│  │  │           前馈神经网络                                 │  ││
│  │  │    FFN = max(0, xW_1 + b_1)W_2 + b_2                 │  ││
│  │  └──────────────────────────────────────────────────────┘  ││
│  │                          ↓ + 残差连接                        ││
│  │                       层归一化                               ││
│  └─────────────────────────────────────────────────────────────┘│
│       ↓                                                      ↑    │
│       ↓              ┌───────────────────────────────────────┐   │
│       └──────────────→       解码器 × N                     │   │
│                      │  ┌──────────────────────────────────┐ │   │
│                      │  │     掩码多头自注意力层            │ │   │
│                      │  │  (因果掩码,防止看到未来)         │ │   │
│                      │  └──────────────────────────────────┘ │   │
│                      │                ↓ + 残差连接            │   │
│                      │             层归一化                    │   │
│                      │                ↓                       │   │
│                      │  ┌──────────────────────────────────┐ │   │
│                      │  │     编码器-解码器注意力层          │ │   │
│                      │  │  Q来自解码器,K,V来自编码器       │ │   │
│                      │  └──────────────────────────────────┘ │   │
│                      │                ↓ + 残差连接            │   │
│                      │             层归一化                    │   │
│                      │                ↓                       │   │
│                      │  ┌──────────────────────────────────┐ │   │
│                      │  │           前馈神经网络            │ │   │
│                      │  └──────────────────────────────────┘ │   │
│                      │                ↓ + 残差连接            │   │
│                      │             层归一化                    │   │
│                      └───────────────────┬───────────────────┘   │
│                                          ↓                       │
│                                    线性层 + Softmax               │
│                                          ↓                       │
│                                      输出概率                     │
└─────────────────────────────────────────────────────────────────┘

4.2 编码器详解

4.2.1 单个编码器层
复制代码
输入: X ∈ R^{n×d_model}
    ↓
┌─────────────────────────────────────────┐
│         多头自注意力 (Multi-Head)        │
│                                          │
│  对于每个头 h_i:                         │
│    Q_i = XW_i^Q, K_i = XW_i^K, V_i = XW_i^V  │
│    head_i = Attention(Q_i, K_i, V_i)    │
│                                          │
│  MultiHead = Concat(head_1, ..., head_h)W^O  │
└─────────────────────────────────────────┘
    ↓
Add & Norm: LayerNorm(X + MultiHead(X))
    ↓
┌─────────────────────────────────────────┐
│         前馈神经网络 (FFN)               │
│                                          │
│  FFN(x) = max(0, xW_1 + b_1)W_2 + b_2  │
│                                          │
│  或使用GLU变体:                          │
│  FFN(x) = (xW_1) ⊙ σ(xW_3) W_2        │
└─────────────────────────────────────────┘
    ↓
Add & Norm: LayerNorm(x + FFN(x))
    ↓
输出: Y ∈ R^{n×d_model}
4.2.2 编码器堆叠
python 复制代码
class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)  # 最终层归一化

4.3 解码器详解

4.3.1 单个解码器层
复制代码
输入: Y ∈ R^{m×d_model}
    ↓
┌─────────────────────────────────────────┐
│      掩码多头自注意力 (Masked)           │
│                                          │
│  防止位置关注后续位置(因果掩码)         │
│  Mask[i,j] = -∞ if j > i else 0         │
└─────────────────────────────────────────┘
    ↓
Add & Norm
    ↓
┌─────────────────────────────────────────┐
│      编码器-解码器注意力                  │
│                                          │
│  Q = 解码器输出                           │
│  K, V = 编码器输出                        │
│                                          │
│  让解码器关注输入序列的相关部分           │
└─────────────────────────────────────────┘
    ↓
Add & Norm
    ↓
┌─────────────────────────────────────────┐
│         前馈神经网络 (FFN)               │
└─────────────────────────────────────────┘
    ↓
Add & Norm
    ↓
输出: Y' ∈ R^{m×d_model}
4.3.2 因果掩码
python 复制代码
def create_causal_mask(seq_len):
    """
    创建因果掩码,防止解码器看到未来位置
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask

# 示例:seq_len = 4
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

4.4 参数规模

4.4.1 原始Transformer参数
复制代码
原始Transformer (base):
- d_model = 512
- num_heads = 8
- d_ff = 2048
- num_layers = 6 (编码器) + 6 (解码器)
- vocab_size = 37000 (shared)

参数计算:
- 嵌入层: 37000 × 512 = 18.9M
- 位置编码: 512 × 512 = 0.26M (可学习版本)
- 每个编码器层:
  - 自注意力: 4 × 512² = 1.05M
  - FFN: 2 × 512 × 2048 = 2.10M
  - 总计: ~3.15M
- 6个编码器层: 6 × 3.15M = 18.9M
- 解码器类似,加上交叉注意力: ~25M
- 总参数: ~63M
4.4.2 不同规模的Transformer
模型 参数量 d_model 层数 头数 d_ff
Transformer-Base 65M 512 6+6 8 2048
Transformer-Big 213M 1024 6+6 16 4096
BERT-Base 110M 768 12 12 3072
BERT-Large 340M 1024 24 16 4096
GPT-2 1.5B 1600 48 25 6400
GPT-3 175B 12288 96 96 49152
LLaMA-65B 65B 8192 80 64 22016

5. 输入表示

5.1 词嵌入 (Word Embedding)

5.1.1 基本概念
复制代码
将离散的词(或token)映射为连续的向量表示

词汇表 V = {the, a, cat, dog, ...}
嵌入矩阵 E ∈ R^{|V|×d_model}

词"cat" → 嵌入向量 e_cat ∈ R^{d_model}
5.1.2 实现
python 复制代码
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # x: [batch_size, seq_len] (token indices)
        # output: [batch_size, seq_len, d_model]
        return self.embedding(x) * math.sqrt(self.d_model)

注意 :乘以 d m o d e l \sqrt{d_{model}} dmodel 是为了平衡嵌入和位置编码的尺度。

5.2 Tokenization方法

5.2.1 词级Tokenization (Word-level)
复制代码
输入: "I love natural language processing"
输出: ["I", "love", "natural", "language", "processing"]

优点: 语义清晰
缺点: 词汇表大,OOV问题
5.2.2 字符级Tokenization (Character-level)
复制代码
输入: "I love"
输出: ["I", " ", "l", "o", "v", "e"]

优点: 词汇表小,无OOV
缺点: 序列长,语义不清晰
5.2.3 子词Tokenization (Subword-level)

BPE (Byte Pair Encoding)

复制代码
训练过程:
1. 初始化: 所有字符作为基本token
2. 统计相邻token对的频率
3. 合并最频繁的token对
4. 重复直到达到目标词汇表大小

示例:
初始: {'l', 'o', 'w', 'e', 'r', 'n', 'e', 's', 's'}
频率: {'lo': 5, 'ow': 3, 'we': 4, 'er': 7, 'ss': 6}
合并 'er': {'l', 'o', 'w', 'e', 'r', 'n', 'es', 's', 'er'}
继续合并...

WordPiece(BERT使用):

复制代码
类似BPE,但使用不同的合并准则
选择使语言模型概率最大化的合并

标识: ## 表示子词的延续
"playing" → ["play", "##ing"]
"unhappy" → ["un", "happy"]

SentencePiece

复制代码
语言无关的分词器
直接在原始文本上训练
支持BPE和Unigram两种算法

特点:
- 不需要预分词
- 可逆编码
- 多语言支持
5.2.4 现代Tokenization示例
python 复制代码
from transformers import AutoTokenizer

# GPT-2 Tokenizer (BPE)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Hello, how are you?"
tokens = tokenizer.tokenize(text)
# ['Hello', ',', 'how', 'are', 'you', '?']

input_ids = tokenizer.encode(text)
# [15496, 11, 703, 389, 345, 30]

# BERT Tokenizer (WordPiece)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.tokenize(text)
# ['hello', ',', 'how', 'are', 'you', '?']

5.3 位置编码

5.3.1 正弦位置编码

数学公式

P E ( p o s , 2 i ) = sin ⁡ ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)

实现

python 复制代码
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                           (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        return x + self.pe[:, :x.size(1), :]

设计直觉

复制代码
1. 唯一性: 每个位置有唯一的编码
2. 有界性: 值在[-1, 1]之间
3. 相对位置: PE(pos+k)可以表示为PE(pos)的线性函数
4. 泛化性: 可以处理比训练时更长的序列
5.3.2 可学习位置编码
python 复制代码
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return x + self.embedding(positions)

对比

特性 正弦编码 可学习编码
参数量 0 max_len × d_model
泛化性 可泛化到更长序列 限于训练长度
性能 与可学习版本相当 与正弦版本相当
灵活性 固定 可学习最优表示
5.3.3 旋转位置编码 (RoPE)

RoPE (Rotary Position Embedding) 被现代大语言模型广泛使用:

python 复制代码
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # 计算频率
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

        # 预计算位置编码
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, x, seq_len):
        return (
            self.cos_cached[:seq_len].to(x.device),
            self.sin_cached[:seq_len].to(x.device)
        )

def rotate_half(x):
    """旋转一半的维度"""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """应用旋转位置编码"""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

RoPE的优势

  • 相对位置信息自然融入注意力计算
  • 无需额外参数
  • 良好的外推能力

5.4 段落编码与类型编码

5.4.1 段落编码 (Segment Embedding)
复制代码
用于区分不同的句子或段落

例如BERT:
[CLS] I love cats [SEP] I love dogs [SEP]
  A    A   A    A    A    B   B   B    B

段落编码: A → embedding_A, B → embedding_B
5.4.2 类型编码 (Type Embedding)
复制代码
用于区分不同类型的token

例如:
- 文本token vs 图像token
- 用户token vs 助手token

第二部分:核心组件详解


6. 自注意力机制

6.1 缩放点积注意力

6.1.1 数学公式

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

其中:

  • Q Q Q (Query): 查询矩阵,形状 [ n , d k ] [n, d_k] [n,dk]
  • K K K (Key): 键矩阵,形状 [ m , d k ] [m, d_k] [m,dk]
  • V V V (Value): 值矩阵,形状 [ m , d v ] [m, d_v] [m,dv]
  • d k d_k dk: 键的维度
  • n n n: 查询序列长度
  • m m m: 键/值序列长度
6.1.2 计算步骤详解
python 复制代码
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    缩放点积注意力

    Args:
        Q: [batch_size, num_heads, seq_len_q, d_k]
        K: [batch_size, num_heads, seq_len_k, d_k]
        V: [batch_size, num_heads, seq_len_k, d_v]
        mask: [batch_size, 1, seq_len_q, seq_len_k] 或 None

    Returns:
        output: [batch_size, num_heads, seq_len_q, d_v]
        attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k]
    """
    d_k = Q.size(-1)

    # 步骤1: 计算注意力分数
    # scores = Q @ K^T / sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: [batch_size, num_heads, seq_len_q, seq_len_k]

    # 步骤2: 应用掩码(可选)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # 步骤3: Softmax归一化
    attention_weights = F.softmax(scores, dim=-1)
    # attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k]

    # 步骤4: 加权求和
    output = torch.matmul(attention_weights, V)
    # output: [batch_size, num_heads, seq_len_q, d_v]

    return output, attention_weights
6.1.3 为什么要缩放?

问题 :当 d k d_k dk 较大时, Q K T QK^T QKT 的值会很大

复制代码
假设 Q, K 的元素独立采样自 N(0, 1)
则 Q·K = Σ(q_i * k_i) 的方差为 d_k

当 d_k = 64 时,Q·K 的标准差约为 8
这会导致 softmax 的输入值很大,梯度接近于0

解决方案 :除以 d k \sqrt{d_k} dk

复制代码
缩放后,Q·K/sqrt(d_k) 的方差为 1
保持 softmax 在合理的范围内

6.2 注意力机制的直觉

6.2.1 信息检索类比
复制代码
注意力机制类似于信息检索系统:

数据库中的每条记录有:
- Key: 用于匹配查询的索引
- Value: 记录的实际内容

查询过程:
1. 用户输入 Query
2. 计算 Query 与每个 Key 的相似度
3. 根据相似度对 Value 进行加权求和
4. 返回结果
6.2.2 软寻址
复制代码
传统哈希表: 硬寻址
Query → 精确匹配一个Key → 返回对应的Value

注意力机制: 软寻址
Query → 与所有Key计算相似度 → 返回所有Value的加权和

优点:
- 可以从多个位置获取信息
- 可微分,可以端到端训练
- 可以处理模糊匹配

6.3 自注意力 vs 交叉注意力

6.3.1 自注意力 (Self-Attention)
复制代码
Q, K, V 都来自同一个序列

输入: X ∈ R^{n×d}
Q = XW_Q, K = XW_K, V = XW_V

作用: 建模序列内部的依赖关系

示例:
"猫 坐在 垫子 上"
自注意力让"坐在"关注"猫"(主语)和"垫子"(地点)
6.3.2 交叉注意力 (Cross-Attention)
复制代码
Q 来自一个序列,K, V 来自另一个序列

输入: X_src (源), X_tgt (目标)
Q = X_tgt W_Q, K = X_src W_K, V = X_src W_V

作用: 建模两个序列之间的对应关系

示例(机器翻译):
源: "I love cats"
目标: "我 喜欢 猫"

交叉注意力让"喜欢"关注"love"(翻译对应)

6.4 注意力权重可视化

6.4.1 可视化示例
python 复制代码
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, source_tokens, target_tokens):
    """
    可视化注意力权重

    Args:
        attention_weights: [target_len, source_len]
        source_tokens: 源序列的token列表
        target_tokens: 目标序列的token列表
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        attention_weights,
        xticklabels=source_tokens,
        yticklabels=target_tokens,
        cmap='viridis',
        ax=ax
    )
    ax.set_xlabel('Source')
    ax.set_ylabel('Target')
    ax.set_title('Attention Weights')
    plt.tight_layout()
    plt.show()
6.4.2 注意力模式
复制代码
模式1: 对角线模式
- 位置i主要关注位置i
- 常见于局部依赖强的任务

模式2: 垂直模式
- 所有位置关注同一个位置(如[CLS])
- 常见于分类任务

模式3: 块状模式
- 形成明显的块状结构
- 常见于有明确语义边界的任务

模式4: 分散模式
- 注意力分散在多个位置
- 常见于需要综合信息的任务

6.5 注意力的变体

6.5.1 加性注意力 (Additive Attention)

Attention ( Q , K , V ) = softmax ( v T tanh ⁡ ( W 1 Q + W 2 K ) ) V \text{Attention}(Q, K, V) = \text{softmax}\left(v^T \tanh(W_1 Q + W_2 K)\right) V Attention(Q,K,V)=softmax(vTtanh(W1Q+W2K))V

python 复制代码
class AdditiveAttention(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.W1 = nn.Linear(d_model, d_hidden)
        self.W2 = nn.Linear(d_model, d_hidden)
        self.v = nn.Linear(d_hidden, 1, bias=False)

    def forward(self, Q, K, V):
        # Q: [batch, n, d], K: [batch, m, d], V: [batch, m, d]

        # 扩展维度以广播
        Q_expanded = Q.unsqueeze(2)  # [batch, n, 1, d]
        K_expanded = K.unsqueeze(1)  # [batch, 1, m, d]

        # 计算注意力分数
        scores = self.v(torch.tanh(
            self.W1(Q_expanded) + self.W2(K_expanded)
        )).squeeze(-1)  # [batch, n, m]

        # Softmax和加权求和
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)

        return output, weights
6.5.2 线性注意力 (Linear Attention)

目标 :将 O ( n 2 ) O(n^2) O(n2) 复杂度降为 O ( n ) O(n) O(n)

Attention ( Q , K , V ) = ϕ ( Q ) ( ϕ ( K ) T V ) ϕ ( Q ) ∑ i ϕ ( k i ) \text{Attention}(Q, K, V) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)\sum_i \phi(k_i)} Attention(Q,K,V)=ϕ(Q)∑iϕ(ki)ϕ(Q)(ϕ(K)TV)

其中 ϕ \phi ϕ 是特征映射函数(如 elu + 1)

python 复制代码
def linear_attention(Q, K, V):
    """
    线性注意力,复杂度 O(n·d²)
    """
    # 特征映射
    Q = F.elu(Q) + 1
    K = F.elu(K) + 1

    # 先计算 K^T V
    KV = torch.einsum('bnm,bnd->bmd', K, V)

    # 计算 Q(K^T V)
    output = torch.einsum('bnm,bmd->bnd', Q, KV)

    # 归一化
    Z = torch.einsum('bnm,bm->bn', Q, K.sum(dim=1))
    output = output / (Z.unsqueeze(-1) + 1e-6)

    return output
6.5.3 稀疏注意力 (Sparse Attention)
复制代码
全注意力: 每个位置关注所有位置
稀疏注意力: 每个位置只关注部分位置

常见模式:
1. 局部窗口: 只关注附近的k个位置
2. 跨步模式: 关注固定间隔的位置
3. 随机模式: 随机选择部分位置
4. 组合模式: 结合多种模式

Longformer的稀疏注意力

python 复制代码
# 局部窗口 + 全局注意力
# 每个位置关注: 局部窗口内的位置 + 特定的全局位置(如[CLS])

def longformer_attention(Q, K, V, window_size, global_mask):
    seq_len = Q.size(2)

    # 局部注意力
    local_attn = sliding_window_attention(Q, K, V, window_size)

    # 全局注意力
    global_attn = full_attention(Q, K, V, global_mask)

    # 组合
    output = local_attn + global_attn

    return output

7. 多头注意力

7.1 核心思想

7.1.1 为什么需要多头?
复制代码
单头注意力的局限:
- 只能学习一种注意力模式
- 不同类型的关系可能需要不同的关注方式

多头注意力的优势:
- 每个头可以学习不同的注意力模式
- 有些头关注语法关系,有些关注语义关系
- 提供更丰富的表示
7.1.2 直观理解
复制代码
句子: "The cat sat on the mat"

头1 (语法头): 关注主谓关系
"The" → "cat", "cat" → "sat"

头2 (语义头): 关注语义相似性
"cat" → "mat" (都是名词,且押韵)

头3 (位置头): 关注局部依赖
"sat" → "on" → "the" → "mat"

多头组合: 综合多种关系信息

7.2 数学公式

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中每个头:

head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)

参数维度:

  • W i Q ∈ R d m o d e l × d k W_i^Q \in \mathbb{R}^{d_{model} \times d_k} WiQ∈Rdmodel×dk
  • W i K ∈ R d m o d e l × d k W_i^K \in \mathbb{R}^{d_{model} \times d_k} WiK∈Rdmodel×dk
  • W i V ∈ R d m o d e l × d v W_i^V \in \mathbb{R}^{d_{model} \times d_v} WiV∈Rdmodel×dv
  • W O ∈ R h d v × d m o d e l W^O \in \mathbb{R}^{hd_v \times d_{model}} WO∈Rhdv×dmodel

通常 d k = d v = d m o d e l / h d_k = d_v = d_{model} / h dk=dv=dmodel/h

7.3 实现

python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 线性投影层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        缩放点积注意力
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 1. 线性投影
        Q = self.W_q(Q)  # [batch, seq_len, d_model]
        K = self.W_k(K)
        V = self.W_v(V)

        # 2. 分割成多头
        # [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. 计算注意力
        output, attention_weights = self.scaled_dot_product_attention(
            Q, K, V, mask
        )

        # 4. 合并多头
        # [batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        # 5. 最终线性投影
        output = self.W_o(output)

        return output, attention_weights

7.4 注意力头的分析

7.4.1 头的功能分化

研究发现,不同的注意力头学习到不同的功能:

复制代码
语法头 (Syntactic Heads):
- 关注主语-谓语关系
- 关注形容词-名词关系
- 关注介词-宾语关系

位置头 (Positional Heads):
- 关注前一个/后一个位置
- 关注特定位置(如句首)

语义头 (Semantic Heads):
- 关注语义相似的词
- 关注同义词/反义词

稀有头 (Rare Heads):
- 几乎不关注任何位置
- 可能是冗余的
7.4.2 头的重要性
python 复制代码
def compute_head_importance(model, dataloader):
    """
    计算每个注意力头的重要性
    """
    head_importance = torch.zeros(model.num_layers, model.num_heads)

    for batch in dataloader:
        outputs = model(batch, output_attentions=True)
        loss = outputs.loss

        for layer in range(model.num_layers):
            for head in range(model.num_heads):
                # 使用梯度衡量重要性
                attn = outputs.attentions[layer][:, head]
                attn.retain_grad()
                loss.backward(retain_graph=True)
                head_importance[layer, head] += attn.grad.abs().mean()

    return head_importance

7.5 多头注意力的变体

7.5.1 分组查询注意力 (GQA)

Grouped Query Attention 被 LLaMA 2、Mistral 等模型使用:

python 复制代码
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        """
        Args:
            num_heads: 查询头数 (如 32)
            num_kv_heads: 键值头数 (如 8)
            每 num_heads/num_kv_heads 个查询头共享一组键值头
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_heads_per_group = num_heads // num_kv_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, num_heads * self.d_k)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 投影
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_kv_heads, self.d_k).transpose(1, 2)

        # 扩展K, V以匹配Q的头数
        K = K.repeat_interleave(self.num_heads_per_group, dim=1)
        V = V.repeat_interleave(self.num_heads_per_group, dim=1)

        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # 合并
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        output = self.W_o(output)

        return output

优势

  • 减少KV缓存大小
  • 推理速度更快
  • 性能损失很小
7.5.2 多查询注意力 (MQA)

Multi-Query Attention 是GQA的极端情况(num_kv_heads=1):

复制代码
所有查询头共享同一组键值头

优势:
- KV缓存最小
- 推理最快

劣势:
- 可能损失一些性能

8. 前馈神经网络

8.1 基本FFN

8.1.1 结构
复制代码
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2

或使用GELU激活:
FFN(x) = GELU(xW_1 + b_1)W_2 + b_2

维度变化

复制代码
输入: x ∈ R^{d_model}
中间: h ∈ R^{d_ff} (通常 d_ff = 4 × d_model)
输出: y ∈ R^{d_model}

参数:
W_1 ∈ R^{d_model × d_ff}, b_1 ∈ R^{d_ff}
W_2 ∈ R^{d_ff × d_model}, b_2 ∈ R^{d_model}
8.1.2 实现
python 复制代码
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1, activation='gelu'):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'silu':
            self.activation = nn.SiLU()

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        x = self.linear1(x)      # [batch, seq_len, d_ff]
        x = self.activation(x)    # [batch, seq_len, d_ff]
        x = self.dropout(x)       # [batch, seq_len, d_ff]
        x = self.linear2(x)      # [batch, seq_len, d_model]
        return x

8.2 GLU变体

Gated Linear Unit (GLU) 及其变体被广泛用于现代大语言模型:

8.2.1 GLU

GLU ( x ) = ( x W 1 ) ⊗ σ ( x W 3 ) \text{GLU}(x) = (xW_1) \otimes \sigma(xW_3) GLU(x)=(xW1)⊗σ(xW3)

其中 ⊗ \otimes ⊗ 是逐元素乘法, σ \sigma σ 是sigmoid函数

8.2.2 SwiGLU (LLaMA, PaLM使用)

SwiGLU ( x ) = Swish ( x W 1 ) ⊗ ( x W 3 ) \text{SwiGLU}(x) = \text{Swish}(xW_1) \otimes (xW_3) SwiGLU(x)=Swish(xW1)⊗(xW3)

其中 Swish ( x ) = x ⋅ σ ( x ) \text{Swish}(x) = x \cdot \sigma(x) Swish(x)=x⋅σ(x)

python 复制代码
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
8.2.3 GeGLU

GeGLU ( x ) = GELU ( x W 1 ) ⊗ ( x W 3 ) \text{GeGLU}(x) = \text{GELU}(xW_1) \otimes (xW_3) GeGLU(x)=GELU(xW1)⊗(xW3)

python 复制代码
class GeGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w2(F.gelu(self.w1(x)) * self.w3(x))

8.3 FFN的作用

8.3.1 记忆存储
复制代码
研究发现FFN可以看作"键值记忆":

第一层 (W_1): 键匹配
h = xW_1 → 计算x与每个"键"的匹配度

激活函数: 选择性
只有匹配度高的"键"被激活

第二层 (W_2): 值读取
y = hW_2 → 读取对应的"值"

因此,FFN存储了大量的事实知识
8.3.2 与注意力的互补
复制代码
注意力层: 信息聚合
- 从其他位置获取信息
- 建模位置间的关系

FFN层: 信息处理
- 对每个位置独立处理
- 应用非线性变换
- 存储和检索知识

两者交替使用,形成强大的表示能力

9. 残差连接与层归一化

9.1 残差连接

9.1.1 动机
复制代码
问题: 深度网络的退化问题
- 理论上更深的网络应该不比浅网络差
- 实际上,过深的网络训练困难,性能下降

原因: 梯度消失/爆炸
- 深层梯度在反向传播中逐渐衰减或膨胀
9.1.2 残差连接

output = Layer ( x ) + x \text{output} = \text{Layer}(x) + x output=Layer(x)+x

python 复制代码
class ResidualBlock(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return self.layer(x) + x

优势

  • 梯度可以直接流过跳跃连接
  • 缓解梯度消失问题
  • 允许训练更深的网络
9.1.3 预归一化 vs 后归一化

后归一化 (Post-Norm) - 原始Transformer:
x o u t = LayerNorm ( x + SubLayer ( x ) ) x_{out} = \text{LayerNorm}(x + \text{SubLayer}(x)) xout=LayerNorm(x+SubLayer(x))

预归一化 (Pre-Norm) - 现代Transformer:
x o u t = x + SubLayer ( LayerNorm ( x ) ) x_{out} = x + \text{SubLayer}(\text{LayerNorm}(x)) xout=x+SubLayer(LayerNorm(x))

python 复制代码
# 后归一化
class PostNormBlock(nn.Module):
    def __init__(self, sublayer, d_model):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

# 预归一化
class PreNormBlock(nn.Module):
    def __init__(self, sublayer, d_model):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

对比

特性 后归一化 预归一化
训练稳定性 需要warmup 更稳定
最终性能 略高 略低
实现复杂度 简单 简单
主流选择 早期模型 现代大模型

9.2 层归一化

9.2.1 批归一化 vs 层归一化

批归一化 (Batch Normalization)
x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵ xi−μB

其中 μ B , σ B \mu_B, \sigma_B μB,σB 是batch维度的统计量

层归一化 (Layer Normalization)
x ^ i = x i − μ L σ L 2 + ϵ \hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}} x^i=σL2+ϵ xi−μL

其中 μ L , σ L \mu_L, \sigma_L μL,σL 是特征维度的统计量

python 复制代码
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
9.2.2 为什么用层归一化?
复制代码
批归一化的问题:
1. 依赖batch size,小batch时统计量不稳定
2. 训练和推理时行为不同
3. 序列长度变化时需要重新计算
4. 难以处理变长序列

层归一化的优势:
1. 对每个样本独立归一化
2. 训练和推理行为一致
3. 不依赖batch size
4. 适合序列任务
9.2.3 RMSNorm

RMSNorm (Root Mean Square Normalization) 被现代大语言模型广泛使用:

RMSNorm ( x ) = x 1 d ∑ i = 1 d x i 2 + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2 + \epsilon}} \cdot \gamma RMSNorm(x)=d1∑i=1dxi2+ϵ x⋅γ

python 复制代码
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

优势

  • 计算更快(不需要计算均值)
  • 效果相当
  • 被LLaMA、Mistral等采用

10. 位置编码

10.1 绝对位置编码

10.1.1 正弦位置编码

(已在5.3.1节详细介绍)

10.1.2 可学习位置编码

(已在5.3.2节详细介绍)

10.2 相对位置编码

10.2.1 动机
复制代码
绝对位置编码的问题:
1. 无法直接建模相对位置关系
2. 泛化到更长序列的能力有限

相对位置编码的优势:
1. 直接建模位置间的相对关系
2. 更好的长度泛化能力
10.2.2 Shaw et al. (2018) 的相对位置编码
复制代码
在注意力计算中加入相对位置信息:

e_{ij} = (x_i W^Q)(x_j W^K + a_{ij}^K)^T

其中 a_{ij}^K 是可学习的相对位置偏置
10.2.3 T5的相对位置偏置
python 复制代码
class T5RelativePositionBias(nn.Module):
    def __init__(self, num_heads, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        """
        将相对位置映射到有限的桶中
        """
        # 对于近距离,使用线性映射
        # 对于远距离,使用对数映射
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) /
            math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, query_length, key_length):
        # 计算相对位置矩阵
        context_position = torch.arange(query_length)[:, None]
        memory_position = torch.arange(key_length)[None, :]
        relative_position = memory_position - context_position

        # 映射到桶
        relative_position_bucket = self._relative_position_bucket(relative_position)

        # 查表获取偏置
        values = self.bias(relative_position_bucket)

        # [query_length, key_length, num_heads] -> [1, num_heads, query_length, key_length]
        values = values.permute(2, 0, 1).unsqueeze(0)
        return values

10.3 旋转位置编码 (RoPE)

10.3.1 核心思想
复制代码
将位置信息编码为旋转矩阵,使得:
- 内积自然包含相对位置信息
- 无需额外参数
- 良好的外推能力
10.3.2 数学推导

对于位置 m m m 的查询 q m q_m qm 和位置 n n n 的键 k n k_n kn,应用旋转:

f ( q , m ) = R Θ , m q f(q, m) = R_{\Theta, m} q f(q,m)=RΘ,mq

其中旋转矩阵:

R Θ , m = ( cos ⁡ m θ 1 − sin ⁡ m θ 1 0 0 ⋯ sin ⁡ m θ 1 cos ⁡ m θ 1 0 0 ⋯ 0 0 cos ⁡ m θ 2 − sin ⁡ m θ 2 ⋯ 0 0 sin ⁡ m θ 2 cos ⁡ m θ 2 ⋯ ⋮ ⋮ ⋮ ⋮ ⋱ ) R_{\Theta, m} = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots \\ \vdots & \vdots & \vdots & \vdots & \ddots \end{pmatrix} RΘ,m= cosmθ1sinmθ100⋮−sinmθ1cosmθ100⋮00cosmθ2sinmθ2⋮00−sinmθ2cosmθ2⋮⋯⋯⋯⋯⋱

关键性质:
⟨ f ( q , m ) , f ( k , n ) ⟩ = g ( q , k , m − n ) \langle f(q, m), f(k, n) \rangle = g(q, k, m-n) ⟨f(q,m),f(k,n)⟩=g(q,k,m−n)

即内积只依赖于相对位置 m − n m-n m−n

10.3.3 实现
python 复制代码
def precompute_freqs_cis(dim, max_len, theta=10000.0):
    """
    预计算旋转位置编码的频率
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # 复数形式
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    """
    应用旋转位置编码
    """
    # 将实数转换为复数
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # 应用旋转
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)
相关推荐
白日做梦Q3 小时前
Label Studio 安装与使用完整文档(可直接复制部署)
深度学习·yolo·计算机视觉
快乐on9仔4 小时前
NLP学习(一)transformers之pipeline体验
人工智能·深度学习
Black蜡笔小新6 小时前
企业私有化AI训练推理一体工作站DLTM深度学习推理工作站重塑安全监控智能化体系
人工智能·深度学习
小a彤6 小时前
ops-nn 快速上手 - 神经网络算子使用入门指南
人工智能·深度学习·神经网络
Yunzenn6 小时前
深度分析字节最新研究cola-DLM 第 07 章:推理流水线逐行拆解 —— 从 prompt 到生成文本
人工智能·驱动开发·深度学习·chatgpt·架构·prompt·github
lqqjuly6 小时前
大语言模型 (LLM) 详解
人工智能·语言模型·自然语言处理
AI医影跨模态组学6 小时前
J Hepatol(IF=33.0)英国帝国理工学院:基于机器学习的影像组学模型在预测肝细胞癌免疫治疗结局中优于临床生物标志物
人工智能·深度学习·机器学习·论文·医学影像·影像组学
ujainu小6 小时前
CANN ops-transformer:MC2 通算融合如何减少通信开销
人工智能·深度学习·transformer
财经资讯数据_灵砚智能7 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月26日
大数据·人工智能·python·信息可视化·自然语言处理·ai编程·灵砚智能