PyTorch深度学习实战(3)—— 小试牛刀:CIFAR-10分类

下面尝试从零搭建一个PyTorch模型来完成CIFAR-10数据集上的图像分类任务,步骤如下。

(1)使用torchvision加载并预处理CIFAR-10数据集。

(2)定义网络。

(3)定义损失函数和优化器。

(4)训练网络,并更新网络参数。

(5)测试网络。

1 CIFAR-10数据加载及预处理

CIFAR-10是一个常用的彩色图片数据集,它有10个类别:airplane、automobile、bird、cat、deer、dog、frog、horse、ship和truck。每张图片大小都是3\\times32\\times32,即3通道彩色图片,分辨率为32\\times32。下面举例说明如何完成图像加载与预处理:

复制代码
In:` `import torch as t`
    `import torchvision as tv`
    `import torchvision.transforms as transforms`
    `from torchvision.transforms import ToPILImage`
`    show = ToPILImage()` `# 可以把Tensor转成Image,Jupyter可直接显示Image对象In: # 第一次运行程序torchvision会自动下载CIFAR-10数据集,`
    `# 数据集大小约为100M,需花费一些时间,`
    `# 如果已经下载好CIFAR-10数据集,那么可通过root参数指定`
    
    `# 定义对数据的预处理`
`    transform = transforms.Compose([`
`            transforms.ToTensor(),` `# 转为Tensor`
`            transforms.Normalize((0.5,` `0.5,` `0.5),` `(0.5,` `0.5,` `0.5)),` `# 归一化`
                                 `])`
    `# 训练集`
`    trainset = tv.datasets.CIFAR10(`
`                        root='./pytorch-book-cifar10/',` 
`                        train=True,` 
`                        download=True,`
`                        transform=transform)`
    
`    trainloader = t.utils.data.DataLoader(`
`                        trainset,` 
`                        batch_size=4,`
`                        shuffle=True,` 
`                        num_workers=2)`
    
    `# 测试集`
`    testset = tv.datasets.CIFAR10(`
                        `'./pytorch-book-cifar10/',`
`                        train=False,` 
`                        download=True,` 
`                        transform=transform)`
    
`    testloader = t.utils.data.DataLoader(`
`                        testset,`
`                        batch_size=4,` 
`                        shuffle=False,`
`                        num_workers=2)`
    
`    classes =` `('plane',` `'car',` `'bird',` `'cat',` `'deer',` 
               `'dog',` `'frog',` `'horse',` `'ship',` `'truck')`
 
` Out:Files already downloaded and verified`
`    Files already downloaded and verifiedDataset`
`

对象是一个数据集,可以按下标访问,返回形如(data, label)的数据,举例说明如下:

复制代码
In:` `(data, label)` `= trainset[100]`
    `print(classes[label])`
    
    `# (data + 1) / 2目的是还原被归一化的数据`
`    show((data +` `1)` `/` `2).resize((100,` `100))Out:ship`
`

Dataloader是一个可迭代对象,它将Dataset返回的每一条数据样本拼接成一个batch,同时提供多线程加速优化和数据打乱等操作。当程序对Dataset的所有数据遍历完一遍后,对Dataloader也完成了一次迭代:

复制代码
In: dataiter =` `iter(trainloader)`     `# 生成迭代器`
`    images, labels = dataiter.next()` `# 返回4张图片及标签`
    `print(' '.join('%11s'%classes[labels[j]]` `for j in` `range(4)))` 
`    show(tv.utils.make_grid((images +` `1)` `/` `2)).resize((400,100))`

`Out:  horse        frog       plane        bird`
`

2 定义网络

拷贝上面的LeNet网络,因为CIFAR-10数据集中的数据是3通道的彩色图像,所以将self.conv1中第一个通道参数修改为3:

复制代码
In:` `import torch.nn as nn`
    `import torch.nn.functional as F`
    
    `class` `Net(nn.Module):`
        `def` `__init__(self):`
            `super(Net, self).__init__()`
