DAY 42 Grad-CAM与Hook函数

浙大疏锦行-CSDN博客
知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

**作业:**理解下今天的代码即可

1 回调函数 (Callback Function)
回调函数是一种在特定事件发生时被调用执行的函数。你可以把它理解为一个"钩子",预先设定好,当某个条件满足时(例如,一个训练周期结束),这个函数就会被自动触发。

在深度学习训练中,回调函数非常实用。例如,Keras框架就广泛使用回调函数来实现以下功能:

模型保存:在每个epoch结束后,自动检查验证集上的性能,并保存最优的模型。
学习率调整:当模型性能陷入停滞时,自动降低学习率。
提前终止:当验证集损失在若干个epoch内不再下降时,提前结束训练,防止过拟合。
训练可视化:将训练过程中的损失、准确率等指标实时发送到可视化工具(如TensorBoard)。
回调函数的核心思想是将控制权反转,你不需要在主训练流程中显式地调用这些操作,而是将它们注册到训练流程中,由框架在特定时机自动调用。

1.2 Lambda函数 (Anonymous Function)
Lambda函数,也称为匿名函数,是Python中一种创建小型、单行函数的便捷方式。 它使用 lambda 关键字定义,语法结构为:lambda arguments: expression。

特点:

简洁:对于一些简单的功能,无需使用 def 关键字定义一个完整的函数。
匿名:它没有正式的函数名。
内联:通常在使用时当场定义,非常适合作为高阶函数(即接受其他函数作为参数的函数,如 map, filter)的参数。
示例:
一个将输入值加10的普通函数:

复制代码
def add_ten(x):
  return x + 10
add_ten_lambda = lambda x: x + 10
print(add_ten_lambda(5)) # 输出: 15

1 张量钩子: tensor.register_hook(hook_fn)

这种钩子直接注册在torch.Tensor对象上。当该张量的梯度被计算出来时,注册的钩子函数会被自动调用。

功能:主要用于检查或修改一个张量的梯度。

钩子函数签名:hook_fn(grad),它接收一个参数,即该张量的梯度。

返回值:可以返回一个新的Tensor来替代原有的梯度,或者返回None(此时对梯度的任何原地修改都会保留)。

示例:查看并修改中间变量的梯度

复制代码
import torch

# 创建一个需要梯度的张量
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x * 2
z = y.mean()

# 定义一个钩子函数来查看y的梯度,并将其乘以2
def y_grad_hook(grad):
    print("钩子函数被调用:原始y的梯度是:\n", grad)
    return grad * 2

# 在张量y上注册钩子
handle = y.register_hook(y_grad_hook)

# 启动反向传播
z.backward()

print("\n反向传播后,x的梯度是:\n", x.grad)

# 使用完毕后移除钩子,防止内存泄漏
handle.remove()

在这个例子中,y是一个中间变量,通常它的梯度y.grad在计算后会是None。但通过Hook,我们成功地捕获并修改了它的梯度,这也影响了x的最终梯度。

2.2 模块钩子 (Module Hook)

模块钩子注册在nn.Module实例(例如 nn.Conv2d 或整个模型)上,可以在前向或反向传播的不同阶段介入。

register_forward_hook(hook_fn)

触发时机:在模块的 forward() 方法执行完毕后被调用。

钩子函数签名:hook_fn(module, input, output),接收三个参数:模块本身、模块的输入和模块的输出。

主要用途:获取中间层的特征图(激活值)。这是实现Grad-CAM的关键。

register_backward_hook(hook_fn)

触发时机:当梯度反向传播到该模块时被调用。

钩子函数签名:hook_fn(module, grad_input, grad_output),接收三个参数:模块本身、模块输入的梯度和模块输出的梯度。

主要用途:获取中间层的梯度。这也是实现Grad-CAM的关键。

register_forward_pre_hook(hook_fn)

触发时机:在模块的 forward() 方法执行之前被调用。

钩子函数签名:hook_fn(module, input),接收模块本身和模块的输入。

