神经网络:池化层

  1. 神经网络 池化操作

    下采样:减少特征数量

    先看池化操作:

    其中ceil_mode设置成True和False池化区别:

    在这个例子当中:ceil_mode=True表示边缘不满3x3的部分也会被池化,False表示边缘不满3x3的部分不会采样

    code:

    python 复制代码
    import torch
    from torch import nn
    from torch.nn import MaxPool2d
    
    input = torch.tensor([[1,2,0,3,1],
                          [0,1,2,3,1],
                          [1,2,1,0,0],
                          [5,2,3,1,1],
                          [2,1,0,1,1]
    ],dtype = torch.float32)
    #这里dtype为float是因为maxpool2d只能处理float类型的数据
    
    input = torch.reshape(input,(-1,1,5,5))
    print(input.shape)
    
    class Net(nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            #ceil_mode=True表示边缘不满3x3的部分也会被池化
            #kernel_size=3 默认是卷积核的大小
            self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
            self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)
    
        def forward(self,input):
            #output = self.maxpool1(input)
            output = self.maxpool2(input)
            return output
    
    net = Net()
    output = net(input)
    print(output)

    ceil_mode=True:

    ceil_mode=False:

  2. 神经网络 池化层

    这里需要先看这篇博客:

https://blog.csdn.net/whdehcy/article/details/149486555?fromshare=blogdetail\&sharetype=blogdetail\&sharerId=149486555\&sharerefer=PC\&sharesource=whdehcy\&sharefrom=from_link

是讲卷积层的

现在将上一步的卷积得到的特征图作为池化的输入

python 复制代码
    pool_output = poolnet(conv_output)
    writer.add_images('pool_output',pool_output,cnt)

只需要添加一下池化的操作

python 复制代码
class poolNet(nn.Module):
    def __init__(self):
        super(poolNet,self).__init__()
        #ceil_mode=True表示边缘不满3x3的部分也会被池化
        #kernel_size=3 默认是卷积核的大小
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
        self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output = self.maxpool1(input)
        #output = self.maxpool2(input)
        return output

poolnet = poolNet()

完整版代码:

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from torchvision import transforms

#数据预处理
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean = [0.5,0.5,0.5],
        std = [0.5,0.5,0.5]
    )
])

#加载数据集
folder_path = '../images'
dataset = ImageFolder(folder_path,transform=transform)
dataloader = DataLoader(dataset,batch_size=1)

#卷积
class convNet(nn.Module):
    def __init__(self):
        #调用父类nn.Module的构造函数
        super(convNet,self).__init__()
        self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

convnet = convNet()

#池化
class poolNet(nn.Module):
    def __init__(self):
        super(poolNet,self).__init__()
        #ceil_mode=True表示边缘不满3x3的部分也会被池化
        #kernel_size=3 默认是卷积核的大小
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
        self.maxpool2 = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output = self.maxpool1(input)
        #output = self.maxpool2(input)
        return output

poolnet = poolNet()

writer = SummaryWriter('../logs')

cnt = 0
for data in dataloader:
    img,label = data
    print(img.shape)
    conv_output = convnet(img)
    print(conv_output.shape)
    writer.add_images('input',img,cnt)
    conv_output = torch.reshape(conv_output,(-1,3,222,222))
    writer.add_images('conv_output',conv_output,cnt)
    pool_output = poolnet(conv_output)
    writer.add_images('pool_output',pool_output,cnt)
    cnt = cnt + 1

writer.close()

卷积:

池化:

相关推荐
白露与泡影1 小时前
2025年高质量Java面试真题汇总
java·python·面试
程序员三藏1 小时前
Fiddler抓取HTTPS
自动化测试·软件测试·python·测试工具·https·fiddler·接口测试
文心快码 Baidu Comate4 小时前
您的前端开发智能工作流待升级,查收最新 Figma2Code!
人工智能·ai编程·文心快码·ai ide·comate ai ide
gc_22994 小时前
学习Python中Selenium模块的基本用法(15:窗口操作)
python·selenium
AIminminHu5 小时前
实战项目(十二:《AI画质增强与LED驱动控制:一场关于‘创造’与‘还原’的对话》):从LED冬奥会、奥运会及春晚等大屏,到手机小屏,快来挖一挖里面都有什么
人工智能·智能手机
skywalk81635 小时前
在Windows10 Edge浏览器里安装DeepSider大模型插件来免费使用gpt-4o、NanoBanana等AI大模型
人工智能
顾道长生'5 小时前
(Arxiv-2025)OmniInsert:无遮罩视频插入任意参考通过扩散 Transformer 模型
深度学习·音视频·transformer
汽车仪器仪表相关领域6 小时前
工业安全新利器:NHQT-4四合一检测线系统深度解析
网络·数据库·人工智能·安全·汽车·检测站·汽车检测
有Li6 小时前
基于神经控制微分方程的采集无关深度学习用于定量MRI参数估计|文献速递-文献分享
论文阅读·人工智能·文献·医学生
keep_di6 小时前
06-django中配置接口文档coreapi
后端·python·django