CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务

一、任务背景

关于应用:(补)CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务的应用-CSDN博客

流程:神经网络训练核心循环流程图资源-CSDN下载

我们要解决的是CIFAR10 图像分类问题

  • CIFAR10 是一个经典数据集,包含 10 个类别的 32×32 彩色图像(飞机、汽车、鸟等)
  • 任务目标:++训练一个神经网络,输入图像后能正确预测其类别++

二、整体流程框架

任何深度学习任务的核心流程都可以总结为:

  1. 准备数据(加载、预处理、封装)
  2. 定义模型(神经网络结构)
  3. 设置训练工具(损失函数、优化器)
  4. 训练模型(循环迭代、反向传播、参数更新)
  5. 测试模型(评估精度、可视化结果)
  6. 保存模型

三、分步详解

1. 准备工作:环境与库

首先需要安装必要的库:

python 复制代码
pip install torch torchvision tensorboard  # PyTorch核心库、视觉工具库、可视化工具

代码中导入的库作用:

  • torch:PyTorch 核心,用于张量运算和神经网络
  • torchvision:提供经典数据集(如 CIFAR10)和图像预处理工具
  • torch.nn:神经网络模块(包含卷积、全连接等层)
  • DataLoader:批量加载数据的工具
  • tensorboard:可视化训练过程(损失、精度等)

