Pytorch中的优化器及其用法

来自 千问

在 PyTorch 中,优化器 就像是一个"智能导航员"。它根据损失函数计算出的梯度(即"路况信息"),决定模型参数(权重)应该往哪个方向、以多大的步幅移动,才能最快地到达"损失最低"的目的地。

PyTorch 的优化器都封装在 torch.optim 模块中。

为什么要用优化器?

如果没有优化器,你需要手动写代码来更新每一个参数,非常繁琐且容易出错:

python 复制代码
# 手动更新(不推荐)
w = w - learning_rate * w.grad

有了优化器,你只需要调用一行代码 optimizer.step(),它就能自动遍历模型中所有需要更新的参数,并根据设定的算法(如 SGD、Adam)完成更新。

核心用法:三步走

  1. 构建优化器:在训练前,告诉优化器它要管理哪些参数,以及学习率是多少。
  2. 清空梯度:在每次反向传播前,把上一轮计算的梯度清零。
  3. 更新参数:在反向传播计算出梯度后,让优化器执行更新。

常用优化器介绍与对比

优化器 全称 特点 适用场景
SGD 随机梯度下降 最基础,可加动量(Momentum) 教学、简单任务、需要精细调参
Adam 自适应矩估计 最常用。自适应学习率,收敛快 默认推荐!90% 的场景都可用
AdamW Adam + 权重衰减 Adam 的改进版,正则化效果更好 训练 Transformer 等大模型时首选
RMSprop 均方根传播 自适应学习率,适合非稳态目标 RNN、GAN 等复杂模型

代码示例:完整训练流程

下面是一个使用 nn.Linear 模型和 Adam 优化器的完整示例:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 1. 定义模型
model = nn.Linear(10, 1)  # 简单线性模型

# 2. 定义优化器 (关键步骤)
# lr (learning_rate): 学习率,控制步长
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 3. 定义损失函数
criterion = nn.MSELoss()

# --- 模拟训练循环 ---
# 模拟数据
inputs = torch.randn(32, 10)  # 32个样本,10维特征
targets = torch.randn(32, 1)   # 真实标签

for epoch in range(100):
    # --- 训练四部曲 ---
    
    # 1. 前向传播:计算预测值
    outputs = model(inputs)
    
    # 2. 计算损失
    loss = criterion(outputs, targets)
    
    # 3. 反向传播:计算梯度
    # ⚠️ 在这之前,必须先清空上一轮的梯度!
    optimizer.zero_grad()   # 清空梯度
    loss.backward()        # 计算梯度 (此时梯度被存入 .grad 属性)
    
    # 4. 更新参数:优化器执行更新
    optimizer.step()       # 核心:根据梯度更新参数
    
    # 打印损失
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

print("训练完成!")

进阶用法:为不同层设置不同学习率

在实际项目中(如迁移学习),你可能希望模型的某些部分学得快一点,某些部分学得慢一点。你可以通过传入参数组列表来实现:

python 复制代码
# 假设 model 有两个部分:features 和 classifier
optimizer = optim.SGD([
    {'params': model.features.parameters()},           # 特征提取层:使用默认 lr
    {'params': model.classifier.parameters(), 'lr': 1e-3} # 分类层:单独设置 lr
], lr=1e-2, momentum=0.9) # 默认学习率为 1e-2,momentum 对所有层都生效

总结

  • optimizer.zero_grad()必须loss.backward() 之前调用,防止梯度累加。
  • optimizer.step()必须loss.backward() 之后调用,执行参数更新。
  • 选择建议
    • 新手/通用 :直接用 Adam,省心且效果好。
    • 追求极致 :在图像分类等任务中,最后微调时尝试 SGD
    • 大模型 :使用 AdamW
相关推荐
遇见小修修9 小时前
选择正规上门修电脑服务,有哪些通用标准和判断方法?
python
黎阳之光9 小时前
黎阳之光透明大楼:实景孪生重构智慧建筑全新范式
人工智能·物联网·算法·安全·数字孪生
步步为营DotNet9 小时前
Blazor 与 Microsoft.Extensions.AI 在客户端性能优化中的协同应用
人工智能·microsoft·性能优化
ayqy贾杰9 小时前
SpaceX 收购 Cursor,马斯克花600亿美元买了个代码编辑器
前端·人工智能·机器学习
JAMSAN09309 小时前
机器人轴承:被低估的“物理关节”,正在打开300倍增长空间
数据库·人工智能·机器人·智能硬件
财经资讯数据_灵砚智能9 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年6月16日
人工智能·python·ai·信息可视化·自然语言处理·ai编程·灵砚智能
“码”力全开9 小时前
解耦与重塑:基于 Docker 容器化与 GB28181/RTSP 统一接入的 AI 视频管理平台架构解析(支持源码交付与边缘计算)
人工智能·docker·边缘计算
小宋102110 小时前
4 万 Star 的开源 ChatGPT 桌面端:用 Jan 把电脑变成离线 AI 工作站
人工智能·chatgpt·开源·jan
searchforAI10 小时前
啥是LLM?大语言模型从原理到选型的完整科普
人工智能·科技·深度学习·ai·语言模型·知识图谱·agent
我就是全世界10 小时前
具身智能难现“ChatGPT时刻”:缺统一范式,更缺优质数据
人工智能·chatgpt·机器人