如何获取类激活图(CAM)?

什么是类激活图?

类激活图(Class Activation Map,CAM)是一种用于可视化深度学习模型在图像中感兴趣区域的技术。它通常与卷积神经网络(Convolutional Neural Networks,CNN)结合使用。CAM 可以帮助理解模型在图像分类任务中的决策过程,显示模型认为对于某个特定类别的分类而言,图像中哪些区域是关键的。具体来说,通过类激活图,我们可以看到模型在进行分类时对图像的不同部分给予了不同的注意力。 直接看效果图:

如何获取类激活图?

1. 训练模型,提取特征

首先,你需要训练一个卷积神经网络(CNN)模型,确保模型在训练集上学到了有效的特征表示。接着,在模型结构中选择一个卷积层,该层的输出将包含输入图像的特征图,这通常是在网络的最后几个卷积层之一。 本文选择采用ResNet50作为骨干网络,并经过多次实验发现,第四层即网络的最后一层的特征图效果最佳。以下是简略的流程和代码示例:

ini 复制代码
# 读取图片
img = Image.open(img_path).convert('RGB')
# 预处理,transform只是调整img的大小以及将其转为张量
img_tensor = transform(img)
# 添加批次维度,因为是单张图片
img_tensor = img_tensor.unsqueeze(0)    # 维度:[C,H,W]  ---> [B,C,H,W]

# 创建模型
model = build_model(cfg)
# 获取特征图
feature_map = model(img_tensor)

2. 计算类激活图:

针对给定的类别,我们需要计算与该类别对应的特征图的权重。通常,这可以通过全局平均池化(Global Average Pooling)来实现。具体来说,对于每个特征图,计算其所有元素的平均值,然后将这个平均值与相应的权重相乘,得到每个特征图的权重。最终,将所有特征图按照它们的权重进行加权求和,得到最终的类激活图。以下是相应的代码示例:

ini 复制代码
#全局平均池化、展平
feat = model.gap(feature_map)
feat = feat.view(feat.shape[0], -1)
# 计算类别概率
output = model.classifier(feat)
# 获取最相关类别的索引和权重
_, class_index = output.max(1)  
weight = model.classifier.weight[class_index[0]] 
# 使用权重对特征图进行加权求和
cam = feature_map[0] * weight[:, None, None]
# 沿着第一个轴对 cam 进行求和
cam = cam.sum(axis=0)
# 应用ReLU激活并进行归一化,得到最终的类激活图
cam = F.relu(cam)
cam = (cam - cam.min()) / (cam.max() - cam.min())

3. 可视化:

将得到的类激活图叠加到原始图像上,可以使用热图(heatmap)来表示不同区域的重要性。这样,就可以看到模型对于特定类别的决策是基于图像的哪些区域。代码如下:

ini 复制代码
def show_cam(img_path, cam):
    # 用cv2加载原始图像
    img = cv2.imread(img_path) 
    # 将cam从计算图中分离(不会影响后续梯度的计算)并转换为NumPy数组,重命名为heatmap
    heatmap = cam.detach().numpy()
    # 调整图像和CAM的大小。因为前面对图像的预处理是将图片调整为(256,256)
    img = cv2.resize(img, (256, 256))
    heatmap = cv2.resize(heatmap, (256, 256))
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将灰度热力图转换为彩色热力图
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # 转换颜色映射
    # 图像混合,根据具体效果调整参数
    result = cv2.addWeighted(heatmap, 0.4, img, 0.6, 0)
    # 绘制
    plt.imshow(result)  # 使用默认颜色映射
    plt.xticks([])  # 禁用 x 轴刻度
    plt.yticks([])  # 禁用 y 轴刻度
    plt.show()

总结

以上是单张图片获取类激活图的过程,仅供参考。如果处理一个批次的图片,再做相应调整。

相关推荐
Trent19853 小时前
影楼精修-肤色统一算法解析
图像处理·人工智能·算法·计算机视觉
kyle~6 小时前
计算机视觉---目标检测(Object Detecting)概览
人工智能·目标检测·计算机视觉
双翌视觉7 小时前
机器视觉对位手机中框点胶的应用
计算机视觉·机器视觉·视觉对位·视觉软件
白熊1887 小时前
【计算机视觉】OpenCV实战项目:基于OpenCV的车牌识别系统深度解析
人工智能·opencv·计算机视觉
胡耀超8 小时前
霍夫圆变换全面解析(OpenCV)
人工智能·python·opencv·算法·计算机视觉·数据挖掘·数据安全
jndingxin8 小时前
OpenCV CUDA 模块中用于在 GPU 上计算两个数组对应元素差值的绝对值函数absdiff(
人工智能·opencv·计算机视觉
硅谷秋水8 小时前
学习以任务为中心的潜动作,随地采取行动
人工智能·深度学习·计算机视觉·语言模型·机器人
Wnq100729 小时前
工业场景轮式巡检机器人纯视觉识别导航的优势剖析与前景展望
人工智能·算法·计算机视觉·激光雷达·视觉导航·人形机器人·巡检机器人
量子-Alex11 小时前
【目标检测】RT-DETR
人工智能·目标检测·计算机视觉
2201_7549184111 小时前
OpenCV 图像透视变换详解
人工智能·opencv·计算机视觉