2. 模型定义(model.py

这部分定义了一个卷积神经网络(CNN),用于提取图像特征并分类。

核心思路

图像分类任务中,CNN 通过卷积层提取空间特征,池化层压缩数据,全连接层完成分类。CIFAR10 是 32×32 的彩色图(3 通道),最终要分为 10 类,因此输出层维度为 10。

python 复制代码
import torch
from torch import nn  # 导入神经网络模块

class Prayer(nn.Module):  # 继承nn.Module(所有神经网络的基类)
    def __init__(self):
        super(Prayer, self).__init__()  # 初始化父类
        # 定义网络序列(用nn.Sequential封装层,简化代码)
        self.module = nn.Sequential(
            # 第1组:卷积+池化
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            # 卷积层:输入3通道(彩色图),输出32通道(32个特征图),5×5卷积核,步长1,padding=2(保持尺寸不变)
            nn.MaxPool2d(kernel_size=2),  # 2×2最大池化,尺寸减半(32×32→16×16)
            
            # 第2组:卷积+池化
            nn.Conv2d(32, 32, 5, 1, 2),  # 输入32通道,输出32通道,参数同上
            nn.MaxPool2d(2),  # 尺寸减半(16×16→8×8)
            
            # 第3组:卷积+池化
            nn.Conv2d(32, 64, 5, 1, 2),  # 输入32通道,输出64通道
            nn.MaxPool2d(2),  # 尺寸减半(8×8→4×4)
            
            # 分类头:全连接层
            nn.Flatten(),  # 展平特征图(64通道×4×4 → 64*4*4=1024维向量)
            nn.Linear(64*4*4, 64),  # 全连接层:1024→64(压缩特征)
            nn.Linear(64, 10)  # 输出层:64→10(对应10个类别)
        )
    
    def forward(self, x):  # 前向传播(必须实现)
        x = self.module(x)  # 输入经过网络序列处理
        return x

# 测试模型输入输出是否正确(调试用)
if __name__ == "__main__":
    prayer = Prayer()
    input = torch.ones(64, 3, 32, 32)  # 模拟输入:64个样本,3通道,32×32
    output = prayer(input)
    print(output.shape)  # 输出应为(64, 10),与预期一致则模型结构正确

3. 训练过程(train.py

这部分是核心,实现数据加载、训练循环、测试评估等完整流程。

步骤 1:加载数据集

python 复制代码
# 加载CIFAR10训练集和测试集
train_data = torchvision.datasets.CIFAR10(
    root='./data',  # 数据保存路径
    train=True,     # True=训练集,False=测试集
    transform=torchvision.transforms.ToTensor(),  # 转换为Tensor(0-255→0-1,维度从HWC→CHW)
    download=True   # 若本地没有数据则自动下载
)
test_data = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False, 
    transform=torchvision.transforms.ToTensor(), 
    download=True
)

步骤 2:查看数据集大小

python 复制代码
train_data_size = len(train_data)  # CIFAR10训练集共50000个样本
test_data_size = len(test_data)    # 测试集共10000个样本
print(f"训练数据集的长度为:{train_data_size}")  # 注意:原代码的format语法错误,应改为f-string
print(f"测试数据集的长度为:{test_data_size}")

步骤 3:用 DataLoader 批量加载数据

python 复制代码
train_dataloader = DataLoader(train_data, batch_size=64)  # 训练集:每次加载64个样本
test_dataloader = DataLoader(test_data, batch_size=64)    # 测试集:每次加载64个样本
  • batch_size=64:每次训练 / 测试用 64 个样本,避免一次性加载全部数据导致内存不足
  • DataLoader 会自动打乱训练集顺序(默认shuffle=True)

步骤 4:初始化模型、损失函数和优化器

python 复制代码
# 创建网络模型
prayer = Prayer()  # 实例化我们定义的Prayer模型

# 损失函数:分类任务常用交叉熵损失
loss_fn = nn.CrossEntropyLoss()  # 内置了Softmax,直接输入原始输出即可

# 优化器:用于更新模型参数(这里用SGD)
learning_rate = 1e-2  # 学习率:控制参数更新幅度
optimizer = torch.optim.SGD(prayer.parameters(), lr=learning_rate)  # 优化模型所有参数

深度学习中的优化器-CSDN博客

步骤 5:设置训练参数

python 复制代码
total_train_step = 0  # 记录总训练步数(每个batch算一步)
total_test_step = 0   # 记录总测试轮数
epoch = 10  # 训练轮数:整个数据集循环10次

步骤 6:训练循环(核心部分)

每一轮训练包含两个阶段:训练阶段(更新参数)和测试阶段(评估性能)。

python 复制代码
for i in range(epoch):  # 循环10轮
    print(f"------第{i+1}轮训练开始------")  # 原代码写的"测试"是笔误
    
    # 训练阶段:开启模型训练模式(对Dropout、BN等层有效)
    prayer.train()
    for data in train_dataloader:  # 遍历训练集中的每个batch
        imgs, targets = data  # 解包:imgs是图像数据,targets是标签(0-9)
        
        # 前向传播:计算模型输出
        outputs = prayer(imgs)  # 注意:原代码导入了outputs变量,这里会冲突,应删除from nn_loss_network import outputs
        
        # 计算损失
        loss = loss_fn(outputs, targets)  # 输出与标签的差距
        
        # 反向传播+参数更新
        optimizer.zero_grad()  # 清空上一轮的梯度(否则会累积)
        loss.backward()        # 计算梯度(链式法则)
        optimizer.step()       # 根据梯度更新参数
        
        # 记录并打印训练信息
        total_train_step += 1
        if total_train_step % 100 == 0:  # 每100步打印一次(原代码每步都打,会刷屏)
            print(f"训练次数:{total_train_step}, loss:{loss.item()}")  # loss.item()获取数值
        
        # 用tensorboard记录损失
        writer.add_scalar('train_loss', loss.item(), total_train_step)  # 原代码未初始化writer,需补充:writer = SummaryWriter("logs")

步骤 7:测试阶段(每轮训练后评估)

python 复制代码
# 测试阶段:开启评估模式(固定Dropout、BN等层)
prayer.eval()
total_test_loss = 0  # 记录测试集总损失
total_accuracy = 0   # 记录正确预测的样本数
with torch.no_grad():  # 关闭梯度计算(节省内存,加快速度)
    for data in test_dataloader:
        imgs, targets = data
        outputs = prayer(imgs)  # 模型预测
        
        # 计算测试损失
        loss = loss_fn(outputs, targets)
        total_test_loss += loss.item()
        
        # 计算准确率:预测类别(outputs.argmax(1))与真实标签一致的数量
        accuracy = (outputs.argmax(1) == targets).sum()  # 逐元素比较,求和得到正确数
        total_accuracy += accuracy

# 打印测试结果
print(f"整体测试集上的Loss:{total_test_loss}")
print(f"整体测试集上的正确率:{total_accuracy / test_data_size}")  # 正确数/总样本数

# 记录测试指标到tensorboard
writer.add_scalar('test_loss', total_test_loss, total_test_step)
writer.add_scalar('test_accuracy', total_accuracy / test_data_size, total_test_step)
total_test_step += 1

步骤 8:保存模型

python 复制代码
torch.save(prayer, f"prayer_{i}.pth")  
print("模型已保存")

步骤 9:关闭 tensorboard 写入器

python 复制代码
writer.close()

四、复现步骤(从 0 开始)

1.创建文件夹结构

python 复制代码
project/
├── model.py       # 模型定义
├── train.py       # 训练代码
└── data/          # 存放CIFAR10数据(自动生成)

2.编写 model.py

确保测试输入输出形状正确

python 复制代码
import torch
from torch import nn  # 导入神经网络模块

class Prayer(nn.Module):  # 继承nn.Module(所有神经网络的基类)
    def __init__(self):
        super(Prayer, self).__init__()  # 初始化父类
        # 定义网络序列(用nn.Sequential封装层,简化代码)
        self.module = nn.Sequential(
            # 第1组:卷积+池化
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            # 卷积层:输入3通道(彩色图),输出32通道(32个特征图),5×5卷积核,步长1,padding=2(保持尺寸不变)
            nn.MaxPool2d(kernel_size=2),  # 2×2最大池化,尺寸减半(32×32→16×16)
            
            # 第2组:卷积+池化
            nn.Conv2d(32, 32, 5, 1, 2),  # 输入32通道,输出32通道,参数同上
            nn.MaxPool2d(2),  # 尺寸减半(16×16→8×8)
            
            # 第3组:卷积+池化
            nn.Conv2d(32, 64, 5, 1, 2),  # 输入32通道,输出64通道
            nn.MaxPool2d(2),  # 尺寸减半(8×8→4×4)
            
            # 分类头:全连接层
            nn.Flatten(),  # 展平特征图(64通道×4×4 → 64*4*4=1024维向量)
            nn.Linear(64*4*4, 64),  # 全连接层:1024→64(压缩特征)
            nn.Linear(64, 10)  # 输出层:64→10(对应10个类别)
        )
    
    def forward(self, x):  # 前向传播(必须实现)
        x = self.module(x)  # 输入经过网络序列处理
        return x

# 测试模型输入输出是否正确(调试用)
if __name__ == "__main__":
    prayer = Prayer()
    input = torch.ones(64, 3, 32, 32)  # 模拟输入:64个样本,3通道,32×32
    output = prayer(input)
    print(output.shape)  # 输出应为(64, 10),与预期一致则模型结构正确

(运行model.py应输出torch.Size([64, 10]))。

3.编写 train.py

python 复制代码
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # 修正:添加SummaryWriter

from model import Prayer

# 初始化tensorboard
writer = SummaryWriter("logs")

# 准备数据集
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)

