DAY 42 Grad-CAM 与 Hook 函数

本文适合:深度学习炼丹师、模型可解释性入门、CV 算法工程师内容:Hook 原理 + Grad-CAM 公式 + 从零手写 Grad-CAM + 工程库使用环境:PyTorch + torchvision + OpenCV + matplotlib


@浙大疏锦行


一、前言:为什么我们需要 Grad-CAM?

CNN 模型强大,但黑盒问题一直存在:模型为什么把这张图分类为 "狗"?它到底看了图片的哪些区域?

Grad-CAM(Gradient-weighted Class Activation Mapping) 就是目前最通用、最稳定、无需修改网络的 CNN 可视化神器

它能生成一张热力图:

  • 红色 = 模型最关注的区域
  • 蓝色 = 模型几乎不关注

而 Grad-CAM 的核心技术就是:PyTorch Hook


二、Hook 到底是什么?(最通俗讲解)

2.1 回调函数(Callback)

你把一个函数丢给别人,别人在合适的时候帮你调用。这就是 回调 = hook 的本质

2.2 Lambda:轻量级回调

复制代码
lambda grad: print(grad.shape)

适合在 Hook 里写简单逻辑。

2.3 PyTorch 两大 Hook(最重要)

1)Module Hook(模块钩子)

注册在网络层上,用来抓:

  • 前向输出(特征图)
  • 反向梯度
python 复制代码
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)

2)Tensor Hook(张量钩子)

只抓梯度:

python 复制代码
tensor.register_hook(hook_fn)

三、Grad-CAM 原理(公式 + 白话)

3.1 四步走

  1. 前向传播拿到最后一个卷积层的特征图 A

  2. 反向传播对目标类别求导,得到梯度

  3. 计算每个通道的权重梯度全局平均池化:αk​=H×W1​∑i,j​∂Ai,j,k​∂yc​

  4. 加权求和 + ReLU + 上采样得到最终 CAM:LGrad−CAM​=ReLU(∑αk​⋅Ak​)


四、实战:从零实现 Grad-CAM(完整代码)

4.1 导入库

python 复制代码
import torch
import torch.nn as nn
import torchvision.models as models
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4.2 Grad-CAM 类(含 Hook 注册)

python 复制代码
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._register_hooks()

    def _forward_hook(self, module, inp, outp):
        self.activations = outp.detach()

    def _backward_hook(self, module, grad_inp, grad_outp):
        self.gradients = grad_outp[0].detach()

    def _register_hooks(self):
        self.target_layer.register_forward_hook(self._forward_hook)
        self.target_layer.register_backward_hook(self._backward_hook)

    def generate_cam(self, x, class_idx=None):
        out = self.model(x.to(device))

        if class_idx is None:
            class_idx = out.argmax(dim=1).item()

        self.model.zero_grad()
        loss = out[0, class_idx]
        loss.backward()

        # 权重:梯度 GAP
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(1, keepdim=True)
        cam = torch.relu(cam)

        cam -= cam.min()
        cam /= cam.max()

        cam = nn.functional.interpolate(
            cam, size=x.shape[2:], mode="bilinear", align_corners=False
        )
        return cam.squeeze().cpu().numpy()

4.3 图像预处理

python 复制代码
def preprocess(img_path):
    tfm = nn.Sequential(
        nn.Resize((224, 224)),
        nn.ToTensor(),
        nn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    )
    img = Image.open(img_path).convert("RGB")
    return tfm(img).unsqueeze(0)

4.4 可视化

python 复制代码
def show_result(img_path, cam):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224)) / 255.0

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)[..., ::-1]
    heatmap = heatmap / 255.0

    fusion = 0.5 * heatmap + 0.5 * img
    fusion = np.clip(fusion, 0, 1)

    plt.subplot(121), plt.imshow(img), plt.axis("off"), plt.title("原图")
    plt.subplot(122), plt.imshow(fusion), plt.axis("off"), plt.title("Grad-CAM")
    plt.show()

4.5 运行

python 复制代码
model = models.resnet50(pretrained=True).to(device)
target_layer = model.layer4[-1]

cam = GradCAM(model, target_layer)
x = preprocess("test.jpg")  # 换成你的图片
heatmap = cam.generate_cam(x)
show_result("test.jpg", heatmap)

五、工程版:使用 pytorch_grad_cam 库(最推荐)

5.1 安装

python 复制代码
pip install grad-cam

5.2 极简代码

python 复制代码
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

model = models.resnet50(pretrained=True).to(device).eval()
target_layers = [model.layer4[-1]]

cam = GradCAM(model=model, target_layers=target_layers)

后续和上面流程一致,非常稳定、支持各种模型。


六、常见问题(避坑指南)

  1. 热力图全黑

    • 没开梯度
    • Hook 没注册成功
    • 目标层不是卷积层
  2. 热力图很模糊

    • 换更深的卷积层
    • 使用 bilinear 上采样
  3. 结果不对

    • 图像归一化是否正确
    • 类别索引是否正确

七、总结

  1. Hook = PyTorch 可解释性的灵魂前向抓特征,反向抓梯度。

  2. Grad-CAM = 最通用的 CNN 可视化方法不修改网络、不需要额外训练。

  3. 本文代码✅ 可直接跑✅ 可直接发 CSDN✅ 可直接用于项目

相关推荐
思绪无限7 小时前
YOLOv5至YOLOv12升级:吸烟行为检测系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·yolo·目标检测·yolov12·yolo全家桶·吸烟行为检测
<-->7 小时前
【tuner passes compile compress autotp】
人工智能·python·深度学习
m0_564876847 小时前
提示词应用
深度学习·学习·算法
日光明媚7 小时前
DMD 一步扩散核心原理:从符号定义到梯度推导
人工智能·机器学习·计算机视觉·ai作画·stable diffusion·aigc
正经人_x8 小时前
学习日记39:GLIGEN
人工智能·深度学习
思绪无限8 小时前
YOLOv5至YOLOv12升级:教室人员检测与计数系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·yolo·目标检测·yolov12·yolo全家桶·教室人员检测与计数
youcans_8 小时前
【HALCON 实战入门】5. 相机接入与图像采集
图像处理·人工智能·计算机视觉·halcon·图像采集
思绪无限8 小时前
YOLOv5至YOLOv12升级:体育赛事目标检测系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·目标跟踪·体育赛事目标检测·yolov12·yolo全家桶
深度学习lover8 小时前
<数据集>yolo 柑橘识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·柑橘识别
思绪无限8 小时前
YOLOv5至YOLOv12升级:遥感目标检测系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·yolo·目标检测·遥感目标检测·yolov12·yolo全家桶