对比tensorflow,从0开始学pytorch(三)--自定义层

上文虽然实现了GMS层的效果,但是前端代码太多,太ugly,也不好复用。今天抽空看了下pytorch中怎么自定义层,很简单,比tensorflow好用。

  1. 任意文件夹创建个文件,和所有编程语言一样
  1. 一样,集成nn.Module,然后自定义一个形参

这里需要花时间搞明白torch.nn.functional下的函数和torch.nn下的类的区别,一开始有点懵,想着为什么不做高级语言当中的静态函数,想明白了也就简单了。

图中的SPP_Sizes做了类型定义,python中一般情况不需要定义类型,但不定义在后面循环就会报错,看了下pytorch自带conv2d的源码,发现源码中非常严谨,每一个变量都定义了类型。

  1. 调用就非常简单了,上一篇笔记中的冗长的代码,就可以一行调用
  1. 简化后,代码看过去顺眼多了。附上GMS封装后的源码和训练结果
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class GMS(nn.Module):
    def __init__(self, Spp_Sizes:[]):
        super().__init__()
        if len(Spp_Sizes) == 0:
            self.SPP_Sizes = [2, 3, 4]
        else:
            self.SPP_Sizes = Spp_Sizes

    def forward(self, x):
        x_gap = F.adaptive_avg_pool2d(x, (1, 1))
        x_gap = torch.flatten(x_gap, 1)

        x_gmp = F.adaptive_max_pool2d(x, (1, 1))
        x_gmp = torch.flatten(x_gmp, 1)

        x_gms = torch.cat((x_gap, x_gmp), dim=1)

        for spp_size in self.SPP_Sizes:
            x_spp = F.adaptive_max_pool2d(x, (spp_size,spp_size))
            x_spp = torch.flatten(x_spp, 1)
            x_gms = torch.cat((x_gms, x_spp), dim=1)
        return x_gms
相关推荐
杜子不疼.10 小时前
计算机视觉热门模型手册:Spring Boot 3.2 自动装配新机制:@AutoConfiguration 使用指南
人工智能·spring boot·计算机视觉
无心水12 小时前
【分布式利器:腾讯TSF】7、TSF高级部署策略全解析:蓝绿/灰度发布落地+Jenkins CI/CD集成(Java微服务实战)
java·人工智能·分布式·ci/cd·微服务·jenkins·腾讯tsf
北辰alk17 小时前
RAG索引流程详解:如何高效解析文档构建知识库
人工智能
九河云17 小时前
海上风电“AI偏航对风”:把发电量提升2.1%,单台年增30万度
大数据·人工智能·数字化转型
wm104317 小时前
机器学习第二讲 KNN算法
人工智能·算法·机器学习
沈询-阿里18 小时前
Skills vs MCP:竞合关系还是互补?深入解析Function Calling、MCP和Skills的本质差异
人工智能·ai·agent·ai编程
xiaobai17818 小时前
测试工程师入门AI技术 - 前序:跨越焦虑,从优势出发开启学习之旅
人工智能·学习
盛世宏博北京18 小时前
云边协同・跨系统联动:智慧档案馆建设与功能落地
大数据·人工智能
TGITCIC19 小时前
讲透知识图谱Neo4j在构建Agent时到底怎么用(二)
人工智能·知识图谱·neo4j·ai agent·ai智能体·大模型落地·graphrag
逆羽飘扬19 小时前
DeepSeek-mHC深度拆解:流形约束如何驯服狂暴的超连接?
人工智能