Python day46

@浙大疏锦行 Python day46.

  • 注意力机制可以理解为对输入特征进行加权求和,注意力权重也是学习到的,类似于卷积,不过卷积的权重一般时是固定的,而注意力机制的权重根据输入数据不同权重也不同;
  • 常见的注意力模块有自注意力、通道注意力、空间注意力、多头注意力集以编码器-解码器注意力;
  • 通道注意力机制的执行过程为,先压缩空间维度只保留通道信息,接下来通过全连接层学习通道之间的权重信息,最后进行相应的加权操作即可;
python 复制代码
class ChannelAttention(nn.Module):
    """通道注意力模块(Squeeze-and-Excitation)"""
    def __init__(self, in_channels, reduction_ratio=16):
        """
        参数:
            in_channels: 输入特征图的通道数
            reduction_ratio: 降维比例,用于减少参数量
        """
        super(ChannelAttention, self).__init__()
        
        # 全局平均池化 - 将空间维度压缩为1x1,保留通道信息
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # 全连接层 + 激活函数,用于学习通道间的依赖关系
        self.fc = nn.Sequential(
            # 降维:压缩通道数,减少计算量
            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
            nn.ReLU(inplace=True),
            # 升维:恢复原始通道数
            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),
            # Sigmoid将输出值归一化到[0,1],表示通道重要性权重
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        参数:
            x: 输入特征图,形状为 [batch_size, channels, height, width]
        
        返回:
            加权后的特征图,形状不变
        """
        batch_size, channels, height, width = x.size()
        
        # 1. 全局平均池化:[batch_size, channels, height, width] → [batch_size, channels, 1, 1]
        avg_pool_output = self.avg_pool(x)
        
        # 2. 展平为一维向量:[batch_size, channels, 1, 1] → [batch_size, channels]
        avg_pool_output = avg_pool_output.view(batch_size, channels)
        
        # 3. 通过全连接层学习通道权重:[batch_size, channels] → [batch_size, channels]
        channel_weights = self.fc(avg_pool_output)
        
        # 4. 重塑为二维张量:[batch_size, channels] → [batch_size, channels, 1, 1]
        channel_weights = channel_weights.view(batch_size, channels, 1, 1)
        
        # 5. 将权重应用到原始特征图上(逐通道相乘)
        return x * channel_weights  # 输出形状:[batch_size, channels, height, width]
相关推荐
跟橙姐学代码17 分钟前
学Python像学做人:从基础语法到人生哲理的成长之路
前端·python
Keying,,,,29 分钟前
力扣hot100 | 矩阵 | 73. 矩阵置零、54. 螺旋矩阵、48. 旋转图像、240. 搜索二维矩阵 II
python·算法·leetcode·矩阵
桃源学社(接毕设)1 小时前
基于人工智能和物联网融合跌倒监控系统(LW+源码+讲解+部署)
人工智能·python·单片机·yolov8
yunhuibin1 小时前
pycharm2025导入anaconda创建的各个AI环境
人工智能·python
杨荧1 小时前
基于Python的电影评论数据分析系统 Python+Django+Vue.js
大数据·前端·vue.js·python
2502_927161281 小时前
DAY 40 训练和测试的规范写法
人工智能·深度学习·机器学习
python-行者2 小时前
akamai鼠标轨迹
爬虫·python·计算机外设·akamai
R-G-B2 小时前
【P14 3-6 】OpenCV Python——视频加载、摄像头调用、视频基本信息获取(宽、高、帧率、总帧数)
python·opencv·视频加载·摄像头调用·获取视频基本信息·获取视频帧率·获取视频帧数
Monkey PilotX2 小时前
机器人“ChatGPT 时刻”倒计时
人工智能·机器学习·计算机视觉·自动驾驶