EPSANet2021笔记


来源:

EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network

相关工作:

#注意力机制 #多尺度特征表示

创新点:

贡献:

  1. 建立了长距离通道依赖关系
  2. 有效获取利用不同尺度特征图的空间信息

问题:

  • 作者提供代码和文章描述处理过程不一致
  • 在小样本上训练测试效果不佳

代码:

python 复制代码
# ---------------------------------------  
# 论文: EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network (AICV 2021)  
# Github:https://github.com/murufeng/EPSANet  
# ---------------------------------------  
import torch  
from torch import nn  
  
  
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):  
    """standard convolution with padding"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,  
                     padding=padding, dilation=dilation, groups=groups, bias=False)  
  
  
def conv1x1(in_planes, out_planes, stride=1):  
    """1x1 convolution"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)  
  
  
class SEWeightModule(nn.Module):  
  
    def __init__(self, channels, reduction=16):  
        super(SEWeightModule, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)  
        self.relu = nn.ReLU(inplace=True)  
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        out = self.avg_pool(x)  
        out = self.fc1(out)  
        out = self.relu(out)  
        out = self.fc2(out)  
        weight = self.sigmoid(out)  
  
        return weight  
  
  
class PSAModule(nn.Module):  
  
    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):  
        super(PSAModule, self).__init__()  
        self.conv_1 = conv(inplans, planes // 4, kernel_size=conv_kernels[0], padding=conv_kernels[0] // 2,  
                           stride=stride, groups=conv_groups[0])  
        self.conv_2 = conv(inplans, planes // 4, kernel_size=conv_kernels[1], padding=conv_kernels[1] // 2,  
                           stride=stride, groups=conv_groups[1])  
        self.conv_3 = conv(inplans, planes // 4, kernel_size=conv_kernels[2], padding=conv_kernels[2] // 2,  
                           stride=stride, groups=conv_groups[2])  
        self.conv_4 = conv(inplans, planes // 4, kernel_size=conv_kernels[3], padding=conv_kernels[3] // 2,  
                           stride=stride, groups=conv_groups[3])  
        self.se = SEWeightModule(planes // 4)  
        self.split_channel = planes // 4  
        self.softmax = nn.Softmax(dim=1)  
  
    def forward(self, x):  
        batch_size = x.shape[0]  
        x1 = self.conv_1(x)  
        x2 = self.conv_2(x)  
        x3 = self.conv_3(x)  
        x4 = self.conv_4(x)  
  
        feats = torch.cat((x1, x2, x3, x4), dim=1)  
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])  
  
        x1_se = self.se(x1)  
        x2_se = self.se(x2)  
        x3_se = self.se(x3)  
        x4_se = self.se(x4)  
  
        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)  
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)  
        attention_vectors = self.softmax(attention_vectors)  
        feats_weight = feats * attention_vectors  
        for i in range(4):  
            x_se_weight_fp = feats_weight[:, i, :, :]  
            if i == 0:  
                out = x_se_weight_fp  
            else:  
                out = torch.cat((x_se_weight_fp, out), 1)  
  
        return out  
  
  
#   输入 N C H W,  输出 N C H Wif __name__ == '__main__':  
    input = torch.randn(3, 64, 32, 32)  
    s2att = PSAModule(inplans=64, planes=64)  
    output = s2att(input)  
    print(output.shape)
    
相关推荐
AA陈超14 分钟前
ASC学习笔记0012:查找现有的属性集,如果不存在则断言
笔记·学习
wshzd39 分钟前
LLM之Agent(二十八)|AI音视频转笔记方法揭秘
人工智能·笔记
The_Second_Coming1 小时前
Python 学习笔记:基础篇
运维·笔记·python·学习
思成不止于此1 小时前
软考中级软件设计师备考指南(二):计算机体系结构与指令系统
笔记·学习·软件设计师
潇冉沐晴9 小时前
div2 1052 个人补题笔记
笔记
蒙奇D索大10 小时前
【计算机网络】[特殊字符] 408高频考点 | 数据链路层组帧:从字符计数到违规编码,一文学透四大实现方法
网络·笔记·学习·计算机网络·考研
njsgcs11 小时前
tekla 使用笔记 切管 分割指定长度的管
笔记·tekla
蒙奇D索大12 小时前
【算法】 递归实战应用:从暴力迭代到快速幂的优化之路
笔记·考研·算法·改行学it
('-')13 小时前
《从根上理解MySQL》第一章学习笔记
笔记·学习·mysql