如何获取类激活图(CAM)?

什么是类激活图?

类激活图(Class Activation Map,CAM)是一种用于可视化深度学习模型在图像中感兴趣区域的技术。它通常与卷积神经网络(Convolutional Neural Networks,CNN)结合使用。CAM 可以帮助理解模型在图像分类任务中的决策过程,显示模型认为对于某个特定类别的分类而言,图像中哪些区域是关键的。具体来说,通过类激活图,我们可以看到模型在进行分类时对图像的不同部分给予了不同的注意力。 直接看效果图:

如何获取类激活图?

1. 训练模型,提取特征

首先,你需要训练一个卷积神经网络(CNN)模型,确保模型在训练集上学到了有效的特征表示。接着,在模型结构中选择一个卷积层,该层的输出将包含输入图像的特征图,这通常是在网络的最后几个卷积层之一。 本文选择采用ResNet50作为骨干网络,并经过多次实验发现,第四层即网络的最后一层的特征图效果最佳。以下是简略的流程和代码示例:

ini 复制代码
# 读取图片
img = Image.open(img_path).convert('RGB')
# 预处理,transform只是调整img的大小以及将其转为张量
img_tensor = transform(img)
# 添加批次维度,因为是单张图片
img_tensor = img_tensor.unsqueeze(0)    # 维度:[C,H,W]  ---> [B,C,H,W]

# 创建模型
model = build_model(cfg)
# 获取特征图
feature_map = model(img_tensor)

2. 计算类激活图:

针对给定的类别,我们需要计算与该类别对应的特征图的权重。通常,这可以通过全局平均池化(Global Average Pooling)来实现。具体来说,对于每个特征图,计算其所有元素的平均值,然后将这个平均值与相应的权重相乘,得到每个特征图的权重。最终,将所有特征图按照它们的权重进行加权求和,得到最终的类激活图。以下是相应的代码示例:

ini 复制代码
#全局平均池化、展平
feat = model.gap(feature_map)
feat = feat.view(feat.shape[0], -1)
# 计算类别概率
output = model.classifier(feat)
# 获取最相关类别的索引和权重
_, class_index = output.max(1)  
weight = model.classifier.weight[class_index[0]] 
# 使用权重对特征图进行加权求和
cam = feature_map[0] * weight[:, None, None]
# 沿着第一个轴对 cam 进行求和
cam = cam.sum(axis=0)
# 应用ReLU激活并进行归一化,得到最终的类激活图
cam = F.relu(cam)
cam = (cam - cam.min()) / (cam.max() - cam.min())

3. 可视化:

将得到的类激活图叠加到原始图像上,可以使用热图(heatmap)来表示不同区域的重要性。这样,就可以看到模型对于特定类别的决策是基于图像的哪些区域。代码如下:

ini 复制代码
def show_cam(img_path, cam):
    # 用cv2加载原始图像
    img = cv2.imread(img_path) 
    # 将cam从计算图中分离(不会影响后续梯度的计算)并转换为NumPy数组,重命名为heatmap
    heatmap = cam.detach().numpy()
    # 调整图像和CAM的大小。因为前面对图像的预处理是将图片调整为(256,256)
    img = cv2.resize(img, (256, 256))
    heatmap = cv2.resize(heatmap, (256, 256))
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将灰度热力图转换为彩色热力图
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # 转换颜色映射
    # 图像混合,根据具体效果调整参数
    result = cv2.addWeighted(heatmap, 0.4, img, 0.6, 0)
    # 绘制
    plt.imshow(result)  # 使用默认颜色映射
    plt.xticks([])  # 禁用 x 轴刻度
    plt.yticks([])  # 禁用 y 轴刻度
    plt.show()

总结

以上是单张图片获取类激活图的过程,仅供参考。如果处理一个批次的图片,再做相应调整。

相关推荐
却道天凉_好个秋2 小时前
OpenCV(二十九):高通滤波-索贝尔算子
人工智能·opencv·计算机视觉
AndrewHZ2 小时前
【图像处理基石】如何从色彩的角度分析一张图是否是好图?
图像处理·计算机视觉·cv·聚类算法·色彩科学
点云SLAM3 小时前
四元数 (Quaternion)微分-四元数导数的矩阵表示推导(8)
线性代数·算法·计算机视觉·矩阵·机器人·slam·四元数
却道天凉_好个秋3 小时前
OpenCV(二十八):双边滤波
人工智能·opencv·计算机视觉
B站_计算机毕业设计之家5 小时前
python手写数字识别计分系统+CNN模型+YOLOv5模型 深度学习 计算机毕业设计(建议收藏)✅
python·深度学习·yolo·计算机视觉·数据分析·cnn
CoovallyAIHub7 小时前
超越像素的视觉:亚像素边缘检测原理、方法与实战
深度学习·算法·计算机视觉
CoovallyAIHub7 小时前
中科大西工大提出RSKT-Seg:精度速度双提升,开放词汇分割不再难
深度学习·算法·计算机视觉
yolo_guo8 小时前
opencv 学习: QA_02 什么是图像中的高频成分和低频成分
linux·c++·opencv·计算机视觉
算法与编程之美8 小时前
探索不同的优化器、损失函数、batch_size对分类精度影响
人工智能·机器学习·计算机视觉·分类·batch
AI科技星9 小时前
引力编程时代:人类文明存续与升维
数据结构·人工智能·经验分享·算法·计算机视觉