深度学习:优化器(Optimizer)

本节目标:

  • 理解优化器在深度学习训练中的作用
  • 掌握不同优化算法的原理,优缺点和使用场景
  • 了解学习率的调节机制(如学习率衰减,Warmup)
  • 理解PyTorch中如何正确使用优化器
  • 对比主流优化器在不同任务中的表现

一、什么是优化器?

优化器是一个算法,用于根据反向传播计算出的梯度,更新神经网络中的可学习参数(如权重和偏置)

优化器的核心目标是最小化损失函数:

其中:

  • :模型参数
  • :学习率
  • :参数的梯度

二、常见优化器原理与对比

|--------------|------------------|------------|------------|-----------------------|
| 优化器 | 原理 | 优点 | 缺点 | 适用场景 |
| SGD | 梯度下降法 | 简单,高效 | 收敛慢,易陷局部极小 | 小规模数据,控制精细更新 |
| SGD+Momentum | 加速梯度方向 | 快速收敛,越过鞍点 | 超参数复杂 | 图像任务(如CNN) |
| RMSProp | 自适应学习率 | 适用于非平稳目标 | 权重更新不直观 | RNN、时许任务 |
| Adam | RMSProp+Momentum | 收敛快,调参少 | 有时过拟合,泛化差 | 默认首选 |
| AdamW | 正确实现权重衰减 | 防止L2范数错误作用 | 略慢 | Transformer类模型(如BERT) |
| Adagrad | 每参数自适应更新 | 对稀疏数据有效 | 学习率快速衰减 | 文本,稀疏特征 |

三、每种优化器工作原理详解

1.SGD(Stochastic Gradient Descent)

最基本的优化器

  • 更新方向依赖当前梯度
  • 易受震荡影响,收敛慢

2.SGD+Momentum(动量项)

引入速度概念:

  • 是动量因子(常取0.9)
  • 避免震荡,跳出鞍点,加速收敛

3.RMSProp(Root Mean Square Prop)

核心思想:适用过去梯度平方的均值调整每个参数的学习率

  • 更适合非平稳目标(如RNN)

4.Adam(Adaptive Moment Estimation)

结合Momentum + RMSProp的优点:

  • 默认推荐优化器
  • 自适应每个参数的学习率
  • =0.9,=0.999通常效果好

5.AdamW(Weight Decay修正)

Adam原始版本中权重衰减不正确,AdamW正确分离了权重衰减项:

  • 当前Transformer、BERT、ViT中的默认优化器

四、学习率策略(LR Schedule)

|-------------------|------------------------------------------|-----------|
| 策略 | 说明 | 使用场景 |
| 固定学习率 | 手动设定不变 | 简单但收敛慢 |
| Step Decay | 每N轮衰减一次 | 经典方式 |
| Exponential Decay | 每轮按指数衰减 | 过拟合时使用 |
| Cosine Annealing | 逐步变慢再重启 | 近年大模型常用 |
| Warmup | 初始逐渐增加学习率 | 避免前期训练不稳定 |

python 复制代码
# 示例:PyTorch中CosineAnnealing + Warmup
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

五、PyTorch优化器使用模板

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

for epoch in range(num_epochs):
    for batch in dataloader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

六、优化器选择建议总结

|-------------|--------------|-------------------|
| 场景 | 首选优化器 | 备注 |
| 初学者&默认 | Adam | 通用表现优良 |
| 文本任务(如NLP) | AdamW | HuggingFace等大模型默认 |
| 图像任务 | SGD+Momentum | 微调CNN效果更稳定 |
| RNN结构 | RMSProp | 时间序列任务 |
| 小样本 or 稀疏特征 | Adagrad | 推荐系统、文本分类 |

七、常见问题与调试建议

|----------|------------|----------------|
| 现象 | 可能原因 | 建议 |
| loss不下降 | 学习率太小或太大 | 使用warmup或调整LR |
| loss波动大 | 没有使用动量 | 使用SGD+Momentum |
| 模型过拟合 | 权重未衰减 | 使用AdamW或添加正则项 |
| loss为nan | 学习率过高或梯度爆炸 | 梯度裁剪、降低lr |

八、优化器进化关系

css 复制代码
SGD
 └── + Momentum → SGD+Momentum
      └── + RMS → RMSProp
           └── + 一阶矩 → Adam
                └── + Weight Decay 修复 → AdamW
相关推荐
小王爱学人工智能2 分钟前
快速了解机器学习
人工智能·机器学习
hqxstudying37 分钟前
SpringAI的使用
java·开发语言·人工智能·springai
victory043142 分钟前
影响人类发音的疾病类型种类和数据集
人工智能·深度学习·ai
lishaoan7743 分钟前
tensorflow目标分类:分绍(一)
人工智能·分类·tensorflow·目标分类
Albert_Lsk44 分钟前
【2025/08/03】GitHub 今日热门项目
人工智能·开源·github·开源协议
点云SLAM1 小时前
PyTorch 应用于3D 点云数据处理汇总和点云配准示例演示
人工智能·pytorch·深度学习·3d·点云目标检测·点云补全·点云分类
AI Echoes1 小时前
ChatGPT、Playground手动模拟Agent摘要缓冲混合记忆功能
人工智能·python·langchain
腾讯云开发者1 小时前
智涌云端,与 AI 共生,腾讯云架构师峰会圆满落幕!
人工智能