`            self.conv1 = nn.Conv2d(3,` `6,` `5)` `# 将第一个通道参数修改为3`
`            self.conv2 = nn.Conv2d(6,` `16,` `5)`  
`            self.fc1   = nn.Linear(16` `*` `5` `*` `5,` `120)`  
`            self.fc2   = nn.Linear(120,` `84)`
`            self.fc3   = nn.Linear(84,` `10)` `# 类别数为10`
    
        `def` `forward(self, x):` 
`            x = F.max_pool2d(F.relu(self.conv1(x)),` `(2,` `2))` 
`            x = F.max_pool2d(F.relu(self.conv2(x)),` `2)` 
`            x = x.view(x.size()[0],` `-1)` 
`            x = F.relu(self.fc1(x))`
`            x = F.relu(self.fc2(x))`
`            x = self.fc3(x)`        
            `return x`
 
`    net = Net()`
    `print(net)Out:Net(`
        `(conv1): Conv2d(3,` `6, kernel_size=(5,` `5), stride=(1,` `1))`
        `(conv2): Conv2d(6,` `16, kernel_size=(5,` `5), stride=(1,` `1))`
        `(fc1): Linear(in_features=400, out_features=120, bias=True)`
        `(fc2): Linear(in_features=120, out_features=84, bias=True)`
        `(fc3): Linear(in_features=84, out_features=10, bias=True)`
    `)`
`

3 定义损失函数和优化器

这里使用交叉熵nn.CrossEntropyLoss作为损失函数,随机梯度下降法作为优化器:

复制代码
In:` `from torch import optim`
`    criterion = nn.CrossEntropyLoss()` `# 交叉熵损失函数`
`    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)`
`

4 训练网络

所有网络的训练流程都是类似的,也就是不断地执行如下流程。

(1)输入数据。

(2)前向传播、反向传播。

(3)更新参数。

复制代码
In:` `for epoch in` `range(2):`  
`        running_loss =` `0.0`
        `for i, data in` `enumerate(trainloader,` `0):`
            `# 输入数据`
`            inputs, labels = data`
            
            `# 梯度清零`
`            optimizer.zero_grad()`
            
            `# forward + backward `
`            outputs = net(inputs)`
`            loss = criterion(outputs, labels)`
`            loss.backward()`   
            
            `# 更新参数 `
`            optimizer.step()`
            
            `# 打印log信息`
`            running_loss += loss.item()`
            `if i %` `2000` `==` `1999:` `# 每2000个batch打印一下训练状态`
                `print('[%d, %5d] loss: %.3f' \`
                      `%` `(epoch+1, i+1, running_loss /` `2000))`
`                running_loss =` `0.0`
    `print('Finished Training')`
 
` Out:[1,`  `2000] loss:` `2.228`
    `[1,`  `4000] loss:` `1.890`
    `[1,`  `6000] loss:` `1.683`
    `[1,`  `8000] loss:` `1.592`
    `[1,` `10000] loss:` `1.513`
    `[1,` `12000] loss:` `1.478`
    `[2,`  `2000] loss:` `1.387`
    `[2,`  `4000] loss:` `1.368`
    `[2,`  `6000] loss:` `1.346`
    `[2,`  `8000] loss:` `1.324`
    `[2,` `10000] loss:` `1.300`
    `[2,` `12000] loss:` `1.255`
`    Finished Training`
`

这里仅训练了2个epoch(遍历完一遍数据集称为1个epoch),下面来看看网络有没有效果。将测试图片输入到网络中,计算它的label,然后与实际的label进行比较:

复制代码
In: dataiter =` `iter(testloader)`
`    images, labels = dataiter.next()` `# 一个batch返回4张图片`
    `print('实际的label: ',` `' '.join(\`
                `'%08s'%classes[labels[j]]` `for j in` `range(4)))`
`    show(tv.utils.make_grid(images /` `2` `-` `0.5)).resize((400,` `100))`
 
 `Out:实际的label:       cat     ship     ship    plane`
`

