PyTorch torch.optim 优化器介绍与论文

目录

    • 概述
    • 常用优化器
      • [1. **SGD** (Stochastic Gradient Descent) - 随机梯度下降](#1. SGD (Stochastic Gradient Descent) - 随机梯度下降)
      • [2. **Adam** (Adaptive Moment Estimation) ⭐ 最常用](#2. Adam (Adaptive Moment Estimation) ⭐ 最常用)
      • [3. **AdamW** (Adam with Weight Decay) ⭐ PI0.5 使用](#3. AdamW (Adam with Weight Decay) ⭐ PI0.5 使用)
      • [4. **RMSprop** (Root Mean Square Propagation)](#4. RMSprop (Root Mean Square Propagation))
      • [5. **Adagrad** (Adaptive Gradient)](#5. Adagrad (Adaptive Gradient))
      • [6. **Adadelta**](#6. Adadelta)
      • [7. **Adamax**](#7. Adamax)
      • [8. **RAdam** (Rectified Adam)](#8. RAdam (Rectified Adam))
      • [9. **LBFGS** (Limited-memory BFGS)](#9. LBFGS (Limited-memory BFGS))
    • 优化器对比表
    • [在 LeRobot 中的使用](#在 LeRobot 中的使用)
      • [PI0.5 配置](#PI0.5 配置)
    • 选择建议
    • 关键论文总结
    • 参考资料
    • 总结

概述

torch.optim 是 PyTorch 提供的优化器模块,包含多种梯度下降优化算法。

常用优化器

1. SGD (Stochastic Gradient Descent) - 随机梯度下降

简介:最基础的优化算法,使用固定学习率更新参数。

公式

复制代码
v_t = momentum × v_{t-1} + g_t
θ_t = θ_{t-1} - lr × v_t

特点

  • 简单稳定
  • 收敛速度较慢
  • 适合凸优化问题

论文

使用示例

python 复制代码
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

2. Adam (Adaptive Moment Estimation) ⭐ 最常用

简介:自适应学习率优化器,结合了动量和自适应学习率。

公式

复制代码
m_t = β₁ × m_{t-1} + (1 - β₁) × g_t          # 一阶矩估计
v_t = β₂ × v_{t-1} + (1 - β₂) × g_t²         # 二阶矩估计
m̂_t = m_t / (1 - β₁^t)                       # 偏差修正
v̂_t = v_t / (1 - β₂^t)
θ_t = θ_{t-1} - lr × m̂_t / (√v̂_t + ε)

特点

  • ✅ 自适应学习率
  • ✅ 收敛速度快
  • ✅ 对超参数不敏感
  • ✅ 适合大多数深度学习任务

论文

  • Adam: A Method for Stochastic Optimization

使用示例

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))

3. AdamW (Adam with Weight Decay) ⭐ PI0.5 使用

简介:Adam 的改进版本,修正了权重衰减的实现。

关键改进

  • 将权重衰减从梯度中分离
  • 更正确的 L2 正则化实现
  • 通常比 Adam 效果更好

公式

复制代码
m_t = β₁ × m_{t-1} + (1 - β₁) × g_t
v_t = β₂ × v_{t-1} + (1 - β₂) × g_t²
m̂_t = m_t / (1 - β₁^t)
v̂_t = v_t / (1 - β₂^t)
θ_t = θ_{t-1} - lr × [m̂_t / (√v̂_t + ε) + weight_decay × θ_{t-1}]

特点

  • ✅ 修正了 Adam 的权重衰减问题
  • ✅ 更好的泛化性能
  • ✅ 适合 Transformer 等大模型

论文

使用示例

python 复制代码
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=2.5e-5, 
    betas=(0.9, 0.95),
    weight_decay=0.01
)

4. RMSprop (Root Mean Square Propagation)

简介:自适应学习率优化器,使用梯度平方的移动平均。

公式

复制代码
v_t = α × v_{t-1} + (1 - α) × g_t²
θ_t = θ_{t-1} - lr × g_t / (√v_t + ε)

特点

  • ✅ 自适应学习率
  • ✅ 适合非平稳目标
  • ✅ RNN 训练效果好

论文

使用示例

python 复制代码
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99)

5. Adagrad (Adaptive Gradient)

简介:自适应学习率优化器,累积历史梯度平方。

公式

复制代码
G_t = G_{t-1} + g_t²
θ_t = θ_{t-1} - lr × g_t / (√G_t + ε)

特点

  • ✅ 自动降低学习率
  • ⚠️ 学习率可能过小
  • ⚠️ 适合稀疏梯度

论文

使用示例

python 复制代码
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

6. Adadelta

简介:Adagrad 的改进版本,解决学习率衰减过快的问题。

公式

复制代码
E[g²]_t = ρ × E[g²]_{t-1} + (1 - ρ) × g_t²
Δθ_t = -√(E[Δθ²]_{t-1} + ε) / √(E[g²]_t + ε) × g_t
E[Δθ²]_t = ρ × E[Δθ²]_{t-1} + (1 - ρ) × Δθ_t²
θ_t = θ_{t-1} + Δθ_t

特点

  • ✅ 不需要手动设置学习率
  • ✅ 解决 Adagrad 学习率衰减问题

论文

使用示例

python 复制代码
optimizer = torch.optim.Adadelta(model.parameters(), rho=0.9)

7. Adamax

简介:Adam 的变体,使用无穷范数代替 L2 范数。

公式

复制代码
m_t = β₁ × m_{t-1} + (1 - β₁) × g_t
u_t = max(β₂ × u_{t-1}, |g_t|)
θ_t = θ_{t-1} - lr × m_t / (u_t + ε)

特点

  • ✅ 在某些情况下比 Adam 更稳定
  • ✅ 适合稀疏梯度

论文

  • Adam: A Method for Stochastic Optimization (与 Adam 同一篇)

使用示例

python 复制代码
optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, betas=(0.9, 0.999))

8. RAdam (Rectified Adam)

简介:修正 Adam 的方差问题,在训练初期更稳定。

特点

  • ✅ 修正 Adam 的方差问题
  • ✅ 训练初期更稳定
  • ✅ 自适应切换到 SGD

论文

  • On the Variance of the Adaptive Learning Rate and Beyond

注意:PyTorch 原生不支持,需要第三方库。


9. LBFGS (Limited-memory BFGS)

简介:拟牛顿法,使用二阶导数信息。

特点

  • ✅ 收敛速度快(接近二阶方法)
  • ⚠️ 内存占用大
  • ⚠️ 不适合大规模模型

论文

使用示例

python 复制代码
optimizer = torch.optim.LBFGS(model.parameters(), lr=1, max_iter=20)

优化器对比表

优化器 学习率 动量 自适应 收敛速度 推荐度
SGD 固定 可选 ⭐⭐
Adam 自适应 ⭐⭐⭐⭐⭐
AdamW 自适应 ⭐⭐⭐⭐⭐
RMSprop 自适应 ⭐⭐⭐
Adagrad 自适应 ⭐⭐
Adadelta 自适应 ⭐⭐⭐

在 LeRobot 中的使用

PI0.5 配置

文件policies/pi05/configuration_pi05.py

python 复制代码
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01

实际使用

python 复制代码
optimizer = torch.optim.AdamW(
    params,
    lr=2.5e-5,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01
)

选择建议

推荐使用

  1. AdamW ⭐⭐⭐⭐⭐

    • 大多数深度学习任务
    • Transformer 模型
    • 大模型微调
  2. Adam ⭐⭐⭐⭐

    • 通用深度学习
    • 快速原型开发
  3. SGD + Momentum ⭐⭐⭐

    • 凸优化问题
    • 需要稳定训练时

特殊场景

  • RNN/LSTM:RMSprop
  • 稀疏梯度:Adagrad, Adamax
  • 小规模模型:LBFGS

关键论文总结

优化器 核心论文 年份 作者
SGD Stochastic Approximation 1951 Robbins & Monro
Momentum On the importance of initialization and momentum 2013 Sutskever et al.
Adam Adam: A Method for Stochastic Optimization 2014 Kingma & Ba
AdamW Decoupled Weight Decay Regularization 2017 Loshchilov & Hutter
RMSprop Neural Networks Lecture 6 2012 Hinton
Adagrad Adaptive Subgradient Methods 2011 Duchi et al.
Adadelta ADADELTA: An Adaptive Learning Rate Method 2012 Zeiler
RAdam On the Variance of the Adaptive Learning Rate 2019 Liu et al.

参考资料

总结

最推荐AdamW - 修正了 Adam 的权重衰减问题,效果更好

PI0.5 使用AdamW - 配置为 lr=2.5e-5, betas=(0.9, 0.95), weight_decay=0.01

相关推荐
艾莉丝努力练剑3 小时前
hixl vs NCCL:昇腾生态通信库的独特优势分析
运维·c++·人工智能·cann
执风挽^3 小时前
Python基础编程题2
开发语言·python·算法·visual studio code
梦帮科技3 小时前
Node.js配置生成器CLI工具开发实战
前端·人工智能·windows·前端框架·node.js·json
程序员泠零澪回家种桔子3 小时前
Spring AI框架全方位详解
java·人工智能·后端·spring·ai·架构
Echo_NGC22373 小时前
【FFmpeg 使用指南】Part 3:码率控制策略与质量评估体系
人工智能·ffmpeg·视频·码率
纤纡.3 小时前
PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战
人工智能·pytorch·python
大大大反派3 小时前
CANN 生态中的自动化部署引擎:深入 `mindx-sdk` 项目构建端到端 AI 应用
运维·人工智能·自动化
程序猿追3 小时前
深度解读 AIR (AI Runtime):揭秘 CANN 极致算力编排与调度的核心引擎
人工智能
2601_949593653 小时前
深入解析CANN-acl应用层接口:构建高效的AI应用开发框架
数据库·人工智能
●VON3 小时前
CANN安全与隐私:从模型加固到数据合规的全栈防护实战
人工智能·安全