基于 PyTorch 的 CIFAR-10 图像分类学习总结

在本次学习中,我通过 PyTorch 实现了一个基于 CNN 的 CIFAR-10 图像分类模型,完整掌握了从数据加载、模型构建到训练评估的全流程。以下是具体学习内容总结,包含关键代码实现:

CIFAR-10图像分类完整代码

V1

创建时间:15:53

关键步骤解析

1. 数据处理

数据处理是深度学习任务的基础,主要包括:

  • 数据转换:使用transforms将图像转为张量并归一化,使模型更容易学习
  • 数据集加载:利用torchvision.datasets加载 CIFAR-10 数据集
  • 数据加载器:通过DataLoader实现批处理、打乱数据和多进程加载
  • 数据可视化:编写imshow函数直观查看数据,验证数据加载是否正确

2. 模型构建

CNN 模型是处理图像任务的有效工具,本模型结构包括:

  • 卷积层:使用nn.Conv2d提取图像特征,通过卷积核捕获局部特征
  • 池化层:使用nn.MaxPool2d降低特征图维度,减少计算量并增强鲁棒性
  • 全连接层:使用nn.Linear实现最终分类,将提取的特征映射到 10 个类别
  • 激活函数:使用 ReLU 增加模型非线性表达能力,解决梯度消失问题

3. 模型训练

训练过程是模型学习的核心,主要步骤包括:

  • 损失函数:选择交叉熵损失函数,适合多分类任务
  • 优化器:使用 SGD 优化器,通过学习率和动量控制参数更新
  • 训练循环:多轮迭代训练,每个批次包括前向传播、损失计算、反向传播和参数更新
  • 设备加速:自动检测 GPU 并利用 CUDA 加速训练过程

4. 模型评估

通过测试集验证模型性能:

  • 加载测试数据并可视化
  • 使用训练好的模型进行预测
  • 对比预测结果与真实标签,直观评估模型分类效果

学习心得

  1. 数据预处理对模型性能影响很大,合适的归一化能加速模型收敛
  2. 网络结构设计需要平衡复杂度和计算效率,过深或过浅的网络都可能影响性能
  3. 训练过程中的超参数(如学习率、批次大小、训练轮次)需要根据实际情况调整
  4. GPU 加速能显著提高训练速度,特别是对于图像等数据量大的任务
  5. 可视化是调试和理解模型的有效手段,有助于发现数据或模型中的问题

通过这个实例,我掌握了 PyTorch 的基本使用方法和 CNN 图像分类的完整流程,为后续更复杂的深度学习任务打下了基础。

基于PyTorch的CIFAR-10图像分类实现

一、准备工作:导入必要的库

import torch

import torchvision

import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import numpy as np

import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

二、数据加载与预处理

1. 定义数据转换:将图像转为张量并归一化

transform = transforms.Compose(

[transforms.ToTensor(), # 转换为PyTorch张量

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] # 归一化到-1, 1范围

)

2. 加载训练集和测试集

trainset = torchvision.datasets.CIFAR10(

root='./data', # 数据存储路径

train=True, # 训练集

download=True, # 如果本地没有数据则下载

transform=transform # 应用数据转换

)

testset = torchvision.datasets.CIFAR10(

root='./data',

train=False, # 测试集

download=True,

transform=transform

)

3. 创建数据加载器

trainloader = torch.utils.data.DataLoader(

trainset,

batch_size=4, # 批处理大小

shuffle=True, # 训练时打乱数据顺序

num_workers=2 # 多进程加载数据

)

testloader = torch.utils.data.DataLoader(

testset,

batch_size=4,

shuffle=False, # 测试时不打乱顺序

num_workers=2

)

4. 定义类别标签