接着计算网络预测的分类结果:

复制代码
In:` `# 计算图片在每个类别上的分数`
`    outputs = net(images)`
    `# 得分最高的那个类`
`    _, predicted = t.max(outputs.data,` `1)` 
    
    `print('预测结果: ',` `' '.join('%5s'% classes[predicted[j]]` `for j in` `range(4)))`
 
`Out:预测结果:    cat  ship  ship  ship`
`

从上述结果可以看出:网络的准确率很高,针对这四张图片达到了75%的准确率。然而,这只是一部分图片,下面再来看看在整个测试集上的效果:

复制代码
In: correct =` `0` `# 预测正确的图片数`
`    total =` `0` `# 总共的图片数`
 
    `# 由于测试的时候不需要求导,可以暂时关闭autograd,提高速度,节约内存`
    `with t.no_grad():`
        `for data in testloader:`
`            images, labels = data`
`            outputs = net(images)`
`            _, predicted = t.max(outputs,` `1)` 
`            total += labels.size(0)`
`            correct +=` `(predicted == labels).sum()`
    
    `print('10000张测试集中的准确率为: %f %%'` `%` `(100` `* correct // total))`

`Out:10000张测试集中的准确率为:` `52.000000` `%`
`

训练结果的准确率远比随机猜测(准确率为10%)好,证明网络确实学到了东西。

5 在GPU上训练

就像把Tensor从CPU转移到GPU一样,模型也可以类似地从CPU转移到GPU,从而加速网络训练:

复制代码
In: device = t.device("cuda:0"` `if t.cuda.is_available()` `else` `"cpu")`
`    net.to(device)`
`    images = images.to(device)`
`    labels = labels.to(device)`
`    output = net(images)`
`    loss= criterion(output,labels)`
    
`    lossOut:tensor(0.5668, device='cuda:0', grad_fn=<NllLossBackward>)`

`

6 小结

本文给出了一个PyTorch快速入门指南,具体包含以下内容。

  • Tensor:类似NumPy数组的数据结构,它的接口与NumPy的接口类似,可以方便地互相转换。
  • autograd:为Tensor提供自动求导功能。
  • nn:专门为神经网络设计的接口,提供了很多有用的功能,如神经网络层、损失函数、优化器等。
  • 神经网络训练:以CIFAR-10分类为例,演示了神经网络的训练流程,包括数据加载、网络搭建、模型训练及模型测试。

通过本文的学习,可以大概了解PyTorch的主要功能,并能够使用PyTorch编写简单的模型。从下一篇开始,将深入系统地讲解PyTorch的各部分知识。

相关推荐
Python图像识别19 小时前
71_基于深度学习的布料瑕疵检测识别系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
python·深度学习·yolo
哥布林学者1 天前
吴恩达深度学习课程一:神经网络和深度学习 第三周:浅层神经网络(二)
深度学习·ai
weixin_519535771 天前
从ChatGPT到新质生产力:一份数据驱动的AI研究方向指南
人工智能·深度学习·机器学习·ai·chatgpt·数据分析·aigc
生命是有光的1 天前
【深度学习】神经网络基础
人工智能·深度学习·神经网络
信田君95271 天前
瑞莎星瑞(Radxa Orion O6) 基于 Android OS 使用 NPU的图片模糊查找APP 开发
android·人工智能·深度学习·神经网络
StarPrayers.1 天前
卷积神经网络(CNN)入门实践及Sequential 容器封装
人工智能·pytorch·神经网络·cnn
数智顾问1 天前
基于深度学习的卫星图像分类(Kaggle比赛实战)——从数据预处理到模型调优的全流程解析
深度学习
望获linux1 天前
【实时Linux实战系列】Linux 内核的实时组调度(Real-Time Group Scheduling)
java·linux·服务器·前端·数据库·人工智能·深度学习
程序员大雄学编程1 天前
「深度学习笔记4」深度学习优化算法完全指南:从梯度下降到Adam的实战详解
笔记·深度学习·算法·机器学习