基于 PyTorch 的 CIFAR-10 图像分类学习总结

在本次学习中,我通过 PyTorch 实现了一个基于 CNN 的 CIFAR-10 图像分类模型,完整掌握了从数据加载、模型构建到训练评估的全流程。以下是具体学习内容总结,包含关键代码实现:

CIFAR-10图像分类完整代码

V1

创建时间:15:53

关键步骤解析

1. 数据处理

数据处理是深度学习任务的基础,主要包括:

  • 数据转换:使用transforms将图像转为张量并归一化,使模型更容易学习
  • 数据集加载:利用torchvision.datasets加载 CIFAR-10 数据集
  • 数据加载器:通过DataLoader实现批处理、打乱数据和多进程加载
  • 数据可视化:编写imshow函数直观查看数据,验证数据加载是否正确

2. 模型构建

CNN 模型是处理图像任务的有效工具,本模型结构包括:

  • 卷积层:使用nn.Conv2d提取图像特征,通过卷积核捕获局部特征
  • 池化层:使用nn.MaxPool2d降低特征图维度,减少计算量并增强鲁棒性
  • 全连接层:使用nn.Linear实现最终分类,将提取的特征映射到 10 个类别
  • 激活函数:使用 ReLU 增加模型非线性表达能力,解决梯度消失问题

3. 模型训练

训练过程是模型学习的核心,主要步骤包括:

  • 损失函数:选择交叉熵损失函数,适合多分类任务
  • 优化器:使用 SGD 优化器,通过学习率和动量控制参数更新
  • 训练循环:多轮迭代训练,每个批次包括前向传播、损失计算、反向传播和参数更新
  • 设备加速:自动检测 GPU 并利用 CUDA 加速训练过程

4. 模型评估

通过测试集验证模型性能:

  • 加载测试数据并可视化
  • 使用训练好的模型进行预测
  • 对比预测结果与真实标签,直观评估模型分类效果

学习心得

  1. 数据预处理对模型性能影响很大,合适的归一化能加速模型收敛
  2. 网络结构设计需要平衡复杂度和计算效率,过深或过浅的网络都可能影响性能
  3. 训练过程中的超参数(如学习率、批次大小、训练轮次)需要根据实际情况调整
  4. GPU 加速能显著提高训练速度,特别是对于图像等数据量大的任务
  5. 可视化是调试和理解模型的有效手段,有助于发现数据或模型中的问题

通过这个实例,我掌握了 PyTorch 的基本使用方法和 CNN 图像分类的完整流程,为后续更复杂的深度学习任务打下了基础。

基于PyTorch的CIFAR-10图像分类实现

一、准备工作:导入必要的库

import torch

import torchvision

import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import numpy as np

import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

二、数据加载与预处理

1. 定义数据转换:将图像转为张量并归一化

