pytorch_grad_cam 库学习笔记——基类ActivationsAndGradient

pytorch_grad_cam 是一个包含用于计算机视觉的可解释 AI 的最先进方法的软件包。 这可用于诊断模型预测,无论是在生产中还是在 开发模型。 其目的还在于作为研究新可解释性方法的算法和指标的基准。

pytorch_grad_cam 官方源码 https://github.com/jacobgil/pytorch-grad-cam

pytorch_grad_cam 官方教程 https://jacobgil.github.io/pytorch-gradcam-book/introduction.html

在./pytorch-grad-cam/pytorch_grad_cam/activations_and_gradients.py里定义了名为ActivationsAndGradients的基类,是实现类激活映射(CAM)算法的核心组件之一,其主要功能是利用 PyTorch 的 Hook 机制,在模型的前向和反向传播过程中,捕获指定目标层(target_layers)的激活值(activations)和梯度(gradients)。

本篇文章主要在这里对基类ActivationsAndGradients进行逐步分析,以理解库函数原理。

ActivationsAndGradients 类

ActivationsAndGradients 类是实现类激活映射(CAM)算法的核心组件之一,其主要功能是利用 PyTorch 的 Hook 机制,在模型的前向和反向传播过程中,捕获指定目标层(target_layers)的激活值(activations)和梯度(gradients)。

以下是该类的详细解析:

1. init(self, model, target_layers, reshape_transform, detach=True)

python 复制代码
class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targetted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform, detach=True):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.detach = detach
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use backward hook to record gradients.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

功能:初始化 ActivationsAndGradients 实例。

参数:

  • model: 要分析的 PyTorch 模型。
  • target_layers: 一个包含目标 torch.nn.Module 层的列表(如 model.layer4 或 model.features)。这些层的激活和梯度将被捕获。
  • reshape_transform: 一个可选的函数,用于重塑激活和梯度。这在处理某些模型(如 Vision Transformers)时非常关键,因为它们的特征图形状(如 [batch, num_patches, features])与标准卷积网络的 [batch, channels, height, width] 不同,需要转换以便后续处理。
  • detach: 布尔值。如果为 True,则在捕获后将激活和梯度从计算图中分离(detach())并转移到 CPU。这可以节省 GPU 内存,并防止意外的梯度累积。如果为 False,则保留原始的张量(在 GPU 上且与计算图相连)。

