Day 47 注意力热图可视化

@浙大疏锦行
今日任务:

  1. 通道注意力:模型的定义和插入的位置
  2. 通道注意力后的特征图和热力图

作业:对比不同卷积层热图可视化的结果

通道注意力

通道注意力是一种让网络能够自动学习每个特征通道重要性 的机制。它的核心思想:不同特征通道对当前任务的贡献不同,应该增强重要通道,抑制不重要通道。

复制代码
# 伪代码:所有通道注意力的核心
def channel_attention(x):
    # 1. 计算每个通道的重要性权重
    channel_weights = compute_importance(x)  # 不同实现方式
    
    # 2. 对原始特征图进行通道级别加权
    output = x * channel_weights             # 相同的加权操作
    
    return output

其中,SE模块 是通道注意力实现 的一种具体方式 。SE模块包括三个核心操作:Squeeze (压缩)、Excitation (激励)、Scale(缩放)。

复制代码
# SE 模块
输入 [B,C,H,W]
    ↓
Squeeze: 全局平均池化 → [B,C,1,1]  
    ↓
Excitation: FC → ReLU → FC → Sigmoid → [B,C,1,1]
    ↓  
Scale: 与输入逐通道相乘 → 输出 [B,C,H,W]

定义注意力模块

这里以SE模块的定义为例:

(1)Squeeze

  • 通过全局平均池化 将每个通道的二维特征图(H×W)压缩为一个标量 ,保留通道的全局信息
  • 物理意义:计算每个通道在整个图像中的 "平均响应强度",例如,"边缘检测通道" 在有物体边缘的图像中响应值会更高。
python 复制代码
# 全局平均池化
self.avg_pool = nn.AdaptiveAvgPool2d(1)

(2)Excitation

  • 通过全连接层 + Sigmoid 激活 ,学习通道间的依赖关系,输出 0-1 之间的权重值
  • 物理意义:让模型自动判断哪些通道更重要 (权重接近 1),哪些通道可忽略(权重接近 0)。
