模型的可解释性之CAM系列

CAM

CAM[1]是class activation mapping(类激活映射)的简称,由周博磊等人2016年的cvpr会议的Learning Deep Features for Discriminative Localization提出。作者了解到训练用于图像分类的CNNs,也可以用于定位目标。因而,作者提出了CAM,不仅可用于定位目标,还用于识别模型使用了图像了哪些区域。

图1 CAM原理

CAM方法

CAM通过使用全局平均池化和线性层以确定各个特征和特征对分类结果的影响 ,从而实现对物体的分类与定位。整体方法非常简单,如图1所示,先通过六次卷积得到最终的特征图,它的每一个通道是卷积网络从图像中提取的一类视觉特征,经过GAP(全局平均池化,即整个特征图上的平均池化)操作,我们可以得到各个特征图对应的特征值,再利用线性层再进行分类。

  • 如何对物体进行定位

    低分辨率的特征图被插值以适配原图尺寸且对应原图中提取的一类视觉特征(注意,无池化操作,池化操作与特征位置无关,即会丢失特征的位置信息)。

  • 如何区分区域的重要性

    借助最后线性层的可解释性 , <math xmlns="http://www.w3.org/1998/Math/MathML"> w i w_i </math>wi越大意味着第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个特征图上的特征对于类别terrier的影响越大。

最后,只需用特征对某类别影响值 <math xmlns="http://www.w3.org/1998/Math/MathML"> w i w_i </math>wi乘以相应特征图,再把带权的特征图相加,便可得到最终的类激活热力图。

综上来看,CAM是一个简单,巧妙的方法,用图像分类方法实现了定位等任务(如图2);利用对类别对应的特征图分析,还可以分析类别相关语义(如图3)。

图2 定位任务

图3 类别相关的特征

Grad-CAM

CAM有一些缺点,如必须使用GAP只能分析最后一层卷积的输出仅限于图像分类任务 等,Grad-CAM[3]则可以解决上述问题。Grad-CAM是由美国佐治亚理工大学研究者于2017在ICCV会议提出的。

图4 Grad-CAM的框架图

Grad-CAM方法

Grad-CAM利用输出对特征的偏导数来确定特征对结果的重要性 。对于分类任务,如图4所示,输入经过CNN得到若干特征图 <math xmlns="http://www.w3.org/1998/Math/MathML"> A 1 A^1 </math>A1, <math xmlns="http://www.w3.org/1998/Math/MathML"> ⋅ \cdot </math>⋅, <math xmlns="http://www.w3.org/1998/Math/MathML"> A k A^{k} </math>Ak(每个特征图都学到了相应的特征),通过全连接神经网络得到输出类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c的分数 <math xmlns="http://www.w3.org/1998/Math/MathML"> y c y^c </math>yc,进而计算偏导数形成的张量 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ y c ∂ A k \frac{\partial y^c}{\partial A^{k}} </math>∂Ak∂yc,对该张量进行全局平均池化操作,得到相应特征图对输出类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c的重要性。对计算得到的重要性乘以特征图,求和过ReLU激活函数,则得到了类别激活热力图。Grad-CAM方法也比较好理解,偏导数本身是变量对输出的影响,求特征图对输出的偏导数张量,再平均得到的也就是特征图上特征对输出类别分数的重要性。

图4中Guided Backpropagation操作代码如下所示,

python 复制代码
def guided_backpropagation(model, input_image, target_class):
    model.eval()
    # 将输入图像转换为需要的格式(如tensor)
    input_tensor= preprocess(input_image)
    # 启用梯度跟踪
    input_tensor.requires_grad= True
    # 前向传播
    output= model(input_tensor)
    # 计算目标类别的损失
    loss= output[0, target_class]
    model.zero_grad()
    loss.backward()
    # 获取输入图像的梯度
    gradients= input_tensor.grad.data
    # 对梯度进行处理(保留正梯度,抑制负梯度)
    guided_gradients= gradients.clone()
    guided_gradients[guided_gradients< 0]= 0
    return guided_gradients

