Grad-CAM与Hook函数
知识点回顾
- 回调函数
- lambda函数
- hook函数的模块钩子和张量钩子
- Grad-CAM的示例
在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具------hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:
- 调试与可视化中间层输出
- 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
- 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
- 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理
1、回调函数
Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名
在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:
- 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
- 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
- Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)
2、lambda函数
在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式
- 参数列表:可以是单个参数、多个参数或无参数
- 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)
举个例子
python
# 定义匿名函数:计算平方
square = lambda x: x ** 2
# 调用
print(square(5)) # 输出: 25
3、hook函数
PyTorch 提供了两种主要的 hook:
- Module Hooks(模块钩子):用于监听整个模块的输入和输出
- Tensor Hooks:用于监听张量的梯度
(1)模块钩子
允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:
- register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
- register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息
前向钩子举个例子
python
# 创建模型实例
model = SimpleModel()
# 创建一个列表用于存储中间层的输出
conv_outputs = []
# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):
print(f"钩子被调用!模块类型: {type(module)}")
print(f"输入形状: {input[0].shape}") # input是一个元组,对应 (image, label)
print(f"输出形状: {output.shape}")
# 保存卷积层的输出用于后续分析
# 使用detach()避免追踪梯度,防止内存泄漏
conv_outputs.append(output.detach())
# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)
# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)
# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)
# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()
反向钩子
python
# 定义一个存储梯度的列表
conv_gradients = []
# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):
print(f"反向钩子被调用!模块类型: {type(module)}")
print(f"输入梯度数量: {len(grad_input)}")
print(f"输出梯度数量: {len(grad_output)}")
# 保存梯度供后续分析
conv_gradients.append((grad_input, grad_output))
# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)
# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)
# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()
# 释放钩子
hook_handle.remove()
(2)张量钩子
PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:
- register_hook:用于监听张量的梯度
- register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
python
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3
# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):
print(f"原始梯度: {grad}")
# 修改梯度,例如将梯度减半
return grad / 2
# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)
# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()
print(f"x的梯度: {x.grad}")
# 释放钩子
hook_handle.remove()
4、Grad-CAM
一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是"猫"),原理:
- 梯度计算:看最后几层特征图的梯度,哪个特征图对预测"猫"的贡献大
- 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
- 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
python
# Grad-CAM实现
class GradCAM:
def __init__(self, model, target_layer):
self.model = model
self.target_layer = target_layer
self.gradients = None
self.activations = None
# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度
self.register_hooks()
def register_hooks(self):
# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)
def forward_hook(module, input, output):
self.activations = output.detach()
# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度
def backward_hook(module, grad_input, grad_output):
self.gradients = grad_output[0].detach()
# 在目标层注册前向钩子和反向钩子
self.target_layer.register_forward_hook(forward_hook)
self.target_layer.register_backward_hook(backward_hook)
def generate_cam(self, input_image, target_class=None):
# 前向传播,得到模型输出
model_output = self.model(input_image)
if target_class is None:
# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别
target_class = torch.argmax(model_output, dim=1).item()
# 清除模型梯度,避免之前的梯度影响
self.model.zero_grad()
# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度
one_hot = torch.zeros_like(model_output)
one_hot[0, target_class] = 1
model_output.backward(gradient=one_hot)
# 获取之前保存的目标层的梯度和激活值
gradients = self.gradients
activations = self.activations
# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性
weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果
cam = torch.sum(weights * activations, dim=1, keepdim=True)
# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响
cam = F.relu(cam)
# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围
cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
cam = cam - cam.min()
cam = cam / cam.max() if cam.max() > 0 else cam
return cam.cpu().squeeze().numpy(), target_class
idx = 102 # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")
# 转换图像以便可视化
def tensor_to_np(tensor):
img = tensor.cpu().numpy().transpose(1, 2, 0)
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
img = std * img + mean
img = np.clip(img, 0, 1)
return img
# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)
# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)
# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)
# 可视化
plt.figure(figsize=(12, 4))
# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')
# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')
# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')
plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()
