Python day46

@浙大疏锦行 Python day46.

  • 注意力机制可以理解为对输入特征进行加权求和,注意力权重也是学习到的,类似于卷积,不过卷积的权重一般时是固定的,而注意力机制的权重根据输入数据不同权重也不同;
  • 常见的注意力模块有自注意力、通道注意力、空间注意力、多头注意力集以编码器-解码器注意力;
  • 通道注意力机制的执行过程为,先压缩空间维度只保留通道信息,接下来通过全连接层学习通道之间的权重信息,最后进行相应的加权操作即可;
python 复制代码
class ChannelAttention(nn.Module):
    """通道注意力模块(Squeeze-and-Excitation)"""
    def __init__(self, in_channels, reduction_ratio=16):
        """
        参数:
            in_channels: 输入特征图的通道数
            reduction_ratio: 降维比例,用于减少参数量
        """
        super(ChannelAttention, self).__init__()
        
        # 全局平均池化 - 将空间维度压缩为1x1,保留通道信息
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # 全连接层 + 激活函数,用于学习通道间的依赖关系
        self.fc = nn.Sequential(
            # 降维:压缩通道数,减少计算量
            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
            nn.ReLU(inplace=True),
            # 升维:恢复原始通道数
            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),
            # Sigmoid将输出值归一化到[0,1],表示通道重要性权重
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        参数:
            x: 输入特征图,形状为 [batch_size, channels, height, width]
        
        返回:
            加权后的特征图,形状不变
        """
        batch_size, channels, height, width = x.size()
        
        # 1. 全局平均池化:[batch_size, channels, height, width] → [batch_size, channels, 1, 1]
        avg_pool_output = self.avg_pool(x)
        
        # 2. 展平为一维向量:[batch_size, channels, 1, 1] → [batch_size, channels]
        avg_pool_output = avg_pool_output.view(batch_size, channels)
        
        # 3. 通过全连接层学习通道权重:[batch_size, channels] → [batch_size, channels]
        channel_weights = self.fc(avg_pool_output)
        
        # 4. 重塑为二维张量:[batch_size, channels] → [batch_size, channels, 1, 1]
        channel_weights = channel_weights.view(batch_size, channels, 1, 1)
        
        # 5. 将权重应用到原始特征图上(逐通道相乘)
        return x * channel_weights  # 输出形状:[batch_size, channels, height, width]
相关推荐
Q26433650235 分钟前
【有源码】基于Hadoop+Spark的AI就业影响数据分析与可视化系统-AI驱动下的就业市场变迁数据分析与可视化研究-基于大数据的AI就业趋势分析可视化平台
大数据·hadoop·机器学习·数据挖掘·数据分析·spark·毕业设计
en-route9 分钟前
从零开始学神经网络——前馈神经网络
人工智能·深度学习·神经网络
前端伪大叔23 分钟前
第13篇:🎯 如何精准控制买入卖出价格?entry/exit\_pricing 实战配置
javascript·python
城南皮卡丘39 分钟前
基于YOLO8+flask+layui的行人跌倒行为检测系统【源码+模型+数据集】
python·flask·layui
成成成成成成果1 小时前
软件测试面试八股文:测试技术 10 大核心考点(二)
python·功能测试·测试工具·面试·职场和发展·安全性测试
二向箔reverse1 小时前
人脸特征可视化进阶:用 dlib+OpenCV 绘制面部轮廓与器官凸包
开发语言·python
这儿有一堆花1 小时前
从图像到精准文字:基于PyTorch与CTC的端到端手写文本识别实战
人工智能·pytorch·python
Michelle80232 小时前
决策树习题
算法·决策树·机器学习
SunnyDays10112 小时前
Python 高效实现 PDF 转 Word:告别手动复制粘贴
python·pdf转word·pdf转docx·pdf转doc·pdf到word转换
hhzz2 小时前
Pythoner 的Flask项目实践-绘制点/线/面并分类型保存为shpfile功能(Mapboxgl底图)
python·flask·gis·mapboxgl