如何获取类激活图(CAM)?

什么是类激活图?

类激活图(Class Activation Map,CAM)是一种用于可视化深度学习模型在图像中感兴趣区域的技术。它通常与卷积神经网络(Convolutional Neural Networks,CNN)结合使用。CAM 可以帮助理解模型在图像分类任务中的决策过程,显示模型认为对于某个特定类别的分类而言,图像中哪些区域是关键的。具体来说,通过类激活图,我们可以看到模型在进行分类时对图像的不同部分给予了不同的注意力。 直接看效果图:

如何获取类激活图?

1. 训练模型,提取特征

首先,你需要训练一个卷积神经网络(CNN)模型,确保模型在训练集上学到了有效的特征表示。接着,在模型结构中选择一个卷积层,该层的输出将包含输入图像的特征图,这通常是在网络的最后几个卷积层之一。 本文选择采用ResNet50作为骨干网络,并经过多次实验发现,第四层即网络的最后一层的特征图效果最佳。以下是简略的流程和代码示例:

ini 复制代码
# 读取图片
img = Image.open(img_path).convert('RGB')
# 预处理,transform只是调整img的大小以及将其转为张量
img_tensor = transform(img)
# 添加批次维度,因为是单张图片
img_tensor = img_tensor.unsqueeze(0)    # 维度:[C,H,W]  ---> [B,C,H,W]

# 创建模型
model = build_model(cfg)
# 获取特征图
feature_map = model(img_tensor)

2. 计算类激活图:

针对给定的类别,我们需要计算与该类别对应的特征图的权重。通常,这可以通过全局平均池化(Global Average Pooling)来实现。具体来说,对于每个特征图,计算其所有元素的平均值,然后将这个平均值与相应的权重相乘,得到每个特征图的权重。最终,将所有特征图按照它们的权重进行加权求和,得到最终的类激活图。以下是相应的代码示例:

ini 复制代码
#全局平均池化、展平
feat = model.gap(feature_map)
feat = feat.view(feat.shape[0], -1)
# 计算类别概率
output = model.classifier(feat)
# 获取最相关类别的索引和权重
_, class_index = output.max(1)  
weight = model.classifier.weight[class_index[0]] 
# 使用权重对特征图进行加权求和
cam = feature_map[0] * weight[:, None, None]
# 沿着第一个轴对 cam 进行求和
cam = cam.sum(axis=0)
# 应用ReLU激活并进行归一化,得到最终的类激活图
cam = F.relu(cam)
cam = (cam - cam.min()) / (cam.max() - cam.min())

3. 可视化:

将得到的类激活图叠加到原始图像上,可以使用热图(heatmap)来表示不同区域的重要性。这样,就可以看到模型对于特定类别的决策是基于图像的哪些区域。代码如下:

ini 复制代码
def show_cam(img_path, cam):
    # 用cv2加载原始图像
    img = cv2.imread(img_path) 
    # 将cam从计算图中分离(不会影响后续梯度的计算)并转换为NumPy数组,重命名为heatmap
    heatmap = cam.detach().numpy()
    # 调整图像和CAM的大小。因为前面对图像的预处理是将图片调整为(256,256)
    img = cv2.resize(img, (256, 256))
    heatmap = cv2.resize(heatmap, (256, 256))
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将灰度热力图转换为彩色热力图
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # 转换颜色映射
    # 图像混合,根据具体效果调整参数
    result = cv2.addWeighted(heatmap, 0.4, img, 0.6, 0)
    # 绘制
    plt.imshow(result)  # 使用默认颜色映射
    plt.xticks([])  # 禁用 x 轴刻度
    plt.yticks([])  # 禁用 y 轴刻度
    plt.show()

总结

以上是单张图片获取类激活图的过程,仅供参考。如果处理一个批次的图片,再做相应调整。

相关推荐
xiaohouzi11223319 小时前
OpenCV的cv2.VideoCapture如何加GStreamer后端
人工智能·opencv·计算机视觉
小关会打代码19 小时前
计算机视觉案例分享之答题卡识别
人工智能·计算机视觉
天天进步201519 小时前
用Python打造专业级老照片修复工具:让时光倒流的数字魔法
人工智能·计算机视觉
荼蘼19 小时前
答题卡识别改分项目
人工智能·opencv·计算机视觉
IT古董20 小时前
【第五章:计算机视觉-项目实战之图像分类实战】1.经典卷积神经网络模型Backbone与图像-(4)经典卷积神经网络ResNet的架构讲解
人工智能·计算机视觉·cnn
张子夜 iiii21 小时前
4步OpenCV-----扫秒身份证号
人工智能·python·opencv·计算机视觉
paid槮1 天前
机器视觉之图像处理篇
图像处理·opencv·计算机视觉
通街市密人有1 天前
IDF: Iterative Dynamic Filtering Networks for Generalizable Image Denoising
人工智能·深度学习·计算机视觉
sali-tec1 天前
C# 基于halcon的视觉工作流-章34-环状测量
开发语言·图像处理·算法·计算机视觉·c#
小王爱学人工智能1 天前
OpenCV一些进阶操作
人工智能·opencv·计算机视觉