主要用途:在数据进入某一层之前,检查或修改数据。

  1. 实战应用:Grad-CAM (梯度加权类激活映射)

Grad-CAM是一种非常流行的模型可解释性技术,它能够生成一张热力图,直观地显示出模型在做出特定分类决策时,主要依赖了输入图像的哪些区域。

工作原理简介

Grad-CAM的核心思想是:特征图中包含空间信息,而梯度中包含重要性信息。

获取特征图:首先,对一张输入图像进行正向传播,并使用register_forward_hook捕获我们感兴趣的最后一个卷积层的输出特征图(Activations)。

获取梯度:以我们想解释的类别(例如,"猫")的得分为起点,进行反向传播。使用register_backward_hook捕获该特征图的梯度。这个梯度反映了特征图上每个像素点对"猫"这个类别的最终得分有多重要。

计算权重:对每个通道的梯度图进行全局平均池化,得到每个特征图通道的重要性权重(alpha)。

加权求和:将捕获的特征图与其对应的重要性权重相乘并求和,得到一个粗糙的热力图。

ReLU激活:对热力图应用ReLU函数,只保留对目标类别有正向贡献的区域。 最终将热力图上采样到和原图一样大小,并叠加在原图上进行可视化。

复制代码
import torch
import torch.nn as nn
from torchvision.models import resnet50
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import requests
import numpy as np
import cv2

# 1. 定义一个包装器来方便地使用Hook
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # 注册钩子
        self.target_layer.register_forward_hook(self.save_activations)
        self.target_layer.register_backward_hook(self.save_gradients)

    def save_activations(self, module, input, output):
        self.activations = output
        
    def save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def __call__(self, x, class_idx=None):
        self.model.eval()
        output = self.model(x)

        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        
        # 清零旧梯度并进行反向传播
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot, retain_graph=True)

        # 计算权重和热力图
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        activations = self.activations.detach()

        # 加权
        for i in range(activations.shape[1]):
            activations[:, i, :, :] *= pooled_gradients[i]

        heatmap = torch.mean(activations, dim=1).squeeze()
        heatmap = np.maximum(heatmap.cpu(), 0) # ReLU
        heatmap /= torch.max(heatmap) # 归一化

        return heatmap.numpy()

# 2. 准备模型和数据
model = resnet50(pretrained=True)
target_layer = model.layer4[-1] # 选择最后一个卷积块的最后一层

# 加载一张示例图片
url = "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg"
img = Image.open(requests.get(url, stream=True).raw).convert('RGB')
img = img.resize((224, 224))
input_tensor = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0) / 255.0

# 3. 生成并可视化热力图
grad_cam = GradCAM(model, target_layer)
heatmap = grad_cam(input_tensor, class_idx=281) # ImageNet中"边境牧羊犬"的索引

# 将热力图叠加到原图
heatmap_resized = cv2.resize(heatmap, (224, 224))
heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
superimposed_img = heatmap_colored * 0.4 + np.array(img)
superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

Image.fromarray(superimposed_img).show()
相关推荐
The Open Group2 小时前
英特尔公司Darren Pulsipher 博士:以架构之力推动政府数字化转型
大数据·人工智能·架构
Ronin-Lotus2 小时前
深度学习篇---卷积核的权重
人工智能·深度学习
.银河系.2 小时前
8.18 机器学习-决策树(1)
人工智能·决策树·机器学习
敬往事一杯酒哈2 小时前
第7节 神经网络
人工智能·深度学习·神经网络
三掌柜6662 小时前
NVIDIA 技术沙龙探秘:聚焦 Physical AI 专场前沿技术
大数据·人工智能
Hello123网站3 小时前
Flowith-节点式GPT-4 驱动的AI生产力工具
人工智能·ai工具
yzx9910133 小时前
Yolov模型的演变
人工智能·算法·yolo
若天明4 小时前
深度学习-计算机视觉-微调 Fine-tune
人工智能·python·深度学习·机器学习·计算机视觉·ai·cnn
爱喝奶茶的企鹅4 小时前
Ethan独立开发新品速递 | 2025-08-19
人工智能