深度学习基础——(3)视觉处理基础实战【CNN实现CIFAR10 多分类】

文章目录

  • 一、步骤说明
  • 二、实验代码
    • [2.1 代码](#2.1 代码)
    • [2.2 显示各层参数](#2.2 显示各层参数)
  • 三、改进
    • [3.1 改进1:全局池化](#3.1 改进1:全局池化)
    • [3.2 改进2:使用模型集成方法](#3.2 改进2:使用模型集成方法)
    • [3.2 改进3:使用现代经典模型VGG16](#3.2 改进3:使用现代经典模型VGG16)

一、步骤说明

CIFAR-10:包含 10 类小图片:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。

  • 图片尺寸:32 × 32 像素(非常小!)
  • 彩色图:3 通道(RGB)
  • 训练集:50000 张
  • 测试集:10000 张

通过一个用卷积网络实现分类的实例,来说明如何处理数据,借助 nn工具箱来实现神经网络,并实现训练测试等完整过程。

  • 数据处理
  • 搭建网络
  • 循环训练 epochs 轮
  • 训练结束:读取保存的各项list,绘制损失、准确率变化曲线

二、实验代码

2.1 代码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets
import torchvision.transforms as transforms

# =============================== 数据处理 ===============================
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = data.DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=4, shuffle=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# =============================== 搭建网络 ===============================
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()	
        self.conv1 = nn.Conv2d(3, 16, 5,1)	
        self.pool = nn.MaxPool2d(2, 2)	
        self.conv2 = nn.Conv2d(16, 36, 3,1)
        self.pool2 = nn.MaxPool2d(2, 2)	
        self.fc1 = nn.Linear(36 * 6 * 6, 128)	
        self.fc2 = nn.Linear(128, 10)	

    def forward(self, x):	# 输入:(3,32,32)
        x = self.pool(F.relu(self.conv1(x)))# conv1:(16,28,28),pool1(16,14,14)
        x = self.pool(F.relu(self.conv2(x)))# conv2:(16,14,14),pool1(36,12,12)
        x = x.view(-1, 36 * 6 * 6)# (36,6,6)
        x = F.relu(self.fc1(x))#把36×6×6 = 1296 维向量变成 128 维
        x = self.fc2(x)# 最后输出 10 个分类
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = CNN()
net.to(device)
# print(net)

# =============================== 训练模型 ===============================
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(10):
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = net(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    avg_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch + 1}/10], Average Loss: {avg_loss:.4f}')


# =============================== 测试模型 ===============================
class_correct = [0] * 10
class_total = [0] * 10

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)

        # 计算每一类正确数量
        c = (predicted == labels)
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

# 打印 10 个分类的正确率
for i in range(10):
    print(f'类别 {classes[i]:<6} 的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')

# 总准确率
total_correct = sum(class_correct)
total_num = sum(class_total)
print(f"\n模型总准确率: {100 * total_correct / total_num:.2f}%")
bash 复制代码
Epoch [1/10], Average Loss: 1.5084
Epoch [2/10], Average Loss: 1.0949
Epoch [3/10], Average Loss: 0.9218
Epoch [4/10], Average Loss: 0.8020
Epoch [5/10], Average Loss: 0.7049
Epoch [6/10], Average Loss: 0.6223
Epoch [7/10], Average Loss: 0.5434
Epoch [8/10], Average Loss: 0.4839
Epoch [9/10], Average Loss: 0.4290
Epoch [10/10], Average Loss: 0.3802
类别 plane  的准确率: 72.70%
类别 car    的准确率: 70.60%
类别 bird   的准确率: 58.30%
类别 cat    的准确率: 49.40%
类别 deer   的准确率: 55.20%
类别 dog    的准确率: 54.20%
类别 frog   的准确率: 73.50%
类别 horse  的准确率: 69.80%
类别 ship   的准确率: 81.10%
类别 truck  的准确率: 78.60%

模型总准确率: 66.34%

2.2 显示各层参数

安装包:

powershell 复制代码
pip install torchsummary

代码:

python 复制代码
from torchsummary import summary
summary(net, (3, 32, 32))  # (输入通道, H, W)

三、改进

3.1 改进1:全局池化

PyTorch可以用 nn.AdaptiveAvgPool2d(1)实现全局平均池化或全局最大池化。

python 复制代码
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)	 
        self.pool1 = nn.MaxPool2d(2, 2)	 
        self.conv2 = nn.Conv2d(16, 36, 5)	
        self.pool2 = nn.MaxPool2d(2, 2)	# 
        self.aap = nn.AdaptiveAvgPool2d(1)	
        self.fc3 = nn.Linear(36, 10)	

    def forward(self, x):	# 输入:3×32×32
        x = self.pool1(F.relu(self.conv1(x)))  # conv1:16×28×28,pool1:16×14×14
        x = self.pool2(F.relu(self.conv2(x))) # conv2:36×10×10,pool2:36×5×5
        x = self.aap(x) 	# 自适应全局平均池化,36×1×1
        x = x.view(x.shape[0], -1)  # 拉平:36
        x = self.fc3(x)  # fc3:10
        return x

循环同样的次数,其精度达到 63% 左右,但其使用的参数比没使用全局池化层的网络少很

多。前者只用了16022个参数,后者使用了173742个参数,是前者的10倍多。这个网络比较

简单,如果遇到复杂网络,差距将更大。

bash 复制代码
Epoch [1/10], Average Loss: 1.7630
Epoch [2/10], Average Loss: 1.4681
Epoch [3/10], Average Loss: 1.3320
Epoch [4/10], Average Loss: 1.2455
Epoch [5/10], Average Loss: 1.1820
Epoch [6/10], Average Loss: 1.1338
Epoch [7/10], Average Loss: 1.0942
Epoch [8/10], Average Loss: 1.0623
Epoch [9/10], Average Loss: 1.0342
Epoch [10/10], Average Loss: 1.0080
类别 plane  的准确率: 71.50%
类别 car    的准确率: 84.00%
类别 bird   的准确率: 45.80%
类别 cat    的准确率: 28.90%
类别 deer   的准确率: 53.90%
类别 dog    的准确率: 62.30%
类别 frog   的准确率: 73.30%
类别 horse  的准确率: 79.30%
类别 ship   的准确率: 80.80%
类别 truck  的准确率: 58.30%

模型总准确率: 63.81%

使用全局平均池化确实能减少很多参数,而且泛化能力也比较好。它的缺点是收敛速度比较慢,但是这个不足可以通过增加循环次数进行弥补。

3.2 改进2:使用模型集成方法

3.2 改进3:使用现代经典模型VGG16

相关推荐
RWKV元始智能15 小时前
RWKV超并发项目教程,RWKV-LM训练提速40%
人工智能·rnn·深度学习·自然语言处理·开源
@insist12316 小时前
信息安全工程师考点精讲:身份认证核心原理与分类体系(上篇)
大数据·网络·分类·信息安全工程师·软件水平考试
AI技术增长17 小时前
Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题
pytorch·深度学习·机器学习
小糖学代码19 小时前
LLM系列:2.pytorch入门:8.神经网络的损失函数(criterion)
人工智能·深度学习·神经网络
Jmayday19 小时前
Pytorch:RNN理论基础
pytorch·rnn·深度学习
AI周红伟21 小时前
周红伟:GPT-Image-2深度解析:从技术原理到实战教程,为什么它能让整个AI圈炸锅?
人工智能·gpt·深度学习·机器学习·语言模型·openclaw
端平入洛1 天前
梯度是什么:PyTorch 自动求导详解
人工智能·深度学习
时序之心1 天前
上海交大、东北大学:时序分类与感知领域的两项前沿突破
人工智能·分类·时间序列
nap-joker1 天前
不完全多模分类的推断时间动态模式选择
人工智能·分类·数据挖掘·不完整模态·插补-丢弃困境