PyTorch 实现 CIFAR-10 图像分类:从基础 CNN 到全局平均池化的探索

引言

CIFAR-10 数据集包含 10 类、共 60000 张 32×32 的彩色图像,是计算机视觉领域入门图像分类任务 的经典数据集。本文基于 PyTorch 框架,完整演示从数据预处理CNN 模型构建模型训练与优化性能评估的全流程,并介绍 "全局平均池化" 对模型结构的优化思路。

一、数据预处理与加载

数据预处理是深度学习任务的 "第一步",需将原始图像转换为模型可学习的格式,并通过归一化让数据分布更适合训练。

1. 数据转换与归一化

借助torchvision.transforms,完成两个核心操作:

  • ToTensor():将图像(PIL/NumPy 格式)转换为 PyTorch 张量,并将像素值缩放到[0, 1]区间。
  • Normalize(mean, std):对张量做归一化,使每个通道的均值为mean、标准差为std(本文将像素映射到[-1, 1]区间,便于模型稳定训练)。

python

运行

复制代码
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

2. 数据集与数据加载器

通过torchvision.datasets.CIFAR10加载数据集(本地已下载时设download=False),再用torch.utils.data.DataLoader创建批量迭代器,实现 "批次划分、数据打乱、多进程读取" 等功能,提升训练效率。

python

运行

复制代码
trainset = torchvision.datasets.CIFAR10(
    root=r'本地数据集路径', train=True, download=False, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root=r'本地数据集路径', train=False, download=False, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2
)

二、CNN 模型的构建与原理

卷积神经网络(CNN)通过 "卷积 - 激活 - 池化 " 的层级结构提取图像特征,再通过全连接层完成 "特征到类别" 的映射。以下是基础 CNN 模型的核心设计。

1. 模型结构解析

定义的CNNNet包含三类核心组件:

  • 卷积层nn.Conv2d负责提取局部特征(如边缘、纹理)。例:self.conv1 = nn.Conv2d(3, 16, 5)表示 "输入 3 通道(彩色图像)、输出 16 通道、卷积核大小 5×5"。
  • 池化层nn.MaxPool2d对特征图 "下采样",减少参数数量同时保留关键特征。例:self.pool1 = nn.MaxPool2d(2, 2)表示 "核大小 2×2、步长 2"。
  • 全连接层nn.Linear将 "展平的特征" 映射到 "类别空间"。例:self.fc1 = nn.Linear(36 * 6 * 6, 128)将卷积层输出的 "36 个 6×6 特征图" 展平后,映射到 128 维隐藏层。

python

运行

复制代码
class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 36, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(36 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)  # 10类输出

    def forward(self, x):
        # 卷积→激活→池化
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # 展平特征图,为全连接层做准备
        x = x.view(-1, 36 * 6 * 6)
        # 全连接→激活→最终分类
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

2. 前向传播逻辑

forward方法定义了 "数据在模型中的流动路径":卷积层提取局部特征→ReLU 激活引入 "非线性"→池化层压缩特征维度→全连接层完成 "特征到类别" 的映射。

三、模型的训练与优化

训练过程需结合损失函数优化器迭代训练逻辑,让模型逐步学习数据中的模式。

1. 损失函数与优化器

  • 损失函数 :选用nn.CrossEntropyLoss,它融合了 "softmax(将输出映射为概率)" 和 "交叉熵(衡量预测与真实标签的差异)",是多分类任务的常用选择。
  • 优化器 :使用optim.SGD(随机梯度下降),并加入momentum=0.9(动量)加速收敛;lr=0.001(学习率)控制参数更新的 "步长"。

python

运行

复制代码
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

2. 迭代训练流程

通过多轮(epochs=10)迭代,每次遍历训练集(trainloader),完成 "前向传播→损失计算→反向传播→参数更新" 的闭环,并定期打印损失以监控训练过程。

python

运行

复制代码
epochs = 10  # 训练轮次
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 数据迁移到设备(CPU/GPU)
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()  # 清空历史梯度
        outputs = net(inputs)  # 前向传播
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播(求梯度)
        optimizer.step()  # 更新模型参数

        running_loss += loss.item()
        # 每2000批次打印一次损失
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

四、模型评估:从 "整体准确" 到 "类别精细分析"

