CoordAtt注意力网络结构

源码:

python 复制代码
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        

    def forward(self, x):
        identity = x
        
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y) 
        
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out



def CA_onnx_gen():
    conv=CoordAtt(64,64)
    dummy_input = torch.randn(8,64, 128, 128)
    out=conv(dummy_input)
    print(out.shape)
 
    print(conv)
    # conv.load_state_dict(checkpoint)
    conv.eval()
    input_names = ["input"]
    output_names = ["output"]
    torch.onnx.export(conv, dummy_input, "CA.onnx", verbose=True, opset_version=13,input_names=input_names,
                      output_names=output_names)


if __name__=="__main__":
    CA_onnx_gen()

onnx结构:

相关推荐
阿拉斯攀登2 小时前
YOLO 视觉检测全栈核心名词指南:从训练调参到边缘部署,商用落地必懂
人工智能·yolo·计算机视觉·视觉检测·bytetrack
科研实践课堂(小绿书)2 小时前
机器学习在智能水泥基复合材料中的应用与实践
人工智能·机器学习·复合材料·水泥基·混凝土
墨韵流芳3 小时前
CCF-CSP第41次认证第三题——进程通信
c++·人工智能·算法·机器学习·csp·ccf
君科程序定做3 小时前
多源遥感与深度学习视角下耕地识别与耕地监测的局限性、研究空白与科学问题
人工智能·深度学习
七夜zippoe3 小时前
可解释AI:构建可信的机器学习系统——反事实解释与概念激活实战
人工智能·python·机器学习·可解释性·概念激活
AI先驱体验官5 小时前
智能体变现:从技术实现到产品化的实践路径
大数据·人工智能·深度学习·重构·aigc
Zero5 小时前
机器学习概率论与统计学--(8)概率论:数字特征
机器学习·概率论·随机变量·统计学·方差·协方差·期望
Zero5 小时前
机器学习概率论与统计学--(9)统计学:参数估计
机器学习·概率论·统计学·矩估计·最大似然估计·点估计
纪伊路上盛名在5 小时前
机器学习中的固定随机种子方案
人工智能·机器学习·数据分析·随机种子
龙腾AI白云6 小时前
什么是AI智能体(AI Agent)
人工智能·深度学习·自然语言处理·数据分析