Python day46

@浙大疏锦行 Python day46.

  • 注意力机制可以理解为对输入特征进行加权求和,注意力权重也是学习到的,类似于卷积,不过卷积的权重一般时是固定的,而注意力机制的权重根据输入数据不同权重也不同;
  • 常见的注意力模块有自注意力、通道注意力、空间注意力、多头注意力集以编码器-解码器注意力;
  • 通道注意力机制的执行过程为,先压缩空间维度只保留通道信息,接下来通过全连接层学习通道之间的权重信息,最后进行相应的加权操作即可;
python 复制代码
class ChannelAttention(nn.Module):
    """通道注意力模块(Squeeze-and-Excitation)"""
    def __init__(self, in_channels, reduction_ratio=16):
        """
        参数:
            in_channels: 输入特征图的通道数
            reduction_ratio: 降维比例,用于减少参数量
        """
        super(ChannelAttention, self).__init__()
        
        # 全局平均池化 - 将空间维度压缩为1x1,保留通道信息
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # 全连接层 + 激活函数,用于学习通道间的依赖关系
        self.fc = nn.Sequential(
            # 降维:压缩通道数,减少计算量
            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
            nn.ReLU(inplace=True),
            # 升维:恢复原始通道数
            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),
            # Sigmoid将输出值归一化到[0,1],表示通道重要性权重
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        参数:
            x: 输入特征图,形状为 [batch_size, channels, height, width]
        
        返回:
            加权后的特征图,形状不变
        """
        batch_size, channels, height, width = x.size()
        
        # 1. 全局平均池化:[batch_size, channels, height, width] → [batch_size, channels, 1, 1]
        avg_pool_output = self.avg_pool(x)
        
        # 2. 展平为一维向量:[batch_size, channels, 1, 1] → [batch_size, channels]
        avg_pool_output = avg_pool_output.view(batch_size, channels)
        
        # 3. 通过全连接层学习通道权重:[batch_size, channels] → [batch_size, channels]
        channel_weights = self.fc(avg_pool_output)
        
        # 4. 重塑为二维张量:[batch_size, channels] → [batch_size, channels, 1, 1]
        channel_weights = channel_weights.view(batch_size, channels, 1, 1)
        
        # 5. 将权重应用到原始特征图上(逐通道相乘)
        return x * channel_weights  # 输出形状:[batch_size, channels, height, width]
相关推荐
Coding茶水间9 分钟前
基于深度学习的非机动车头盔检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
Rose sait19 分钟前
【环境配置】Linux配置虚拟环境pytorch
linux·人工智能·python
过期动态1 小时前
JDBC高级篇:优化、封装与事务全流程指南
android·java·开发语言·数据库·python·mysql
baby_hua1 小时前
20251024_PyTorch深度学习快速入门教程
人工智能·pytorch·深度学习
brave and determined1 小时前
CANN训练营 学习(day9)昇腾AscendC算子开发实战:从零到性能冠军
人工智能·算法·机器学习·ai·开发环境·算子开发·昇腾ai
一世琉璃白_Y1 小时前
pg配置国内数据源安装
linux·python·postgresql·centos
liwulin05061 小时前
【PYTHON】COCO数据集中的物品ID
开发语言·python
小鸡吃米…1 小时前
Python - XML 处理
xml·开发语言·python·开源
我赵帅的飞起1 小时前
python国密SM4加解密
python·sm4加解密·国密sm4加解密
yaoh.wang2 小时前
力扣(LeetCode) 1: 两数之和 - 解法思路
python·程序人生·算法·leetcode·面试·跳槽·哈希算法