PyTorch——优化器(9)

优化器根据梯度调整参数,以达到降低误差

python 复制代码
import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

# 加载CIFAR10测试数据集,设置transform将图像转换为Tensor
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)

# 定义卷积神经网络模型
class TY(nn.Module):
    def __init__(self):
        super(TY, self).__init__()
        # 构建网络结构:3个卷积层+池化层组合,2个全连接层
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),    # 输入3通道,输出32通道,卷积核5x5
            MaxPool2d(2),                   # 最大池化,步长2
            Conv2d(32, 32, 5, padding=2),   # 第二层卷积
            MaxPool2d(2),                   # 第二次池化
            Conv2d(32, 64, 5, padding=2),   # 第三层卷积
            MaxPool2d(2),                   # 第三次池化
            Flatten(),                      # 将多维张量展平为向量
            Linear(1024, 64),               # 全连接层,输入1024维,输出64维
            Linear(64, 10),                 # 输出层,10个类别对应10个输出
        )

    def forward(self, x):
        # 定义前向传播路径
        x = self.model1(x)
        return x

# 定义损失函数(交叉熵损失适用于多分类问题)
loss = nn.CrossEntropyLoss()
# 实例化模型
ty = TY()
# 定义优化器(随机梯度下降),设置学习率为0.01
optim = torch.optim.SGD(ty.parameters(), lr=0.01)

# 训练20个完整轮次
for epoch in range(20):
    running_loss = 0.0  # 初始化本轮累计损失
    
    # 遍历数据加载器中的每个批次
    for data in dataloader:
        imgs, targets = data  # 获取图像和标签
        outputs = ty(imgs)    # 前向传播
        result_loss = loss(outputs, targets)  # 计算损失
        
        optim.zero_grad()     # 梯度清零,防止累积
        result_loss.backward()  # 反向传播计算梯度
        optim.step()          # 更新模型参数
        
        running_loss += result_loss  # 累加损失值
    
    # 打印本轮训练的累计损失
    print(f"Epoch {epoch+1}, Loss: {running_loss}")
相关推荐
zzc92140 分钟前
时频图数据集更正程序,去除坐标轴白边及调整对应的标签值
人工智能·深度学习·数据集·标签·时频图·更正·白边
Blossom.1182 小时前
机器学习在智能供应链中的应用:需求预测与物流优化
人工智能·深度学习·神经网络·机器学习·计算机视觉·机器人·语音识别
Gyoku Mint2 小时前
深度学习×第4卷:Pytorch实战——她第一次用张量去拟合你的轨迹
人工智能·pytorch·python·深度学习·神经网络·算法·聚类
m0_751336394 小时前
突破性进展:超短等离子体脉冲实现单电子量子干涉,为飞行量子比特奠定基础
人工智能·深度学习·量子计算·材料科学·光子器件·光子学·无线电电子
有Li7 小时前
通过具有一致性嵌入的大语言模型实现端到端乳腺癌放射治疗计划制定|文献速递-最新论文分享
论文阅读·深度学习·分类·医学生
郭庆汝8 小时前
pytorch、torchvision与python版本对应关系
人工智能·pytorch·python
叶子爱分享11 小时前
计算机视觉与图像处理的关系
图像处理·人工智能·计算机视觉
张较瘦_11 小时前
[论文阅读] 人工智能 | 深度学习系统崩溃恢复新方案:DaiFu框架的原位修复技术
论文阅读·人工智能·深度学习
cver12311 小时前
野生动物检测数据集介绍-5,138张图片 野生动物保护监测 智能狩猎相机系统 生态研究与调查
人工智能·pytorch·深度学习·目标检测·计算机视觉·目标跟踪
学技术的大胜嗷11 小时前
离线迁移 Conda 环境到 Windows 服务器:用 conda-pack 摆脱硬路径限制
人工智能·深度学习·yolo·目标检测·机器学习