PyTorch torch.optim 优化器介绍与论文

目录

    • 概述
    • 常用优化器
      • [1. **SGD** (Stochastic Gradient Descent) - 随机梯度下降](#1. SGD (Stochastic Gradient Descent) - 随机梯度下降)
      • [2. **Adam** (Adaptive Moment Estimation) ⭐ 最常用](#2. Adam (Adaptive Moment Estimation) ⭐ 最常用)
      • [3. **AdamW** (Adam with Weight Decay) ⭐ PI0.5 使用](#3. AdamW (Adam with Weight Decay) ⭐ PI0.5 使用)
      • [4. **RMSprop** (Root Mean Square Propagation)](#4. RMSprop (Root Mean Square Propagation))
      • [5. **Adagrad** (Adaptive Gradient)](#5. Adagrad (Adaptive Gradient))
      • [6. **Adadelta**](#6. Adadelta)
      • [7. **Adamax**](#7. Adamax)
      • [8. **RAdam** (Rectified Adam)](#8. RAdam (Rectified Adam))
      • [9. **LBFGS** (Limited-memory BFGS)](#9. LBFGS (Limited-memory BFGS))
    • 优化器对比表
    • [在 LeRobot 中的使用](#在 LeRobot 中的使用)
      • [PI0.5 配置](#PI0.5 配置)
    • 选择建议
    • 关键论文总结
    • 参考资料
    • 总结

概述

torch.optim 是 PyTorch 提供的优化器模块,包含多种梯度下降优化算法。

常用优化器

1. SGD (Stochastic Gradient Descent) - 随机梯度下降

简介:最基础的优化算法,使用固定学习率更新参数。

公式

复制代码
v_t = momentum × v_{t-1} + g_t
θ_t = θ_{t-1} - lr × v_t

特点

  • 简单稳定
  • 收敛速度较慢
  • 适合凸优化问题

论文

使用示例

python 复制代码
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

2. Adam (Adaptive Moment Estimation) ⭐ 最常用

简介:自适应学习率优化器,结合了动量和自适应学习率。

公式

复制代码
m_t = β₁ × m_{t-1} + (1 - β₁) × g_t          # 一阶矩估计
v_t = β₂ × v_{t-1} + (1 - β₂) × g_t²         # 二阶矩估计
m̂_t = m_t / (1 - β₁^t)                       # 偏差修正
v̂_t = v_t / (1 - β₂^t)
θ_t = θ_{t-1} - lr × m̂_t / (√v̂_t + ε)

特点

  • ✅ 自适应学习率
  • ✅ 收敛速度快
  • ✅ 对超参数不敏感
  • ✅ 适合大多数深度学习任务

论文

  • Adam: A Method for Stochastic Optimization

使用示例

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))

3. AdamW (Adam with Weight Decay) ⭐ PI0.5 使用

简介:Adam 的改进版本,修正了权重衰减的实现。

关键改进

  • 将权重衰减从梯度中分离
  • 更正确的 L2 正则化实现
  • 通常比 Adam 效果更好

公式

复制代码
m_t = β₁ × m_{t-1} + (1 - β₁) × g_t
v_t = β₂ × v_{t-1} + (1 - β₂) × g_t²
m̂_t = m_t / (1 - β₁^t)
v̂_t = v_t / (1 - β₂^t)
θ_t = θ_{t-1} - lr × [m̂_t / (√v̂_t + ε) + weight_decay × θ_{t-1}]

特点

  • ✅ 修正了 Adam 的权重衰减问题
  • ✅ 更好的泛化性能
  • ✅ 适合 Transformer 等大模型

论文

使用示例

python 复制代码
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=2.5e-5, 
    betas=(0.9, 0.95),
    weight_decay=0.01
)

4. RMSprop (Root Mean Square Propagation)

简介:自适应学习率优化器,使用梯度平方的移动平均。

公式

复制代码
v_t = α × v_{t-1} + (1 - α) × g_t²
θ_t = θ_{t-1} - lr × g_t / (√v_t + ε)

特点

  • ✅ 自适应学习率
  • ✅ 适合非平稳目标
  • ✅ RNN 训练效果好

论文

使用示例

python 复制代码
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99)

5. Adagrad (Adaptive Gradient)

简介:自适应学习率优化器,累积历史梯度平方。

公式

复制代码
G_t = G_{t-1} + g_t²
θ_t = θ_{t-1} - lr × g_t / (√G_t + ε)

特点

  • ✅ 自动降低学习率
  • ⚠️ 学习率可能过小
  • ⚠️ 适合稀疏梯度

论文

使用示例

python 复制代码
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

6. Adadelta

简介:Adagrad 的改进版本,解决学习率衰减过快的问题。

公式

复制代码
E[g²]_t = ρ × E[g²]_{t-1} + (1 - ρ) × g_t²
Δθ_t = -√(E[Δθ²]_{t-1} + ε) / √(E[g²]_t + ε) × g_t
E[Δθ²]_t = ρ × E[Δθ²]_{t-1} + (1 - ρ) × Δθ_t²
θ_t = θ_{t-1} + Δθ_t

特点

  • ✅ 不需要手动设置学习率
  • ✅ 解决 Adagrad 学习率衰减问题

论文

使用示例

python 复制代码
optimizer = torch.optim.Adadelta(model.parameters(), rho=0.9)

7. Adamax

简介:Adam 的变体,使用无穷范数代替 L2 范数。

公式

复制代码
m_t = β₁ × m_{t-1} + (1 - β₁) × g_t
u_t = max(β₂ × u_{t-1}, |g_t|)
θ_t = θ_{t-1} - lr × m_t / (u_t + ε)

特点

  • ✅ 在某些情况下比 Adam 更稳定
  • ✅ 适合稀疏梯度

论文

  • Adam: A Method for Stochastic Optimization (与 Adam 同一篇)

使用示例

python 复制代码
optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, betas=(0.9, 0.999))

8. RAdam (Rectified Adam)

简介:修正 Adam 的方差问题,在训练初期更稳定。

特点

  • ✅ 修正 Adam 的方差问题
  • ✅ 训练初期更稳定
  • ✅ 自适应切换到 SGD

论文

  • On the Variance of the Adaptive Learning Rate and Beyond

注意:PyTorch 原生不支持,需要第三方库。


9. LBFGS (Limited-memory BFGS)

简介:拟牛顿法,使用二阶导数信息。

特点

  • ✅ 收敛速度快(接近二阶方法)
  • ⚠️ 内存占用大
  • ⚠️ 不适合大规模模型

论文

使用示例

python 复制代码
optimizer = torch.optim.LBFGS(model.parameters(), lr=1, max_iter=20)

优化器对比表

优化器 学习率 动量 自适应 收敛速度 推荐度
SGD 固定 可选 ⭐⭐
Adam 自适应 ⭐⭐⭐⭐⭐
AdamW 自适应 ⭐⭐⭐⭐⭐
RMSprop 自适应 ⭐⭐⭐
Adagrad 自适应 ⭐⭐
Adadelta 自适应 ⭐⭐⭐

在 LeRobot 中的使用

PI0.5 配置

文件policies/pi05/configuration_pi05.py

python 复制代码
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01

实际使用

python 复制代码
optimizer = torch.optim.AdamW(
    params,
    lr=2.5e-5,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01
)

选择建议

推荐使用

  1. AdamW ⭐⭐⭐⭐⭐

    • 大多数深度学习任务
    • Transformer 模型
    • 大模型微调
  2. Adam ⭐⭐⭐⭐

    • 通用深度学习
    • 快速原型开发
  3. SGD + Momentum ⭐⭐⭐

    • 凸优化问题
    • 需要稳定训练时

特殊场景

  • RNN/LSTM:RMSprop
  • 稀疏梯度:Adagrad, Adamax
  • 小规模模型:LBFGS

关键论文总结

优化器 核心论文 年份 作者
SGD Stochastic Approximation 1951 Robbins & Monro
Momentum On the importance of initialization and momentum 2013 Sutskever et al.
Adam Adam: A Method for Stochastic Optimization 2014 Kingma & Ba
AdamW Decoupled Weight Decay Regularization 2017 Loshchilov & Hutter
RMSprop Neural Networks Lecture 6 2012 Hinton
Adagrad Adaptive Subgradient Methods 2011 Duchi et al.
Adadelta ADADELTA: An Adaptive Learning Rate Method 2012 Zeiler
RAdam On the Variance of the Adaptive Learning Rate 2019 Liu et al.

参考资料

总结

最推荐AdamW - 修正了 Adam 的权重衰减问题,效果更好

PI0.5 使用AdamW - 配置为 lr=2.5e-5, betas=(0.9, 0.95), weight_decay=0.01

相关推荐
高工智能汽车4 小时前
爱芯元智通过港交所聆讯,智能汽车芯片市场格局加速重构
人工智能·重构·汽车
大力财经4 小时前
悬架、底盘、制动被同时重构,星空计划想把“驾驶”变成一种系统能力
人工智能
喵手5 小时前
Python爬虫零基础入门【第九章:实战项目教学·第15节】搜索页采集:关键词队列 + 结果去重 + 反爬友好策略!
爬虫·python·爬虫实战·python爬虫工程化实战·零基础python爬虫教学·搜索页采集·关键词队列
梁下轻语的秋缘5 小时前
Prompt工程核心指南:从入门到精通,让AI精准响应你的需求
大数据·人工智能·prompt
FreeBuf_5 小时前
ChatGPT引用马斯克AI生成的Grokipedia是否陷入“内容陷阱“?
人工智能·chatgpt
Suchadar5 小时前
if判断语句——Python
开发语言·python
ʚB҉L҉A҉C҉K҉.҉基҉德҉^҉大5 小时前
自动化机器学习(AutoML)库TPOT使用指南
jvm·数据库·python
福客AI智能客服5 小时前
工单智转:电商智能客服与客服AI系统重构售后服务效率
大数据·人工智能
柳鲲鹏5 小时前
OpenCV:超分辨率、超采样及测试性能
人工智能·opencv·计算机视觉
喵手5 小时前
Python爬虫零基础入门【第九章:实战项目教学·第14节】表格型页面采集:多列、多行、跨页(通用表格解析)!
爬虫·python·python爬虫实战·python爬虫工程化实战·python爬虫零基础入门·表格型页面采集·通用表格解析