使用 PyTorch 框架对 CIFAR - 10 数据集进行CNN分类

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# 设置在 Jupyter Notebook 中显示 matplotlib 图像
%matplotlib inline

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# 定义类别
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 显示图像的函数
def imshow(img):
    img = img / 2 + 0.5  # 反标准化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 随机获取部分训练数据并显示
dataiter = iter(trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print(' '.join(f'%5s' % classes[labels[j]] for j in range(4)))

# 构建 CNN 网络
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(1296, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 36 * 6 * 6)
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return x

net = CNNNet()
net = net.to(device)

# 输出网络总参数数量
print("net have {} parameters in total".format(sum(x.numel() for x in net.parameters())))

# 设置损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
print('Finished Training')

# 在测试集上显示图像和真实标签
dataiter = iter(testloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'%5s' % classes[labels[j]] for j in range(4)))

# 在测试集上进行预测并显示预测结果
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join(f'%5s' % classes[predicted[j]] for j in range(4)))

一、数据加载与预处理

  1. 数据集选择:使用的是 CIFAR - 10 数据集,它包含 10 个类别(plane、car、bird、cat、deer、dog、frog、horse、ship、truck)的彩色图像,每个图像大小为 32×32×3。
  2. 数据转换(transforms
    • transforms.ToTensor():将 PIL 图像或 numpy 数组转换为张量(Tensor),并将图像像素值从 [0,255] 范围归一化到 [0,1] 范围。
    • transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):对张量进行标准化处理。这里的两个元组分别是均值和标准差,每个通道(R、G、B)都使用相同的均值和标准差,处理后像素值范围变为 [- 1,1],有助于模型的训练收敛。
  3. 数据加载器(DataLoader
    • torch.utils.data.DataLoader用于批量加载数据。batch_size指定每个批次的样本数量(这里是 4);shuffleTrue时在每个 epoch 开始时打乱数据顺序,有助于模型泛化;num_workers指定用于数据加载的子进程数量(这里是 2),可以加速数据加载。

二、图像显示函数(imshow

  1. 反标准化 :因为之前对图像进行了标准化(像素值范围 [-1,1]),所以在显示前需要通过img = img / 2 + 0.5将其转换回 [0,1] 范围。
  2. 维度转换np.transpose(npimg, (1, 2, 0))将张量的维度从(通道数,高度,宽度)转换为(高度,宽度,通道数),这是matplotlib.pyplot.imshow函数要求的图像维度格式。

三、卷积神经网络(CNN)构建

  1. 网络结构
    • 卷积层(nn.Conv2d
      • 第一个卷积层self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1):输入通道数为 3(对应彩色图像的 R、G、B 通道),输出通道数为 16(即 16 个卷积核,产生 16 个特征图),卷积核大小为 5×5,步长为 1。
      • 第二个卷积层self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1):输入通道数为 16(来自上一层的输出),输出通道数为 36,卷积核大小为 3×3,步长为 1。卷积层的作用是提取图像的局部特征。
    • 池化层(nn.MaxPool2d
      • self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2):都是最大池化层,核大小为 2×2,步长为 2。池化层的作用是降低特征图的空间维度,减少参数数量,同时保留重要特征,增强模型的平移不变性。
    • 全连接层(nn.Linear
      • self.fc1 = nn.Linear(1296, 128):输入特征数为 1296(由前面卷积和池化操作后特征图的尺寸计算得到,36×6×6),输出特征数为 128。
      • self.fc2 = nn.Linear(128, 10):输入特征数为 128,输出特征数为 10(对应 CIFAR - 10 的 10 个类别)。全连接层用于将提取的特征映射到类别空间。
  2. 前向传播(forward方法)
    • 首先经过第一个卷积层,然后通过 ReLU 激活函数(F.relu),再经过第一个池化层。
    • 接着经过第二个卷积层、ReLU 激活函数和第二个池化层。
    • 然后通过x.view(-1, 36 * 6 * 6)将特征图展平为一维向量,作为全连接层的输入。
    • 最后经过两个全连接层,中间使用 ReLU 激活函数,得到最终的输出。

四、模型训练

  1. 损失函数(nn.CrossEntropyLoss :用于多分类问题,它结合了nn.LogSoftmaxnn.NLLLoss(负对数似然损失),能够计算模型预测概率与真实标签之间的交叉熵损失。
  2. 优化器(optim.SGD :随机梯度下降优化器,lr=0.001是学习率,控制参数更新的步长;momentum=0.9是动量,有助于加速梯度下降过程,减少震荡。
  3. 训练循环
    • 遍历多个 epoch(这里是 10 个),每个 epoch 遍历整个训练数据集。
    • 在每个批次中,首先将输入数据和标签移动到指定设备(CPU 或 GPU)。
    • 使用optimizer.zero_grad()清除之前的梯度,防止梯度累积。
    • 进行正向传播(net(inputs))得到模型输出,计算损失(criterion(outputs, labels))。
    • 反向传播(loss.backward())计算梯度,然后通过optimizer.step()更新模型参数。
    • 累计损失并定期(每 2000 个批次)输出,以监控训练过程。

五、模型评估

  1. 测试数据加载与显示:从测试数据加载器中获取一批数据,显示图像并打印真实标签,用于直观对比模型预测结果。
  2. 预测过程 :将测试图像移动到指定设备,通过模型得到输出(net(images))。使用torch.max(outputs, 1)获取每个样本预测概率最大的类别索引(predicted),然后根据类别索引获取对应的类别名称并打印,与真实标签对比,评估模型的分类效果。
相关推荐
凳子(刘博浩)3 小时前
使用 PyTorch 实现 CIFAR-10 图像分类:从数据加载到模型训练全流程
人工智能·pytorch·分类
史锦彪5 小时前
PyTorch 实现 CIFAR-10 图像分类:从基础 CNN 到全局平均池化的探索
pytorch·分类·cnn
41号学员6 小时前
构建神经网络的两大核心工具
人工智能·pytorch·深度学习
Wah-Aug8 小时前
PyTorch 模型评估与全局平均池化的应用实践
人工智能·pytorch·python
诸葛箫声8 小时前
基于PyTorch的CIFAR-10图像分类项目总结(2)
人工智能·pytorch·分类
鲸鱼24018 小时前
图像分类笔记
大数据·笔记·分类
麒羽76019 小时前
PyTorch 实现 CIFAR10 数据集的 CNN 分类实践
pytorch·分类·cnn
热爱生活的猴子19 小时前
使用bert或roberta模型做分类训练时,分类数据不平衡时,可以采取哪些优化的措施
人工智能·分类·bert
jie*19 小时前
小杰机器学习高级(five)——分类算法的评估标准
人工智能·python·深度学习·神经网络·机器学习·分类·回归