112_深度学习的导航仪:PyTorch 优化器(Optimizer)全解析

在经历了前向传播计算 Loss、反向传播计算梯度(Gradient)后,我们来到了最关键的一步:更新参数。优化器就像是一位经验丰富的导航员,它根据梯度指示的方向,决定如何调整模型的权重,使 Loss 降到最低。

1. 优化器的核心逻辑

优化器的主要工作包含以下三个步骤,缺一不可:

  1. 梯度清零 ( zero_grad**)**:在每一轮计算开始前,必须把之前的梯度清空,否则梯度会不断累加,导致训练出错。
  2. 反向传播 ( backward**)**:计算当前误差对每个参数的梯度。
  3. 参数更新 ( step**)**:根据选定的算法(如 SGD、Adam)和梯度值,实际修改网络中的权重。

2. 实战代码:神经网络的完整训练循环

通过 CIFAR-10 数据集演示了如何使用 SGD(随机梯度下降) 优化器进行多轮(Epoch)训练。

复制代码
import torch
import torchvision
from torch import nn 
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)       
dataloader = DataLoader(dataset, batch_size=64,drop_last=True)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x
    
loss = nn.CrossEntropyLoss() # 交叉熵    
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(),lr=0.01)   # 随机梯度下降优化器
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距
        optim.zero_grad()  # 梯度清零
        result_loss.backward() # 反向传播,计算损失函数的梯度
        optim.step()   # 根据梯度,对网络的参数进行调优
        running_loss = running_loss + result_loss
    print(running_loss) # 对这一轮所有误差的总和

3. 进阶:学习率调整策略 (LR Scheduler)

文件中还提到了一个进阶工具:StepLR

  • 为什么要调整学习率? 训练初期我们希望走得快(学习率大),训练后期为了精准落入最低点,我们需要走得慢(学习率小)。

  • 代码实现

    import torch
    import torchvision
    from torch import nn
    from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter

    dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
    dataloader = DataLoader(dataset, batch_size=64,drop_last=True)

    class Tudui(nn.Module):
    def init(self):
    super(Tudui, self).init()
    self.model1 = Sequential(
    Conv2d(3,32,5,padding=2),
    MaxPool2d(2),
    Conv2d(32,32,5,padding=2),
    MaxPool2d(2),
    Conv2d(32,64,5,padding=2),
    MaxPool2d(2),
    Flatten(),
    Linear(1024,64),
    Linear(64,10)
    )

    复制代码
      def forward(self, x):
          x = self.model1(x)
          return x

    loss = nn.CrossEntropyLoss() # 交叉熵
    tudui = Tudui()
    optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # 随机梯度下降优化器
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5, gamma=0.1) # 每过 step_size 更新一次优化器,更新是学习率为原来的学习率的的 0.1 倍
    for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距
    optim.zero_grad() # 梯度清零
    result_loss.backward() # 反向传播,计算损失函数的梯度
    optim.step() # 根据梯度,对网络的参数进行调优
    scheduler.step() # 学习率太小了,所以20个轮次后,相当于没走多少
    running_loss = running_loss + result_loss
    print(running_loss) # 对这一轮所有误差的总和


4. 总结:训练全流程闭环

分析完此文件后,我们终于完成了 PyTorch 训练的完整拼图:

  1. 准备数据 (Dataset & DataLoader)
  2. 搭建结构 (nn.Module & Sequential)
  3. 衡量误差 (Loss Function)
  4. 计算方向 (Backward)
  5. 调整参数 (Optimizer)

💡 学习心得

优化器的 lr(学习率)设置非常关键。设置过大,模型可能在最低点附近反复横跳无法收敛;设置过小,模型学习速度会极慢。在实际开发中,Adam 优化器由于其自适应学习率的特性,通常比 SGD 更容易上手。

相关推荐
云程笔记3 小时前
004.环境搭建基础篇:Python、CUDA、cuDNN、PyTorch/TensorFlow安装与版本兼容性踩坑
pytorch·python·tensorflow
逻辑君8 小时前
认知神经科学研究报告【20260010】
人工智能·深度学习·神经网络·机器学习
龙文浩_9 小时前
Attention Mechanism: From Theory to Code
人工智能·深度学习·神经网络·学习·自然语言处理
微臣愚钝9 小时前
prompt
人工智能·深度学习·prompt
宝贝儿好10 小时前
【LLM】第二章:文本表示:词袋模型、小案例:基于文本的推荐系统(酒店推荐)
人工智能·python·深度学习·神经网络·自然语言处理·机器人·语音识别
YBAdvanceFu11 小时前
从零构建智能体:深入理解 ReAct Plan Solve Reflection 三大经典范式
人工智能·python·机器学习·数据挖掘·多智能体·智能体
啦啦啦在冲冲冲11 小时前
多头注意力机制的优势是啥,遇到长文本的情况,可以从哪些情况优化呢
人工智能·深度学习
CV-杨帆12 小时前
ICLR 2026 LLM安全相关论文整理
人工智能·深度学习·安全
小程故事多_8012 小时前
从零吃透Transformer核心,多头注意力、残差连接与前馈网络(大白话完整版)
人工智能·深度学习·架构·aigc·transformer
AI应用实战 | RE13 小时前
012、检索器(Retrievers)核心:从向量库中智能查找信息
人工智能·算法·机器学习·langchain