pytorch-优化器

PyTorch 2.11 文档中的 torch.optim 是一个专门用于优化神经网络参数的包。

如果说"模型"是躯干,"损失函数"是痛觉感官,那么 Optimizer(优化器) 就是大脑。它根据痛觉(Loss)来决定如何移动身体(更新参数),以求减小痛苦。


1. 核心工作流程:三步走

文档强调了使用优化器时的标准"公式",这几乎出现在每个 PyTorch 训练脚本中:

  1. 清零梯度optimizer.zero_grad()

    • 因为 PyTorch 的梯度是累加的,如果不清零,上一次的梯度会干扰本次计算。
  2. 反向传播loss.backward()

    • 计算每个参数对损失的影响(求导)。
  3. 更新参数optimizer.step()

    • 根据计算出的梯度,真正动手修改模型里的权重。

2. 如何创建一个优化器?

要构造一个优化器,你必须给它两个东西:

  1. 待优化的参数 :通常是 model.parameters()

  2. 学习率 (Learning Rate, lr):决定步子迈多大。

python 复制代码
# 以常用的随机梯度下降 (SGD) 为例
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 或者更聪明的 Adam
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

3. 常见优化算法对比

文档中列出了许多算法,最常用的有这几种:

优化器 全称 特点 适用场景
SGD Stochastic Gradient Descent 最基础,配合动量(Momentum)效果扎实。 几乎所有场景,尤其是微调预训练模型。
Adam Adaptive Moment Estimation 自适应学习率,非常鲁棒,收敛极快。 绝大多数初学者和复杂任务的首选。
RMSprop Root Mean Square Prop 适合处理非平稳目标。 循环神经网络 (RNN) 常用。
LBFGS 限定内存 BFGS 拟牛顿法,内存开销大。 小规模数据的精确优化。

4. 进阶功能:学习率调度器 (LR Scheduler)

文档中还有很大一部分在讲 torch.optim.lr_scheduler

  • 目的:训练刚开始时步子大一点(快速收敛),快结束时步子小一点(精确寻找最低点)。

  • 常见写法

python 复制代码
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step() # 每个 epoch 结束后更新一次学习率

5. 将优化器加入你之前的代码

结合你上一段 CIFAR-10 的代码,完整的"学习"闭环应该是这样的:

python 复制代码
# ... 之前的模型定义和数据加载 ...

loss_func = nn.CrossEntropyLoss()
tudui = Tudui()
# 1. 定义优化器
optimizer = torch.optim.SGD(tudui.parameters(), lr=0.01)

for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss_func(outputs, targets)
    
    # 2. 优化三步走
    optimizer.zero_grad() # 清零
    result_loss.backward() # 算梯度
    optimizer.step()       # 改参数
    
    print("模型学习了一次...")

6. 💡 文档中的关键提示

  • Per-parameter options:优化器允许你为不同的层设置不同的学习率(比如冻结特征层,只调优分类层)。

  • Weight Decay (权重衰减):这是 L2 正则化的别名,文档中大多数优化器都支持这个参数,用于防止模型过拟合。

相关推荐
沅柠-AI营销2 小时前
TOB 工业制造与高端装备行业:AI 语义搜索赋能企业精准获客
人工智能·ai搜索优化·geo优化·企业降本·制造业获客·tob营销·b2b获客
m0_617881422 小时前
在 Go 中声明包级全局 Map 的正确方法
jvm·数据库·python
Polar__Star2 小时前
Redis怎样管理废弃的数据集合_利用EXPIRE指令为任意数据类型设置生命周期
jvm·数据库·python
weixin_568996062 小时前
CSS布局如何解决父级因全是绝对定位导致本身没高度的问题
jvm·数据库·python
weixin_381288182 小时前
MySQL无法通过网络连接服务器_检查bind-address与访问权限
jvm·数据库·python
Raink老师2 小时前
【AI面试临阵磨枪】什么是上下文窗口(Context Window)限制?主流解决方法有哪些?
人工智能·ai 面试
Irene19912 小时前
Python 中的 round() 函数不是严格的“四舍五入“,而是采用银行家舍入法(Bankers‘ Rounding)
python
ZC跨境爬虫2 小时前
3D 地球卫星轨道可视化平台开发 Day9(AI阈值调控+小众卫星识别+低Token测试模式实战)
人工智能·python·3d·信息可视化·json
GJGCY2 小时前
2026企业RPA+AI智能体落地技术全景:四阶段演进与关键架构决策
人工智能·安全·ai·rpa·智能体