【深度学习注意力机制系列】—— SENet注意力机制(附pytorch实现)

深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息,提高模型的性能和泛化能力。

卷积神经网络引入的注意力机制主要有以下几种方法:

  • 在空间维度上增加注意力机制
  • 在通道维度上增加注意力机制
  • 在两者的混合维度上增加注意力机制

我们将在本系列对多种注意力机制进行讲解,并使用pytorch进行实现,今天我们讲解SENet注意力机制

SENet(Squeeze-and-Excitation Networks)注意力机制通道维度上引入注意力机制,其核心思想在于通过网络根据loss去学习特征权重,使得有效的feature map权重大,无效或效果小的feature map权重小的方式训练模型达到更好的结果。SE block嵌在原有的一些分类网络中不可避免地增加了一些参数和计算量,但是在效果面前还是可以接受的 。Sequeeze-and-Excitation(SE) block并不是一个完整的网络结构,而是一个子结构,可以嵌到其他分类或检测模型中。

以上是SENet的结构示意图, 其关键操作为squeeze和excitation. 通过自动学习获得特征图在每个通道上的重要程度,以此为不同通道赋予不同的权重,提升有用通道的贡献程度.

实现机制:

  1. Squeeze: 通过全剧平均池化层,将每个通道大的二维特征(h*w)压缩为一个实数,维度变化: (C, H, W) -> (C, 1, 1)
  2. Excitation: 给予每个通道的一个特征权重, 然后经过两次全连接层的信息整合提取,构建通道间的自相关性,输出权重数目和特征图通道数一致, 维度变化: (C, 1, 1) -> (C, 1, 1)
  3. Scale: 将归一化后的权重加权道每个通道的特征上, 论文中使用的是相乘加权, 维度变化: (C, H, W) * (C, 1, 1) -> (C, H, W)

pytorch实现:

python 复制代码
class SENet(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super(SENet, self).__init__()
        self.in_channels = in_channels
        self.fgp = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(self.in_channels, int(self.in_channels / ratio), bias=False)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(int(self.in_channels / ratio), self.in_channels, bias=False)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        output = self.fgp(x)
        output = output.view(b, c)
        output = self.fc1(output)
        output = self.act1(output)
        output = self.fc2(output)
        output = self.act2(output)
        output = output.view(b, c, 1, 1)
        return torch.multiply(x, output)
相关推荐
roman_日积跬步-终至千里17 分钟前
【强化学习基础(5)】策略搜索与学徒学习:从专家行为中学习加速学习过程
人工智能
梁正雄43 分钟前
2、Python流程控制
开发语言·python
Eric.Lee20212 小时前
ubuntu 安装 Miniconda
linux·运维·python·ubuntu·miniconda
无心水2 小时前
【Python实战进阶】1、Python高手养成指南:四阶段突破法从入门到架构师
开发语言·python·django·matplotlib·gil·python实战进阶·python工程化实战进阶
杭州泽沃电子科技有限公司2 小时前
在线监测:为医药精细化工奠定安全、合规与质量基石
运维·人工智能·物联网·安全·智能监测
GIS数据转换器2 小时前
GIS+大模型助力安全风险精细化管理
大数据·网络·人工智能·安全·无人机
李剑一2 小时前
Python学习笔记1
python
OJAC1112 小时前
AI跨界潮:金融精英与应届生正涌入人工智能领域
人工智能·金融
机器之心2 小时前
Adam的稳+Muon的快?华为诺亚开源ROOT破解大模型训练「既要又要」的两难困境
人工智能·openai
可观测性用观测云3 小时前
观测云 MCP Server 接入和使用最佳实践
人工智能