transform = transforms.Compose(

transforms.ToTensor(), # 转换为PyTorch张量 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\] # 归一化到\[-1, 1\]范围 ) # 2. 加载训练集和测试集 trainset = torchvision.datasets.CIFAR10( root='./data', # 数据存储路径 train=True, # 训练集 download=True, # 如果本地没有数据则下载 transform=transform # 应用数据转换 ) testset = torchvision.datasets.CIFAR10( root='./data', train=False, # 测试集 download=True, transform=transform ) # 3. 创建数据加载器 trainloader = torch.utils.data.DataLoader( trainset, batch_size=4, # 批处理大小 shuffle=True, # 训练时打乱数据顺序 num_workers=2 # 多进程加载数据 ) testloader = torch.utils.data.DataLoader( testset, batch_size=4, shuffle=False, # 测试时不打乱顺序 num_workers=2 ) # 4. 定义类别标签 classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 三、数据可视化 # 定义图像显示函数 def imshow(img): img = img / 2 + 0.5 # 反归一化 npimg = img.numpy() # 转换为numpy数组 # 调整通道顺序:从(C, H, W)转为(H, W, C) plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() # 获取一些随机的训练图像 dataiter = iter(trainloader) images, labels = next(dataiter) # 显示图像 imshow(torchvision.utils.make_grid(images)) # 打印标签 print(' '.join(f'{classes\[labels\[j\]\]:5s}' for j in range(4))) # 四、构建CNN模型 class CNNNet(nn.Module): def __init__(self): super(CNNNet, self).__init__() # 第一个卷积层:3输入通道,16输出通道,5x5卷积核 self.conv1 = nn.Conv2d(3, 16, 5) # 第一个池化层:2x2池化核,步长为2 self.pool = nn.MaxPool2d(2, 2) # 第二个卷积层:16输入通道,36输出通道,3x3卷积核 self.conv2 = nn.Conv2d(16, 36, 3) # 第一个全连接层 self.fc1 = nn.Linear(36 \* 6 \* 6, 128) # 第二个全连接层(输出层,10个类别) self.fc2 = nn.Linear(128, 10) def forward(self, x): # 第一个卷积块:卷积-\>ReLU-\>池化 x = self.pool(F.relu(self.conv1(x))) # 第二个卷积块:卷积-\>ReLU-\>池化 x = self.pool(F.relu(self.conv2(x))) # 展平特征图 x = x.view(-1, 36 \* 6 \* 6) # 第一个全连接层-\>ReLU x = F.relu(self.fc1(x)) # 输出层 x = self.fc2(x) return x # 实例化模型并移动到可用设备(GPU/CPU) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") net = CNNNet() net.to(device) # 五、定义损失函数和优化器 criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,适合多分类任务 optimizer = optim.SGD( # 随机梯度下降优化器 net.parameters(), lr=0.001, # 学习率 momentum=0.9 # 动量参数 ) # 六、训练模型 for epoch in range(10): # 训练10个epoch running_loss = 0.0 for i, data in enumerate(trainloader, 0): # 获取输入数据和标签,并移动到设备 inputs, labels = data\[0\].to(device), data\[1\].to(device) # 清零梯度 optimizer.zero_grad() # 前向传播、计算损失、反向传播、参数更新 outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 打印训练状态 running_loss += loss.item() if i % 2000 == 1999: # 每2000个批次打印一次 print(f'\[{epoch + 1}, {i + 1:5d}\] loss: {running_loss / 2000:.3f}') running_loss = 0.0 print('Finished Training') # 七、模型测试 # 在测试集上进行预测 dataiter = iter(testloader) images, labels = next(dataiter) # 显示测试图像 imshow(torchvision.utils.make_grid(images)) print('GroundTruth: ', ' '.join(f'{classes\[labels\[j\]\]:5s}' for j in range(4))) # 进行预测 images, labels = images.to(device), labels.to(device) outputs = net(images) _, predicted = torch.max(outputs, 1) # 获取预测概率最大的类别 print('Predicted: ', ' '.join(f'{classes\[predicted\[j\]\]:5s}' for j in range(4)))

相关推荐
A尘埃3 小时前
线性代数(标量与向量+矩阵与张量+矩阵求导)
python·线性代数·矩阵
数据牧羊人的成长笔记3 小时前
python爬虫进阶版练习(只说重点,selenium)
开发语言·chrome·python
databook3 小时前
Manim实现渐变填充特效
后端·python·动效
可触的未来,发芽的智生3 小时前
新奇特:神经网络的自洁之道,学会出淤泥而不染
人工智能·python·神经网络·算法·架构
计算机毕设残哥3 小时前
基于Hadoop+Spark的商店购物趋势分析与可视化系统技术实现
大数据·hadoop·python·scrapy·spark·django·dash
FserSuN3 小时前
python模块导入冲突问题笔记
开发语言·python
马诗剑3 小时前
使用 uv 在 Windows 上快速搭建 Python 开发环境
python
ZHOU_WUYI5 小时前
构建实时网络速度监控面板:Python Flask + SSE 技术详解
网络·python·flask
chinesegf5 小时前
conda虚拟环境直接复制依赖包可能会报错
python·conda