# 数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为:{train_data_size}")
print(f"测试数据集的长度为:{test_data_size}")

# 加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)  # 修正:用test_data

# 创建模型
prayer = Prayer()

# 损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(prayer.parameters(), lr=learning_rate)

# 训练参数
total_train_step = 0
total_test_step = 0
epoch = 10

for i in range(epoch):
    print(f"------第{i+1}轮训练开始------")
    
    # 训练
    prayer.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = prayer(imgs)  # 修正:删除冲突的outputs导入
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:  # 每100步打印一次
            print(f"训练次数:{total_train_step}, loss:{loss.item()}")
        writer.add_scalar('train_loss', loss.item(), total_train_step)

    # 测试
    prayer.eval()
    total_test_loss = 0
    total_accuracy = 0  # 修正:每轮测试前初始化
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = prayer(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    print(f"整体测试集上的Loss:{total_test_loss}")
    print(f"整体测试集上的正确率:{total_accuracy / test_data_size}")
    writer.add_scalar('test_loss', total_test_loss, total_test_step)
    writer.add_scalar('test_accuracy', total_accuracy / test_data_size, total_test_step)
    total_test_step += 1

    # 保存模型
    torch.save(prayer, f"prayer_{i}.pth")  # 修正拼写错误
    print("模型已保存")

writer.close()

4.运行训练代码

python 复制代码
python train.py

5.查看训练可视化结果

打开新终端,运行:

python 复制代码
tensorboard --logdir=logs  # logs是保存日志的文件夹

在浏览器访问提示的地址(通常是http://localhost:6006),可查看损失和准确率曲线。

六、关键知识点总结

  1. 数据处理torchvision.datasets加载经典数据集,DataLoader实现批量加载
  2. 模型定义 :继承nn.Module,实现__init__(定义层)和forward(前向传播)
  3. 训练三要素
    • 模型输出:outputs = model(imgs)
    • 损失计算:loss = loss_fn(outputs, targets)
    • 参数更新:optimizer.zero_grad()loss.backward()optimizer.step()
  4. 测试技巧model.eval()torch.no_grad()关闭梯度计算,提高效率
  5. 可视化tensorboard记录训练过程,便于分析模型性能

通过以上步骤,可以完全复现这个 CIFAR10 分类任务。训练完成后,模型文件(prayer_0.pth等)可用于后续的图像预测。

七、测试结果

复制代码
训练数据集的长度为:50000
测试数据集的长度为:10000
------第1轮训练开始------
训练次数:100, loss:2.292586326599121
训练次数:200, loss:2.2810027599334717
训练次数:300, loss:2.2674713134765625
训练次数:400, loss:2.188361644744873
训练次数:500, loss:2.067962646484375
训练次数:600, loss:2.0249435901641846
训练次数:700, loss:2.0278453826904297
整体测试集上的Loss:318.6863216161728
整体测试集上的正确率:0.2669000029563904
模型已保存
------第2轮训练开始------
训练次数:800, loss:1.8828389644622803
训练次数:900, loss:1.8570338487625122
训练次数:1000, loss:1.9429094791412354
训练次数:1100, loss:1.957800269126892
训练次数:1200, loss:1.7039369344711304
训练次数:1300, loss:1.6572980880737305
训练次数:1400, loss:1.7372108697891235
训练次数:1500, loss:1.8313004970550537
整体测试集上的Loss:298.9056884050369
整体测试集上的正确率:0.32670000195503235
模型已保存
------第3轮训练开始------
训练次数:1600, loss:1.7614809274673462
训练次数:1700, loss:1.6178432703018188
训练次数:1800, loss:1.9279940128326416
训练次数:1900, loss:1.6716428995132446
训练次数:2000, loss:1.9582788944244385
训练次数:2100, loss:1.4907219409942627
训练次数:2200, loss:1.4549304246902466
训练次数:2300, loss:1.8031885623931885
整体测试集上的Loss:261.5988036394119
整体测试集上的正确率:0.4001999795436859
模型已保存
------第4轮训练开始------
训练次数:2400, loss:1.7241209745407104
训练次数:2500, loss:1.356602668762207
训练次数:2600, loss:1.6022555828094482
训练次数:2700, loss:1.6960515975952148
训练次数:2800, loss:1.4876827001571655
训练次数:2900, loss:1.6000099182128906
训练次数:3000, loss:1.353346586227417
训练次数:3100, loss:1.5333189964294434
整体测试集上的Loss:257.71949791908264
整体测试集上的正确率:0.4041999876499176
模型已保存
------第5轮训练开始------
训练次数:3200, loss:1.360127329826355
训练次数:3300, loss:1.4818179607391357
训练次数:3400, loss:1.4619064331054688
训练次数:3500, loss:1.562610149383545
训练次数:3600, loss:1.5666663646697998
训练次数:3700, loss:1.3336483240127563
训练次数:3800, loss:1.2914965152740479
训练次数:3900, loss:1.451879858970642
整体测试集上的Loss:255.3749772310257
整体测试集上的正确率:0.4197999835014343
模型已保存
------第6轮训练开始------
训练次数:4000, loss:1.370840311050415
训练次数:4100, loss:1.4321507215499878
训练次数:4200, loss:1.5426149368286133
训练次数:4300, loss:1.2162929773330688
训练次数:4400, loss:1.1133944988250732
训练次数:4500, loss:1.3967304229736328
训练次数:4600, loss:1.3962610960006714
整体测试集上的Loss:240.8339524269104
整体测试集上的正确率:0.4487999975681305
模型已保存
------第7轮训练开始------
训练次数:4700, loss:1.3106122016906738
训练次数:4800, loss:1.5192369222640991
训练次数:4900, loss:1.4253202676773071
训练次数:5000, loss:1.4319653511047363
训练次数:5100, loss:0.996913492679596
训练次数:5200, loss:1.3940272331237793
训练次数:5300, loss:1.177820086479187
训练次数:5400, loss:1.3661075830459595
整体测试集上的Loss:229.9147927761078
整体测试集上的正确率:0.47599998116493225
模型已保存
------第8轮训练开始------
训练次数:5500, loss:1.241894006729126
训练次数:5600, loss:1.2053898572921753
训练次数:5700, loss:1.2199749946594238
训练次数:5800, loss:1.2041431665420532
训练次数:5900, loss:1.3678442239761353
训练次数:6000, loss:1.5297755002975464
训练次数:6100, loss:1.0360474586486816
训练次数:6200, loss:1.1391490697860718
整体测试集上的Loss:218.75120985507965
整体测试集上的正确率:0.5063999891281128
模型已保存
------第9轮训练开始------
训练次数:6300, loss:1.4089828729629517
训练次数:6400, loss:1.1370986700057983
训练次数:6500, loss:1.5074737071990967
训练次数:6600, loss:1.076493501663208
训练次数:6700, loss:1.0735267400741577
训练次数:6800, loss:1.1444426774978638
训练次数:6900, loss:1.110133409500122
训练次数:7000, loss:0.8954514265060425
整体测试集上的Loss:207.2557065486908
整体测试集上的正确率:0.5336999893188477
模型已保存
------第10轮训练开始------
训练次数:7100, loss:1.2467191219329834
训练次数:7200, loss:0.9513144493103027
训练次数:7300, loss:1.1857160329818726
训练次数:7400, loss:0.8370732665061951
训练次数:7500, loss:1.2208924293518066
训练次数:7600, loss:1.1963437795639038
训练次数:7700, loss:0.855516254901886
训练次数:7800, loss:1.2431912422180176
整体测试集上的Loss:198.26297962665558
整体测试集上的正确率:0.5557999610900879
模型已保存

Process finished with exit code 0
相关推荐
王嘉俊9254 小时前
HarmonyOS 分布式与 AI 集成:构建智能协同应用的进阶实践
人工智能·分布式·harmonyos
赋创小助手4 小时前
实测对比 32GB RTX 5090 与 48GB RTX 4090,多场景高并发测试,全面解析 AI 服务器整机性能与显存差异。
运维·服务器·人工智能·科技·深度学习·神经网络·自然语言处理
阿水实证通4 小时前
能源经济选题推荐:可再生能源转型政策如何提高能源韧性?基于双重机器学习的因果推断
人工智能·机器学习·能源
掘金安东尼4 小时前
大模型嵌入浏览器,Atlas 和 Gemini 将带来怎样的变革?
人工智能
亚马逊云开发者4 小时前
基于Amazon Bedrock的TwelveLabs Marengo Embed 2.7多模态搜索系统
人工智能
Geoking.4 小时前
深度学习基础:Tensor(张量)的创建方法详解
人工智能·深度学习
海拥4 小时前
合合信息推出“多模态文本智能技术”:让AI真正理解与守护信息
人工智能
suke4 小时前
LLM入局,OCR换代:DeepSeek与PaddleOCR-VL等LLM-OCR引领的文档理解新浪潮
人工智能·程序员·开源
良策金宝AI4 小时前
良策金宝AI实战录:效率如何从口号照进现实?
人工智能·工程设计