PyTorch|构建自己的卷积神经网络--池化操作

在卷积神经网络中,一般在卷积层后,我们往往进行池化操作。实现池化操作很简单,pytorch中早已有相应的实现。

nn.MaxPool2d(kernel_size= ,stride= )

这种池化叫做最大池化

最大池化原理很简单,就是一个filter以一定的stride在原数据上进行操作,就像这样:

这里是一个2x2的filter,同时stride为2,在原始数据上扫描,最终的到新的数据

通过代码来实现:

复制代码
>>> import torch>>> import torch.nn as nn>>> data=torch.tensor([[1,1,2,4],[5,6,7,8],[3,2,1,0],[1,2,3,4]],dtype=torch.float32)>>> data.size()torch.Size([4, 4])>>> data=data.unsqueeze(0)>>> data.size()torch.Size([1, 4, 4])>>> data=data.unsqueeze(0)>>> data.size()torch.Size([1, 1, 4, 4])

之所以这样操作,是为了模仿实际训练网络时的场景。这样data的尺寸变为了1x1x4x4

可以理解为Batch为1,Channel为1,Height为4,Width为4​​​​​​​

复制代码
>>> net=nn.MaxPool2d(kernel_size=2,stride=2)>>> end=net(data)>>> endtensor([[[[6., 8.],          [3., 4.]]]])

结果符合预期

当然,nn.MaxPool2d()的可接受参数很多,完整如下:

nn.MaxPool2d(kernel_size,stride=None,padding=0,dilation=1,return_indices=False,ceil_mode=False)

  • kernel_size:设置filter大小

  • stride:控制移动步长,默认为kernel_size的尺寸

  • padding:对原始数据周围进行填充

  • dilation:给原始数据之间添加0

  • return_indices :如果为True,会返回输出最大值对应的序号序列。

  • ceil_mode:控制当卷积核超过原始图像时,是否对max进行保留

这里参数很多,但很多时候我们并不需要一些参数,就像这样:

nn.MaxPool2d(kernel_size=,stride=,padding=)

或者这样:

nn.MaxPool2d(kernel_size,stride)

当然,进行池化操作后,我们很可能再来连接一个卷积层或者全连接层,所以对数据进行池化后,数据的尺寸如何变化是必须要知道的。

公式(假设输入的数据H和W相同):

输入:N,C,H,W

输出:N,C,H',W'

  • H'=(H+2*padding-kernel_size)/stride+1

  • W'=(W'+2*padding-kernel_size)/stride+1

再次回到最初的代码:​​​​​​​

复制代码
>>> import torch>>> import torch.nn as nn>>> data=torch.tensor([[1,1,2,4],[5,6,7,8],[3,2,1,0],[1,2,3,4]],dtype=torch.float32)>>> data.size()torch.Size([4, 4])>>> data=data.unsqueeze(0)>>> data.size()torch.Size([1, 4, 4])>>> data=data.unsqueeze(0)>>> data.size()torch.Size([1, 1, 4, 4])
>>> net=nn.MaxPool2d(kernel_size=2,stride=2)>>> end=net(data)>>> endtensor([[[[6., 8.],          [3., 4.]]]])

输入:1x1x4x4

输出:1x1x2x2

2=(4+2*0-2)/2+1

当然池化种类也有很多,但理解起来都不难,同时,kernel_size,dilation,stride,padding的尺寸也不一定为正方形。

相关推荐
Kusunoki_D1 小时前
PyTorch 环境配置
人工智能·pytorch·python
魔乐社区3 小时前
OpenAI重新开源!gpt-oss-20b适配昇腾并上线魔乐社区
人工智能·gpt·深度学习·开源·大模型
Coovally AI模型快速验证5 小时前
全景式综述|多模态目标跟踪全面解析:方法、数据、挑战与未来
人工智能·深度学习·算法·机器学习·计算机视觉·目标跟踪·无人机
格林威5 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型和EasyOCR实现汽车牌照动态检测和识别(C#代码,UI界面版)
人工智能·深度学习·数码相机·yolo·c#·汽车·视觉检测
左灯右行的爱情6 小时前
深度学习设计模式:责任链(Chain of Responsibility)模式(例子+业务场景+八股)
深度学习·设计模式·责任链模式
Virgil13916 小时前
【TrOCR】模型预训练权重各个文件解读
人工智能·pytorch·计算机视觉·自然语言处理·ocr·transformer
MaxCode-116 小时前
【机器学习 / 深度学习】基础教程
人工智能·深度学习·机器学习
先做个垃圾出来………16 小时前
神经网络(Neural Network, NN)
人工智能·深度学习·神经网络
朝日六六花_LOCK20 小时前
深度学习之NLP基础
人工智能·深度学习·自然语言处理
Hao想睡觉21 小时前
循环神经网络实战:用 LSTM 做中文情感分析(二)
rnn·深度学习·lstm