PyTorch实战(39)------使用Captum解释深度学习模型
0. 前言
我们已经学习了如何使用 PyTorch 解释深度学习模型,现代机器学习提供了更高效的方法来解释模型预测行为,无需直接处理海量参数。在本节中,我们将使用 Captum 工具包------这是一个专为 PyTorch 设计的模型可解释性工具库,仅需少量代码即可解析模型决策逻辑。通过本节学习,将掌握解构深度学习模型内部运作的关键技能。以这种方式深入了解模型有助于理解模型的预测行为逻辑。运用本节的实践经验,能够使用 Captum 解读自定义的深度学习模型。
1. 设置 Captum
Captum 是由 Meta 基于 PyTorch 开发的开源模型可解释性库。本节将继续使用训练好的手写数字分类模型,运用 Captum 提供的多种解释工具来剖析模型的预测机制。
模型训练代码与《使用PyToch可视化神经网络》一节所示的代码类似,接下来将使用训练好的模型和样本图像,解析模型对给定图像进行预测时的内部运作机制。
(1) 首先需导入 Captum 相关模块以使用其内置的模型可解释性功能:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import visualization as viz
(2) 为使输入图像符合模型要求的 (1,28,28) 尺寸,需进行张量重塑:
python
captum_input = sample_data[0].unsqueeze(0)
captum_input.requires_grad = True
根据 Captum 的要求,输入张量(本节中为图像)需要参与梯度计算。因此,我们需要将输入的 requires_grad 标志设置为 True。
(3) 接下来,准备待分析的样本图像,以便通过模型的可解释性方法进行处理:
python
orig_image = np.tile(np.transpose((sample_data[0].cpu().detach().numpy() / 2) + 0.5, (1, 2, 0)), (1,1,3))
dummy_attr = np.random.rand(*orig_image.shape) * 1e-6
_ = viz.visualize_image_attr(dummy_attr, orig_image, cmap='gray', method="original_image", title="Original Image")
输出结果如下所示:

上述代码输出了原始手写数字图像。我们通过灰度图像在通道维度上的平铺处理,使其适配 Captum 方法所需的三通道输入格式(尽管实际仍为单通道数据)。接下来,我们将对预处理后的灰度图像应用 Captum 提供的多种可解释性方法,分析其在预训练手写数字分类模型中的前向传播过程。
2. 探索 Captum 的可解释性工具
在本节中,我们将探讨 Captum 提供的模型可解释性方法。
解释模型结果的最基本方法之一是显著图 (Saliency Map) 分析,该方法通过计算模型输出(本节为类别概率)相对于输入(本节为图像像素)的梯度,相对于特定输入的梯度越大,说明该输入对输出的影响越大。具体到本节而言,使用显著图表征各像素的重要性,梯度值越大,表明对应像素对模型决策的影响越显著。这些梯度实质上量化了每个输入特征的微小变化对输出结果的扰动程度。
(1) 使用 Captum 的 Saliency 模块来计算梯度显著图:
python
saliency = Saliency(model)
gradients = saliency.attribute(captum_input, target=sample_targets[0].item())
gradients = np.reshape(gradients.squeeze().cpu().detach().numpy(), (28, 28, 1))
_ = viz.visualize_image_attr(gradients, orig_image, method="blended_heat_map", sign="absolute_value",
show_colorbar=True, title="Overlayed Gradients")
输出结果如下所示:

在以上代码中,我们将获得的梯度矩阵重塑为 (28,28,1) 的尺寸,以便与原始图像叠加显示(如上图所示)。Captum 的 viz 模块已封装了可视化功能。若需单独观察梯度分布,可使用以下代码:
python
plt.imshow(np.tile(gradients/(np.max(gradients)), (1,1,3)));
输出结果如下所示:

如输出所示,梯度集中在数字 1 的笔画区域,这些像素对模型判断起决定性作用。
接下来,我们将使用类似代码探索另一种可解释性方法------集成梯度 (Integrated Gradients)。该方法用于寻找特征归因 (feature attribution),也就是特征重要性,即识别哪些像素对预测结果起关键作用。在集成梯度方法中,除了输入图像外,我们还需要指定一个基准图像,通常将基准图像设置为所有像素值均为零的图像。零基准作为一个空白画布,用来衡量每个像素对输出的影响,从一个没有影响的状态开始,可以清晰地展示各个像素如何逐渐影响模型的决策。
(2) 通过计算从基线图像到输入图像的路径上梯度积分来确定像素重要性。使用 Captum 的 IntegratedGradients 模块计算输入图像各像素的重要性:
python
integ_grads = IntegratedGradients(model)
attributed_ig, delta = integ_grads.attribute(captum_input, target=sample_targets[0].item(), baselines=captum_input * 0,
return_convergence_delta=True)
attributed_ig = np.reshape(attributed_ig.squeeze().cpu().detach().numpy(), (28, 28, 1))
_ = viz.visualize_image_attr(attributed_ig, orig_image, method="blended_heat_map",sign="all", show_colorbar=True,
title="Overlayed Integrated Gradients")
输出结果如下所示:

