如何获取类激活图(CAM)?

什么是类激活图?

类激活图(Class Activation Map,CAM)是一种用于可视化深度学习模型在图像中感兴趣区域的技术。它通常与卷积神经网络(Convolutional Neural Networks,CNN)结合使用。CAM 可以帮助理解模型在图像分类任务中的决策过程,显示模型认为对于某个特定类别的分类而言,图像中哪些区域是关键的。具体来说,通过类激活图,我们可以看到模型在进行分类时对图像的不同部分给予了不同的注意力。 直接看效果图:

如何获取类激活图?

1. 训练模型,提取特征

首先,你需要训练一个卷积神经网络(CNN)模型,确保模型在训练集上学到了有效的特征表示。接着,在模型结构中选择一个卷积层,该层的输出将包含输入图像的特征图,这通常是在网络的最后几个卷积层之一。 本文选择采用ResNet50作为骨干网络,并经过多次实验发现,第四层即网络的最后一层的特征图效果最佳。以下是简略的流程和代码示例:

ini 复制代码
# 读取图片
img = Image.open(img_path).convert('RGB')
# 预处理,transform只是调整img的大小以及将其转为张量
img_tensor = transform(img)
# 添加批次维度,因为是单张图片
img_tensor = img_tensor.unsqueeze(0)    # 维度:[C,H,W]  ---> [B,C,H,W]

# 创建模型
model = build_model(cfg)
# 获取特征图
feature_map = model(img_tensor)

2. 计算类激活图:

针对给定的类别,我们需要计算与该类别对应的特征图的权重。通常,这可以通过全局平均池化(Global Average Pooling)来实现。具体来说,对于每个特征图,计算其所有元素的平均值,然后将这个平均值与相应的权重相乘,得到每个特征图的权重。最终,将所有特征图按照它们的权重进行加权求和,得到最终的类激活图。以下是相应的代码示例:

ini 复制代码
#全局平均池化、展平
feat = model.gap(feature_map)
feat = feat.view(feat.shape[0], -1)
# 计算类别概率
output = model.classifier(feat)
# 获取最相关类别的索引和权重
_, class_index = output.max(1)  
weight = model.classifier.weight[class_index[0]] 
# 使用权重对特征图进行加权求和
cam = feature_map[0] * weight[:, None, None]
# 沿着第一个轴对 cam 进行求和
cam = cam.sum(axis=0)
# 应用ReLU激活并进行归一化,得到最终的类激活图
cam = F.relu(cam)
cam = (cam - cam.min()) / (cam.max() - cam.min())

3. 可视化:

将得到的类激活图叠加到原始图像上,可以使用热图(heatmap)来表示不同区域的重要性。这样,就可以看到模型对于特定类别的决策是基于图像的哪些区域。代码如下:

ini 复制代码
def show_cam(img_path, cam):
    # 用cv2加载原始图像
    img = cv2.imread(img_path) 
    # 将cam从计算图中分离(不会影响后续梯度的计算)并转换为NumPy数组,重命名为heatmap
    heatmap = cam.detach().numpy()
    # 调整图像和CAM的大小。因为前面对图像的预处理是将图片调整为(256,256)
    img = cv2.resize(img, (256, 256))
    heatmap = cv2.resize(heatmap, (256, 256))
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将灰度热力图转换为彩色热力图
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # 转换颜色映射
    # 图像混合,根据具体效果调整参数
    result = cv2.addWeighted(heatmap, 0.4, img, 0.6, 0)
    # 绘制
    plt.imshow(result)  # 使用默认颜色映射
    plt.xticks([])  # 禁用 x 轴刻度
    plt.yticks([])  # 禁用 y 轴刻度
    plt.show()

总结

以上是单张图片获取类激活图的过程,仅供参考。如果处理一个批次的图片,再做相应调整。

相关推荐
Coovally AI模型快速验证2 小时前
仅192万参数的目标检测模型,Micro-YOLO如何做到目标检测精度与效率兼得
人工智能·神经网络·yolo·目标检测·计算机视觉·目标跟踪·自然语言处理
Mrs.Gril4 小时前
目标检测: rtdetr在RK3588上部署
人工智能·目标检测·计算机视觉
qunaa01014 小时前
【计算机视觉】YOLOv10n-SPPF-LSKA托盘识别与检测
人工智能·yolo·计算机视觉
管牛牛4 小时前
图像的几何变换
人工智能·opencv·计算机视觉
sali-tec5 小时前
C# 基于OpenCv的视觉工作流-章11-高斯滤波
图像处理·人工智能·opencv·算法·计算机视觉
PeterClerk6 小时前
计算机视觉(CV)期刊(按 CCF 推荐目录 A/B/C + 交叉方向整理
论文阅读·图像处理·人工智能·深度学习·搜索引擎·计算机视觉·计算机期刊
智驱力人工智能7 小时前
矿山皮带锚杆等异物识别 从事故预防到智慧矿山的工程实践 锚杆检测 矿山皮带铁丝异物AI预警系统 工厂皮带木桩异物实时预警技术
人工智能·算法·安全·yolo·目标检测·计算机视觉·边缘计算
hudawei9967 小时前
google.mlkit:face-detection和 opencv的人脸识别有什么区别
人工智能·opencv·计算机视觉·google·人脸识别·mlkit·face-detection
格林威7 小时前
多光源条件下图像一致性校正:消除阴影与高光干扰的 6 个核心策略,附 OpenCV+Halcon 实战代码!
人工智能·数码相机·opencv·算法·计算机视觉·分类·视觉检测
程序员哈基耄8 小时前
一站式在线图像编辑器:全面解析多功能图像处理工具
图像处理·人工智能·计算机视觉