EPSANet2021笔记


来源:

EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network

相关工作:

#注意力机制 #多尺度特征表示

创新点:

贡献:

  1. 建立了长距离通道依赖关系
  2. 有效获取利用不同尺度特征图的空间信息

问题:

  • 作者提供代码和文章描述处理过程不一致
  • 在小样本上训练测试效果不佳

代码:

python 复制代码
# ---------------------------------------  
# 论文: EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network (AICV 2021)  
# Github:https://github.com/murufeng/EPSANet  
# ---------------------------------------  
import torch  
from torch import nn  
  
  
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):  
    """standard convolution with padding"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,  
                     padding=padding, dilation=dilation, groups=groups, bias=False)  
  
  
def conv1x1(in_planes, out_planes, stride=1):  
    """1x1 convolution"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)  
  
  
class SEWeightModule(nn.Module):  
  
    def __init__(self, channels, reduction=16):  
        super(SEWeightModule, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)  
        self.relu = nn.ReLU(inplace=True)  
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        out = self.avg_pool(x)  
        out = self.fc1(out)  
        out = self.relu(out)  
        out = self.fc2(out)  
        weight = self.sigmoid(out)  
  
        return weight  
  
  
class PSAModule(nn.Module):  
  
    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):  
        super(PSAModule, self).__init__()  
        self.conv_1 = conv(inplans, planes // 4, kernel_size=conv_kernels[0], padding=conv_kernels[0] // 2,  
                           stride=stride, groups=conv_groups[0])  
        self.conv_2 = conv(inplans, planes // 4, kernel_size=conv_kernels[1], padding=conv_kernels[1] // 2,  
                           stride=stride, groups=conv_groups[1])  
        self.conv_3 = conv(inplans, planes // 4, kernel_size=conv_kernels[2], padding=conv_kernels[2] // 2,  
                           stride=stride, groups=conv_groups[2])  
        self.conv_4 = conv(inplans, planes // 4, kernel_size=conv_kernels[3], padding=conv_kernels[3] // 2,  
                           stride=stride, groups=conv_groups[3])  
        self.se = SEWeightModule(planes // 4)  
        self.split_channel = planes // 4  
        self.softmax = nn.Softmax(dim=1)  
  
    def forward(self, x):  
        batch_size = x.shape[0]  
        x1 = self.conv_1(x)  
        x2 = self.conv_2(x)  
        x3 = self.conv_3(x)  
        x4 = self.conv_4(x)  
  
        feats = torch.cat((x1, x2, x3, x4), dim=1)  
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])  
  
        x1_se = self.se(x1)  
        x2_se = self.se(x2)  
        x3_se = self.se(x3)  
        x4_se = self.se(x4)  
  
        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)  
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)  
        attention_vectors = self.softmax(attention_vectors)  
        feats_weight = feats * attention_vectors  
        for i in range(4):  
            x_se_weight_fp = feats_weight[:, i, :, :]  
            if i == 0:  
                out = x_se_weight_fp  
            else:  
                out = torch.cat((x_se_weight_fp, out), 1)  
  
        return out  
  
  
#   输入 N C H W,  输出 N C H Wif __name__ == '__main__':  
    input = torch.randn(3, 64, 32, 32)  
    s2att = PSAModule(inplans=64, planes=64)  
    output = s2att(input)  
    print(output.shape)
    
相关推荐
安冬的码畜日常3 小时前
【Vim Masterclass 笔记18】第八章 + S08L35:Vim 的可视化模式(二)
笔记·vim·自学笔记·vim可视化模式·vim可视模式·vim视觉模式
小殷要努力刷题!5 小时前
JavaWeb项目——如何处理管理员登录和退出——笔记
java·javascript·笔记·学习·servlet·javaweb·寒假
?-ldl6 小时前
docker使用笔记
笔记·docker·容器
月印千江6717 小时前
从密码学原理与应用新方向到移动身份认证与实践
经验分享·笔记·其他·网络安全·密码学
bohu837 小时前
opencv笔记2
人工智能·笔记·opencv
Pandaconda8 小时前
【新人系列】Python 入门(二十七):Python 库
开发语言·笔记·后端·python·面试··python库
小菜鸟博士8 小时前
大模型学习笔记 - 第一期 - Milvus向量数据库
数据库·笔记·学习·算法·milvus
索然无味io8 小时前
PHP基础--流程控制
前端·笔记·后端·学习·web安全·网络安全·php
Pandaconda8 小时前
【Golang 面试题】每日 3 题(三十六)
开发语言·经验分享·笔记·后端·面试·golang·go