Pytorch常用内置优化器合集

PyTorch 提供了多种优化器,每种优化器都有其独特的特点和适用场景。选择合适的优化器可以显著影响模型的训练效率、收敛速度和最终性能。以下是 PyTorch 中常见的几种优化器及其详细说明和使用场景:

1. SGD(随机梯度下降,Stochastic Gradient Descent)

简介:

SGD 是最基础的优化器之一,它直接根据损失函数的梯度来更新模型参数。每次更新的公式为:

其中:

  • θt 是当前的参数值。
  • η 是学习率(learning rate),控制每次更新的步长。
  • gt 是当前参数的梯度。
优点:
  • 简单易用:SGD 是最基础的优化器,易于理解和实现。
  • 适用于凸优化问题:在凸优化问题中,SGD 可以有效地找到全局最优解。
缺点:
  • 容易陷入局部最小值:对于非凸优化问题(如深度神经网络),SGD 可能会陷入局部最小值或鞍点。
  • 收敛速度较慢:SGD 的收敛速度相对较慢,尤其是在高维空间中。
  • 对学习率敏感:SGD 对学习率的选择非常敏感,学习率过大可能导致发散,过小则导致收敛缓慢。
使用场景:
  • 简单的线性模型:如线性回归、逻辑回归等任务,SGD 是一个不错的选择。
  • 大规模数据集:SGD 可以处理大规模数据集,因为它只需要计算每个批次的梯度,而不是整个数据集的梯度。
示例代码:
import torch.optim as optim

# 创建 SGD 优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

2. SGD with Momentum(带动量的随机梯度下降)

简介:

SGD with Momentum 在标准的 SGD 基础上引入了动量项,使得参数更新不仅依赖于当前的梯度,还考虑了之前更新的方向。动量项可以帮助加速收敛,并且有助于穿越平坦的区域(如鞍点)。更新公式为:

其中:

  • vt是动量项,表示历史梯度的累积。
  • β是动量系数,通常设置为 0.9。
优点:
  • 加速收敛:动量项可以帮助模型更快地穿越平坦区域,加速收敛。
  • 避免局部最小值:动量可以帮助模型逃离局部最小值,减少陷入局部最优解的风险。
缺点:
  • 对超参数敏感:动量系数 β 和学习率 η 需要仔细调整,否则可能会影响收敛效果。
使用场景:
  • 深度神经网络:对于复杂的深度神经网络,尤其是卷积神经网络(CNN)和循环神经网络(RNN),SGD with Momentum 是一个常用的选择。
  • 需要加速收敛的任务:当训练过程中遇到平坦区域或鞍点时,SGD with Momentum 可以帮助加速收敛。
示例代码:
# 创建带有动量的 SGD 优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

3. Adam(Adaptive Moment Estimation)

简介:

Adam 是一种自适应学习率的优化器,结合了动量(Momentum)和 RMSprop 的优点。Adam 通过维护梯度的一阶矩估计(均值)和二阶矩估计(方差)来动态调整每个参数的学习率。更新公式为:

其中:

  • mt 是梯度的一阶矩估计(均值)。
  • vt 是梯度的二阶矩估计(方差)。
  • β1 和 β2 分别是动量系数和二阶矩衰减系数,通常设置为 0.9 和 0.999。
  • ϵ 是一个小常数,防止除零错误,通常设置为 1e-8。
优点:
  • 自适应学习率:Adam 为每个参数分配不同的学习率,能够更好地处理稀疏梯度和噪声梯度。
  • 快速收敛:Adam 通常比其他优化器更快地收敛,尤其是在高维空间中。
  • 稳定性好:Adam 对学习率的选择相对不那么敏感,适合大多数深度学习任务。
缺点:
  • 内存消耗较大:Adam 需要维护两个额外的状态(一阶矩和二阶矩),因此相比其他优化器,它的内存消耗更大。
  • 可能过拟合:在某些情况下,Adam 可能会导致模型过拟合,尤其是在训练后期。
使用场景:
  • 大多数深度学习任务:Adam 是目前最常用的优化器之一,适用于各种类型的深度学习任务,包括图像分类、自然语言处理、强化学习等。
  • 复杂模型:对于复杂的模型(如深度卷积神经网络、Transformer 模型),Adam 通常能提供较好的收敛速度和稳定性。
示例代码:
# 创建 Adam 优化器
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8)

4. RMSprop(Root Mean Squared Propagation)

简介:

RMSprop 是一种自适应学习率的优化器,专门用于处理稀疏梯度问题。它通过维护梯度的平方的移动平均值来动态调整学习率。更新公式为:

其中:

  • vt 是梯度平方的移动平均值。
  • β 是衰减系数,通常设置为 0.9。
  • ϵ 是一个小常数,防止除零错误。
优点:
  • 处理稀疏梯度:RMSprop 特别适合处理稀疏梯度问题,例如在自然语言处理任务中,词嵌入矩阵中的许多元素可能是稀疏的。
  • 稳定性强:RMSprop 对学习率的选择相对不那么敏感,适合大多数深度学习任务。
