PyTorch 实现 CIFAR10 数据集的 CNN 分类实践

基于CIFAR10 数据集,用 PyTorch 完成了一个卷积神经网络(CNN)图像分类的实践,这里把核心流程分享出来~

一、数据加载与预处理

CIFAR10 数据集包含 10 类32×32 像素的彩色图像 (如飞机、汽车、动物等)。借助torchvision工具,我们完成数据加载与预处理:

  • 预处理操作 :用ToTensor()将图像转为 PyTorch 张量,再用Normalize()对张量做标准化(让数据分布更适合模型训练)。
  • 数据加载器 :通过DataLoader创建trainloader(训练集)和testloader(测试集),设置batch_size=4(每批处理 4 张图像)、shuffle=True(训练集打乱数据增强)等参数,方便后续批量训练 / 测试。

二、卷积神经网络(CNN)构建

定义了名为CNNNet的网络类,遵循「卷积 + 池化 + 全连接」的经典范式:

  • 卷积层 :共 2 层卷积(conv1conv2)。conv1用 16 个5×5卷积核提取底层特征;conv2用 36 个3×3卷积核提取更抽象的特征。
  • 池化层 :搭配 2 层最大池化(pool1pool2),核大小2×2,用于缩小特征图尺寸、减少计算量,同时保留关键特征。
  • 全连接层 :共 2 层全连接(fc1fc2)。先将卷积后的特征展平,通过fc1映射到 128 维特征空间,再通过fc2最终映射到10 个类别(匹配 CIFAR10 的类别数)。

此外,还统计了模型总参数量(约 17 万参数),便于直观了解模型复杂度。

三、模型训练流程

训练阶段选择SGD 优化器 (学习率0.001、动量0.9),损失函数用CrossEntropyLoss(适合多分类任务),核心步骤如下:

  1. 迭代训练 :循环多个epoch,每个epoch遍历全部训练数据。
  2. 梯度管理 :每次训练前用optimizer.zero_grad()清空梯度,避免梯度累积影响参数更新。
  3. 前向 + 反向传播 :模型对批量数据做前向传播生成预测,计算预测与真实标签的损失后,通过loss.backward()反向传播求梯度,再用optimizer.step()更新模型参数。
  4. 损失监控:每训练 2000 个 mini-batch,打印一次损失值,能观察到损失逐渐下降,说明模型在有效学习。

四、数据与结果可视化

为了直观理解数据和模型效果,还做了图像可视化

  • 展示训练集 / 测试集的样本图像,以及对应的类别标签(比如 "plane""car" 等),能快速了解数据样貌;
  • 后续也可基于此扩展,可视化模型的预测结果(比如对比 "真实标签" 和 "模型预测标签")。

整个流程覆盖了「数据准备→模型构建→训练→可视化」等深度学习图像分类的核心环节,是入门 PyTorch 与 CNN 的很好实践~

相关推荐
热爱生活的猴子2 小时前
使用bert或roberta模型做分类训练时,分类数据不平衡时,可以采取哪些优化的措施
人工智能·分类·bert
jie*2 小时前
小杰机器学习高级(five)——分类算法的评估标准
人工智能·python·深度学习·神经网络·机器学习·分类·回归
彭祥.5 小时前
点云-标注-分类-航线规划软件 (一)点云自动分类
人工智能·分类·数据挖掘
Teacher.chenchong5 小时前
PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化技术
pytorch·深度学习·分类
yaso_zhang8 小时前
jetpack6.1 的新 pytorch 2.5.1 版本在哪里?下载中心仅提供 pytorch v2.5.0a0。
人工智能·pytorch·python
这儿有一堆花10 小时前
从图像到精准文字:基于PyTorch与CTC的端到端手写文本识别实战
人工智能·pytorch·python
丰年稻香11 小时前
神经网络二分类任务详解:前向传播与反向传播的数学计算
人工智能·神经网络·分类
缘友一世12 小时前
PyTorch深度学习实战【12】之基于RNN的自然语言处理入门
pytorch·rnn·深度学习
青春不败 177-3266-052012 小时前
基于PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化实践技术应用
人工智能·pytorch·深度学习·目标检测·生态学·遥感