Transformer架构详解:起源、原理与应用
目录
- Transformer的起源与历史背景
- Transformer核心架构详解
- 自注意力机制深入解析
- Transformer在结构化数据中的应用
- Transformer在图像数据中的应用
- 性能优化与变体
- 实际应用案例
1. Transformer的起源与历史背景
1.1 深度学习序列建模的演进
在Transformer出现之前,自然语言处理领域主要依赖以下架构:
循环神经网络时代(2010-2017)
- RNN (Recurrent Neural Networks):基础循环架构,存在梯度消失问题
- LSTM (Long Short-Term Memory):通过门控机制解决长期依赖问题
- GRU (Gated Recurrent Units):LSTM的简化版本
- Seq2Seq模型:编码器-解码器架构,用于机器翻译
RNN架构的局限性:
- 顺序计算:无法并行化,训练效率低
- 长期依赖问题:即使使用LSTM,处理超长序列仍然困难
- 梯度传播路径长:容易出现梯度消失或爆炸
- 信息瓶颈:编码器需要将所有信息压缩到固定维度的向量中
1.2 注意力机制的诞生
Bahdanau Attention (2014)
- 由Dzmitry Bahdanau等人提出
- 允许解码器在生成每个词时关注输入序列的不同部分
- 解决了Seq2Seq模型的信息瓶颈问题
核心思想:
对于输出序列的每个位置,计算其与输入序列所有位置的相关性
相关性越高的位置获得更大的权重
1.3 "Attention is All You Need" 论文
发表信息:
- 时间:2017年6月(arXiv)
- 作者团队:Google Brain和Google Research
- 核心作者:Ashish Vaswani, Noam Shazeer, Niki Parmar等
- 发表会议:NIPS 2017(现NeurIPS)
革命性贡献:
- 完全抛弃循环结构:首次提出纯注意力架构
- 并行化计算:大幅提升训练和推理速度
- 多头注意力机制:从多个子空间捕获信息
- 位置编码:通过三角函数编码位置信息
- 残差连接与层归一化:稳定深层网络训练
影响力:
- 截至2024年,论文引用次数超过10万次
- 催生了BERT、GPT、T5等众多里程碑模型
- 从NLP扩展到CV、语音、生物信息学等多个领域
2. Transformer核心架构详解
2.1 整体架构
Transformer采用编码器-解码器(Encoder-Decoder)架构:
输入序列 → [嵌入层 + 位置编码] → 编码器栈 → 解码器栈 → 输出概率分布
↓
上下文向量
核心组件:
- 编码器(Encoder):6层堆叠,每层包含多头自注意力和前馈网络
- 解码器(Decoder):6层堆叠,包含掩码自注意力、交叉注意力和前馈网络
- 注意力机制:核心计算单元
- 位置编码:注入序列位置信息
2.2 编码器详细结构
单个编码器层包含:
python
# 伪代码表示
class EncoderLayer:
def forward(x):
# 1. 多头自注意力
attn_output = MultiHeadAttention(Q=x, K=x, V=x)
x = LayerNorm(x + attn_output) # 残差连接 + 层归一化
# 2. 位置前馈网络
ffn_output = FeedForward(x)
x = LayerNorm(x + ffn_output) # 残差连接 + 层归一化
return x
参数配置(原论文):
- 模型维度 (d_model):512
- 头数 (num_heads):8
- 前馈网络维度 (d_ff):2048
- Dropout率:0.1
2.3 解码器详细结构
单个解码器层包含:
python
# 伪代码表示
class DecoderLayer:
def forward(x, encoder_output):
# 1. 掩码多头自注意力(防止看到未来信息)
masked_attn = MaskedMultiHeadAttention(Q=x, K=x, V=x)
x = LayerNorm(x + masked_attn)
# 2. 编码器-解码器注意力(交叉注意力)
cross_attn = MultiHeadAttention(
Q=x,
K=encoder_output,
V=encoder_output
)
x = LayerNorm(x + cross_attn)
# 3. 位置前馈网络
ffn_output = FeedForward(x)
x = LayerNorm(x + ffn_output)
return x
关键差异:
- 第一个注意力层使用掩码,确保位置i只能关注i之前的位置
- 增加交叉注意力层,Query来自解码器,Key和Value来自编码器输出
2.4 位置编码(Positional Encoding)
由于Transformer没有循环或卷积结构,需要显式注入位置信息。
公式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中:
pos:序列中的位置(0到max_len-1)i:维度索引(0到d_model/2-1)d_model:模型维度
特点:
- 确定性:不需要学习,节省参数
- 外推性:理论上可以处理训练时未见过的序列长度
- 相对位置关系:PE(pos+k)可以表示为PE(pos)的线性函数
可视化示例:
位置0: [sin(0/10000^0), cos(0/10000^0), sin(0/10000^(2/512)), ...]
位置1: [sin(1/10000^0), cos(1/10000^0), sin(1/10000^(2/512)), ...]
位置2: [sin(2/10000^0), cos(2/10000^0), sin(2/10000^(2/512)), ...]
...
3. 自注意力机制深入解析
3.1 缩放点积注意力(Scaled Dot-Product Attention)
核心公式:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
计算步骤:
-
计算相似度得分:
Scores = QK^T 维度:(seq_len_q, d_k) × (d_k, seq_len_k) = (seq_len_q, seq_len_k) -
缩放:
Scaled_Scores = Scores / √d_k- 目的:防止点积结果过大导致softmax梯度过小
- 当d_k较大时,点积方差为d_k,缩放使方差归一化
-
应用Softmax:
Attention_Weights = softmax(Scaled_Scores)- 每一行的权重和为1
- 表示Query对所有Key的注意力分布
-
加权求和:
Output = Attention_Weights × V 维度:(seq_len_q, seq_len_k) × (seq_len_k, d_v) = (seq_len_q, d_v)
示例计算(简化版):
python
import numpy as np
# 假设我们有3个词,每个词的表示维度为4
Q = np.array([[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 1, 0, 0]])
K = np.array([[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 1, 0, 0]])
V = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 1. 计算QK^T
scores = np.dot(Q, K.T) # 结果:[[2, 0, 2], [0, 2, 1], [2, 1, 2]]
# 2. 缩放
d_k = Q.shape[-1]
scaled_scores = scores / np.sqrt(d_k) # 除以2
# 3. Softmax
attention_weights = np.exp(scaled_scores) / np.exp(scaled_scores).sum(axis=-1, keepdims=True)
# 4. 加权求和
output = np.dot(attention_weights, V)
3.2 多头注意力(Multi-Head Attention)
核心思想:
- 不是只计算一次注意力,而是并行计算h次
- 每个"头"学习不同的注意力模式
- 类似于CNN中的多通道
公式:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
其中 head_i = Attention(QW^Q_i, KW^K_i, VW^V_i)
参数矩阵:
- W^Q_i ∈ ℝ^(d_model × d_k):Query投影矩阵
- W^K_i ∈ ℝ^(d_model × d_k):Key投影矩阵
- W^V_i ∈ ℝ^(d_model × d_v):Value投影矩阵
- W^O ∈ ℝ^(h·d_v × d_model):输出投影矩阵
典型配置:
d_model = 512
h = 8
d_k = d_v = d_model / h = 64
计算流程图:
输入 (batch, seq_len, d_model=512)
↓
分割成8个头,每个头维度64
↓
并行计算8个注意力
↓
Head 1: Attention(Q1, K1, V1) → (batch, seq_len, 64)
Head 2: Attention(Q2, K2, V2) → (batch, seq_len, 64)
...
Head 8: Attention(Q8, K8, V8) → (batch, seq_len, 64)
↓
拼接 (batch, seq_len, 512)
↓
线性投影 W^O
↓
输出 (batch, seq_len, d_model=512)
为什么使用多头?
-
不同表示子空间:每个头可以关注不同方面的信息
- 某个头可能关注语法关系
- 某个头可能关注语义相似性
- 某个头可能关注位置关系
-
增加模型容量:在不增加参数量的情况下增强表达能力
-
并行计算:多个头可以同时计算
3.3 掩码机制(Masking)
填充掩码(Padding Mask):
python
# 用于处理变长序列
# 假设序列长度为5,实际有效长度为3
sequence = [token1, token2, token3, <PAD>, <PAD>]
padding_mask = [0, 0, 0, 1, 1] # 1表示需要掩码的位置
# 在计算注意力前应用
scores = scores.masked_fill(padding_mask == 1, -1e9)
# Softmax后,被掩码的位置权重接近0
前瞻掩码(Look-Ahead Mask):
python
# 用于解码器自注意力,防止看到未来信息
# 序列长度为4的掩码矩阵
look_ahead_mask = [
[0, 1, 1, 1], # 位置0只能看到位置0
[0, 0, 1, 1], # 位置1只能看到位置0-1
[0, 0, 0, 1], # 位置2只能看到位置0-2
[0, 0, 0, 0], # 位置3可以看到所有位置
]
# 0表示允许注意,1表示掩码
3.4 前馈网络(Feed-Forward Network)
结构:
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
特点:
- 两层全连接网络
- 中间层使用ReLU激活
- 每个位置独立应用(position-wise)
- 参数在所有位置共享
维度变化:
(batch, seq_len, d_model=512)
↓ 第一层
(batch, seq_len, d_ff=2048)
↓ ReLU + 第二层
(batch, seq_len, d_model=512)
作用:
- 增加模型的非线性变换能力
- 每个位置可以进行独立的特征变换
- 类似于1×1卷积的作用
4. Transformer在结构化数据中的应用
4.1 结构化数据的挑战
表格数据特点:
- 异构特征:数值型、类别型混合
- 特征交互:特征之间存在复杂的非线性关系
- 数据量相对较小:通常几千到几十万样本
- 可解释性需求:金融、医疗等领域需要理解模型决策
传统方法:
- 树模型:XGBoost, LightGBM, CatBoost(表格数据的黄金标准)
- 深度学习:MLP、Wide & Deep、DeepFM
- 挑战:深度学习模型在小规模表格数据上通常不如树模型
4.2 TabTransformer (2020)
论文:"TabTransformer: Tabular Data Modeling Using Contextual Embeddings"
核心创新:
- 将类别特征转换为嵌入向量
- 使用Transformer学习类别特征之间的上下文关系
- 数值特征保持原样或简单归一化
架构设计:
类别特征 → 嵌入层 → Transformer编码器 → 拼接数值特征 → MLP → 输出
具体流程:
1. 类别特征处理:
- 特征1: [A] → Embedding → [e1_1, e1_2, ..., e1_d]
- 特征2: [B] → Embedding → [e2_1, e2_2, ..., e2_d]
- ...
2. Transformer编码:
- 输入: (num_categorical_features, embedding_dim)
- 多层自注意力捕获特征间关系
- 输出: (num_categorical_features, embedding_dim)
3. 特征融合:
- Flatten编码后的类别特征
- 拼接原始数值特征
- 通过MLP进行最终预测
代码示例:
python
import torch
import torch.nn as nn
class TabTransformer(nn.Module):
def __init__(self,
num_categories, # 每个类别特征的类别数
num_numerical, # 数值特征数量
embedding_dim=32, # 嵌入维度
num_layers=6, # Transformer层数
num_heads=8, # 注意力头数
ffn_dim=128, # 前馈网络维度
output_dim=1): # 输出维度
super().__init__()
# 类别特征嵌入
self.embeddings = nn.ModuleList([
nn.Embedding(num_cat, embedding_dim)
for num_cat in num_categories
])
# Transformer编码器
encoder_layer = nn.TransformerEncoderLayer(
d_model=embedding_dim,
nhead=num_heads,
dim_feedforward=ffn_dim,
batch_first=True
)
self.transformer = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers
)
# 输出MLP
total_dim = len(num_categories) * embedding_dim + num_numerical
self.mlp = nn.Sequential(
nn.Linear(total_dim, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(128, output_dim)
)
def forward(self, categorical_features, numerical_features):
# 嵌入类别特征
embeddings = [
emb(categorical_features[:, i])
for i, emb in enumerate(self.embeddings)
]
# (batch, num_categorical, embedding_dim)
cat_embedded = torch.stack(embeddings, dim=1)
# Transformer编码
# (batch, num_categorical, embedding_dim)
encoded = self.transformer(cat_embedded)
# Flatten类别特征
encoded_flat = encoded.reshape(encoded.size(0), -1)
# 拼接数值特征
combined = torch.cat([encoded_flat, numerical_features], dim=1)
# MLP预测
output = self.mlp(combined)
return output
# 使用示例
model = TabTransformer(
num_categories=[10, 5, 20, 8], # 4个类别特征
num_numerical=15, # 15个数值特征
embedding_dim=32,
num_layers=4,
num_heads=4
)
# 假设batch_size=64
cat_features = torch.randint(0, 10, (64, 4)) # 类别特征
num_features = torch.randn(64, 15) # 数值特征
output = model(cat_features, num_features)
优势:
- 特征交互:自动学习类别特征间的复杂关系
- 鲁棒性:对缺失值和噪声有一定容忍度
- 迁移学习:预训练的嵌入可以迁移到相关任务
实验结果(原论文):
- 在多个UCI数据集上超越传统MLP
- 在某些数据集上接近甚至超越树模型
- 特别适合类别特征较多的场景
4.3 FT-Transformer (2021)
论文:"Revisiting Deep Learning Models for Tabular Data"
核心改进:
- 所有特征都嵌入化:数值特征也转换为嵌入
- 特征级注意力:每个特征作为一个token
- 更简洁的架构:去除额外的MLP层
架构:
所有特征 → 特征嵌入 → [CLS] Token → Transformer → [CLS]输出 → 预测
特征嵌入方式:
- 类别特征: Embedding(category_value)
- 数值特征: Linear(numerical_value) + Feature_embedding
代码示例:
python
class FTTransformer(nn.Module):
def __init__(self,
num_features,
num_categories,
embedding_dim=64,
num_layers=3,
num_heads=8):
super().__init__()
# 类别特征嵌入
self.cat_embeddings = nn.ModuleList([
nn.Embedding(num_cat, embedding_dim)
for num_cat in num_categories
])
# 数值特征线性投影
self.num_projections = nn.ModuleList([
nn.Linear(1, embedding_dim)
for _ in range(num_features)
])
# 特征位置嵌入
total_features = len(num_categories) + num_features
self.feature_embeddings = nn.Parameter(
torch.randn(1, total_features, embedding_dim)
)
# CLS token
self.cls_token = nn.Parameter(
torch.randn(1, 1, embedding_dim)
)
# Transformer
encoder_layer = nn.TransformerEncoderLayer(
d_model=embedding_dim,
nhead=num_heads,
dim_feedforward=embedding_dim * 4,
batch_first=True
)
self.transformer = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers
)
# 输出层
self.output = nn.Linear(embedding_dim, 1)
def forward(self, categorical_features, numerical_features):
batch_size = categorical_features.size(0)
# 处理类别特征
cat_embeds = [
emb(categorical_features[:, i])
for i, emb in enumerate(self.cat_embeddings)
]
# 处理数值特征
num_embeds = [
proj(numerical_features[:, i:i+1])
for i, proj in enumerate(self.num_projections)
]
# 拼接所有特征嵌入
all_embeds = cat_embeds + num_embeds
features = torch.stack(all_embeds, dim=1) # (batch, num_features, dim)
# 添加特征位置嵌入
features = features + self.feature_embeddings
# 添加CLS token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
features = torch.cat([cls_tokens, features], dim=1)
# Transformer编码
encoded = self.transformer(features)
# 使用CLS token的输出进行预测
cls_output = encoded[:, 0, :]
output = self.output(cls_output)
return output
优势:
- 统一处理所有类型的特征
- 更好的特征交互学习
- 在多个基准测试中达到SOTA
4.4 时间序列预测中的应用
Temporal Fusion Transformer (TFT, 2019)
场景:多变量时间序列预测
架构特点:
- 变量选择网络:学习哪些特征重要
- 时序融合解码器:融合不同时间尺度的信息
- 多头注意力:捕获长期依赖关系
应用:
- 电力负荷预测
- 股票价格预测
- 零售需求预测
代码框架:
python
class TemporalFusionTransformer(nn.Module):
def __init__(self, config):
super().__init__()
# 变量选择网络
self.variable_selection = VariableSelectionNetwork(
input_dim=config.num_features,
hidden_dim=config.hidden_dim
)
# LSTM编码器(处理历史序列)
self.encoder_lstm = nn.LSTM(
input_size=config.hidden_dim,
hidden_size=config.hidden_dim,
num_layers=1,
batch_first=True
)
# LSTM解码器(生成未来序列)
self.decoder_lstm = nn.LSTM(
input_size=config.hidden_dim,
hidden_size=config.hidden_dim,
num_layers=1,
batch_first=True
)
# 时序注意力层
self.temporal_attention = nn.MultiheadAttention(
embed_dim=config.hidden_dim,
num_heads=config.num_heads,
batch_first=True
)
# 门控残差网络
self.grn = GatedResidualNetwork(config.hidden_dim)
# 输出层
self.output_layer = nn.Linear(
config.hidden_dim,
config.forecast_horizon
)
4.5 推荐系统中的应用
BERT4Rec (2019)
思想:将用户行为序列看作"句子",物品看作"词"
架构:
用户行为序列: [item1, item2, [MASK], item4, item5]
↓
Transformer编码器
↓
预测被mask的item
训练方式:
- 随机mask序列中的物品
- 模型预测被mask的物品
- 类似BERT的预训练方式
优势:
- 双向建模:考虑前后文信息
- 捕获长期兴趣
- 自监督学习:无需额外标注
5. Transformer在图像数据中的应用
5.1 从CNN到Transformer的转变
传统CNN的局限:
- 局部感受野:卷积核只能看到局部区域
- 归纳偏置强:平移不变性和局部性是硬编码的
- 长距离依赖:需要堆叠很多层才能建立全局关系
Transformer的优势:
- 全局感受野:每个位置都能关注整个图像
- 灵活的归纳偏置:通过数据学习而非硬编码
- 可扩展性:性能随数据量和模型大小持续提升
5.2 Vision Transformer (ViT, 2020)
论文:"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"(Google Research)
核心思想:将图像分割成固定大小的patch,每个patch作为一个token
架构流程:
1. 图像分割
输入图像 (224×224×3)
↓
分割成patches (14×14个patches,每个16×16×3)
2. Patch嵌入
每个patch → flatten → 线性投影
(16×16×3 = 768维) → 嵌入空间 (D维,如768)
3. 添加位置嵌入
+ 可学习的位置编码
4. 添加[CLS] token
[CLS] + patch_1 + patch_2 + ... + patch_196
5. Transformer编码器
多层自注意力 + FFN
6. 分类头
[CLS] token的输出 → MLP → 类别概率
详细代码实现:
python
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
"""将图像分割成patches并嵌入"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# 使用卷积实现patch提取和线性投影
self.projection = nn.Conv2d(
in_channels,
embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
# x: (batch, 3, 224, 224)
x = self.projection(x) # (batch, embed_dim, 14, 14)
x = x.flatten(2) # (batch, embed_dim, 196)
x = x.transpose(1, 2) # (batch, 196, embed_dim)
return x
class VisionTransformer(nn.Module):
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
num_classes=1000,
embed_dim=768,
depth=12, # Transformer层数
num_heads=12,
mlp_ratio=4.0, # FFN隐藏层维度 = embed_dim * mlp_ratio
dropout=0.1):
super().__init__()
# Patch嵌入
self.patch_embed = PatchEmbedding(
img_size, patch_size, in_channels, embed_dim
)
num_patches = self.patch_embed.num_patches
# CLS token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 位置嵌入(可学习)
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
self.pos_drop = nn.Dropout(p=dropout)
# Transformer编码器
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=int(embed_dim * mlp_ratio),
dropout=dropout,
batch_first=True
)
self.transformer = nn.TransformerEncoder(
encoder_layer,
num_layers=depth
)
# 分类头
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
# 初始化权重
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
batch_size = x.shape[0]
# Patch嵌入
x = self.patch_embed(x) # (batch, 196, 768)
# 添加CLS token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # (batch, 197, 768)
# 添加位置嵌入
x = x + self.pos_embed
x = self.pos_drop(x)
# Transformer编码
x = self.transformer(x)
# 分类(使用CLS token)
x = self.norm(x[:, 0])
x = self.head(x)
return x
# 模型变体
def vit_base_patch16_224():
"""ViT-Base: 86M参数"""
return VisionTransformer(
img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12
)
def vit_large_patch16_224():
"""ViT-Large: 307M参数"""
return VisionTransformer(
img_size=224,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16
)
def vit_huge_patch14_224():
"""ViT-Huge: 632M参数"""
return VisionTransformer(
img_size=224,
patch_size=14,
embed_dim=1280,
depth=32,
num_heads=16
)
关键发现(原论文):
-
数据量是关键:
- 在ImageNet-1K(130万图像)上:ViT < ResNet
- 在ImageNet-21K(1400万图像)上:ViT ≈ ResNet
- 在JFT-300M(3亿图像)上:ViT > ResNet
-
归纳偏置:
- ViT几乎没有图像特定的归纳偏置
- 完全依赖数据学习,因此需要大规模数据
-
计算效率:
- 相同精度下,ViT训练成本更低
- 更容易扩展到更大模型
5.3 Swin Transformer (2021)
论文:"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"(Microsoft Research Asia)
动机:
- ViT计算复杂度是O(N²),其中N是patch数量
- 对于高分辨率图像,计算量巨大
- 缺少层次化结构,难以用于密集预测任务
核心创新:
- 局部窗口注意力:限制注意力在固定大小的窗口内
- 滑动窗口机制:通过移位窗口建立跨窗口连接
- 层次化结构:类似CNN的金字塔结构
架构设计:
Stage 1: 56×56 patches
↓ Window Attention
↓ Shifted Window Attention
Stage 2: 28×28 patches (Patch Merging)
↓ Window Attention
↓ Shifted Window Attention
Stage 3: 14×14 patches (Patch Merging)
↓ Window Attention
↓ Shifted Window Attention
Stage 4: 7×7 patches (Patch Merging)
↓ Window Attention
↓ Shifted Window Attention
窗口注意力示例:
python
class WindowAttention(nn.Module):
"""在固定大小窗口内计算注意力"""
def __init__(self, dim, window_size, num_heads):
super().__init__()
self.dim = dim
self.window_size = window_size # (window_height, window_width)
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# 相对位置偏置
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)
)
# QKV投影
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x, mask=None):
"""
x: (num_windows*batch, window_size*window_size, C)
mask: (num_windows, window_size*window_size, window_size*window_size)
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# 计算注意力
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
# 添加相对位置偏置
# ... (省略细节)
# 应用掩码(用于shifted window)
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
"""Swin Transformer基本块"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, (window_size, window_size), num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, 4 * dim),
nn.GELU(),
nn.Linear(4 * dim, dim)
)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# 滑动窗口(如果shift_size > 0)
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
else:
shifted_x = x
# 分割成窗口
x_windows = window_partition(shifted_x, self.window_size)
# 窗口注意力
attn_windows = self.attn(x_windows)
# 合并窗口
shifted_x = window_reverse(attn_windows, self.window_size, H, W)
# 反向滑动
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
优势:
- 计算效率:O(N)复杂度,可处理高分辨率图像
- 层次化表示:适合检测、分割等任务
- 性能优异:ImageNet分类、COCO检测都达到SOTA
实验结果:
- ImageNet-1K: Top-1 87.3% (Swin-L)
- COCO目标检测: 58.7 box AP
- ADE20K语义分割: 53.5 mIoU
5.4 Detection Transformer (DETR, 2020)
论文:"End-to-End Object Detection with Transformers"(Facebook AI)
革命性创新:
- 第一个端到端的目标检测器
- 不需要NMS(非极大值抑制)
- 不需要anchor boxes
架构:
输入图像
↓
CNN Backbone (ResNet) → 特征图
↓
Flatten + 位置编码
↓
Transformer编码器
↓
Object Queries(可学习)
↓
Transformer解码器
↓
FFN → N个预测 (类别 + 边界框)
↓
匈牙利匹配 + Loss
关键组件:
- Object Queries:
python
# N个可学习的查询向量(N=100)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
# 在解码器中使用
query_pos = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)
- 二分匹配:
python
def hungarian_matching(pred_boxes, pred_logits, target_boxes, target_labels):
"""使用匈牙利算法进行预测和GT的最优匹配"""
# 计算分类成本
cost_class = -pred_logits[:, target_labels]
# 计算L1距离成本
cost_bbox = torch.cdist(pred_boxes, target_boxes, p=1)
# 计算GIoU成本
cost_giou = -generalized_box_iou(pred_boxes, target_boxes)
# 总成本
C = cost_class + 5 * cost_bbox + 2 * cost_giou
# 匈牙利算法
indices = linear_sum_assignment(C.cpu())
return indices
完整实现框架:
python
class DETR(nn.Module):
def __init__(self, num_classes, num_queries=100, hidden_dim=256):
super().__init__()
# CNN骨干网络
self.backbone = resnet50(pretrained=True)
# 降维
self.conv = nn.Conv2d(2048, hidden_dim, 1)
# 位置编码
self.pos_encoder = PositionEmbeddingSine(hidden_dim // 2)
# Transformer
self.transformer = nn.Transformer(
d_model=hidden_dim,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048
)
# Object queries
self.query_embed = nn.Embedding(num_queries, hidden_dim)
# 预测头
self.class_embed = nn.Linear(hidden_dim, num_classes + 1) # +1 for no-object
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) # (x, y, w, h)
def forward(self, images):
# 特征提取
features = self.backbone(images) # (batch, 2048, H, W)
features = self.conv(features) # (batch, 256, H, W)
# 位置编码
pos = self.pos_encoder(features)
# Flatten
batch_size = features.shape[0]
features = features.flatten(2).permute(2, 0, 1) # (HW, batch, 256)
pos = pos.flatten(2).permute(2, 0, 1)
# Object queries
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
# Transformer
memory = self.transformer.encoder(features, pos=pos)
hs = self.transformer.decoder(
query_embed,
memory,
memory_key_padding_mask=None,
pos=pos,
query_pos=query_embed
)
# 预测
outputs_class = self.class_embed(hs) # (num_queries, batch, num_classes+1)
outputs_coord = self.bbox_embed(hs).sigmoid() # (num_queries, batch, 4)
return {
'pred_logits': outputs_class[-1],
'pred_boxes': outputs_coord[-1]
}
优势:
- 真正的端到端训练
- 全局推理能力
- 代码简洁,易于扩展
局限性:
- 训练收敛较慢(需要500 epochs)
- 小物体检测性能较弱
- 计算量较大
5.5 图像分割中的应用
Segmenter (2021)
将ViT应用于语义分割:
图像 → Patch嵌入 → Transformer编码器 → 逐patch分类 → 上采样 → 分割图
Mask2Former (2022)
结合DETR思想进行实例分割和全景分割:
图像 → Backbone → Pixel Decoder → Transformer Decoder → Mask预测
6. 性能优化与变体
6.1 高效注意力机制
Linformer (2020)
- 将注意力复杂度从O(N²)降到O(N)
- 方法:低秩矩阵近似K和V
python
# 核心思想
K_projected = K @ E # E: (N, k) 其中 k << N
V_projected = V @ F # F: (N, k)
Attention = softmax(Q @ K_projected.T) @ V_projected
Performer (2020)
- 使用快速注意力算法(FAVOR+)
- 线性复杂度,无近似误差
Reformer (2020)
- 局部敏感哈希(LSH)注意力
- 只计算相似query和key的注意力
6.2 位置编码改进
相对位置编码:
python
# 不是编码绝对位置,而是编码相对距离
relative_position = position_i - position_j
bias = learned_bias[relative_position]
attention_score = (Q @ K.T) + bias
旋转位置编码(RoPE):
- 通过旋转矩阵注入位置信息
- 用于LLaMA、PaLM等大模型
ALiBi(Attention with Linear Biases):
- 在注意力分数上添加线性偏置
- 外推性能优异
6.3 稀疏注意力
局部注意力:
python
# 只关注周围k个位置
for i in range(seq_len):
attend_to = range(max(0, i-k), min(seq_len, i+k+1))
Strided注意力:
python
# 每隔s个位置关注一次
attend_to = [0, s, 2s, 3s, ...]
Longformer:
- 结合局部注意力和全局注意力
- 可处理长达16k的序列
7. 实际应用案例
7.1 案例1:电商推荐系统
场景:用户商品点击序列预测
数据:
- 用户历史点击:[item_1, item_2, ..., item_n]
- 用户特征:年龄、性别、地域
- 商品特征:类别、价格、品牌
方案:
python
class RecommendationTransformer(nn.Module):
def __init__(self, num_items, embedding_dim=128):
super().__init__()
# 商品嵌入
self.item_embedding = nn.Embedding(num_items, embedding_dim)
# Transformer编码器
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=embedding_dim,
nhead=8,
dim_feedforward=512,
batch_first=True
),
num_layers=4
)
# 预测头
self.output = nn.Linear(embedding_dim, num_items)
def forward(self, item_sequence):
# 嵌入
x = self.item_embedding(item_sequence)
# 位置编码
x = x + self.positional_encoding(x)
# Transformer编码
encoded = self.transformer(x)
# 预测下一个商品
logits = self.output(encoded[:, -1, :])
return logits
效果:
- 相比LSTM提升5-10%的点击率
- 能够捕获长期用户兴趣
7.2 案例2:医学图像分割
场景:CT扫描肿瘤分割
挑战:
- 医学图像分辨率高(512×512甚至更大)
- 标注数据有限
- 需要高精度
方案:使用Swin-UNETR
输入CT切片
↓
Swin Transformer编码器(多尺度特征)
↓
UNet风格解码器
↓
分割结果
优势:
- 全局上下文建模
- 层次化特征提取
- 在BTCV数据集上达到SOTA
7.3 案例3:金融风控
场景:信用卡欺诈检测
数据特点:
- 高维特征:100+维
- 混合类型:交易金额(数值)、商户类别(类别)
- 类别不平衡:欺诈样本<1%
方案:使用TabTransformer
python
model = TabTransformer(
num_categories=[1000, 50, 20, ...], # 商户ID、地区、卡类型
num_numerical=45, # 交易金额、时间等
embedding_dim=64,
num_layers=6,
num_heads=8
)
# 训练时使用Focal Loss处理类别不平衡
criterion = FocalLoss(alpha=0.25, gamma=2.0)
效果:
- F1-score提升3-5%
- 特征交互自动学习,减少特征工程
7.4 案例4:自动驾驶
场景:3D目标检测
数据:激光雷达点云 + 相机图像
方案:TransFusion
点云 → PointNet提取特征
↓
相机图像 → CNN提取特征
↓
Transformer融合多模态特征
↓
检测头 → 3D边界框
创新点:
- 跨模态注意力机制
- 查询向量同时关注点云和图像特征
总结
Transformer的核心优势
- 并行化:不同于RNN的顺序计算,可以充分利用GPU
- 长距离依赖:通过注意力机制直接建立任意位置间的连接
- 可解释性:注意力权重可视化,了解模型关注点
- 可扩展性:性能随模型大小和数据量持续提升
- 通用性:从NLP到CV,从表格数据到多模态
未来发展方向
-
效率优化:
- 更高效的注意力机制(Flash Attention)
- 模型压缩与量化
- 稀疏模型(MoE)
-
架构创新:
- 更好的位置编码
- 动态深度网络
- 神经架构搜索
-
应用拓展:
- 多模态融合(CLIP, DALL-E)
- 科学计算(AlphaFold)
- 强化学习(Decision Transformer)
-
理论理解:
- 为什么Transformer有效?
- 如何设计更好的归纳偏置?
- 泛化性理论
学习资源
论文必读:
- Attention is All You Need (2017)
- BERT: Pre-training of Deep Bidirectional Transformers (2018)
- An Image is Worth 16x16 Words (ViT, 2020)
- Swin Transformer (2021)
代码实践:
- Hugging Face Transformers库
- PyTorch官方教程
- Annotated Transformer
课程推荐:
- Stanford CS224N (NLP with Deep Learning)
- Stanford CS231N (Computer Vision)
- Deep Learning Specialization (Andrew Ng)
文档版本 :v1.0
最后更新 :2024年
作者:Claude (Anthropic)