本节目标:
- 理解优化器在深度学习训练中的作用
- 掌握不同优化算法的原理,优缺点和使用场景
- 了解学习率的调节机制(如学习率衰减,Warmup)
- 理解PyTorch中如何正确使用优化器
- 对比主流优化器在不同任务中的表现
一、什么是优化器?
优化器是一个算法,用于根据反向传播计算出的梯度,更新神经网络中的可学习参数(如权重和偏置)
优化器的核心目标是最小化损失函数:
其中:
:模型参数
:学习率
:参数的梯度
二、常见优化器原理与对比
|--------------|------------------|------------|------------|-----------------------|
| 优化器 | 原理 | 优点 | 缺点 | 适用场景 |
| SGD | 梯度下降法 | 简单,高效 | 收敛慢,易陷局部极小 | 小规模数据,控制精细更新 |
| SGD+Momentum | 加速梯度方向 | 快速收敛,越过鞍点 | 超参数复杂 | 图像任务(如CNN) |
| RMSProp | 自适应学习率 | 适用于非平稳目标 | 权重更新不直观 | RNN、时许任务 |
| Adam | RMSProp+Momentum | 收敛快,调参少 | 有时过拟合,泛化差 | 默认首选 |
| AdamW | 正确实现权重衰减 | 防止L2范数错误作用 | 略慢 | Transformer类模型(如BERT) |
| Adagrad | 每参数自适应更新 | 对稀疏数据有效 | 学习率快速衰减 | 文本,稀疏特征 |
三、每种优化器工作原理详解
1.SGD(Stochastic Gradient Descent)
最基本的优化器
- 更新方向依赖当前梯度
- 易受震荡影响,收敛慢
2.SGD+Momentum(动量项)
引入速度概念:
是动量因子(常取0.9)
- 避免震荡,跳出鞍点,加速收敛
3.RMSProp(Root Mean Square Prop)
核心思想:适用过去梯度平方的均值调整每个参数的学习率
- 更适合非平稳目标(如RNN)
4.Adam(Adaptive Moment Estimation)
结合Momentum + RMSProp的优点:
,
- 默认推荐优化器
- 自适应每个参数的学习率
=0.9,
=0.999通常效果好
5.AdamW(Weight Decay修正)
Adam原始版本中权重衰减不正确,AdamW正确分离了权重衰减项:
- 当前Transformer、BERT、ViT中的默认优化器
四、学习率策略(LR Schedule)
|-------------------|------------------------------------------|-----------|
| 策略 | 说明 | 使用场景 |
| 固定学习率 | 手动设定不变 | 简单但收敛慢 |
| Step Decay | 每N轮衰减一次 | 经典方式 |
| Exponential Decay | 每轮按指数衰减 | 过拟合时使用 |
| Cosine Annealing | 逐步变慢再重启 | 近年大模型常用 |
| Warmup | 初始逐渐增加学习率 | 避免前期训练不稳定 |
python
# 示例:PyTorch中CosineAnnealing + Warmup
from torch.optim.lr_scheduler import CosineAnnealingLR
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10)
五、PyTorch优化器使用模板
python
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
for epoch in range(num_epochs):
for batch in dataloader:
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
六、优化器选择建议总结
|-------------|--------------|-------------------|
| 场景 | 首选优化器 | 备注 |
| 初学者&默认 | Adam | 通用表现优良 |
| 文本任务(如NLP) | AdamW | HuggingFace等大模型默认 |
| 图像任务 | SGD+Momentum | 微调CNN效果更稳定 |
| RNN结构 | RMSProp | 时间序列任务 |
| 小样本 or 稀疏特征 | Adagrad | 推荐系统、文本分类 |
七、常见问题与调试建议
|----------|------------|----------------|
| 现象 | 可能原因 | 建议 |
| loss不下降 | 学习率太小或太大 | 使用warmup或调整LR |
| loss波动大 | 没有使用动量 | 使用SGD+Momentum |
| 模型过拟合 | 权重未衰减 | 使用AdamW或添加正则项 |
| loss为nan | 学习率过高或梯度爆炸 | 梯度裁剪、降低lr |
八、优化器进化关系
css
SGD
└── + Momentum → SGD+Momentum
└── + RMS → RMSProp
└── + 一阶矩 → Adam
└── + Weight Decay 修复 → AdamW