缺点:
  • 收敛速度较慢:相比 Adam,RMSprop 的收敛速度可能稍慢,尤其是在高维空间中。
  • 对超参数敏感:虽然 RMSprop 对学习率的选择相对不敏感,但仍然需要仔细调整衰减系数 ββ。
使用场景:
  • 稀疏梯度问题:对于涉及稀疏梯度的任务(如自然语言处理、推荐系统),RMSprop 是一个不错的选择。
  • 深度神经网络:RMSprop 也适用于深度神经网络,尤其是卷积神经网络(CNN)和循环神经网络(RNN)。
示例代码:
# 创建 RMSprop 优化器
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.9, eps=1e-8)

5. Adagrad(Adaptive Gradient Algorithm)

简介:

Adagrad 是一种自适应学习率的优化器,它为每个参数分配不同的学习率,基于该参数的历史梯度。更新公式为:

其中:

  • Gt 是所有历史梯度的平方和。
  • ϵ 是一个小常数,防止除零错误。
优点:
  • 处理稀疏梯度:Adagrad 特别适合处理稀疏梯度问题,因为它为每个参数分配不同的学习率,能够更好地处理不同频率更新的参数。
  • 不需要手动调整学习率:Adagrad 自动调整每个参数的学习率,减少了手动调整学习率的工作量。
缺点:
  • 学习率逐渐变小:随着训练的进行,Adagrad 的学习率会逐渐变小,导致训练后期的更新步长过小,可能难以继续优化。
  • 内存消耗较大:Adagrad 需要存储所有历史梯度的平方和,因此内存消耗较大。
使用场景:
  • 稀疏梯度问题:对于涉及稀疏梯度的任务(如自然语言处理、推荐系统),Adagrad 是一个不错的选择。
  • 早期训练阶段:Adagrad 在训练初期表现较好,但在训练后期可能会因为学习率过小而难以继续优化。
示例代码:
# 创建 Adagrad 优化器
optimizer = optim.Adagrad(model.parameters(), lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)

6. AdamW(Adam with Weight Decay)

简介:

AdamW 是 Adam 的改进版本,结合了 Adam 的自适应学习率和权重衰减(L2 正则化)。与传统的 Adam 不同,AdamW 在权重衰减时不会影响学习率的自适应性。更新公式为:

其中:

  • λ 是权重衰减系数,用于控制正则化的强度。
优点:
  • 防止过拟合:AdamW 通过引入权重衰减(L2 正则化),能够有效防止模型过拟合,提升泛化能力。
  • 保持自适应学习率的优点:AdamW 保留了 Adam 的自适应学习率特性,能够在训练过程中动态调整每个参数的学习率。
缺点:
  • 内存消耗较大:与 Adam 类似,AdamW 也需要维护两个额外的状态(一阶矩和二阶矩),因此内存消耗较大。
使用场景:
  • 需要正则化的任务:对于容易过拟合的任务(如图像分类、自然语言处理),AdamW 是一个非常好的选择,因为它结合了 Adam 的快速收敛和权重衰减的正则化效果。
  • 深度学习任务:AdamW 适用于大多数深度学习任务,尤其是在训练大型模型时,能够有效防止过拟合。
示例代码:
# 创建 AdamW 优化器
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
相关推荐
Trouvaille ~9 分钟前
【机器学习】从流动到恒常,无穷中归一:积分的数学诗意
人工智能·python·机器学习·ai·数据分析·matplotlib·微积分
dundunmm17 分钟前
论文阅读:Deep Fusion Clustering Network With Reliable Structure Preservation
论文阅读·人工智能·数据挖掘·聚类·深度聚类·图聚类
szxinmai主板定制专家26 分钟前
【国产NI替代】基于FPGA的4通道电压 250M采样终端边缘计算采集板卡,主控支持龙芯/飞腾
人工智能·边缘计算
是十一月末26 分钟前
Opencv实现图像的腐蚀、膨胀及开、闭运算
人工智能·python·opencv·计算机视觉
云空33 分钟前
《探索PyTorch计算机视觉:原理、应用与实践》
人工智能·pytorch·python·深度学习·计算机视觉
杭杭爸爸35 分钟前
无人直播源码
人工智能·语音识别
dowhileprogramming44 分钟前
Python 中的迭代器
linux·数据库·python
Ainnle2 小时前
微软 CEO 萨提亚・纳德拉:回顾过去十年,展望 AI 时代的战略布局
人工智能·microsoft
长风清留扬2 小时前
基于OpenAI Whisper AI模型自动生成视频字幕:全面解析与实战指南
人工智能·神经网络·opencv·计算机视觉·自然语言处理·数据挖掘·whisper
0zxm2 小时前
08 Django - Django媒体文件&静态文件&文件上传
数据库·后端·python·django·sqlite