python 复制代码
# 全连接层 + Sigmoid激活
self.fc = nn.Sequential(
    nn.Linear(in_channel,in_channel//reduction_ratio,bias=False),
    nn.ReLU(inplace=True),
    nn.Linear(in_channel//reduction_ratio,in_channel,bias=False),
    nn.Sigmoid(),
    )

(3)Scale

  • 将学习到的通道权重与原始特征图逐通道相乘,增强重要通道,抑制不重要通道。
  • 物理意义:类似人类视觉系统聚焦于关键特征(如猫的轮廓),忽略无关特征(如背景颜色)

完整:

python 复制代码
# SE模块的定义
class ChannelAttention(nn.Module):
    def __init__(self,in_channel,reduction_ratio=16):
        super(ChannelAttention,self).__init__()
        # 全局平均池化
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 全连接层 + Sigmoid激活
        self.fc = nn.Sequential(
            nn.Linear(in_channel,in_channel//reduction_ratio,bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channel//reduction_ratio,in_channel,bias=False),
            nn.Sigmoid(),
        )
       
    def forward(self,x):
        # 获得bath_size,channels等数值,为后续变换维度作准备
        batch_size,channels,height,width = x.size()

        # 1-squeeze
        avg_pool_out = self.avg_pool(x)

        # 展平为一维:[batch_size,channels],全连接层要求输入为一维
        avg_pool_out = avg_pool_out.view(batch_size,channels)
        # 2-获得通道权重
        channel_weights = self.fc(avg_pool_out)

        # 维度变换:匹配原始特征图的维度,[batch_size, channels, 1, 1]
        channel_weights = channel_weights.view(batch_size,channels,1,1)
        # 3-将权重应用到原始特征图上(逐通道相乘)
        return x*channel_weights # 输出形状:[batch_size, channels, height, width]

插入模型定义部分

在每一个卷积块中,插入SE模块:Conv → BN → ReLU → SE → Pool(常见顺序)。

  • 信息完整:ReLU后的特征包含所有激活信息
  • 池化前加权:在降采样前强调重要特征
python 复制代码
# 在CNN架构中插入SE模块
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        # ---------------------- 第一个卷积块 ----------------------
        self.conv1 = nn.Conv2d(3,32,3,padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        # 添加通道注意力
        self.ca1 = ChannelAttention(in_channel=32,reduction_ratio=16) 
        self.pool1 = nn.MaxPool2d(2)
        # ---------------------- 第二个卷积块 ----------------------
        self.conv2 = nn.Conv2d(32,64,3,padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        # 添加通道注意力
        self.ca2 = ChannelAttention(in_channel=64,reduction_ratio=16)
        self.pool2 = nn.MaxPool2d(2)
        # ---------------------- 第三个卷积块 ----------------------
        self.conv3 = nn.Conv2d(64,128,3,padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        # 添加通道注意力
        self.ca3 = ChannelAttention(in_channel=128,reduction_ratio=16)
        self.pool3 = nn.MaxPool2d(2)
        # ---------------------- 全连接器 ----------------------
        self.fc1 = nn.Linear(128*4*4,512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512,10)

    def forward(self,x):
        # ---------- 卷积块1处理 ----------
        x = self.conv1(x)       
        x = self.bn1(x)         
        x = self.relu1(x)       
        x = self.ca1(x)  # 应用通道注意力
        x = self.pool1(x)       
        
        # ---------- 卷积块2处理 ----------
        x = self.conv2(x)       
        x = self.bn2(x)         
        x = self.relu2(x)       
        x = self.ca2(x)  # 应用通道注意力
        x = self.pool2(x)       
        
        # ---------- 卷积块3处理 ----------
        x = self.conv3(x)       
        x = self.bn3(x)         
        x = self.relu3(x)       
        x = self.ca3(x)  # 应用通道注意力
        x = self.pool3(x)       
        
        # ---------- 展平与全连接层 ----------
        x = x.view(-1, 128 * 4 * 4)  
        x = self.fc1(x)           
        x = self.relu3(x)         
        x = self.dropout(x)       
        x = self.fc2(x)   

        return x

# 实例化
model = CNN().to(device)

# 损失函数、优化器、学习率调度器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.5,patience=3)

# 训练模型(复用原有的train函数)
print("开始训练带通道注意力的CNN模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=50)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

可视化空间注意力热图

整体的思路与特征图可视化是类似的,这里选择对最后一个卷积层可视化。

步骤

(1)初始化设置

python 复制代码
# 1-初始化设置
model.eval()
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

(2)数据加载与处理

python 复制代码
# 2-数据加载与处理
for i,(images,labels) in enumerate(test_loader):
    if i >= num_samples:
        break
    images,labels = images.to(device),labels.to(device)

(3)注册钩子捕获特征图

python 复制代码
# 3-注册钩子捕获特征图
activation_maps = []
# 定义钩子
def hook(module,input,output):
    activation_maps.append(output.cpu())

forward_hook = model.conv3.register_forward_hook(hook)

(4)前向传播与特征提取

python 复制代码
# 4-前向传播与特征提取
output = model(images) # 前向传播,触发钩子
forward_hook.remove() # 移除钩子
_,predicted = torch.max(output,1) # 获取预测结果

(5)可视化热图

这里需要计算通道注意力权重,然后选取前n个活跃通道(权重大)可视化。

  • 通道归一化:保证颜色映射的一致性
  • 尺寸匹配:热力图与原始图像尺寸要匹配
  • 图像叠加:原始图像作为底图,热力图叠加
python 复制代码
# 5-可视化热图
# 反标准化--原始图像
img = images[0].cpu().permute(1,2,0).numpy()
img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)
img = np.clip(img, 0, 1)
# 获取激活图(最后一个卷积层的输出)
feature_maps = activation_maps[0][0].cpu() # 获取第一个样本, [C, H, W] 
channel_weights = torch.mean(feature_maps,dim=(1,2))
sorted_weight = torch.argsort(channel_weights,descending=True)
# 创建子图
fig,axes = plt.subplots(1,4,figsize=(16,4))

# 绘制原始图像
axes[0].imshow(img)
axes[0].set_title(f'Original Picture\nActual:{class_names[labels[0]]}\nPredicted:{class_names[predicted[0]]}')
axes[0].axis('off')

for j in range(3):
    channel_idx = sorted_weight[j] # 获得通道索引
    channel_map = feature_maps[channel_idx].numpy() # 获得单通道,[H, W]  
    # 归一化
    channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min()+1e-8)
    # 调整热图大小,匹配原始图像
    from scipy.ndimage import zoom
    heatmap = zoom(channel_map,(32/feature_maps.shape[1],32/feature_maps.shape[2]))
    # 叠加图像
    axes[j+1].imshow(img)
    axes[j+1].imshow(heatmap,alpha=0.5,cmap='jet')
    axes[j+1].set_title(f'Attention Heatmap - Channnel {channel_idx}')
    axes[j+1].axis('off')

plt.tight_layout()
plt.show()

结果

python 复制代码
# 注意力热图可视化
def visualize_attention_map(model,test_loader,device,num_samples=3):
    # 1-初始化设置
    model.eval()
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    
    with torch.no_grad():
        # 2-数据加载与处理
        for i,(images,labels) in enumerate(test_loader):
            if i >= num_samples:
                break
            images,labels = images.to(device),labels.to(device)

            # 3-注册钩子捕获特征图
            activation_maps = []
            # 定义钩子
            def hook(module,input,output):
                activation_maps.append(output.cpu())

            forward_hook = model.conv3.register_forward_hook(hook)
            
            # 4-前向传播与特征提取
            output = model(images) # 前向传播,触发钩子
            forward_hook.remove() # 移除钩子
            _,predicted = torch.max(output,1) # 获取预测结果

            # 5-可视化热图
            # 反标准化--原始图像
            img = images[0].cpu().permute(1,2,0).numpy()
            img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)
            img = np.clip(img, 0, 1)
            # 获取激活图(最后一个卷积层的输出)
            feature_maps = activation_maps[0][0].cpu() # 获取第一个样本, [C, H, W] 
            channel_weights = torch.mean(feature_maps,dim=(1,2))
            sorted_weight = torch.argsort(channel_weights,descending=True)
            # 创建子图
            fig,axes = plt.subplots(1,4,figsize=(16,4))

            # 绘制原始图像
            axes[0].imshow(img)
            axes[0].set_title(f'Original Picture\nActual:{class_names[labels[0]]}\nPredicted:{class_names[predicted[0]]}')
            axes[0].axis('off')

            for j in range(3):
                channel_idx = sorted_weight[j] # 获得通道索引
                channel_map = feature_maps[channel_idx].numpy() # 获得单通道,[H, W]  
                # 归一化
                channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min()+1e-8)
                # 调整热图大小,匹配原始图像
                from scipy.ndimage import zoom
                heatmap = zoom(channel_map,(32/feature_maps.shape[1],32/feature_maps.shape[2]))
                # 叠加图像
                axes[j+1].imshow(img)
                axes[j+1].imshow(heatmap,alpha=0.5,cmap='jet')
                axes[j+1].set_title(f'Attention Heatmap - Channnel {channel_idx}')
                axes[j+1].axis('off')

            plt.tight_layout()
            plt.show()

# 调用可视化函数
visualize_attention_map(model, test_loader, device, num_samples=3)

未加入通道注意力的版本:

加入通道注意力的版本:

作业

修改注册钩子的模块,改成conv1和conv2:

python 复制代码
forward_hook = model.conv3.register_forward_hook(hook)

conv1(第一层卷积)

conv2(第二层卷积)

conv3(第三层卷积)

小结

(1)热图分析:红色表示高关注(对分类最重要的区域),蓝色表示低关注。如果热力图错误聚焦(红色区域)在背景上,可能说明模型过拟合或训练不足。

(2)多通道对比:不同的通道关注不同的特征,比如整体轮廓、纹理细节、颜色分布。

(3)使用热图可视化的应用:

相关推荐
zhjadsf2 小时前
Huggingface_hub源码解析 - 简介篇
python·huggingface
20岁30年经验的码农3 小时前
Python语言基础文档
开发语言·python
清静诗意3 小时前
独立 IoT 客户端绕过 Django 生命周期导致数据库断链:诊断与修复
python·mysql·django·生命周期
不知更鸟6 小时前
Django 项目设置流程
后端·python·django
自动化代码美学7 小时前
【Python3.13】官网学习之控制流
开发语言·windows·python·学习
王哈哈^_^10 小时前
YOLO11实例分割训练任务——从构建数据集到训练的完整教程
人工智能·深度学习·算法·yolo·目标检测·机器学习·计算机视觉
百锦再10 小时前
第18章 高级特征
android·java·开发语言·后端·python·rust·django
源码之家10 小时前
基于Python房价预测系统 数据分析 Flask框架 爬虫 随机森林回归预测模型、链家二手房 可视化大屏 大数据毕业设计(附源码)✅
大数据·爬虫·python·随机森林·数据分析·spark·flask
SalvoGao10 小时前
Python学习 | 怎么理解epoch?
数据结构·人工智能·python·深度学习·学习