python打卡day42

Grad-CAM与Hook函数
知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具------hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:

  1. 调试与可视化中间层输出
  2. 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
  3. 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
  4. 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理

1、回调函数

Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名

在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:

  1. 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
  2. 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
  3. Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)

2、lambda函数

在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式

  • 参数列表:可以是单个参数、多个参数或无参数
  • 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)

举个例子

python 复制代码
# 定义匿名函数:计算平方
square = lambda x: x ** 2

# 调用
print(square(5))  # 输出: 25

3、hook函数

PyTorch 提供了两种主要的 hook:

  • Module Hooks(模块钩子):用于监听整个模块的输入和输出
  • Tensor Hooks:用于监听张量的梯度

(1)模块钩子

允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:

  1. register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
  2. register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息

前向钩子举个例子

python 复制代码
# 创建模型实例
model = SimpleModel()

# 创建一个列表用于存储中间层的输出
conv_outputs = []

# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):
    print(f"钩子被调用!模块类型: {type(module)}")
    print(f"输入形状: {input[0].shape}") #  input是一个元组,对应 (image, label)
    print(f"输出形状: {output.shape}")
    
    # 保存卷积层的输出用于后续分析
    # 使用detach()避免追踪梯度,防止内存泄漏
    conv_outputs.append(output.detach())

# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)

# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)

# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)

# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()

反向钩子

python 复制代码
# 定义一个存储梯度的列表
conv_gradients = []

# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):
    print(f"反向钩子被调用!模块类型: {type(module)}")
    print(f"输入梯度数量: {len(grad_input)}")
    print(f"输出梯度数量: {len(grad_output)}")
    
    # 保存梯度供后续分析
    conv_gradients.append((grad_input, grad_output))

# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)

# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)

# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()

# 释放钩子
hook_handle.remove()

(2)张量钩子

PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:

  1. register_hook:用于监听张量的梯度
  2. register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
python 复制代码
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3

# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):
    print(f"原始梯度: {grad}")
    # 修改梯度,例如将梯度减半
    return grad / 2

# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)

# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()

print(f"x的梯度: {x.grad}")

# 释放钩子
hook_handle.remove()

4、Grad-CAM

一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是"猫"),原理:

  • 梯度计算:看最后几层特征图的梯度,哪个特征图对预测"猫"的贡献大
  • 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
  • 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
python 复制代码
# Grad-CAM实现
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # 注册钩子,用于获取目标层的前向传播输出和反向传播梯度
        self.register_hooks()
        
    def register_hooks(self):
        # 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)
        def forward_hook(module, input, output):
            self.activations = output.detach()
        
        # 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        
        # 在目标层注册前向钩子和反向钩子
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)
    
    def generate_cam(self, input_image, target_class=None):
        # 前向传播,得到模型输出
        model_output = self.model(input_image)
        
        if target_class is None:
            # 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别
            target_class = torch.argmax(model_output, dim=1).item()
        
        # 清除模型梯度,避免之前的梯度影响
        self.model.zero_grad()
        
        # 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot)
        
        # 获取之前保存的目标层的梯度和激活值
        gradients = self.gradients
        activations = self.activations
        
        # 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
        
        # 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果
        cam = torch.sum(weights * activations, dim=1, keepdim=True)
        
        # ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响
        cam = F.relu(cam)
        
        # 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围
        cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam = cam / cam.max() if cam.max() > 0 else cam
        
        return cam.cpu().squeeze().numpy(), target_class

idx = 102  # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")

# 转换图像以便可视化
def tensor_to_np(tensor):
    img = tensor.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    img = std * img + mean
    img = np.clip(img, 0, 1)
    return img

# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)

# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)

# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)

# 可视化
plt.figure(figsize=(12, 4))

# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')

# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')

# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')

plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()

@浙大疏锦行

相关推荐
源码师傅3 分钟前
PHP+mysql 美容美发预约小程序源码 支持DIY装修+完整图文搭建教程
开发语言·mysql·php·预约小程序源码·预约服务系统源码·美容预约小程序源码·美发预约小程序
hvinsion26 分钟前
【开源工具】Python+PyQt5打造智能桌面单词记忆工具:悬浮窗+热键切换+自定义词库
python·qt·考研·开源·英语·翻译·英语单词
t1987512832 分钟前
matlab实现求解兰伯特问题
开发语言·算法·matlab
梓仁沐白32 分钟前
【Kotlin】表达式&关键字
开发语言·python·kotlin
玉~你还好吗35 分钟前
【FreeRTOS#1】多任务处理&任务调度器&任务状态
java·开发语言
小二·36 分钟前
JavaScript 获取当前日期与时间的方法详解
开发语言·前端·javascript
胡萝卜3.040 分钟前
c语言内存函数
c语言·开发语言·笔记·学习方法
日升1 小时前
如何在 Chrome 136+ 用 browser-use 打开「带登录态」的浏览器
python·ai编程·trae
vortex51 小时前
Python进阶与常用库:探索高效编程的奥秘
开发语言·网络·python
南京**1 小时前
python学习(一)
windows·python·学习