PyTorch入门学习(九):神经网络-最大池化使用

目录

一、数据准备

二、创建神经网络模型

三、可视化最大池化效果


一、数据准备

首先,需要准备一个数据集来演示最大池化层的应用。在本例中,使用了CIFAR-10数据集,这是一个包含10个不同类别图像的数据集,用于分类任务。我们使用PyTorch的torchvision库来加载CIFAR-10数据集并进行必要的数据转换。

python 复制代码
import torch
import torchvision
from torch.utils.data import DataLoader

# 数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)

# 使用DataLoader加载数据集,每批次包含64张图像
dataLoader = DataLoader(dataset, batch_size=64)

二、创建神经网络模型

接下来,创建一个简单的神经网络模型,其中包含一个卷积层和一个最大池化层。这个模型将帮助演示最大池化层的效果。首先定义一个Tudui类,该类继承了nn.Module,并在初始化方法中创建了一个卷积层和一个最大池化层。

python 复制代码
import torch.nn as nn
from torch.nn import Conv2d
from torch.nn.functional import max_pool2d

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init()
        # 卷积层
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
        # 最大池化层
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        return x

tudui = Tudui()
print(tudui)

上述代码中,定义了Tudui类,包括了一个卷积层和一个最大池化层。在forward方法中,数据首先经过卷积层,然后通过最大池化层,以减小图像的维度。

三、可视化最大池化效果

最大池化层有助于减小图像的维度,提取图像中的主要特征。接下来将使用TensorBoard来可视化最大池化的效果,以更好地理解它。首先,导入SummaryWriter类并创建一个SummaryWriter对象。

python 复制代码
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

然后,遍历数据集,对每个批次的图像应用卷积和最大池化操作,并将卷积前后的图像写入TensorBoard。

python 复制代码
step = 0
for data in dataLoader:
    imgs, targets = data
    
    # 卷积和最大池化操作
    output = tudui(imgs)
    
    # 将输入图像写入TensorBoard
    writer.add_images("input", imgs, step)
    
    # 由于TensorBoard不能直接显示多通道图像,我们需要重定义输出图像的大小
    output = torch.reshape(output, (-1, 6, 15, 15))
    
    # 将卷积和最大池化后的图像写入TensorBoard
    writer.add_images("output", output, step)
    
    step += 1

writer.close()

在上述代码中,使用writer.add_images将输入和输出的图像写入TensorBoard,并使用torch.reshape来重定义输出图像的大小,以适应TensorBoard的显示要求。

运行上述代码后,将在TensorBoard中看到卷积和最大池化的效果。最大池化层有助于提取图像中的关键信息,减小图像维度,并提高模型的计算效率。

完整代码如下:

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#使用dataloader加载数据集,批次数为64
dataLoader = DataLoader(dataset,batch_size=64)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui,self).__init__()
        # 该神经网络调用conv2d进行一层卷积,输入通道为3层(彩色图像为3通道),卷积核大小为3*3,输出通道为6,设置步长为1,padding为0,不进行填充。
        self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

tudui = Tudui()
print(tudui)

# 生成日志
writer = SummaryWriter("logs")

step = 0
# 输出卷积前的图片大小和卷积后的图片大小
for data in dataLoader:
    imgs,targets = data
    # 卷积操作
    output = tudui(imgs)
    print(imgs.shape)
    print(output.shape)
    writer.add_images("input",imgs,step)
    """
     注意:使用tensorboard输出时需要重新定义图片大小
     对于输入的图片集imgs来说,tensor.size([64,3,32,32]),即一批次为64张,一张图片为三个通道,大小为32*32
     对于经过卷积后输出的图片集output来说,tensor.size([64,6,30,30]),通道数变成了6,tensorboard不知道怎么显示通道数为6的图片,所以如果直接输出会报错
     解决方案:
     使用reshape方法对outputs进行重定义,把通道数改成3,如果不知道批次数大小,可以使用-1代替,程序会自动匹配批次大小。
    """
    #重定义输出图片的大小
    output = torch.reshape(output,(-1,3,30,30))
    # 显示输出的图片
    writer.add_images("output",output,step)
    step = step + 1
writer.close()

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

相关推荐
xuhaoyu_cpp_java6 小时前
项目学习(三)分页查询
java·经验分享·笔记·学习
小宋加油啊8 小时前
机械臂抓取物体 PVN3D算法调研学习
学习·算法·3d
Xzh04239 小时前
AI Agent 学习路线(Java 后端方向)
java·人工智能·学习
做cv的小昊9 小时前
计算机图形学:【Games101】学习笔记08——光线追踪(辐射度量学、渲染方程与全局光照、蒙特卡洛积分与路径追踪)
图像处理·笔记·学习·计算机视觉·游戏引擎·图形渲染·概率论
星恒随风10 小时前
C++ 类和对象入门(五):初始化列表、explicit 和 static 成员详解
开发语言·c++·笔记·学习·状态模式
sensen_kiss11 小时前
CPT304 SoftwareEngineeringII 软件工程 2 Pt.8 软件测试 (Software Testing)(上)
学习·软件工程
力学与人工智能11 小时前
PPT分享 | 洛桑联邦理工学院魏震:深度几何学习在工业设计优化中的应用
学习·优化·工业设计·深度几何学习·洛桑联邦理工学院
湘美书院--湘美谈教育12 小时前
湘美谈教育AI系列经验集锦:赋能整理聊斋志异大寓言
大数据·人工智能·深度学习·神经网络·机器学习
sensen_kiss13 小时前
CPT304 SoftwareEngineeringII 软件工程 2 Pt.9 软件测试 (Software Testing)(下)
学习·软件工程
wu_ye_m13 小时前
学习c语言第35天 函数声明和定义
c语言·开发语言·学习