21 - GAM模块

论文《Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions》

1、作用

这篇论文提出了全局注意力机制(Global Attention Mechanism, GAM),旨在通过保留通道和空间方面的信息来增强跨维度交互,从而提升深度神经网络的性能。GAM通过引入3D排列与多层感知器(MLP)用于通道注意力,并辅以卷积空间注意力子模块,提高了图像分类任务的表现。该方法在CIFAR-100和ImageNet-1K数据集上的图像分类任务中均稳定地超越了几种最新的注意力机制,包括在ResNet和轻量级MobileNet模型上的应用。

2、机制

1、通道注意力子模块

利用3D排列保留跨三个维度的信息,并通过两层MLP放大跨维度的通道-空间依赖性。这个子模块通过编码器-解码器结构,以一个缩减比例r(与BAM相同)来实现。

2、空间注意力子模块

为了聚焦空间信息,使用了两个卷积层进行空间信息的融合。同时,为了进一步保留特征图,移除了池化操作。此外,为了避免参数数量显著增加,当应用于ResNet50时,采用了分组卷积与通道混洗。

3、独特优势

1、效率与灵活性

GAM展示了与现有的高效SR方法相比,如IMDN,其模型大小小了3倍,同时实现了可比的性能,展现了在内存使用上的高效性。

2、动态空间调制

通过利用独立学习的多尺度特征表示并动态地进行空间调制,GAM能够高效地聚合特征,提升重建性能,同时保持低计算和存储成本。

3、有效整合局部和非局部特征

GAM通过其层和CCM的结合,有效地整合了局部和非局部特征信息,实现了更精确的图像超分辨率重建。

4、代码

python 复制代码
import torch.nn as nn
import torch

class GAM_Attention(nn.Module):
    def __init__(self, in_channels, rate=4):
        super(GAM_Attention, self).__init__()

        # 通道注意力子模块
        self.channel_attention = nn.Sequential(
            # 降维,减少参数数量和计算复杂度
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),  # 非线性激活
            # 升维,恢复到原始通道数
            nn.Linear(int(in_channels / rate), in_channels)
        )

        # 空间注意力子模块
        self.spatial_attention = nn.Sequential(
            # 使用7x7卷积核进行空间特征的降维处理
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),  # 批归一化,加速收敛,提升稳定性
            nn.ReLU(inplace=True),  # 非线性激活
            # 使用7x7卷积核进行空间特征的升维处理
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)  # 批归一化
        )

    def forward(self, x):
        b, c, h, w = x.shape  # 输入张量的维度信息
        # 调整张量形状以适配通道注意力处理
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        # 应用通道注意力,并恢复原始张量形状
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        # 生成通道注意力图
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()

        # 应用通道注意力图进行特征加权
        x = x * x_channel_att

        # 生成空间注意力图并应用进行特征加权
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out

# 示例代码:使用GAM_Attention对一个随机初始化的张量进行处理
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)  # 随机生成输入张量
    b, c, h, w = x.shape  # 获取输入张量的维度信息
    net = GAM_Attention(in_channels=c)  # 实例化GAM_Attention模块
    y = net(x)  # 通过GAM_Attention模块处理输入张量
    print(y.shape)  # 打印输出张量的维度信息
相关推荐
Wendy14413 分钟前
【灰度实验】——图像预处理(OpenCV)
人工智能·opencv·计算机视觉
中杯可乐多加冰15 分钟前
五大低代码平台横向深度测评:smardaten 2.0领衔AI原型设计
人工智能
无线图像传输研究探索25 分钟前
单兵图传终端:移动场景中的 “实时感知神经”
网络·人工智能·5g·无线图传·5g单兵图传
zzywxc7872 小时前
AI在编程、测试、数据分析等领域的前沿应用(技术报告)
人工智能·深度学习·机器学习·数据挖掘·数据分析·自动化·ai编程
铭keny2 小时前
YOLOv8 基于RTSP流目标检测
人工智能·yolo·目标检测
墨尘游子2 小时前
11-大语言模型—Transformer 盖楼,BERT 装修,RoBERTa 直接 “拎包入住”|预训练白话指南
人工智能·语言模型·自然语言处理
金井PRATHAMA2 小时前
主要分布于内侧内嗅皮层的层Ⅲ的网格-速度联合细胞(Grid × Speed Conjunctive Cells)对NLP中的深层语义分析的积极影响和启示
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·知识图谱
天道哥哥3 小时前
InsightFace(RetinaFace + ArcFace)人脸识别项目(预训练模型,鲁棒性很好)
人工智能·目标检测
幻风_huanfeng3 小时前
学习人工智能所需知识体系及路径详解
人工智能·学习
云道轩3 小时前
使用Docker在Rocky Linux 9.5上在线部署LangFlow
linux·人工智能·docker·容器·langflow