EPSANet2021笔记


来源:

EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network

相关工作:

#注意力机制 #多尺度特征表示

创新点:

贡献:

  1. 建立了长距离通道依赖关系
  2. 有效获取利用不同尺度特征图的空间信息

问题:

  • 作者提供代码和文章描述处理过程不一致
  • 在小样本上训练测试效果不佳

代码:

python 复制代码
# ---------------------------------------  
# 论文: EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network (AICV 2021)  
# Github:https://github.com/murufeng/EPSANet  
# ---------------------------------------  
import torch  
from torch import nn  
  
  
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):  
    """standard convolution with padding"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,  
                     padding=padding, dilation=dilation, groups=groups, bias=False)  
  
  
def conv1x1(in_planes, out_planes, stride=1):  
    """1x1 convolution"""  
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)  
  
  
class SEWeightModule(nn.Module):  
  
    def __init__(self, channels, reduction=16):  
        super(SEWeightModule, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)  
        self.relu = nn.ReLU(inplace=True)  
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        out = self.avg_pool(x)  
        out = self.fc1(out)  
        out = self.relu(out)  
        out = self.fc2(out)  
        weight = self.sigmoid(out)  
  
        return weight  
  
  
class PSAModule(nn.Module):  
  
    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):  
        super(PSAModule, self).__init__()  
        self.conv_1 = conv(inplans, planes // 4, kernel_size=conv_kernels[0], padding=conv_kernels[0] // 2,  
                           stride=stride, groups=conv_groups[0])  
        self.conv_2 = conv(inplans, planes // 4, kernel_size=conv_kernels[1], padding=conv_kernels[1] // 2,  
                           stride=stride, groups=conv_groups[1])  
        self.conv_3 = conv(inplans, planes // 4, kernel_size=conv_kernels[2], padding=conv_kernels[2] // 2,  
                           stride=stride, groups=conv_groups[2])  
        self.conv_4 = conv(inplans, planes // 4, kernel_size=conv_kernels[3], padding=conv_kernels[3] // 2,  
                           stride=stride, groups=conv_groups[3])  
        self.se = SEWeightModule(planes // 4)  
        self.split_channel = planes // 4  
        self.softmax = nn.Softmax(dim=1)  
  
    def forward(self, x):  
        batch_size = x.shape[0]  
        x1 = self.conv_1(x)  
        x2 = self.conv_2(x)  
        x3 = self.conv_3(x)  
        x4 = self.conv_4(x)  
  
        feats = torch.cat((x1, x2, x3, x4), dim=1)  
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])  
  
        x1_se = self.se(x1)  
        x2_se = self.se(x2)  
        x3_se = self.se(x3)  
        x4_se = self.se(x4)  
  
        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)  
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)  
        attention_vectors = self.softmax(attention_vectors)  
        feats_weight = feats * attention_vectors  
        for i in range(4):  
            x_se_weight_fp = feats_weight[:, i, :, :]  
            if i == 0:  
                out = x_se_weight_fp  
            else:  
                out = torch.cat((x_se_weight_fp, out), 1)  
  
        return out  
  
  
#   输入 N C H W,  输出 N C H Wif __name__ == '__main__':  
    input = torch.randn(3, 64, 32, 32)  
    s2att = PSAModule(inplans=64, planes=64)  
    output = s2att(input)  
    print(output.shape)
    
相关推荐
東雪木7 小时前
多线程与并发编程 专属复习笔记
java·开发语言·笔记·java面试
Oll Correct7 小时前
实验二十九:TCP的运输连接管理
网络·笔记
飞翔中文网9 小时前
Java学习笔记之抽象类与接口(设计思想)
java·笔记·学习
智者知已应修善业9 小时前
【proteus设计文氏正弦波信号发生器】2023-5-9
驱动开发·经验分享·笔记·硬件架构·proteus·硬件工程
凉、介11 小时前
深入理解 ARMv8-A|处理器模式与寄存器
笔记·学习·嵌入式·arm
whyTeaFo12 小时前
MIT 6.1810: Lec 5: calling conventions and stack frames RISC-V
笔记
上课不要睡觉了12 小时前
【统计法规】4.1统计管理体制概述
笔记·统计师考试
墨白曦煜13 小时前
算法实战笔记:剥开回溯算法的外衣——从通用模板到高阶去重(八)
笔记·算法
Upsy-Daisy14 小时前
IOTA 学习笔记(四):当前 IOTA 架构总览
笔记·学习·架构
山楂树の14 小时前
JS中??和||的区别
笔记