神经网络:池化层

  1. 神经网络 池化操作

    下采样:减少特征数量

    先看池化操作:

    其中ceil_mode设置成True和False池化区别:

    在这个例子当中:ceil_mode=True表示边缘不满3x3的部分也会被池化,False表示边缘不满3x3的部分不会采样

    code:

    python 复制代码
    import torch
    from torch import nn
    from torch.nn import MaxPool2d
    
    input = torch.tensor([[1,2,0,3,1],
                          [0,1,2,3,1],
                          [1,2,1,0,0],
                          [5,2,3,1,1],
                          [2,1,0,1,1]
    ],dtype = torch.float32)
    #这里dtype为float是因为maxpool2d只能处理float类型的数据
    
    input = torch.reshape(input,(-1,1,5,5))
    print(input.shape)
    
    class Net(nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            #ceil_mode=True表示边缘不满3x3的部分也会被池化
            #kernel_size=3 默认是卷积核的大小
            self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
            self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)
    
        def forward(self,input):
            #output = self.maxpool1(input)
            output = self.maxpool2(input)
            return output
    
    net = Net()
    output = net(input)
    print(output)

    ceil_mode=True:

    ceil_mode=False:

  2. 神经网络 池化层

    这里需要先看这篇博客:

https://blog.csdn.net/whdehcy/article/details/149486555?fromshare=blogdetail\&sharetype=blogdetail\&sharerId=149486555\&sharerefer=PC\&sharesource=whdehcy\&sharefrom=from_link

是讲卷积层的

现在将上一步的卷积得到的特征图作为池化的输入

python 复制代码
    pool_output = poolnet(conv_output)
    writer.add_images('pool_output',pool_output,cnt)

只需要添加一下池化的操作

python 复制代码
class poolNet(nn.Module):
    def __init__(self):
        super(poolNet,self).__init__()
        #ceil_mode=True表示边缘不满3x3的部分也会被池化
        #kernel_size=3 默认是卷积核的大小
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
        self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output = self.maxpool1(input)
        #output = self.maxpool2(input)
        return output

poolnet = poolNet()

完整版代码:

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from torchvision import transforms

#数据预处理
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean = [0.5,0.5,0.5],
        std = [0.5,0.5,0.5]
    )
])

#加载数据集
folder_path = '../images'
dataset = ImageFolder(folder_path,transform=transform)
dataloader = DataLoader(dataset,batch_size=1)

#卷积
class convNet(nn.Module):
    def __init__(self):
        #调用父类nn.Module的构造函数
        super(convNet,self).__init__()
        self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

convnet = convNet()

#池化
class poolNet(nn.Module):
    def __init__(self):
        super(poolNet,self).__init__()
        #ceil_mode=True表示边缘不满3x3的部分也会被池化
        #kernel_size=3 默认是卷积核的大小
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
        self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output = self.maxpool1(input)
        #output = self.maxpool2(input)
        return output

poolnet = poolNet()

writer = SummaryWriter('../logs')

cnt = 0
for data in dataloader:
    img,label = data
    print(img.shape)
    conv_output = convnet(img)
    print(conv_output.shape)
    writer.add_images('input',img,cnt)
    conv_output = torch.reshape(conv_output,(-1,3,222,222))
    writer.add_images('conv_output',conv_output,cnt)
    pool_output = poolnet(conv_output)
    writer.add_images('pool_output',pool_output,cnt)
    cnt = cnt + 1

writer.close()

卷积:

池化:

相关推荐
神码小Z几秒前
特斯拉前AI总监开源的一款“小型本地版ChatGPT”,普通家用电脑就能运行!
人工智能·chatgpt
IT_陈寒几秒前
Redis性能翻倍的7个冷门技巧:从P5到P8都在偷偷用的优化策略!
前端·人工智能·后端
AKAMAI4 分钟前
直播监控的生死时速:深夜告警引发的系统崩溃危机
人工智能·云计算·直播
B站计算机毕业设计之家5 分钟前
深度学习实战:python动物识别分类检测系统 计算机视觉 Django框架 CNN算法 深度学习 卷积神经网络 TensorFlow 毕业设计(建议收藏)✅
python·深度学习·算法·计算机视觉·分类·毕业设计·动物识别
程序猿小D11 分钟前
【完整源码+数据集+部署教程】 【运输&加载码头】仓库新卸物料检测系统源码&数据集全套:改进yolo11-DRBNCSPELAN
python·yolo·计算机视觉·目标跟踪·数据集·yolo11·仓库新卸物料检测系统
SiYuanFeng23 分钟前
《Synthetic Visual Genome》论文数据集的预处理
python·场景图
MUTA️24 分钟前
python中进程和线程
python
CoovallyAIHub30 分钟前
顶刊新发!上海交大提出PreCM:即插即用的旋转等变卷积,显著提升分割模型鲁棒性
人工智能·算法·计算机视觉
Francek Chen40 分钟前
【深度学习计算机视觉】12:风格迁移
人工智能·pytorch·深度学习·计算机视觉·风格迁移
拓端研究室41 分钟前
专题:2025年AI Agent智能体行业价值及应用分析报告:核心趋势、经济影响与治理框架|附700+份报告PDF、数据仪表盘汇总下载
人工智能