pytorch-优化器

PyTorch 2.11 文档中的 torch.optim 是一个专门用于优化神经网络参数的包。

如果说"模型"是躯干,"损失函数"是痛觉感官,那么 Optimizer(优化器) 就是大脑。它根据痛觉(Loss)来决定如何移动身体(更新参数),以求减小痛苦。


1. 核心工作流程:三步走

文档强调了使用优化器时的标准"公式",这几乎出现在每个 PyTorch 训练脚本中:

  1. 清零梯度optimizer.zero_grad()

    • 因为 PyTorch 的梯度是累加的,如果不清零,上一次的梯度会干扰本次计算。
  2. 反向传播loss.backward()

    • 计算每个参数对损失的影响(求导)。
  3. 更新参数optimizer.step()

    • 根据计算出的梯度,真正动手修改模型里的权重。

2. 如何创建一个优化器?

要构造一个优化器,你必须给它两个东西:

  1. 待优化的参数 :通常是 model.parameters()

  2. 学习率 (Learning Rate, lr):决定步子迈多大。

python 复制代码
# 以常用的随机梯度下降 (SGD) 为例
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 或者更聪明的 Adam
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

3. 常见优化算法对比

文档中列出了许多算法,最常用的有这几种:

优化器 全称 特点 适用场景
SGD Stochastic Gradient Descent 最基础,配合动量(Momentum)效果扎实。 几乎所有场景,尤其是微调预训练模型。
Adam Adaptive Moment Estimation 自适应学习率,非常鲁棒,收敛极快。 绝大多数初学者和复杂任务的首选。
RMSprop Root Mean Square Prop 适合处理非平稳目标。 循环神经网络 (RNN) 常用。
LBFGS 限定内存 BFGS 拟牛顿法,内存开销大。 小规模数据的精确优化。

4. 进阶功能:学习率调度器 (LR Scheduler)

文档中还有很大一部分在讲 torch.optim.lr_scheduler

  • 目的:训练刚开始时步子大一点(快速收敛),快结束时步子小一点(精确寻找最低点)。

  • 常见写法

python 复制代码
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step() # 每个 epoch 结束后更新一次学习率

5. 将优化器加入你之前的代码

结合你上一段 CIFAR-10 的代码,完整的"学习"闭环应该是这样的:

python 复制代码
# ... 之前的模型定义和数据加载 ...

loss_func = nn.CrossEntropyLoss()
tudui = Tudui()
# 1. 定义优化器
optimizer = torch.optim.SGD(tudui.parameters(), lr=0.01)

for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss_func(outputs, targets)
    
    # 2. 优化三步走
    optimizer.zero_grad() # 清零
    result_loss.backward() # 算梯度
    optimizer.step()       # 改参数
    
    print("模型学习了一次...")

6. 💡 文档中的关键提示

  • Per-parameter options:优化器允许你为不同的层设置不同的学习率(比如冻结特征层,只调优分类层)。

  • Weight Decay (权重衰减):这是 L2 正则化的别名,文档中大多数优化器都支持这个参数,用于防止模型过拟合。

相关推荐
YueJoy.AI10 小时前
创业公司如何实现持续增长
人工智能·ai·语言模型
zxsz_com_cn10 小时前
工厂中需要预测性维护的关键设备
人工智能·数据挖掘
AI科技星10 小时前
基于**v=c(空间光速螺旋运动)唯一第一性原理**重新完整求导证明
人工智能·线性代数·算法·机器学习·架构·概率论·学习方法
__log10 小时前
如何优雅地“借鉴”任何网站的设计系统
人工智能·架构·知识图谱
卡次卡次110 小时前
vibecoding起步注意点:插件、Skills、MCP、Hooks
服务器·数据库·python·oracle
醒醒该学习了!10 小时前
AI生成音频
人工智能
我的xiaodoujiao10 小时前
API 接口自动化测试详细图文教程学习系列24--如何用Pytest去设计接口测试用例并执行
python·学习·测试工具·pytest
SOC罗三炮10 小时前
OpenHuman 源码深度解构:一个 Rust 驱动的本地优先 AI 个人助手
开发语言·人工智能·rust
冰西瓜60010 小时前
深度学习的数学原理(四十一)—— KV Cache
人工智能·深度学习
一点一木10 小时前
🚀 2026 年 5 月 GitHub 十大热门项目排行榜 🔥
人工智能·github·ai编程