最大池化pytorch

**前置知识:

1、

复制代码
self.maxpool_2=MaxPool2d(kernel_size=3,ceil_mode=True)
复制代码
output=self.maxpool_2(input)

输入:张量的形状是(N,C,H,W)或(C,H,W)

  • Input: (N,C,Hin,Win)or (C,Hin,Win)

  • Output: (N,C,Hout,Wout)or (C,Hout,Wout)

参数:

  • 池化核(池化窗口)大小:kernel_size (Union[ int, Tuple[ int, int] ]) -- the size of the window to take a max over

  • 步长:stride (Union[ int, Tuple[ int, int] ] ) -- the stride of the window. Default value is kernel_size(默认是池化核的大小)

  • 补边缘padding (Union[ int, Tuple[ int, int] ]) -- Implicit negative infinity padding to be added on both sides

  • 取整方式:ceil_mode (bool) -- when True, will use ceil instead of floor to compute the output shape(True:向上取整,保留不足的部分;False:向下取整,去除不足一份的部分)

  • 空洞卷积dilation (Union[ int, Tuple[ int, int] ]) -- a parameter that controls the stride of elements in the window

2、池化的作用:

从特征图中提取最有代表性的特征;防止过拟合,实现降维;保持平移不变性。

(即保留重要特征,同时减少数据量,使模型训练得更快 eg: 1080P高清------>720P高清)

**代码:

1、对单一二维矩阵进行最大池化:

input 单一二维矩阵reshape(变成3D或4D)------>nn 创建神经元------>output 计算并输出

python 复制代码
import torch
from torch import nn
from torch.nn import MaxPool2d

input=torch.tensor([
    [1,2,0,3,1],
    [0,1,2,3,1],
    [1,2,1,0,0],
    [5,2,3,1,1],
    [2,1,0,1,1] #dtype=torch.float32把整数变成小数
])

input=torch.reshape(input,(-1,1,5,5)) #-1是占位符,后续自动计算batch_size的大小
print(input.shape)

#神经元
class Xigua(nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpool_2=MaxPool2d(kernel_size=3,ceil_mode=True) #保留不足的部分,也把它算进去

    def forward(self,input):
        output=self.maxpool_2(input)
        return output

xigua1=Xigua()
output=xigua1(input)
print(output)

2、对RGB图像进行池化:

input 导入并加载RGB图片数据集------>nn 创建神经元------>output 计算并记录

python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

dataloader=DataLoader(test_set,batch_size=64)

#神经元
class Xigua(nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpool_2=MaxPool2d(kernel_size=3,ceil_mode=True) #保留不足的部分,也把它算进去

    def forward(self,input):
        output=self.maxpool_2(input)
        return output

xigua1=Xigua()
writer=SummaryWriter("logs2")
step=1
for imgs,targets in dataloader:
    print(imgs.shape)
    writer.add_images("input",imgs,step)
    imgs=xigua1(imgs)
    print(imgs.shape)
    writer.add_images("output",imgs,step)
    step=step+1
    if step>=3:
        break
writer.close()
相关推荐
IT古董28 分钟前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师1 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)2 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
小王子10242 小时前
设计模式Python版 组合模式
python·设计模式·组合模式
kakaZhui3 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20253 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥4 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
Mason Lin4 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客4 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
云空5 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析