Guided Backpropagation考虑高分辨率下对分类结果有影响的特征,通过ReLU激活函数减少了噪声梯度,即像素级特征 ;而Grad-CAM聚焦于粗粒度的重要特征 。将Grad-CAM结果插值再与Guided Backpropagation结果相加,即可得到一个适中的结果,实验结果如图5所示。

图5 Guided Grad-CAM结果

Grad-CAM代码

通过定义下面的钩子类,当模型某卷积层前向传播和反向传播后,就会激活钩子,从而获取相应的激活值和梯度。再对获取的梯度进行处理,就可以实现Grad-CAM功能。

python 复制代码
class HookActsGrads:
    def __init__(self, model):
        self.model= model
        self.activations= None
        self.gradients= None
        self.register_hooks()

    def register_hooks(self):
        self.model.layer4[-1].register_forward_hook(self.save_activation)
        self.model.layer4[-1].register_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activations= output

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients= grad_output[0]

下图6中有一只虎纹猫和边牧,我们通过对边牧犬的输出分数对卷积层计算偏导数,得到加权的特征图,再插值与原图结合就可得图7。可以看到,模型识别其为边牧,是因为它委屈巴巴的脸。

图6 虎纹猫和边牧

图7 Grad-CAM识别狗的解释

完整代码

详细Grad-CAM代码见我的github.

Reference

1\] \[[1512.04150\] Learning Deep Features for Discriminative Localization](https://link.juejin.cn?target=https%3A%2F%2Farxiv.org%2Fabs%2F1512.04150 "https://arxiv.org/abs/1512.04150") \[2\] [CAM可解释性论文精读:Learning Deep Features for Discriminative Localization](https://link.juejin.cn?target=https%3A%2F%2Fwww.bilibili.com%2Fvideo%2FBV1fe4y1C7mN%2F%3Fshare_source%3Dcopy_web%26vd_source%3D4e1d84dfd96e839e80aa487fc1c25f46 "https://www.bilibili.com/video/BV1fe4y1C7mN/?share_source=copy_web&vd_source=4e1d84dfd96e839e80aa487fc1c25f46") \[3\] \[[1610.02391\] Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization](https://link.juejin.cn?target=https%3A%2F%2Farxiv.org%2Fabs%2F1610.02391 "https://arxiv.org/abs/1610.02391")

相关推荐
阿坡RPA11 小时前
手搓MCP客户端&服务端:从零到实战极速了解MCP是什么?
人工智能·aigc
用户277844910499311 小时前
借助DeepSeek智能生成测试用例:从提示词到Excel表格的全流程实践
人工智能·python
机器之心11 小时前
刚刚,DeepSeek公布推理时Scaling新论文,R2要来了?
人工智能
算AI13 小时前
人工智能+牙科:临床应用中的几个问题
人工智能·算法
凯子坚持 c14 小时前
基于飞桨框架3.0本地DeepSeek-R1蒸馏版部署实战
人工智能·paddlepaddle
你觉得20514 小时前
哈尔滨工业大学DeepSeek公开课:探索大模型原理、技术与应用从GPT到DeepSeek|附视频与讲义下载方法
大数据·人工智能·python·gpt·学习·机器学习·aigc
8K超高清14 小时前
中国8K摄像机:科技赋能文化传承新图景
大数据·人工智能·科技·物联网·智能硬件
hyshhhh15 小时前
【算法岗面试题】深度学习中如何防止过拟合?
网络·人工智能·深度学习·神经网络·算法·计算机视觉
薛定谔的猫-菜鸟程序员15 小时前
零基础玩转深度神经网络大模型:从Hello World到AI炼金术-详解版(含:Conda 全面使用指南)
人工智能·神经网络·dnn
币之互联万物15 小时前
2025 AI智能数字农业研讨会在苏州启幕,科技助农与数据兴业成焦点
人工智能·科技