pytorch-优化器

PyTorch 2.11 文档中的 torch.optim 是一个专门用于优化神经网络参数的包。

如果说"模型"是躯干,"损失函数"是痛觉感官,那么 Optimizer(优化器) 就是大脑。它根据痛觉(Loss)来决定如何移动身体(更新参数),以求减小痛苦。


1. 核心工作流程:三步走

文档强调了使用优化器时的标准"公式",这几乎出现在每个 PyTorch 训练脚本中:

  1. 清零梯度optimizer.zero_grad()

    • 因为 PyTorch 的梯度是累加的,如果不清零,上一次的梯度会干扰本次计算。
  2. 反向传播loss.backward()

    • 计算每个参数对损失的影响(求导)。
  3. 更新参数optimizer.step()

    • 根据计算出的梯度,真正动手修改模型里的权重。

2. 如何创建一个优化器?

要构造一个优化器,你必须给它两个东西:

  1. 待优化的参数 :通常是 model.parameters()

  2. 学习率 (Learning Rate, lr):决定步子迈多大。

python 复制代码
# 以常用的随机梯度下降 (SGD) 为例
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 或者更聪明的 Adam
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

3. 常见优化算法对比

文档中列出了许多算法,最常用的有这几种:

优化器 全称 特点 适用场景
SGD Stochastic Gradient Descent 最基础,配合动量(Momentum)效果扎实。 几乎所有场景,尤其是微调预训练模型。
Adam Adaptive Moment Estimation 自适应学习率,非常鲁棒,收敛极快。 绝大多数初学者和复杂任务的首选。
RMSprop Root Mean Square Prop 适合处理非平稳目标。 循环神经网络 (RNN) 常用。
LBFGS 限定内存 BFGS 拟牛顿法,内存开销大。 小规模数据的精确优化。

4. 进阶功能:学习率调度器 (LR Scheduler)

文档中还有很大一部分在讲 torch.optim.lr_scheduler

  • 目的:训练刚开始时步子大一点(快速收敛),快结束时步子小一点(精确寻找最低点)。

  • 常见写法

python 复制代码
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step() # 每个 epoch 结束后更新一次学习率

5. 将优化器加入你之前的代码

结合你上一段 CIFAR-10 的代码,完整的"学习"闭环应该是这样的:

python 复制代码
# ... 之前的模型定义和数据加载 ...

loss_func = nn.CrossEntropyLoss()
tudui = Tudui()
# 1. 定义优化器
optimizer = torch.optim.SGD(tudui.parameters(), lr=0.01)

for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss_func(outputs, targets)
    
    # 2. 优化三步走
    optimizer.zero_grad() # 清零
    result_loss.backward() # 算梯度
    optimizer.step()       # 改参数
    
    print("模型学习了一次...")

6. 💡 文档中的关键提示

  • Per-parameter options:优化器允许你为不同的层设置不同的学习率(比如冻结特征层,只调优分类层)。

  • Weight Decay (权重衰减):这是 L2 正则化的别名,文档中大多数优化器都支持这个参数,用于防止模型过拟合。

相关推荐
道友可好15 分钟前
AI 是最好的混乱放大器:代码熵管理实战
前端·人工智能·后端
不加辣椒2 小时前
第7章 边界与约束技术:确保输出的准确性与安全性
人工智能
AI悦创Python辅导2 小时前
Claude Code 越用越乱?Sub-Agents 才是上下文污染的解法
人工智能
Bigfish_coding2 小时前
前端转agent-【python】-07 长期记忆进阶:用 ChromaDB + 语义搜索给 Agent 装上真正的长期记忆
人工智能
阿黎梨梨2 小时前
AI Loop:告别“人肉写提示词”,让代码替你“鞭策”AI
javascript·人工智能
Csvn3 小时前
Python 两大经典坑点 —— 可变默认参数 & 闭包延迟绑定
后端·python
甲维斯3 小时前
坦克大战测试全翻车了!豆包,DeepSeek,Qwen,GPT,Claude
前端·人工智能·游戏开发
若丶相见3 小时前
AI 大模型零基础知识扫盲
人工智能
曲幽4 小时前
别再用网页翻译看源码了!你的私人翻译神器LibreTranslate,部署避坑指南来了
python·docker·web·pot·translate·libretranslate·arogstranslate
猿人谷4 小时前
不只是 CPU 阈值:STAR 如何用 GAT + Transformer 做容器级自动扩缩容?
人工智能·算法