classes = ('plane', 'car', 'bird', 'cat',

'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

三、数据可视化

定义图像显示函数

def imshow(img):

img = img / 2 + 0.5 # 反归一化

npimg = img.numpy() # 转换为numpy数组

调整通道顺序:从(C, H, W)转为(H, W, C)

plt.imshow(np.transpose(npimg, (1, 2, 0)))

plt.show()

获取一些随机的训练图像

dataiter = iter(trainloader)

images, labels = next(dataiter)

显示图像

imshow(torchvision.utils.make_grid(images))

打印标签

print(' '.join(f'{classeslabels\[j]:5s}' for j in range(4)))

四、构建CNN模型

class CNNNet(nn.Module):

def init(self):

super(CNNNet, self).init()

第一个卷积层:3输入通道,16输出通道,5x5卷积核

self.conv1 = nn.Conv2d(3, 16, 5)

第一个池化层:2x2池化核,步长为2

self.pool = nn.MaxPool2d(2, 2)

第二个卷积层:16输入通道,36输出通道,3x3卷积核

self.conv2 = nn.Conv2d(16, 36, 3)

第一个全连接层

self.fc1 = nn.Linear(36 * 6 * 6, 128)

第二个全连接层(输出层,10个类别)

self.fc2 = nn.Linear(128, 10)

def forward(self, x):

第一个卷积块:卷积->ReLU->池化

x = self.pool(F.relu(self.conv1(x)))

第二个卷积块:卷积->ReLU->池化

x = self.pool(F.relu(self.conv2(x)))

展平特征图

x = x.view(-1, 36 * 6 * 6)

第一个全连接层->ReLU

x = F.relu(self.fc1(x))

输出层

x = self.fc2(x)

return x

实例化模型并移动到可用设备(GPU/CPU)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = CNNNet()

net.to(device)

五、定义损失函数和优化器

criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,适合多分类任务

optimizer = optim.SGD( # 随机梯度下降优化器

net.parameters(),

lr=0.001, # 学习率

momentum=0.9 # 动量参数

)

六、训练模型

for epoch in range(10): # 训练10个epoch

running_loss = 0.0

for i, data in enumerate(trainloader, 0):

获取输入数据和标签,并移动到设备

inputs, labels = data0.to(device), data1.to(device)

清零梯度

optimizer.zero_grad()

前向传播、计算损失、反向传播、参数更新

outputs = net(inputs)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

打印训练状态

running_loss += loss.item()

if i % 2000 == 1999: # 每2000个批次打印一次

print(f'{epoch + 1}, {i + 1:5d} loss: {running_loss / 2000:.3f}')

running_loss = 0.0

print('Finished Training')

七、模型测试

在测试集上进行预测

dataiter = iter(testloader)

images, labels = next(dataiter)

显示测试图像

imshow(torchvision.utils.make_grid(images))

print('GroundTruth: ', ' '.join(f'{classeslabels\[j]:5s}' for j in range(4)))

进行预测

images, labels = images.to(device), labels.to(device)

outputs = net(images)

_, predicted = torch.max(outputs, 1) # 获取预测概率最大的类别

print('Predicted: ', ' '.join(f'{classespredicted\[j]:5s}' for j in range(4)))

相关推荐
菜板春9 小时前
jupyter入门-手册-特征探索
python·jupyter
Metaphor6929 小时前
使用 Python 将 PDF 转换为 HTML
python·pdf·html
极光代码工作室9 小时前
基于数据仓库的电商数据分析平台
大数据·hadoop·python·spark·数据可视化
开发小能手-roy9 小时前
StringBuilder vs StringBuffer:2024年还需要线程安全字符串吗?
开发语言·python·安全
AC赳赳老秦10 小时前
用 OpenClaw 搭建服务器故障应急响应系统,自动处理 80% 常见运维故障
android·运维·服务器·python·rxjava·deepseek·openclaw
2601_9547064910 小时前
云手机技术详解+Python实战调用|2026高稳云手机平台推荐
开发语言·python·智能手机
chushiyunen10 小时前
java中的路径处理、左右斜杠
java·开发语言·python
jay神10 小时前
基于 FastAPI + Vue 的宠物领养管理系统
前端·vue.js·python·毕业设计·fastapi·宠物
程序员小远11 小时前
自动化测试基础知识总结
自动化测试·软件测试·python·selenium·测试工具·职场和发展·测试用例
GEO优化小助手11 小时前
2026临沂GEO优化公司实测解析:3家本土机构适配性参考
大数据·人工智能·python