神经网络:池化层

  1. 神经网络 池化操作

    下采样:减少特征数量

    先看池化操作:

    其中ceil_mode设置成True和False池化区别:

    在这个例子当中:ceil_mode=True表示边缘不满3x3的部分也会被池化,False表示边缘不满3x3的部分不会采样

    code:

    python 复制代码
    import torch
    from torch import nn
    from torch.nn import MaxPool2d
    
    input = torch.tensor([[1,2,0,3,1],
                          [0,1,2,3,1],
                          [1,2,1,0,0],
                          [5,2,3,1,1],
                          [2,1,0,1,1]
    ],dtype = torch.float32)
    #这里dtype为float是因为maxpool2d只能处理float类型的数据
    
    input = torch.reshape(input,(-1,1,5,5))
    print(input.shape)
    
    class Net(nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            #ceil_mode=True表示边缘不满3x3的部分也会被池化
            #kernel_size=3 默认是卷积核的大小
            self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
            self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)
    
        def forward(self,input):
            #output = self.maxpool1(input)
            output = self.maxpool2(input)
            return output
    
    net = Net()
    output = net(input)
    print(output)

    ceil_mode=True:

    ceil_mode=False:

  2. 神经网络 池化层

    这里需要先看这篇博客:

https://blog.csdn.net/whdehcy/article/details/149486555?fromshare=blogdetail\&sharetype=blogdetail\&sharerId=149486555\&sharerefer=PC\&sharesource=whdehcy\&sharefrom=from_link

是讲卷积层的

现在将上一步的卷积得到的特征图作为池化的输入

python 复制代码
    pool_output = poolnet(conv_output)
    writer.add_images('pool_output',pool_output,cnt)

只需要添加一下池化的操作

python 复制代码
class poolNet(nn.Module):
    def __init__(self):
        super(poolNet,self).__init__()
        #ceil_mode=True表示边缘不满3x3的部分也会被池化
        #kernel_size=3 默认是卷积核的大小
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
        self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output = self.maxpool1(input)
        #output = self.maxpool2(input)
        return output

poolnet = poolNet()

完整版代码:

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from torchvision import transforms

#数据预处理
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean = [0.5,0.5,0.5],
        std = [0.5,0.5,0.5]
    )
])

#加载数据集
folder_path = '../images'
dataset = ImageFolder(folder_path,transform=transform)
dataloader = DataLoader(dataset,batch_size=1)

#卷积
class convNet(nn.Module):
    def __init__(self):
        #调用父类nn.Module的构造函数
        super(convNet,self).__init__()
        self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

convnet = convNet()

#池化
class poolNet(nn.Module):
    def __init__(self):
        super(poolNet,self).__init__()
        #ceil_mode=True表示边缘不满3x3的部分也会被池化
        #kernel_size=3 默认是卷积核的大小
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
        self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output = self.maxpool1(input)
        #output = self.maxpool2(input)
        return output

poolnet = poolNet()

writer = SummaryWriter('../logs')

cnt = 0
for data in dataloader:
    img,label = data
    print(img.shape)
    conv_output = convnet(img)
    print(conv_output.shape)
    writer.add_images('input',img,cnt)
    conv_output = torch.reshape(conv_output,(-1,3,222,222))
    writer.add_images('conv_output',conv_output,cnt)
    pool_output = poolnet(conv_output)
    writer.add_images('pool_output',pool_output,cnt)
    cnt = cnt + 1

writer.close()

卷积:

池化:

相关推荐
会飞的老朱1 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º3 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
寻星探路3 小时前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
Codebee5 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º5 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys6 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56786 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子6 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder6 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能6 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算