DAY52 通道注意力(SE注意力)

目录

[1. 不同 CNN 层的特征图:从"看轮廓"到"理解语义"](#1. 不同 CNN 层的特征图:从“看轮廓”到“理解语义”)

[2. 什么是注意力:不仅是算法,更是"动物园"](#2. 什么是注意力:不仅是算法,更是“动物园”)

[3. 通道注意力:SE 模块的定义与插入](#3. 通道注意力:SE 模块的定义与插入)

[4. 特征图与热力图:给模型做"脑部扫描"](#4. 特征图与热力图:给模型做“脑部扫描”)

附例:


1. 不同 CNN 层的特征图:从"看轮廓"到"理解语义"

卷积神经网络(CNN)通过层层堆叠,实现特征从低级到高级的抽象转换。

  • 浅层卷积层(如 conv1

    • 特点:特征图尺寸大,保留了大量的边缘、纹理和颜色细节。

    • 功能:类似于人眼的初步视觉感知,识别物体的轮廓。

  • 中层卷积层(如 conv2

    • 特点:尺寸减小,语义开始抽象。

    • 功能:对边缘进行组合,识别局部特征(如某个器官、某种特定形状的色块)。

  • 深层卷积层(如 conv3

    • 特点:尺寸极小(高度抽象),肉眼已无法识别具体的图像内容。

    • 功能:聚焦全局语义特征,是模型判断"这到底是什么类别"的关键依据。

  • 总结:不同通道在每一层各司其职,有的负责找横线,有的负责找红色。层数越深,模型越从"看图像细节"转向"理解物体本质"。


2. 什么是注意力:不仅是算法,更是"动物园"

注意力机制本质上是让模型学会"在海量信息中挑重点"。

  • 核心定义:它是一种"动态权重"提取器。普通卷积的权重是训练好后固定的,而注意力的权重会根据输入内容实时变化(输入一张猫,它看耳朵;输入一张车,它看轮子)。

  • "动物园"比喻:注意力是一个庞大的家族,里面有很多不同的"物种":

    • 自注意力(Self-Attention):建模同一输入内部的依赖(Transformer 的核心)。

    • 通道注意力(Channel Attention):关注"哪些特征通道"重要。

    • 空间注意力(Spatial Attention):关注图像中"哪个区域"重要。

  • 意义:现实中没有万能的模块,只有适合不同任务的模块。


3. 通道注意力:SE 模块的定义与插入

代码中实现的是经典的 SE(Squeeze-and-Excitation)模块,它是通道注意力的代表。

  • 三个关键动作

    1. Squeeze(压缩):用全局平均池化(GAP)把一张特征图压扁成一个数字,代表该通道的平均强度。

    2. Excitation(激发):用全连接层学习通道之间的重要性,输出 0~1 之间的权重。

    3. Reweight(重加权):把权重乘回原图。重要的通道被放大(如猫的特征),不重要的被抑制(如杂乱的背景)。

  • 插入位置 :在代码中,SE 模块被插入在每一层卷积+激活(ReLU)之后,池化(Pooling)之前。这样可以在特征被缩小之前,先对现有特征进行一轮"优胜劣汰"。


4. 特征图与热力图:给模型做"脑部扫描"

为了看清模型到底在想什么,我们使用了热力图(Heatmap)可视化

  • 原理 :通过"钩子(Hook)"机制拦截最后一次卷积(conv3)的输出。由于 conv3 包含了最核心的分类信息,我们分析哪些通道最活跃,并将其还原到原图大小。

  • 视觉解读

    • 红色区域:表示模型注意力高度集中的地方。如果识别狗时,红区在狗头上,说明模型学对了;如果红区在草地上,说明模型过拟合了背景。

    • 蓝色区域:模型忽略的区域,通常是背景或无关信息。

  • 结论:热力图让黑盒模型变得"可解释",帮助我们判断模型是真的聪明,还是在"投机取巧"。

附例:

对比不同卷积层特征图可视化的结果

python 复制代码
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from scipy.ndimage import zoom

# ==========================================
# 1. 定义通道注意力模块 (SE 模块)
# ==========================================
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1) # 把特征图压扁成 1x1
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),
            nn.Sigmoid() # 输出 0~1 之间的权重
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y # 将权重乘回原特征图

# ==========================================
# 2. 定义完整的 CNN 模型
# ==========================================
class SimpleAttentionCNN(nn.Module):
    def __init__(self):
        super(SimpleAttentionCNN, self).__init__()
        # 第一层卷积:抓取基础纹理
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.ca1 = ChannelAttention(32) # 在第一层后加入注意力
        
        # 第二层卷积:组合局部特征
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.ca2 = ChannelAttention(64) # 在第二层后加入注意力
        
        # 第三层卷积:提取核心语义
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.ca3 = ChannelAttention(128) # 在第三层后加入注意力
        
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(128 * 4 * 4, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.ca1(self.conv1(x)))) # 输出 16x16
        x = self.pool(torch.relu(self.ca2(self.conv2(x)))) # 输出 8x8
        x = self.pool(torch.relu(self.ca3(self.conv3(x)))) # 输出 4x4
        x = x.view(-1, 128 * 4 * 4)
        x = self.fc(x)
        return x

# ==========================================
# 3. 核心可视化函数
# ==========================================
def run_visualization():
    # A. 环境准备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleAttentionCNN().to(device)


    # --- 就在这里加上下面这段代码 ---
    import os
    checkpoint_path = 'best_model.pth' # 这里的名字必须和你保存的文件名一模一样

    if os.path.exists(checkpoint_path):
        # 核心动作:加载"记忆"
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"成功加载模型权重:{checkpoint_path},现在生成的对比图是有意义的!")
    else:
        print(f"警告:找不到文件 {checkpoint_path},当前显示的是随机初始化的结果。")
    # ----------------------------




    
    model.eval() # 开启评估模式
    
    # B. 加载数据 (使用 CIFAR-10 测试集)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=True)
    
    # 类别名称
    classes = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']
    
    # C. 准备监控钩子 (Hook)
    captured_features = {}
    def get_hook(name):
        def hook(m, i, o): captured_features[name] = o.detach().cpu()
        return hook

    # 我们想对比这三个卷积层
    layer_names = ['conv1', 'conv2', 'conv3']
    for name in layer_names:
        layer = getattr(model, name)
        layer.register_forward_hook(get_hook(name))

    # D. 获取一张图片并进行推理
    img_tensor, label_idx = next(iter(test_loader))
    img_tensor = img_tensor.to(device)
    
    with torch.no_grad():
        _ = model(img_tensor)

    # E. 绘图对比
    fig, axes = plt.subplots(1, 4, figsize=(18, 5))
    plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文显示

    # 1. 处理原图
    orig_img = img_tensor[0].cpu().permute(1, 2, 0).numpy()
    orig_img = orig_img * np.array([0.2023, 0.1994, 0.2010]) + np.array([0.4914, 0.4822, 0.4465])
    orig_img = np.clip(orig_img, 0, 1)
    
    axes[0].imshow(orig_img)
    axes[0].set_title(f"原始图像: {classes[label_idx]}")
    axes[0].axis('off')

    # 2. 对比三个卷积层
    for i, name in enumerate(layer_names):
        f_map = captured_features[name][0] # 取出捕获到的特征图
        # 计算所有通道的平均响应
        heatmap = torch.mean(f_map, dim=0).numpy()
        # 归一化到 0-1
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
        # 将小图拉伸到 32x32,盖在原图上
        heatmap_resized = zoom(heatmap, (32/heatmap.shape[0], 32/heatmap.shape[1]))
        
        # 叠加绘图
        axes[i+1].imshow(orig_img)
        axes[i+1].imshow(heatmap_resized, alpha=0.5, cmap='jet') # 用 jet 色阶(红高蓝低)
        axes[i+1].set_title(f"层级: {name}\n(分辨率: {f_map.shape[1]}x{f_map.shape[2]})")
        axes[i+1].axis('off')

    plt.tight_layout()
    plt.show()

# 运行程序
if __name__ == "__main__":
    run_visualization()
相关推荐
GitCode官方2 小时前
【无标题】
人工智能·开源·atomgit
三不原则2 小时前
实战:ELK 分析 AI 系统日志,快速定位接口报错问题
人工智能·elk
AI_56782 小时前
Postman接口测试极速入门指南
开发语言·人工智能·学习·测试工具·lua
我的golang之路果然有问题2 小时前
开源绘画大模型简单了解
人工智能·ai作画·stable diffusion·人工智能作画
极智视界2 小时前
目标检测数据集 - 自动驾驶场景车辆方向检测数据集下载
人工智能·目标检测·自动驾驶
田井中律.2 小时前
知识图谱(四)之LSTM+CRF
人工智能·机器学习
Hcoco_me2 小时前
大模型面试题74:在使用GRPO训练LLM时,训练数据有什么要求?
人工智能·深度学习·算法·机器学习·chatgpt·机器人
筱昕~呀2 小时前
基于深度生成对抗网络的智能实时美妆设计
人工智能·python·生成对抗网络·mediapipe·beautygan
qunaa01012 小时前
钻井作业场景下设备与产品识别与检测:基于YOLO11-SRFD的目标检测系统实现与应用
人工智能·目标检测·计算机视觉