EPSANet 2021 Notes


Source:

EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network

Related work:

#attention-mechanism #multi-scale-feature-representation

Innovations:

Contributions:

  1. Establishes long-range channel dependencies
  2. Effectively captures and exploits the spatial information of feature maps at different scales

Problems:

  • The code released by the authors is inconsistent with the processing procedure described in the paper (a sketch of the paper-described split step follows this list)
  • Performance is poor when training and testing on small datasets
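
To make the first point concrete: the paper describes the Split and Concat (SPC) step as splitting the input into S groups along the channel dimension and convolving each group with its own kernel size, whereas the released code below applies every multi-scale convolution to the full input. The following is a minimal sketch of the split-then-convolve reading of the paper, assuming S = 4 and the same kernel/group settings as the code; the class name SPCSplitSketch is made up for illustration and this is not the authors' implementation.

import torch
from torch import nn

class SPCSplitSketch(nn.Module):
    """Sketch of the paper-described SPC step: split channels into 4 groups,
    then apply one multi-scale group convolution per split."""

    def __init__(self, channels, conv_kernels=(3, 5, 7, 9), conv_groups=(1, 4, 8, 16)):
        super().__init__()
        self.split = channels // 4
        self.convs = nn.ModuleList([
            nn.Conv2d(self.split, self.split, kernel_size=k, padding=k // 2,
                      groups=g, bias=False)
            for k, g in zip(conv_kernels, conv_groups)
        ])

    def forward(self, x):
        # split (B, C, H, W) into four (B, C/4, H, W) chunks along the channel axis
        splits = torch.split(x, self.split, dim=1)
        outs = [conv(s) for conv, s in zip(self.convs, splits)]
        # concatenate back to (B, C, H, W)
        return torch.cat(outs, dim=1)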

Code:

# ---------------------------------------  
# Paper: EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network (AICV 2021)
# GitHub: https://github.com/murufeng/EPSANet
# ---------------------------------------  
import torch  
from torch import nn  
  
  
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):  
    """standard convolution with padding"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,  
                     padding=padding, dilation=dilation, groups=groups, bias=False)  
  
  
def conv1x1(in_planes, out_planes, stride=1):  
    """1x1 convolution"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)  
  
  
class SEWeightModule(nn.Module):  
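    """Generates per-channel attention weights: global average pooling -> 1x1 conv (reduce by `reduction`) -> ReLU -> 1x1 conv (restore) -> sigmoid."""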
  
    def __init__(self, channels, reduction=16):  
        super(SEWeightModule, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)  
        self.relu = nn.ReLU(inplace=True)  
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        out = self.avg_pool(x)  
        out = self.fc1(out)  
        out = self.relu(out)  
        out = self.fc2(out)  
        weight = self.sigmoid(out)  
  
        return weight  
  
  
class PSAModule(nn.Module):  
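    """Pyramid Squeeze Attention: four parallel group convolutions (kernels 3/5/7/9) over the full input, per-branch SE weights, softmax across branches, weighted concatenation."""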
  
    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):  
        super(PSAModule, self).__init__()  
        self.conv_1 = conv(inplans, planes // 4, kernel_size=conv_kernels[0], padding=conv_kernels[0] // 2,  
                           stride=stride, groups=conv_groups[0])  
        self.conv_2 = conv(inplans, planes // 4, kernel_size=conv_kernels[1], padding=conv_kernels[1] // 2,  
                           stride=stride, groups=conv_groups[1])  
        self.conv_3 = conv(inplans, planes // 4, kernel_size=conv_kernels[2], padding=conv_kernels[2] // 2,  
                           stride=stride, groups=conv_groups[2])  
        self.conv_4 = conv(inplans, planes // 4, kernel_size=conv_kernels[3], padding=conv_kernels[3] // 2,  
                           stride=stride, groups=conv_groups[3])  
        self.se = SEWeightModule(planes // 4)  
        self.split_channel = planes // 4  
        self.softmax = nn.Softmax(dim=1)  
  
    def forward(self, x):  
        batch_size = x.shape[0]  
        x1 = self.conv_1(x)  
        x2 = self.conv_2(x)  
        x3 = self.conv_3(x)  
        x4 = self.conv_4(x)  
  
        # stack the four multi-scale branch outputs as (B, 4, C/4, H, W)
        feats = torch.cat((x1, x2, x3, x4), dim=1)
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])
  
        x1_se = self.se(x1)  
        x2_se = self.se(x2)  
        x3_se = self.se(x3)  
        x4_se = self.se(x4)  
  
        # concatenate the per-branch SE weights and normalize across the 4 branches with softmax
        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        # re-weight each branch and concatenate the splits back into a (B, C, H, W) tensor
        feats_weight = feats * attention_vectors
        for i in range(4):
            x_se_weight_fp = feats_weight[:, i, :, :]
            if i == 0:
                out = x_se_weight_fp
            else:
                out = torch.cat((x_se_weight_fp, out), 1)
  
        return out  
  
  
#   Input: N C H W, Output: N C H W
if __name__ == '__main__':
    input = torch.randn(3, 64, 32, 32)  
    s2att = PSAModule(inplans=64, planes=64)  
    output = s2att(input)  
    print(output.shape)
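
In the paper, the PSA module is inserted into a ResNet-style bottleneck in place of the 3x3 convolution to form the EPSA block, and stacking these blocks gives EPSANet. Below is a minimal sketch of such a bottleneck built on the PSAModule above, assuming the usual Conv-BN-ReLU bottleneck layout; the class name EPSABlockSketch and the exact BatchNorm/ReLU placement are illustrative and may differ from the official repository.

class EPSABlockSketch(nn.Module):
    """Sketch of an EPSA bottleneck: 1x1 reduce -> PSA module (in place of the 3x3 conv) -> 1x1 expand, with a residual connection."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.psa = PSAModule(planes, planes, stride=stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample  # e.g. conv1x1 + BatchNorm when the shape changes

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.psa(out)))
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        return self.relu(out + identity)

For example, EPSABlockSketch(64, 64, downsample=nn.Sequential(conv1x1(64, 256), nn.BatchNorm2d(256))) maps the (3, 64, 32, 32) input above to a (3, 256, 32, 32) output.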
    