正如预期,梯度值在数字 1 的笔画区域呈现显著峰值。
(3) 最后,我们探讨另一种基于梯度的归因技术,称为 DeepLIFT。DeepLIFT 方法同样需要输入图像和基线图像(此处仍采用零值图像作为基线)。DeepLIFT 计算从基线图像到输入图像变化过程中非线性激活输出的改变量来评估特征重要性。使用 Captum 提供的 DeepLIFT 模块计算梯度,并将这些梯度叠加在原始输入图像上显示:
python
deep_lift = DeepLift(model)
attributed_dl = deep_lift.attribute(captum_input, target=sample_targets[0].item(), baselines=captum_input * 0,
return_convergence_delta=False)
attributed_dl = np.reshape(attributed_dl.squeeze(0).cpu().detach().numpy(), (28, 28, 1))
_ = viz.visualize_image_attr(attributed_dl, orig_image, method="blended_heat_map",sign="all",show_colorbar=True,
title="Overlayed DeepLift")
输出结果如下所示:

可以观察到,梯度值在数字 1 的像素区域再次出现显著极值(如上图所示),这验证了不同解释方法的一致性。
除了本节介绍的技术之外,Captum 提供了许多其他的的模型可解释性技术,如 LayerConductance、GradCAM 和 SHAP。
小结
本节介绍了使用 Captum 工具包解析 PyTorch 深度学习模型的方法。通过显著图、集成梯度和 DeepLIFT 三种可解释性技术,分析了手写数字分类模型的决策逻辑。这些方法通过计算输入特征对模型输出的梯度贡献,可视化关键像素区域,揭示了模型聚焦于数字笔画特征进行预测的机制。Captum 提供了统一 API 实现多种解释算法,仅需少量代码即可获得直观的可视化结果,帮助开发者理解模型内部运作方式,为模型调试和优化提供依据。
系列链接
PyTorch实战(1)------深度学习(Deep Learning)
PyTorch实战(2)------使用PyTorch构建神经网络
PyTorch实战(3)------PyTorch vs. TensorFlow详解
PyTorch实战(4)------卷积神经网络(Convolutional Neural Network,CNN)
PyTorch实战(5)------深度卷积神经网络
PyTorch实战(6)------模型微调详解
PyTorch实战(7)------循环神经网络
PyTorch实战(8)------图像描述生成
PyTorch实战(9)------从零开始实现Transformer
PyTorch实战(10)------从零开始实现GPT模型
PyTorch实战(11)------随机连接神经网络(RandWireNN)
PyTorch实战(12)------图神经网络(Graph Neural Network,GNN)
PyTorch实战(13)------图卷积网络(Graph Convolutional Network,GCN)
PyTorch实战(14)------图注意力网络(Graph Attention Network,GAT)
PyTorch实战(15)------基于Transformer的文本生成技术
PyTorch实战(16)------基于LSTM实现音乐生成
PyTorch实战(17)------神经风格迁移
PyTorch实战(18)------自编码器(Autoencoder,AE)
PyTorch实战(19)------变分自编码器(Variational Autoencoder,VAE)
PyTorch实战(20)------生成对抗网络(Generative Adversarial Network,GAN)
PyTorch实战(21)------扩散模型(Diffusion Model)
PyTorch实战(22)------MuseGAN详解与实现
PyTorch实战(23)------基于Transformer生成音乐
PyTorch实战(24)------深度强化学习
PyTorch实战(25)------使用PyTorch构建DQN模型
PyTorch实战(26)------PyTorch分布式训练
PyTorch实战(27)------自动混合精度训练
PyTorch实战(28)------PyTorch深度学习模型部署
PyTorch实战(29)------使用TorchServe部署PyTorch模型
PyTorch实战(30)------使用TorchScript和ONNX导出通用PyTorch模型
PyTorch实战(31)------在Android上部署PyTorch模型
PyTorch实战(32)------在iOS上构建PyTorch应用
PyTorch实战(33)------使用fastai进行快速原型开发
PyTorch实战(34)------基于PyTorch Lightning的跨硬件模型训练
PyTorch实战(35)------使用PyTorch Profiler分析模型推理性能
PyTorch实战(36)------PyTorch自动机器学习
PyTorch实战(37)------使用Optuna搜索最优超参数
PyTorch实战(38)------深度学习模型可解释性