Python day46

@浙大疏锦行 Python day46.

  • 注意力机制可以理解为对输入特征进行加权求和,注意力权重也是学习到的,类似于卷积,不过卷积的权重一般时是固定的,而注意力机制的权重根据输入数据不同权重也不同;
  • 常见的注意力模块有自注意力、通道注意力、空间注意力、多头注意力集以编码器-解码器注意力;
  • 通道注意力机制的执行过程为,先压缩空间维度只保留通道信息,接下来通过全连接层学习通道之间的权重信息,最后进行相应的加权操作即可;
python 复制代码
class ChannelAttention(nn.Module):
    """通道注意力模块(Squeeze-and-Excitation)"""
    def __init__(self, in_channels, reduction_ratio=16):
        """
        参数:
            in_channels: 输入特征图的通道数
            reduction_ratio: 降维比例,用于减少参数量
        """
        super(ChannelAttention, self).__init__()
        
        # 全局平均池化 - 将空间维度压缩为1x1,保留通道信息
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # 全连接层 + 激活函数,用于学习通道间的依赖关系
        self.fc = nn.Sequential(
            # 降维:压缩通道数,减少计算量
            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
            nn.ReLU(inplace=True),
            # 升维:恢复原始通道数
            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),
            # Sigmoid将输出值归一化到[0,1],表示通道重要性权重
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        参数:
            x: 输入特征图,形状为 [batch_size, channels, height, width]
        
        返回:
            加权后的特征图,形状不变
        """
        batch_size, channels, height, width = x.size()
        
        # 1. 全局平均池化:[batch_size, channels, height, width] → [batch_size, channels, 1, 1]
        avg_pool_output = self.avg_pool(x)
        
        # 2. 展平为一维向量:[batch_size, channels, 1, 1] → [batch_size, channels]
        avg_pool_output = avg_pool_output.view(batch_size, channels)
        
        # 3. 通过全连接层学习通道权重:[batch_size, channels] → [batch_size, channels]
        channel_weights = self.fc(avg_pool_output)
        
        # 4. 重塑为二维张量:[batch_size, channels] → [batch_size, channels, 1, 1]
        channel_weights = channel_weights.view(batch_size, channels, 1, 1)
        
        # 5. 将权重应用到原始特征图上(逐通道相乘)
        return x * channel_weights  # 输出形状:[batch_size, channels, height, width]
相关推荐
m0_7467523034 分钟前
c++怎么利用std--variant处理多种二进制子协议包的自动分支解析【进阶】
jvm·数据库·python
m0_734949798 小时前
MySQL如何配置定时清理过期备份文件_find命令与保留周期策略
jvm·数据库·python
思绪无限8 小时前
YOLOv5至YOLOv12升级:钢材表面缺陷检测系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·yolov12·yolo全家桶·钢材表面缺陷检测
m0_514520578 小时前
MySQL索引优化后性能没提升_通过EXPLAIN查看索引命中率
jvm·数据库·python
H Journey8 小时前
Python 国内pip install 安装缓慢
python·pip·install 加速
Jmayday9 小时前
机器学习基本理论
人工智能·机器学习
王_teacher9 小时前
机器学习 矩阵求导 完整公式+严谨推导
人工智能·线性代数·考研·机器学习·矩阵·线性回归
Polar__Star9 小时前
如何在 AWS Lambda 中正确使用临时凭证生成 S3 预签名 URL
jvm·数据库·python
m0_7436239210 小时前
React 自定义 Hook 的命名规范与调用规则详解
jvm·数据库·python
xiaotao13110 小时前
02-机器学习基础: 无监督学习——scikit-learn实战与模型管理
学习·机器学习·scikit-learn