关键操作:

  1. 初始化存储列表:self.gradients = [] 和 self.activations = []。
  2. 初始化 self.handles = [] 以存储注册的 Hook 句柄,便于后续移除。
  3. 注册 Hook:
  • 为 target_layers 中的每一个层,调用 register_forward_hook(self.save_activation)。这会在该层的前向传播结束后,自动调用 save_activation 方法,捕获其输出(即激活值)。
  • 为同一个层,再次调用 register_forward_hook(self.save_gradient)。注意:这里没有使用 register_backward_hook。注释中提到了一个 PyTorch 的 issue (#61519),暗示使用后向 Hook 可能存在问题。因此,这里采用了另一种方法:在前向 Hook 中,为层的输出张量 output 注册一个梯度 Hook (output.register_hook(_store_grad))。当反向传播到达该 output 张量时,_store_grad 函数就会被调用。

2. save_activation(self, module, input, output)

python 复制代码
    def save_activation(self, module, input, output):
        activation = output
        if self.detach:
            if self.reshape_transform is not None:
                activation = self.reshape_transform(activation)
            self.activations.append(activation.cpu().detach())
        else:
            self.activations.append(activation)

功能:

前向 Hook 的回调函数,用于捕获目标层的激活值。

参数(由 PyTorch 自动提供):

module: 调用此 Hook 的层(即 target_layer)。

input: 该层的输入张量(通常用不到)。

output: 该层的输出张量(即激活值)。

流程:

  • 将 output 赋值给 activation。
  • 如果提供了 reshape_transform 函数,则对 activation 进行重塑。
  • 根据 detach 参数决定如何存储:
    如果 detach=True:将 activation 移动到 CPU 并从计算图中分离,然后添加到 self.activations 列表。
    否则:直接将原始的 activation 添加到列表。

结果:

每次前向传播后,self.activations 列表会按顺序存储所有 target_layers 的激活值。

3. save_gradient(self, module, input, output)

python 复制代码
    def save_gradient(self, module, input, output):
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # You can only register hooks on tensor requires grad.
            return

        # Gradients are computed in reverse order
        def _store_grad(grad):
            if self.detach:
                if self.reshape_transform is not None:
                    grad = self.reshape_transform(grad)
                self.gradients = [grad.cpu().detach()] + self.gradients
            else:
                self.gradients = [grad] + self.gradients

        output.register_hook(_store_grad)

功能:前向 Hook 的回调函数,用于为目标层的输出张量注册一个梯度 Hook。这个梯度 Hook 会在反向传播时被触发。

参数:同 save_activation。

流程:

  1. 检查 output 是否需要梯度 (requires_grad)。如果不需要(例如,某些层的输出是整数或布尔值),则直接返回,不注册 Hook。
  2. 定义一个内部函数 _store_grad(grad):
  • 这个函数是真正的梯度 Hook,它接收反向传播计算出的梯度 grad 作为输入。
  • 如果提供了 reshape_transform,则对 grad 进行重塑。
  • 根据 detach 参数决定如何存储:
    ** 如果 detach=True:将 grad 移动到 CPU 并分离,然后插入到 * * self.gradients 列表的开头 ([grad.cpu().detach()] + self.gradients)。
    ** 否则:直接将 grad 插入到列表开头。
  • 为什么插入开头? 因为反向传播是从后往前进行的。最后层的梯度先计算,最先被捕获。为了保持 self.gradients 列表的顺序与 self.activations 和 target_layers 的顺序一致,需要将新捕获的梯度放在列表前面。
  1. 调用 output.register_hook(_store_grad),将 _store_grad 函数注册为 output 张量的梯度 Hook。

结果:在反向传播过程中,每当计算到某个 target_layer 的输出梯度时,_store_grad 就会被调用,该梯度被处理后按正确的顺序存储在 self.gradients 列表中。

4. call(self, x)

python 复制代码
    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

功能:使 ActivationsAndGradients 对象可以像函数一样被调用。

流程:

  1. 在每次调用前,清空self.gradients 和 self.activations 列表。这是非常重要的,确保了每次调用捕获的都是本次前向/反向传播的数据,不会与之前的结果混合。
  2. 调用 self.model(x) 执行模型的前向传播。在此过程中,所有注册的 Hook 都会被触发,save_activation 会捕获激活值,save_gradient 会为输出张量注册梯度 Hook。

返回值:模型的前向输出(self.model(x) 的结果)。

副作用:self.activations 和 self.gradients 列表被填充。

5. release(self)

python 复制代码
    def release(self):
        for handle in self.handles:
            handle.remove()

功能:

移除所有已注册的 Hook。

流程:

遍历 self.handles 列表,调用每个 handle.remove()。

重要性:

这是资源管理的关键步骤。如果不移除 Hook,它们会一直存在于模型中,导致:

  1. 内存泄漏:捕获的激活和梯度会持续累积。
  2. 性能下降:每次前向/反向传播都会执行不必要的 Hook 函数。
  3. 错误:可能干扰模型的其他操作。

调用时机:通常在 BaseCAM 的 delexit 方法中调用。

总结

ActivationsAndGradients 类巧妙地利用了 PyTorch 的 Hook 机制:

  1. 捕获激活:通过 register_forward_hook 在前向传播后直接捕获目标层的输出。
  2. 捕获梯度:通过在前向 Hook 中为输出张量注册 register_hook,在反向传播时捕获其梯度,并通过将新梯度插入列表开头来保证顺序正确。
  3. 灵活性:支持 reshape_transform 以适应不同模型架构。
  4. 内存管理:通过 detach 选项控制是否保留计算图,并通过 release 方法确保 Hook 被正确移除,防止资源泄漏。
  5. 易用性:提供 call 接口,使得用户只需调用一次即可完成前向传播并自动捕获所需数据。

这个类是 BaseCAM 及其所有子类能够工作的基石,它透明地拦截了模型内部的计算过程,为 CAM 算法提供了必需的中间数据。