如何获取类激活图(CAM)?

什么是类激活图?

类激活图(Class Activation Map,CAM)是一种用于可视化深度学习模型在图像中感兴趣区域的技术。它通常与卷积神经网络(Convolutional Neural Networks,CNN)结合使用。CAM 可以帮助理解模型在图像分类任务中的决策过程,显示模型认为对于某个特定类别的分类而言,图像中哪些区域是关键的。具体来说,通过类激活图,我们可以看到模型在进行分类时对图像的不同部分给予了不同的注意力。 直接看效果图:

如何获取类激活图?

1. 训练模型,提取特征

首先,你需要训练一个卷积神经网络(CNN)模型,确保模型在训练集上学到了有效的特征表示。接着,在模型结构中选择一个卷积层,该层的输出将包含输入图像的特征图,这通常是在网络的最后几个卷积层之一。 本文选择采用ResNet50作为骨干网络,并经过多次实验发现,第四层即网络的最后一层的特征图效果最佳。以下是简略的流程和代码示例:

ini 复制代码
# 读取图片
img = Image.open(img_path).convert('RGB')
# 预处理,transform只是调整img的大小以及将其转为张量
img_tensor = transform(img)
# 添加批次维度,因为是单张图片
img_tensor = img_tensor.unsqueeze(0)    # 维度:[C,H,W]  ---> [B,C,H,W]

# 创建模型
model = build_model(cfg)
# 获取特征图
feature_map = model(img_tensor)

2. 计算类激活图:

针对给定的类别,我们需要计算与该类别对应的特征图的权重。通常,这可以通过全局平均池化(Global Average Pooling)来实现。具体来说,对于每个特征图,计算其所有元素的平均值,然后将这个平均值与相应的权重相乘,得到每个特征图的权重。最终,将所有特征图按照它们的权重进行加权求和,得到最终的类激活图。以下是相应的代码示例:

ini 复制代码
#全局平均池化、展平
feat = model.gap(feature_map)
feat = feat.view(feat.shape[0], -1)
# 计算类别概率
output = model.classifier(feat)
# 获取最相关类别的索引和权重
_, class_index = output.max(1)  
weight = model.classifier.weight[class_index[0]] 
# 使用权重对特征图进行加权求和
cam = feature_map[0] * weight[:, None, None]
# 沿着第一个轴对 cam 进行求和
cam = cam.sum(axis=0)
# 应用ReLU激活并进行归一化,得到最终的类激活图
cam = F.relu(cam)
cam = (cam - cam.min()) / (cam.max() - cam.min())

3. 可视化:

将得到的类激活图叠加到原始图像上,可以使用热图(heatmap)来表示不同区域的重要性。这样,就可以看到模型对于特定类别的决策是基于图像的哪些区域。代码如下:

ini 复制代码
def show_cam(img_path, cam):
    # 用cv2加载原始图像
    img = cv2.imread(img_path) 
    # 将cam从计算图中分离(不会影响后续梯度的计算)并转换为NumPy数组,重命名为heatmap
    heatmap = cam.detach().numpy()
    # 调整图像和CAM的大小。因为前面对图像的预处理是将图片调整为(256,256)
    img = cv2.resize(img, (256, 256))
    heatmap = cv2.resize(heatmap, (256, 256))
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将灰度热力图转换为彩色热力图
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # 转换颜色映射
    # 图像混合,根据具体效果调整参数
    result = cv2.addWeighted(heatmap, 0.4, img, 0.6, 0)
    # 绘制
    plt.imshow(result)  # 使用默认颜色映射
    plt.xticks([])  # 禁用 x 轴刻度
    plt.yticks([])  # 禁用 y 轴刻度
    plt.show()

总结

以上是单张图片获取类激活图的过程,仅供参考。如果处理一个批次的图片,再做相应调整。

相关推荐
超龄超能程序猿5 小时前
(三)PS识别:基于噪声分析PS识别的技术实现
图像处理·人工智能·计算机视觉
Chef_Chen7 小时前
从0开始学习计算机视觉--Day07--神经网络
神经网络·学习·计算机视觉
加油吧zkf9 小时前
YOLO目标检测数据集类别:分类与应用
人工智能·计算机视觉·目标跟踪
加油吧zkf11 小时前
水下目标检测:突破与创新
人工智能·计算机视觉·目标跟踪
静心问道12 小时前
GoT:超越思维链:语言模型中的有效思维图推理
人工智能·计算机视觉·语言模型
晓131313 小时前
第七章 OpenCV篇——角点检测与特征检测
人工智能·深度学习·计算机视觉
PyAIExplorer15 小时前
图像旋转:从原理到 OpenCV 实践
人工智能·opencv·计算机视觉
PyAIExplorer21 小时前
OpenCV 图像操作:颜色识别、替换与水印添加
人工智能·opencv·计算机视觉
千宇宙航1 天前
闲庭信步使用SV搭建图像测试平台:第三十一课——基于神经网络的手写数字识别
图像处理·人工智能·深度学习·神经网络·计算机视觉·fpga开发
jndingxin1 天前
OpenCV CUDA模块设备层-----高效地计算两个 uint 类型值的带权重平均值
人工智能·opencv·计算机视觉