训练完成后,需在测试集上评估模型性能,包括整体准确率 (全局表现)和各类别准确率(精细分析)。

1. 整体准确率测试

遍历测试集,统计 "预测正确的样本数" 与 "总样本数" 的比例。用torch.no_grad()关闭梯度计算(测试阶段无需更新参数,提升效率)。

python

运行

复制代码
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取每个样本概率最大的类别
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)  # 累加总样本数
        correct += (predicted == labels).sum().item()  # 累加正确数

print(f'模型在10000张测试图像上的整体准确率: {100 * correct / total} %')

2. 各类别准确率分析

为每个类别单独统计 "正确数 / 总数",可分析模型对不同类别的 "识别偏好"(如对 "car" 识别准度高,对 "cat" 准度低)。

python

运行

复制代码
class_correct = [0.] * 10  # 每个类别正确数
class_total = [0.] * 10    # 每个类别总数
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        # 标记每个样本是否预测正确
        c = (predicted == labels).squeeze()
        for i in range(4):  # batch_size=4,遍历每个样本
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

# 打印每个类别的准确率
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
for i in range(10):
    print(f'{classes[i]} 类别的准确率: {100 * class_correct[i] / class_total[i]} %')

五、模型优化:全局平均池化的应用

传统 CNN 依赖 "大维度全连接层" 完成分类,易导致参数过多过拟合全局平均池化(Global Average Pooling, GAP) 可替代全连接层,简化结构并保留 "全局特征"。

1. 全局平均池化的原理

nn.AdaptiveAvgPool2d(1)每个特征图 压缩为 "一个平均值"(即每个通道输出1×1的特征)。这样无需手动计算 "展平后的特征维度",且能直接衔接 "小维度全连接层"(甚至直接输出类别)。

2. 改进后的模型结构

用 "全局平均池化 + 小全连接层" 替代 "大维度全连接层",模型更轻量化:

python

运行

复制代码
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 36, 5)
        self.aap = nn.AdaptiveAvgPool2d(1)  # 全局平均池化
        self.fc3 = nn.Linear(36, 10)  # 小全连接层(36→10类)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.aap(x)  # 全局平均池化,输出形状(batch_size, 36, 1, 1)
        x = x.view(x.size(0), -1)  # 展平为(batch_size, 36)
        x = self.fc3(x)  # 全连接输出类别
        return x

3. 参数数量对比

全局平均池化避免了 "大维度全连接层" 的海量参数。通过sum(x.numel() for x in net.parameters())统计参数总数,可发现 GAP 模型的参数显著少于传统 CNN。

总结

本文基于 PyTorch 完成了 CIFAR-10 图像分类任务,覆盖了数据预处理、CNN 模型构建、训练优化、多维度评估的全流程,并引入 "全局平均池化" 优化模型结构。从实践中可观察到:

  • 基础 CNN 能学习到图像分类能力,但依赖大全连接层,参数较多;
  • 全局平均池化通过 "压缩特征 + 小全连接层",在减少参数的同时保留了有效特征提取能力,是 ResNet 等大型 CNN 常用的优化技巧。

这套流程也可迁移到其他图像分类任务,为更复杂的计算机视觉应用(如目标检测、语义分割)打下基础。

相关推荐
41号学员2 小时前
构建神经网络的两大核心工具
人工智能·pytorch·深度学习
Wah-Aug5 小时前
PyTorch 模型评估与全局平均池化的应用实践
人工智能·pytorch·python
诸葛箫声5 小时前
基于PyTorch的CIFAR-10图像分类项目总结(2)
人工智能·pytorch·分类
鲸鱼24015 小时前
图像分类笔记
大数据·笔记·分类
麒羽76016 小时前
PyTorch 实现 CIFAR10 数据集的 CNN 分类实践
pytorch·分类·cnn
热爱生活的猴子16 小时前
使用bert或roberta模型做分类训练时,分类数据不平衡时,可以采取哪些优化的措施
人工智能·分类·bert
jie*16 小时前
小杰机器学习高级(five)——分类算法的评估标准
人工智能·python·深度学习·神经网络·机器学习·分类·回归
彭祥.18 小时前
点云-标注-分类-航线规划软件 (一)点云自动分类
人工智能·分类·数据挖掘
Teacher.chenchong19 小时前
PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化技术
pytorch·深度学习·分类