Python day46

@浙大疏锦行 Python day46.

  • 注意力机制可以理解为对输入特征进行加权求和,注意力权重也是学习到的,类似于卷积,不过卷积的权重一般时是固定的,而注意力机制的权重根据输入数据不同权重也不同;
  • 常见的注意力模块有自注意力、通道注意力、空间注意力、多头注意力集以编码器-解码器注意力;
  • 通道注意力机制的执行过程为,先压缩空间维度只保留通道信息,接下来通过全连接层学习通道之间的权重信息,最后进行相应的加权操作即可;
python 复制代码
class ChannelAttention(nn.Module):
    """通道注意力模块(Squeeze-and-Excitation)"""
    def __init__(self, in_channels, reduction_ratio=16):
        """
        参数:
            in_channels: 输入特征图的通道数
            reduction_ratio: 降维比例,用于减少参数量
        """
        super(ChannelAttention, self).__init__()
        
        # 全局平均池化 - 将空间维度压缩为1x1,保留通道信息
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # 全连接层 + 激活函数,用于学习通道间的依赖关系
        self.fc = nn.Sequential(
            # 降维:压缩通道数,减少计算量
            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
            nn.ReLU(inplace=True),
            # 升维:恢复原始通道数
            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),
            # Sigmoid将输出值归一化到[0,1],表示通道重要性权重
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        参数:
            x: 输入特征图,形状为 [batch_size, channels, height, width]
        
        返回:
            加权后的特征图,形状不变
        """
        batch_size, channels, height, width = x.size()
        
        # 1. 全局平均池化:[batch_size, channels, height, width] → [batch_size, channels, 1, 1]
        avg_pool_output = self.avg_pool(x)
        
        # 2. 展平为一维向量:[batch_size, channels, 1, 1] → [batch_size, channels]
        avg_pool_output = avg_pool_output.view(batch_size, channels)
        
        # 3. 通过全连接层学习通道权重:[batch_size, channels] → [batch_size, channels]
        channel_weights = self.fc(avg_pool_output)
        
        # 4. 重塑为二维张量:[batch_size, channels] → [batch_size, channels, 1, 1]
        channel_weights = channel_weights.view(batch_size, channels, 1, 1)
        
        # 5. 将权重应用到原始特征图上(逐通道相乘)
        return x * channel_weights  # 输出形状:[batch_size, channels, height, width]
相关推荐
勾股导航10 小时前
K-means
人工智能·机器学习·kmeans
HAPPY酷10 小时前
C++ 和 Python 的“容器”对决:从万金油到核武器
开发语言·c++·python
Jay Kay10 小时前
GVPO:Group Variance Policy Optimization
人工智能·算法·机器学习
gpfyyds66610 小时前
Python代码练习
开发语言·python
小鸡吃米…11 小时前
机器学习面试问题及答案
机器学习
aiguangyuan12 小时前
使用LSTM进行情感分类:原理与实现剖析
人工智能·python·nlp
小小张说故事12 小时前
BeautifulSoup:Python网页解析的优雅利器
后端·爬虫·python
Yeats_Liao12 小时前
评估体系构建:基于自动化指标与人工打分的双重验证
运维·人工智能·深度学习·算法·机器学习·自动化
luoluoal12 小时前
基于python的医疗领域用户问答的意图识别算法研究(源码+文档)
python
Shi_haoliu12 小时前
python安装操作流程-FastAPI + PostgreSQL简单流程
python·postgresql·fastapi