EPSANet (2021) Notes


Source:

EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network

Related work:

#attention-mechanism #multi-scale-feature-representation

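Innovations:

Contributions:

  1. Establishes long-range channel dependencies
  2. Effectively extracts and exploits spatial information from feature maps at different scales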
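The multi-scale part of the second contribution relies on grouped convolutions with growing kernel sizes. As a rough sketch, assuming the default kernel sizes (3, 5, 7, 9) and group counts (1, 4, 8, 16) from the code listing in this note with 64 input and 64 output channels, the weight count of each branch can be worked out by hand. The helper `conv_params` is a hypothetical name introduced only for this illustration.

```python
def conv_params(in_planes, out_planes, kernel_size, groups=1):
    """Weight count of a bias-free 2D convolution (hypothetical helper)."""
    # With grouping, each output channel only sees in_planes // groups input channels.
    return out_planes * (in_planes // groups) * kernel_size * kernel_size


# Assumed setting: inplans = planes = 64, so each branch outputs 64 // 4 = 16 channels.
in_planes, planes = 64, 64
kernels = (3, 5, 7, 9)
groups = (1, 4, 8, 16)

branch_params = [
    conv_params(in_planes, planes // 4, k, g) for k, g in zip(kernels, groups)
]
print(branch_params)       # [9216, 6400, 6272, 5184]
print(sum(branch_params))  # 27072

# For comparison: one ungrouped 9x9 branch with the same channel counts.
ungrouped_9x9 = conv_params(in_planes, planes // 4, 9)
print(ungrouped_9x9)       # 82944
```

Under these assumptions the larger kernels do not cost more weights than the 3x3 branch, because the group count grows along with the kernel size.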

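Problems:

  • The code released by the authors is inconsistent with the processing described in the paper
  • Poor training and test performance on small-sample datasets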

Code:

# ---------------------------------------
# Paper: EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network (AICV 2021)
# GitHub: https://github.com/murufeng/EPSANet
# ---------------------------------------
import torch  
from torch import nn  
  
  
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):  
    """standard convolution with padding"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,  
                     padding=padding, dilation=dilation, groups=groups, bias=False)  
  
  
def conv1x1(in_planes, out_planes, stride=1):  
    """1x1 convolution"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)  
  
  
class SEWeightModule(nn.Module):  
  
    def __init__(self, channels, reduction=16):  
        super(SEWeightModule, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)  
        self.relu = nn.ReLU(inplace=True)  
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        out = self.avg_pool(x)  
        out = self.fc1(out)  
        out = self.relu(out)  
        out = self.fc2(out)  
        weight = self.sigmoid(out)  
  
        return weight  
  
  
class PSAModule(nn.Module):  
  
    def __init__(self, inplans, planes, conv_kernels=(3, 5, 7, 9), stride=1, conv_groups=(1, 4, 8, 16)):
        super(PSAModule, self).__init__()  
        self.conv_1 = conv(inplans, planes // 4, kernel_size=conv_kernels[0], padding=conv_kernels[0] // 2,  
                           stride=stride, groups=conv_groups[0])  
        self.conv_2 = conv(inplans, planes // 4, kernel_size=conv_kernels[1], padding=conv_kernels[1] // 2,  
                           stride=stride, groups=conv_groups[1])  
        self.conv_3 = conv(inplans, planes // 4, kernel_size=conv_kernels[2], padding=conv_kernels[2] // 2,  
                           stride=stride, groups=conv_groups[2])  
        self.conv_4 = conv(inplans, planes // 4, kernel_size=conv_kernels[3], padding=conv_kernels[3] // 2,  
                           stride=stride, groups=conv_groups[3])  
        self.se = SEWeightModule(planes // 4)  
        self.split_channel = planes // 4  
        self.softmax = nn.Softmax(dim=1)  
  
    def forward(self, x):  
        batch_size = x.shape[0]  
        x1 = self.conv_1(x)  
        x2 = self.conv_2(x)  
        x3 = self.conv_3(x)  
        x4 = self.conv_4(x)  
  
        feats = torch.cat((x1, x2, x3, x4), dim=1)  
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])  
  
        x1_se = self.se(x1)  
        x2_se = self.se(x2)  
        x3_se = self.se(x3)  
        x4_se = self.se(x4)  
  
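        # Stack the per-branch SE weights as (N, 4, C/4, 1, 1) and normalize them
        # across the 4 branches (dim=1), then reweight the multi-scale features.
        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_weight = feats * attention_vectors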
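        # Fold the branch axis back into channels. Each branch is concatenated in
        # front of the running result, so the output channel order is branch 3, 2, 1, 0.
        for i in range(4):
            x_se_weight_fp = feats_weight[:, i, :, :]
            if i == 0:
                out = x_se_weight_fp
            else:
                out = torch.cat((x_se_weight_fp, out), 1)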
  
        return out  
  
  
# Input: N C H W, output: N C H W
if __name__ == '__main__':
    x = torch.randn(3, 64, 32, 32)
    psa = PSAModule(inplans=64, planes=64)
    output = psa(x)
    print(output.shape)  # torch.Size([3, 64, 32, 32])
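The final loop in `PSAModule.forward` concatenates each new branch in front of the accumulated result. The toy sketch below reproduces only that accumulation pattern on stand-in tensors filled with their branch index, to make the resulting channel order visible. Whether this is the inconsistency listed under the problems above is not stated in these notes; the order-preserving `reshape` at the end is an assumed alternative for comparison, not the authors' code.

```python
import torch

# Stand-in for feats_weight: shape (batch=1, branches=4, channels_per_branch=2, H=1, W=1).
# Every element of branch i holds the value i, so the order can be read off directly.
feats_weight = torch.stack(
    [torch.full((1, 2, 1, 1), float(i)) for i in range(4)], dim=1
)

# Same accumulation pattern as the loop in PSAModule.forward:
# the current branch is placed BEFORE the running result.
out = None
for i in range(4):
    branch = feats_weight[:, i]
    out = branch if out is None else torch.cat((branch, out), dim=1)

loop_order = out.flatten().tolist()
print(loop_order)  # [3.0, 3.0, 2.0, 2.0, 1.0, 1.0, 0.0, 0.0]

# Assumed alternative: merge the branch axis into channels without reordering.
in_order = feats_weight.reshape(1, -1, 1, 1).flatten().tolist()
print(in_order)    # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
```

Both variants give the same output shape, so the difference only shows up in which channels come from which kernel size.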
    