1. 详细解释CNN的工作原理,为什么卷积神经网络特别适合处理图像数据?
详细解答:
CNN核心组件
1. 卷积层(Convolutional Layer)
数学原理:
对于输入I和卷积核K:
(I * K)(i, j) = ΣΣ I(i+m, j+n) × K(m, n)
输出特征图大小:
output_size = (input_size - kernel_size + 2×padding) / stride + 1
工作机制:
- 卷积核在输入上滑动,执行逐元素乘法和求和
- 共享参数:同一卷积核扫描整个输入
- 多个卷积核提取不同的特征模式
实现示例:
python
# TensorFlow/Keras实现
layers.Conv2D(
filters=64, # 卷积核数量
kernel_size=(3, 3), # 卷积核大小
strides=(1, 1), # 步长
padding='same', # 填充方式
activation='relu'
)
# 参数量计算:
# (kernel_h × kernel_w × input_channels + 1) × num_filters
# (3 × 3 × 3 + 1) × 64 = 1,792 参数(假设RGB输入)
特征提取层次:
- 浅层:边缘、纹理、颜色
- 中层:形状、部件(眼睛、轮子)
- 深层:高级语义(脸部、汽车)
2. 池化层(Pooling Layer)
作用:
- 降采样,减少空间维度
- 提供平移不变性
- 降低计算复杂度
- 扩大感受野
类型对比:
python
# Max Pooling(最常用)
layers.MaxPooling2D(pool_size=(2, 2), strides=2)
# 保留最显著的特征,丢弃位置信息
# Average Pooling
layers.AveragePooling2D(pool_size=(2, 2))
# 保留平均信息,更平滑
# Global Average Pooling(常用于分类网络末端)
layers.GlobalAveragePooling2D()
# 将每个特征图压缩为一个值
Max vs Average选择:
- Max Pooling:检测特征是否存在(分类任务)
- Average Pooling:保留整体信息(分割任务)
3. 全连接层(Fully Connected Layer)
作用:
- 整合卷积层提取的局部特征
- 执行最终的分类或回归
位置:
通常位于网络末端,现代架构趋向于减少或替代FC层
CNN为什么适合图像处理
1. 局部连接性(Local Connectivity)
传统全连接问题:
对于224×224×3的图像:
第一层全连接需要 224×224×3 = 150,528 个输入
如果隐藏层有1000个神经元:
参数量 = 150,528 × 1000 = 1.5亿参数!
卷积的优势:
3×3卷积核,64个filter:
参数量 = (3×3×3 + 1) × 64 = 1,792 参数
降低了5个数量级!
原理:
- 图像的局部区域具有强相关性
- 相邻像素比远距离像素更相关
- 卷积核捕获局部模式
2. 参数共享(Parameter Sharing)
机制:
同一卷积核在整个图像上滑动,参数被重复使用
好处:
- 大幅减少参数量
- 减少过拟合风险
- 降低内存需求
- 加速训练
实例对比:
全连接检测边缘:
需要为每个位置学习独立的检测器 → 参数冗余
卷积检测边缘:
一个3×3 Sobel核应用于所有位置 → 参数高效
3. 平移不变性(Translation Invariance)
定义:
对象在图像中的位置改变不影响检测结果
实现方式:
- 参数共享 → 特征检测器在所有位置相同
- 池化层 → 进一步增强位移容忍度
示例:
猫的图像无论在左上角还是右下角,同一组卷积核都能识别
4. 层次化特征学习(Hierarchical Feature Learning)
金字塔结构:
输入层: 224×224×3 (原始像素)
↓
Conv1: 112×112×64 (边缘、纹理)
↓
Conv2: 56×56×128 (简单形状)
↓
Conv3: 28×28×256 (物体部件)
↓
Conv4: 14×14×512 (完整对象)
↓
FC: 1000 (类别概率)
符合视觉认知规律:
人类视觉系统也是层次化处理,从简单到复杂
5. 感受野(Receptive Field)
定义:
输出特征图中一个神经元对应输入图像中的区域大小
逐层增长:
3×3卷积堆叠的感受野增长:
Layer 1: 3×3
Layer 2: 5×5 (3 + 2×1)
Layer 3: 7×7 (5 + 2×1)
Layer n: 2n+1
两个3×3卷积 = 一个5×5卷积的感受野
但参数量更少:2×(3×3) = 18 vs 5×5 = 25
设计原则:
- 深度网络通过堆叠小卷积核获得大感受野
- VGGNet、ResNet都采用3×3卷积核
经典CNN架构演进
LeNet-5 (1998)
结构:
Input(32×32) → Conv(5×5) → Pool → Conv(5×5) → Pool → FC → FC → Output(10)
贡献:
- 奠定CNN基本架构
- 用于手写数字识别
AlexNet (2012)
创新点:
python
# 使用ReLU激活
layers.Conv2D(96, 11, strides=4, activation='relu')
# 引入Dropout
layers.Dropout(0.5)
# 数据增强
ImageDataGenerator(
rotation_range=10,
width_shift_range=0.1,
horizontal_flip=True
)
影响:
- 深度学习在ImageNet上的突破
- GPU加速训练
- ReLU替代Sigmoid
VGGNet (2014)
核心思想:
深度很重要,统一使用3×3小卷积核
架构模式:
python
# VGG块
def vgg_block(num_convs, num_filters):
block = Sequential()
for _ in range(num_convs):
block.add(layers.Conv2D(num_filters, 3,
padding='same',
activation='relu'))
block.add(layers.MaxPooling2D(2))
return block
# VGG-16
model = Sequential([
vgg_block(2, 64), # 2个卷积
vgg_block(2, 128),
vgg_block(3, 256), # 3个卷积
vgg_block(3, 512),
vgg_block(3, 512),
layers.Flatten(),
layers.Dense(4096, activation='relu'),
layers.Dropout(0.5),
layers.Dense(4096, activation='relu'),
layers.Dropout(0.5),
layers.Dense(1000, activation='softmax')
])
优势:
- 结构简单、规律
- 易于理解和实现
缺点:
- 参数量巨大(138M)
- 训练和推理慢
GoogLeNet / Inception (2014)
创新:Inception模块
python
def inception_module(x, f1, f3_reduce, f3, f5_reduce, f5, pool_proj):
# 1×1卷积分支
branch1 = layers.Conv2D(f1, 1, activation='relu')(x)
# 1×1 → 3×3卷积分支
branch2 = layers.Conv2D(f3_reduce, 1, activation='relu')(x)
branch2 = layers.Conv2D(f3, 3, padding='same', activation='relu')(branch2)
# 1×1 → 5×5卷积分支
branch3 = layers.Conv2D(f5_reduce, 1, activation='relu')(x)
branch3 = layers.Conv2D(f5, 5, padding='same', activation='relu')(branch3)
# 池化 → 1×1卷积分支
branch4 = layers.MaxPooling2D(3, strides=1, padding='same')(x)
branch4 = layers.Conv2D(pool_proj, 1, activation='relu')(branch4)
# 拼接所有分支
output = layers.Concatenate()([branch1, branch2, branch3, branch4])
return output
优势:
- 多尺度特征提取(1×1、3×3、5×5)
- 1×1卷积降维,减少计算量
- 参数效率高
ResNet (2015)
解决问题:
网络越深,训练越困难(退化问题)
残差块:
python
def residual_block(x, filters, stride=1):
shortcut = x
# 主路径
x = layers.Conv2D(filters, 3, strides=stride, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters, 3, padding='same')(x)
x = layers.BatchNormalization()(x)
# 如果维度不匹配,调整shortcut
if stride != 1:
shortcut = layers.Conv2D(filters, 1, strides=stride)(shortcut)
shortcut = layers.BatchNormalization()(shortcut)
# 残差连接
x = layers.Add()([x, shortcut])
x = layers.Activation('relu')(x)
return x
数学原理:
传统:H(x) = F(x)
ResNet: H(x) = F(x) + x
学习残差:F(x) = H(x) - x
恒等映射更容易学习
突破:
- 训练超深网络(152层,甚至1000层)
- 梯度直接通过shortcut传播
- 性能随深度提升
现代CNN设计原则
1. 深度与宽度平衡
- 更深的网络 vs 更宽的层
- EfficientNet通过复合缩放同时调整
2. 计算效率
python
# 深度可分离卷积(MobileNet)
layers.SeparableConv2D(64, 3, padding='same')
# 参数量:3×3×1 + 1×1×input_ch×64 << 3×3×input_ch×64
3. 注意力机制
python
# Squeeze-and-Excitation块(SENet)
def se_block(x, ratio=16):
channels = x.shape[-1]
# Squeeze:全局平均池化
se = layers.GlobalAveragePooling2D()(x)
# Excitation:全连接层
se = layers.Dense(channels // ratio, activation='relu')(se)
se = layers.Dense(channels, activation='sigmoid')(se)
se = layers.Reshape((1, 1, channels))(se)
# 通道加权
return layers.Multiply()([x, se])
4. 自动化架构搜索(NAS)
- EfficientNet、RegNet等通过NAS优化
实践建议
选择架构:
- 精度优先:ResNet-50/101、EfficientNet
- 速度优先:MobileNetV3、EfficientNet-Lite
- 边缘设备:MobileNet、ShuffleNet
迁移学习:
python
# 使用预训练模型
base_model = tf.keras.applications.ResNet50(
weights='imagenet',
include_top=False,
input_shape=(224, 224, 3)
)
base_model.trainable = False # 冻结预训练权重
# 添加自定义分类头
model = Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.Dense(256, activation='relu'),
layers.Dropout(0.5),
layers.Dense(num_classes, activation='softmax')
])
2. 解释Transformer架构的核心机制,为什么它在NLP和CV领域都取得了巨大成功?
详细解答:
Transformer核心组件
1. 自注意力机制(Self-Attention)
数学定义:
Attention(Q, K, V) = softmax(QK^T / √d_k) × V
其中:
- Q (Query): 查询矩阵
- K (Key): 键矩阵
- V (Value): 值矩阵
- d_k: 键向量维度(用于缩放)
计算流程:
python
import torch
import torch.nn as nn
import math
class SelfAttention(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.embed_dim = embed_dim
# 线性变换生成Q, K, V
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
# x shape: (batch_size, seq_len, embed_dim)
# 生成Q, K, V
Q = self.query(x) # (batch, seq_len, embed_dim)
K = self.key(x)
V = self.value(x)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, seq_len, seq_len)
scores = scores / math.sqrt(self.embed_dim) # 缩放
# Softmax归一化
attention_weights = torch.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(attention_weights, V) # (batch, seq_len, embed_dim)
return output, attention_weights
直观理解:
对于句子"The cat sat on the mat":
- Query"sat"询问:我应该关注哪些词?
- Key"cat"、"mat"等回应:我与你的相关性
- 注意力权重:sat对cat的权重可能很高
- Value"cat"的语义被加权融入"sat"的表示
优势:
- 捕获长距离依赖(不受距离限制)
- 并行计算(不像RNN需要顺序处理)
- 动态权重(根据上下文调整)
2. 多头注意力(Multi-Head Attention)
动机:
不同的注意力头可以关注不同的语义子空间
实现:
python
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
assert embed_dim % num_heads == 0
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
self.out = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, embed_dim = x.size()
# 生成Q, K, V并reshape为多头
Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
# 转置以便并行计算: (batch, num_heads, seq_len, head_dim)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
attention = torch.softmax(scores, dim=-1)
# 加权求和
out = torch.matmul(attention, V) # (batch, num_heads, seq_len, head_dim)
# 拼接所有头
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
# 最终线性变换
return self.out(out)
示例分析:
embed_dim = 512, num_heads = 8
每个头的维度 = 512 / 8 = 64
头1可能关注:主谓关系
头2可能关注:形容词-名词关系
头3可能关注:长距离依赖
...
好处:
- 增加模型容量
- 捕获多种语义关系
- 集成多个子空间的信息
3. 位置编码(Positional Encoding)
问题:
Self-Attention对位置不敏感(置换等变性)
解决方案:
添加位置信息到输入嵌入
正弦位置编码:
python
def positional_encoding(seq_len, d_model):
position = torch.arange(seq_len).unsqueeze(1) # (seq_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2) *
-(math.log(10000.0) / d_model))
pe = torch.zeros(seq_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
return pe
# 使用
x = token_embeddings + positional_encoding(seq_len, d_model)
性质:
- 对每个维度使用不同频率的sin/cos
- 相对位置可以通过线性变换表示
- 可以外推到训练时未见过的序列长度
可学习位置编码(BERT等):
python
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
4. 前馈网络(Feed-Forward Network)
结构:
python
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# FFN(x) = max(0, xW1 + b1)W2 + b2
return self.linear2(self.dropout(torch.relu(self.linear1(x))))
作用:
- 增加非线性
- 对每个位置独立应用(位置无关)
- 通常d_ff = 4 × d_model
5. 层归一化与残差连接
Transformer块结构:
python
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.norm1 = nn.LayerNorm(d_model)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 自注意力 + 残差 + LayerNorm
attn_output = self.attention(x)
x = self.norm1(x + self.dropout(attn_output))
# 前馈网络 + 残差 + LayerNorm
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
LayerNorm vs BatchNorm:
- LayerNorm:在特征维度归一化,适合序列数据
- BatchNorm:在batch维度归一化,适合CNN
Transformer完整架构
编码器-解码器结构(原始Transformer):
python
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size,
d_model=512, num_heads=8, num_layers=6, d_ff=2048):
super().__init__()
# 嵌入层
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
# 编码器
self.encoder = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
# 解码器(包含交叉注意力)
self.decoder = nn.ModuleList([
DecoderBlock(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
# 输出层
self.fc_out = nn.Linear(d_model, tgt_vocab_size)
def forward(self, src, tgt):
# 编码
src_embed = self.src_embedding(src) + positional_encoding(...)
enc_output = src_embed
for encoder_layer in self.encoder:
enc_output = encoder_layer(enc_output)
# 解码
tgt_embed = self.tgt_embedding(tgt) + positional_encoding(...)
dec_output = tgt_embed
for decoder_layer in self.decoder:
dec_output = decoder_layer(dec_output, enc_output)
# 输出
return self.fc_out(dec_output)
为什么Transformer如此成功
在NLP领域
1. 长距离依赖建模
RNN问题:
"The cat, which was sitting on the mat that was in the corner of the room, meowed."
距离太远,梯度消失
Transformer优势:
通过attention直接连接任意两个词,O(1)复杂度
2. 并行化训练
RNN: t1 → t2 → t3 → ... → tn (顺序)
Transformer: 所有位置同时计算 (并行)
训练速度提升:10-100倍
3. 可扩展性
模型规模可以轻松扩展:
GPT-3: 175B参数
PaLM: 540B参数
性能随规模提升(Scaling Law)
4. 预训练-微调范式
python
# 预训练(自监督)
BERT: 掩码语言模型 + 下一句预测
GPT: 自回归语言模型
# 微调(少量标注数据)
分类、NER、问答等下游任务
成功案例:
- BERT: 双向编码器,适合理解任务
- GPT: 单向解码器,适合生成任务
- T5: 统一文本到文本框架
- ChatGPT/GPT-4: 大规模对话系统
在CV领域
Vision Transformer (ViT)核心思想:
python
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, num_classes=1000,
d_model=768, num_heads=12, num_layers=12):
super().__init__()
num_patches = (img_size // patch_size) ** 2
# 图像patch嵌入
self.patch_embedding = nn.Conv2d(
3, d_model,
kernel_size=patch_size,
stride=patch_size
) # 将图像分割为patch
# 位置嵌入
self.position_embedding = nn.Parameter(
torch.randn(1, num_patches + 1, d_model)
)
# [CLS] token
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
# Transformer编码器
self.transformer = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_model * 4)
for _ in range(num_layers)
])
# 分类头
self.mlp_head = nn.Linear(d_model, num_classes)
def forward(self, x):
batch_size = x.shape[0]
# 图像 → patch序列
x = self.patch_embedding(x) # (B, d_model, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, num_patches, d_model)
# 添加[CLS] token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# 添加位置嵌入
x = x + self.position_embedding
# Transformer编码
for transformer_layer in self.transformer:
x = transformer_layer(x)
# 分类(使用[CLS] token)
return self.mlp_head(x[:, 0])
关键设计:
- 将图像分割为16×16的patch(相当于"词")
- Patch嵌入(相当于词嵌入)
- 标准Transformer处理
ViT优势:
-
更少归纳偏置
- CNN硬编码了局部性和平移等变性
- Transformer学习这些模式,更灵活
-
全局建模
- 从第一层就能建立全局依赖
- CNN需要堆叠多层才能获得大感受野
-
大规模数据效率
- 在大数据集(JFT-300M)上超越CNN
- 数据越多,优势越明显
CV领域其他成功应用:
DETR(目标检测):
python
# 端到端检测,无需NMS
queries = nn.Parameter(torch.randn(num_queries, d_model))
# Transformer解码queries → 目标位置和类别
Swin Transformer(层次化视觉backbone):
- 窗口注意力(降低复杂度)
- 层次化结构(类似CNN金字塔)
- 适配密集预测任务(分割、检测)
CLIP(视觉-语言预训练):
- 图像编码器:ViT
- 文本编码器:Transformer
- 对比学习对齐两个模态
Transformer的局限性与改进
计算复杂度问题:
Self-Attention复杂度:O(n²d)
其中n是序列长度,d是维度
对于长序列(如高分辨率图像),计算量巨大
改进方法:
1. 稀疏注意力(Sparse Attention)
python
# Longformer: 局部窗口 + 全局注意力
# 复杂度:O(n × w),w是窗口大小
2. 线性注意力(Linear Attention)
Performer、Linformer等
将O(n²)降至O(n)
3. 高效Transformer
- Reformer:局部敏感哈希
- BigBird:稀疏注意力模式
- Flash Attention:内存优化
实践建议
NLP任务选择:
- 理解任务:BERT、RoBERTa
- 生成任务:GPT、T5
- 多模态:CLIP、DALL-E
CV任务选择:
- 分类:ViT、Swin Transformer
- 检测:DETR、Deformable DETR
- 分割:Segmenter、Mask2Former
训练技巧:
python
# 预训练模型微调
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# 冻结大部分层,只微调顶层
for param in model.encoder.layer[:10].parameters():
param.requires_grad = False
3. 比较RNN、LSTM、GRU的架构差异,解释为什么LSTM和GRU能够缓解梯度消失问题
详细解答:
标准RNN (Vanilla RNN)
结构:
python
class VanillaRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.hidden_size = hidden_size
# 权重矩阵
self.W_hh = nn.Linear(hidden_size, hidden_size) # 隐藏到隐藏
self.W_xh = nn.Linear(input_size, hidden_size) # 输入到隐藏
self.W_hy = nn.Linear(hidden_size, output_size) # 隐藏到输出
def forward(self, x, h_prev):
# h_t = tanh(W_hh * h_{t-1} + W_xh * x_t)
h_t = torch.tanh(self.W_hh(h_prev) + self.W_xh(x))
y_t = self.W_hy(h_t)
return y_t, h_t
数学表示:
h_t = tanh(W_hh × h_{t-1} + W_xh × x_t + b_h)
y_t = W_hy × h_t + b_y
梯度消失问题分析:
反向传播时的梯度链:
∂L/∂h_t = ∂L/∂h_{t+1} × ∂h_{t+1}/∂h_t
∂h_{t+1}/∂h_t = W_hh^T × diag(1 - tanh²(h_{t+1}))
经过T个时间步:
∂L/∂h_1 = ∂L/∂h_T × ∏(t=1 to T-1) W_hh^T × diag(1 - tanh²(h_t))
问题:
- tanh导数 ∈ (0, 1],最大值为1
- 连乘导致指数衰减
- 长序列时梯度趋近于0
具体例子:
python
# 假设tanh导数平均为0.25
# 经过10步:0.25^10 ≈ 9.5 × 10^-7
# 经过20步:0.25^20 ≈ 9.1 × 10^-13
# 梯度几乎完全消失!
局限性:
- 无法学习长期依赖(>10步)
- 训练不稳定
- 只能记住最近的信息
LSTM (Long Short-Term Memory)
核心思想:
引入"细胞状态"(cell state)作为信息高速公路,以及三个门控机制精确控制信息流动。
架构图解:
┌─────────────────────────────┐
│ Cell State (C_t) │ ← 信息高速公路
└─────────────────────────────┘
↑ ↑ ↑
forget input output
gate gate gate
↑ ↑ ↑
┌────┴──────────┴─────────┴────┐
│ x_t, h_{t-1} │
└───────────────────────────────┘
完整实现:
python
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# 四个门(forget, input, candidate, output)共享输入
self.W = nn.Linear(input_size + hidden_size, 4 * hidden_size)
def forward(self, x, h_prev, c_prev):
# 拼接输入和隐藏状态
combined = torch.cat([x, h_prev], dim=1)
# 一次线性变换生成所有门
gates = self.W(combined)
# 分割为四个门
f_t, i_t, g_t, o_t = gates.chunk(4, dim=1)
# 1. 遗忘门:决定丢弃多少旧信息
f_t = torch.sigmoid(f_t)
# 2. 输入门:决定添加多少新信息
i_t = torch.sigmoid(i_t)
# 3. 候选值:新的候选信息
g_t = torch.tanh(g_t)
# 4. 输出门:决定输出多少信息
o_t = torch.sigmoid(o_t)
# 更新细胞状态
c_t = f_t * c_prev + i_t * g_t
# 计算隐藏状态
h_t = o_t * torch.tanh(c_t)
return h_t, c_t
数学公式:
遗忘门:f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
输入门:i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
候选值:g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)
输出门:o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
细胞状态更新:C_t = f_t ⊙ C_{t-1} + i_t ⊙ g_t
隐藏状态:h_t = o_t ⊙ tanh(C_t)
其中⊙表示逐元素乘法
为什么能缓解梯度消失:
关键:细胞状态的梯度路径
∂C_t/∂C_{t-1} = f_t
反向传播时:
∂L/∂C_1 = ∂L/∂C_T × ∏(t=1 to T-1) f_t
优势:
1. f_t ∈ (0, 1) 由sigmoid动态控制,而非固定的tanh导数
2. 如果f_t接近1,梯度几乎无衰减
3. 模型学会在需要长期记忆时让f_t ≈ 1
4. 细胞状态提供了一条相对"平滑"的梯度路径
直观理解:
传统RNN:信息必须通过所有隐藏状态
LSTM:信息可以直接沿细胞状态传递,门控决定何时读写
门控机制的作用:
| 门 | 作用 | 示例 |
|---|---|---|
| 遗忘门 | 选择性遗忘旧信息 | 句子结束时忘记前一句的主语 |
| 输入门 | 选择性添加新信息 | 遇到新主语时更新 |
| 输出门 | 选择性输出信息 | 根据任务需要决定暴露多少内部状态 |
GRU (Gated Recurrent Unit)
核心思想:
简化LSTM,合并细胞状态和隐藏状态,只使用两个门。
架构:
python
class GRUCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# 重置门和更新门
self.W_z = nn.Linear(input_size + hidden_size, hidden_size) # 更新门
self.W_r = nn.Linear(input_size + hidden_size, hidden_size) # 重置门
self.W_h = nn.Linear(input_size + hidden_size, hidden_size) # 候选隐藏状态
def forward(self, x, h_prev):
combined = torch.cat([x, h_prev], dim=1)
# 1. 更新门:决定保留多少旧信息
z_t = torch.sigmoid(self.W_z(combined))
# 2. 重置门:决定遗忘多少旧信息来计算候选状态
r_t = torch.sigmoid(self.W_r(combined))
# 3. 候选隐藏状态(使用重置后的h_prev)
combined_reset = torch.cat([x, r_t * h_prev], dim=1)
h_tilde = torch.tanh(self.W_h(combined_reset))
# 4. 更新隐藏状态(线性插值)
h_t = (1 - z_t) * h_prev + z_t * h_tilde
return h_t
数学公式:
更新门:z_t = σ(W_z · [h_{t-1}, x_t])
重置门:r_t = σ(W_r · [h_{t-1}, x_t])
候选状态:h̃_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t])
最终状态:h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t
梯度流动分析:
∂h_t/∂h_{t-1} = (1 - z_t) + z_t × ∂h̃_t/∂h_{t-1}
= (1 - z_t) + z_t × r_t × tanh'(·) × W_h
优势:
- z_t接近0时:h_t ≈ h_{t-1},梯度直接传递(类似LSTM的f_t≈1)
- 提供了一条相对直接的梯度路径
三者对比
| 特性 | RNN | LSTM | GRU |
|---|---|---|---|
| 参数量 | 最少 | 最多 | 中等 |
| 计算复杂度 | O(h²) | O(4h²) | O(3h²) |
| 门的数量 | 0 | 3个 | 2个 |
| 状态数量 | 1 (h_t) | 2 (h_t, C_t) | 1 (h_t) |
| 长期依赖 | ✗ 差 | ✓ 优秀 | ✓ 良好 |
| 训练速度 | 快 | 慢 | 中等 |
| 表达能力 | 弱 | 强 | 中强 |
参数量对比(input_size=100, hidden_size=128):
python
# RNN
params_rnn = (100 + 128) * 128 = 29,184
# LSTM(4个门)
params_lstm = (100 + 128) * 128 * 4 = 116,736
# GRU(3个门)
params_gru = (100 + 128) * 128 * 3 = 87,552
实际应用场景
选择RNN:
- 简单的序列任务
- 序列很短(<10步)
- 计算资源极度受限
- 实际中很少使用
选择LSTM:
- 需要精确的长期记忆(如语言建模)
- 复杂的时间依赖关系
- 性能最重要,计算资源充足
- 经典任务(机器翻译、语音识别)
选择GRU:
- 数据量有限时(参数少,不易过拟合)
- 需要较快训练速度
- 长期依赖不是极端复杂
- 现代推荐:通常优先尝试GRU
实际代码示例:
python
import torch.nn as nn
# PyTorch内置实现
lstm = nn.LSTM(input_size=100, hidden_size=128, num_layers=2,
batch_first=True, dropout=0.2)
gru = nn.GRU(input_size=100, hidden_size=128, num_layers=2,
batch_first=True, dropout=0.2)
# 前向传播
x = torch.randn(32, 50, 100) # (batch, seq_len, input_size)
# LSTM
output, (h_n, c_n) = lstm(x)
# GRU
output, h_n = gru(x)
现代替代方案
Transformer的兴起:
- 2017年后,Transformer逐渐替代RNN/LSTM/GRU
- 并行化训练,速度快10-100倍
- 更好的长距离依赖建模
- BERT、GPT完全基于Transformer
何时仍使用LSTM/GRU:
- 序列长度可变且较短
- 实时在线学习场景
- 边缘设备部署(参数量小)
- 时间序列预测(特别是多变量)
- 某些特定领域(如语音合成的韵律建模)
混合架构:
python
# CNN提取局部特征 + LSTM建模时序
class CNN_LSTM(nn.Module):
def __init__(self):
super().__init__()
self.cnn = nn.Conv1d(in_channels=10, out_channels=64, kernel_size=3)
self.lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2)
self.fc = nn.Linear(128, num_classes)
def forward(self, x):
# x: (batch, seq_len, features)
x = x.transpose(1, 2) # (batch, features, seq_len)
x = self.cnn(x)
x = x.transpose(1, 2) # (batch, seq_len, channels)
output, (h_n, c_n) = self.lstm(x)
return self.fc(output[:, -1, :]) # 使用最后时刻的输出
训练技巧
梯度裁剪(必须):
python
# 防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
初始化策略:
python
# LSTM遗忘门偏置初始化为1(保持长期记忆)
for name, param in lstm.named_parameters():
if 'bias' in name:
n = param.size(0)
param.data[n//4:n//2].fill_(1.0) # 遗忘门偏置
序列打包(处理变长序列):
python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# 打包序列
packed_input = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_output, (h_n, c_n) = lstm(packed_input)
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)