112_深度学习的导航仪:PyTorch 优化器(Optimizer)全解析

在经历了前向传播计算 Loss、反向传播计算梯度(Gradient)后,我们来到了最关键的一步:更新参数。优化器就像是一位经验丰富的导航员,它根据梯度指示的方向,决定如何调整模型的权重,使 Loss 降到最低。

1. 优化器的核心逻辑

优化器的主要工作包含以下三个步骤,缺一不可:

  1. 梯度清零 ( zero_grad**)**:在每一轮计算开始前,必须把之前的梯度清空,否则梯度会不断累加,导致训练出错。
  2. 反向传播 ( backward**)**:计算当前误差对每个参数的梯度。
  3. 参数更新 ( step**)**:根据选定的算法(如 SGD、Adam)和梯度值,实际修改网络中的权重。

2. 实战代码:神经网络的完整训练循环

通过 CIFAR-10 数据集演示了如何使用 SGD(随机梯度下降) 优化器进行多轮(Epoch)训练。

复制代码
import torch
import torchvision
from torch import nn 
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)       
dataloader = DataLoader(dataset, batch_size=64,drop_last=True)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x
    
loss = nn.CrossEntropyLoss() # 交叉熵    
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(),lr=0.01)   # 随机梯度下降优化器
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距
        optim.zero_grad()  # 梯度清零
        result_loss.backward() # 反向传播,计算损失函数的梯度
        optim.step()   # 根据梯度,对网络的参数进行调优
        running_loss = running_loss + result_loss
    print(running_loss) # 对这一轮所有误差的总和

3. 进阶:学习率调整策略 (LR Scheduler)

文件中还提到了一个进阶工具:StepLR

  • 为什么要调整学习率? 训练初期我们希望走得快(学习率大),训练后期为了精准落入最低点,我们需要走得慢(学习率小)。

  • 代码实现

    import torch
    import torchvision
    from torch import nn
    from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter

    dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
    dataloader = DataLoader(dataset, batch_size=64,drop_last=True)

    class Tudui(nn.Module):
    def init(self):
    super(Tudui, self).init()
    self.model1 = Sequential(
    Conv2d(3,32,5,padding=2),
    MaxPool2d(2),
    Conv2d(32,32,5,padding=2),
    MaxPool2d(2),
    Conv2d(32,64,5,padding=2),
    MaxPool2d(2),
    Flatten(),
    Linear(1024,64),
    Linear(64,10)
    )

    复制代码
      def forward(self, x):
          x = self.model1(x)
          return x

    loss = nn.CrossEntropyLoss() # 交叉熵
    tudui = Tudui()
    optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # 随机梯度下降优化器
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5, gamma=0.1) # 每过 step_size 更新一次优化器,更新是学习率为原来的学习率的的 0.1 倍
    for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距
    optim.zero_grad() # 梯度清零
    result_loss.backward() # 反向传播,计算损失函数的梯度
    optim.step() # 根据梯度,对网络的参数进行调优
    scheduler.step() # 学习率太小了,所以20个轮次后,相当于没走多少
    running_loss = running_loss + result_loss
    print(running_loss) # 对这一轮所有误差的总和


4. 总结:训练全流程闭环

分析完此文件后,我们终于完成了 PyTorch 训练的完整拼图:

  1. 准备数据 (Dataset & DataLoader)
  2. 搭建结构 (nn.Module & Sequential)
  3. 衡量误差 (Loss Function)
  4. 计算方向 (Backward)
  5. 调整参数 (Optimizer)

💡 学习心得

优化器的 lr(学习率)设置非常关键。设置过大,模型可能在最低点附近反复横跳无法收敛;设置过小,模型学习速度会极慢。在实际开发中,Adam 优化器由于其自适应学习率的特性,通常比 SGD 更容易上手。

相关推荐
vx_biyesheji00012 小时前
计算机毕业设计:Python全栈图书数据挖掘与可视化看板 Django框架 爬虫 当当图书 Pandas 可视化 大数据 大模型 书籍(建议收藏)✅
爬虫·python·机器学习·数据挖掘·django·毕业设计·课程设计
吴佳浩10 小时前
GPU 编号进阶:CUDA\_VISIBLE\_DEVICES、多进程与容器化陷阱
人工智能·pytorch·python
吴佳浩10 小时前
GPU 编号错乱踩坑指南:PyTorch cuda 编号与 nvidia-smi 不一致
人工智能·pytorch·nvidia
卧蚕土豆10 小时前
【有啥问啥】OpenClaw 安装与使用教程
人工智能·深度学习
AI科技星10 小时前
全尺度角速度统一:基于 v ≡ c 的纯推导与验证
c语言·开发语言·人工智能·opencv·算法·机器学习·数据挖掘
星空下的月光影子11 小时前
一维CNN在工业过程信号处理与故障预警中的应用
人工智能·机器学习
【建模先锋】12 小时前
创新首发!基于注意力机制优化的高创新故障诊断模型
深度学习·信号处理·故障诊断·特征融合·轴承故障诊断·fft变换·vmd分解
云上的云端14 小时前
vLLM-Ascend operator torchvision::nms does not exist 问题解决
人工智能·pytorch·深度学习
Zhansiqi14 小时前
dayy43
pytorch·python·深度学习