对比tensorflow,从0开始学pytorch(三)--自定义层

上文虽然实现了GMS层的效果,但是前端代码太多,太ugly,也不好复用。今天抽空看了下pytorch中怎么自定义层,很简单,比tensorflow好用。

  1. 任意文件夹创建个文件,和所有编程语言一样
  1. 一样,集成nn.Module,然后自定义一个形参

这里需要花时间搞明白torch.nn.functional下的函数和torch.nn下的类的区别,一开始有点懵,想着为什么不做高级语言当中的静态函数,想明白了也就简单了。

图中的SPP_Sizes做了类型定义,python中一般情况不需要定义类型,但不定义在后面循环就会报错,看了下pytorch自带conv2d的源码,发现源码中非常严谨,每一个变量都定义了类型。

  1. 调用就非常简单了,上一篇笔记中的冗长的代码,就可以一行调用
  1. 简化后,代码看过去顺眼多了。附上GMS封装后的源码和训练结果
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class GMS(nn.Module):
    def __init__(self, Spp_Sizes:[]):
        super().__init__()
        if len(Spp_Sizes) == 0:
            self.SPP_Sizes = [2, 3, 4]
        else:
            self.SPP_Sizes = Spp_Sizes

    def forward(self, x):
        x_gap = F.adaptive_avg_pool2d(x, (1, 1))
        x_gap = torch.flatten(x_gap, 1)

        x_gmp = F.adaptive_max_pool2d(x, (1, 1))
        x_gmp = torch.flatten(x_gmp, 1)

        x_gms = torch.cat((x_gap, x_gmp), dim=1)

        for spp_size in self.SPP_Sizes:
            x_spp = F.adaptive_max_pool2d(x, (spp_size,spp_size))
            x_spp = torch.flatten(x_spp, 1)
            x_gms = torch.cat((x_gms, x_spp), dim=1)
        return x_gms
相关推荐
OpenBayes2 小时前
教程上新|DeepSeek-OCR 2公式/表格解析同步改善,以低视觉token成本实现近4%的性能跃迁
人工智能·深度学习·目标检测·机器学习·大模型·ocr·gpu算力
退休钓鱼选手2 小时前
[ Pytorch教程 ] 神经网络的基本骨架 torch.nn -Neural Network
pytorch·深度学习·神经网络
冰糖猕猴桃3 小时前
【AI】把“大杂烩抽取”拆成多步推理:一个从单提示到多阶段管线的实践案例
大数据·人工智能·ai·提示词·多步推理
PPIO派欧云3 小时前
PPIO上线GLM-OCR:0.9B参数SOTA性能,支持一键部署
人工智能·ai·大模型·ocr·智谱
雨大王5123 小时前
怎么打造一个能自我进化的制造数字基座?
人工智能·汽车·制造
fengfuyao9853 小时前
基于MATLAB的表面织构油润滑轴承故障频率提取(改进VMD算法)
人工智能·算法·matlab
爱吃泡芙的小白白3 小时前
深入解析CNN中的Dropout层:从基础原理到最新变体实战
人工智能·神经网络·cnn·dropout·防止过拟合
DeniuHe3 小时前
用 PyTorch 库创建了一个随机张量,并演示了多种张量取整和分解操作
pytorch
Eloudy3 小时前
全文 -- TileLang: A Composable Tiled Programming Model for AISystems
人工智能·量子计算·arch
才盛智能科技4 小时前
K链通×才盛云:自助KTV品牌从0到1孵化超简单
大数据·人工智能·物联网·自助ktv系统·才盛云