使用GradCAM对vision transformer进行注意力可视化的过程中遇到的问题记载

0.pip install grad-cam

1.首先是跑了其他一些博主的帖子,都没跑通,最后在知乎上找到的这篇帖子,修改了图片的输入设置后跑通了,实例代码如下

bash 复制代码
import cv2
import numpy as np
import torch

from pytorch_grad_cam import GradCAM, \
                            ScoreCAM, \
                            GradCAMPlusPlus, \
                            AblationCAM, \
                            XGradCAM, \
                            EigenCAM, \
                            EigenGradCAM, \
                            LayerCAM, \
                            FullGrad

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

# 加载预训练的 ViT 模型
model = torch.hub.load('facebookresearch/deit:main','deit_tiny_patch16_224', pretrained=True)
model.eval()

# 判断是否使用 GPU 加速
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()

def reshape_transform(tensor, height=14, width=14):
    # 去掉cls token
    result = tensor[:, 1:, :].reshape(tensor.size(0),
    height, width, tensor.size(2))

    # 将通道维度放到第一个位置
    result = result.transpose(2, 3).transpose(1, 2)
    return result

# 创建 GradCAM 对象
cam = GradCAM(model=model,
            target_layers=[model.blocks[-1].norm1],
            # 这里的target_layer要看模型情况,
            # 比如还有可能是:target_layers = [model.blocks[-1].ffn.norm]
            use_cuda=use_cuda,
            reshape_transform=reshape_transform)

# 读取输入图像
image_path = "2.png"
rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
rgb_img = cv2.resize(rgb_img, (224, 224))
rgb_img = np.float32(rgb_img) / 255.0

# 预处理图像
input_tensor = preprocess_image(rgb_img,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

# 看情况将图像转换为批量形式
# input_tensor = input_tensor.unsqueeze(0)
if use_cuda:
    input_tensor = input_tensor.cuda()

# 计算 grad-cam
target_category = None # 可以指定一个类别,或者使用 None 表示最高概率的类别
grayscale_cam = cam(input_tensor=input_tensor, targets=target_category)
grayscale_cam = grayscale_cam[0, :]

# 将 grad-cam 的输出叠加到原始图像上
visualization = show_cam_on_image(rgb_img, grayscale_cam)

# 保存可视化结果
cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR, visualization)
cv2.imwrite('cam3.jpg', visualization)

参考链接:

bash 复制代码
https://zhuanlan.zhihu.com/p/640450435

2.接下来参考了这位大佬的博客,针对我的网络进行修改,一些遇到的错误在大佬在里面说了

bash 复制代码
https://blog.csdn.net/holly_Z_P_F/article/details/130011296

3.新的问题

AttributeError: 'NoneType' object has no attribute 'shape'

bash 复制代码
Traceback (most recent call last):
  File "D:\yanjiusheng\ZSE-SBIR\kk.py", line 68, in <module>
    test()
  File "D:\yanjiusheng\ZSE-SBIR\kk.py", line 56, in test
    grayscale_cam = cam(input_tensor=input_tensor, targets=target_category)
  File "D:\anaconda3\envs\zse-sbir\lib\site-packages\pytorch_grad_cam\base_cam.py", line 186, in __call__
    return self.forward(input_tensor, targets, eigen_smooth)
  File "D:\anaconda3\envs\zse-sbir\lib\site-packages\pytorch_grad_cam\base_cam.py", line 110, in forward
    cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
  File "D:\anaconda3\envs\zse-sbir\lib\site-packages\pytorch_grad_cam\base_cam.py", line 141, in compute_cam_per_layer
    cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
  File "D:\anaconda3\envs\zse-sbir\lib\site-packages\pytorch_grad_cam\base_cam.py", line 66, in get_cam_image
    weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
  File "D:\anaconda3\envs\zse-sbir\lib\site-packages\pytorch_grad_cam\grad_cam.py", line 23, in get_cam_weights
    if len(grads.shape) == 4:
AttributeError: 'NoneType' object has no attribute 'shape'

这个问题我最后解决也是一知半解做出来了,目标层的问题:GradCAM 需要一个可以计算梯度的层作为 target_layer,通常是卷积层或自注意力机制中的特定部分,我原来选择的是model.sa.model.transformer.layers[-1] 作为目标层,但是这可能并不是一个适合用于 GradCAM 的层。我选择了原来目标层下更具体的一层,即model.sa.model.transformer.layers[-1][0].layer_norm_input ,兄弟们可以 print(model) 看下结构,多试几层看看哪层能用,我太菜了只能试。。。

参考的另外一些帖子:

bash 复制代码
https://ask.csdn.net/questions/8051598
bash 复制代码
https://github.com/open-mmlab/mmdetection/issues/1809
相关推荐
张较瘦_1 小时前
[论文阅读] 人工智能 + 软件工程 | 需求获取访谈中LLM生成跟进问题研究:来龙去脉与创新突破
论文阅读·人工智能
一 铭2 小时前
AI领域新趋势:从提示(Prompt)工程到上下文(Context)工程
人工智能·语言模型·大模型·llm·prompt
麻雀无能为力5 小时前
CAU数据挖掘实验 表分析数据插件
人工智能·数据挖掘·中国农业大学
时序之心5 小时前
时空数据挖掘五大革新方向详解篇!
人工智能·数据挖掘·论文·时间序列
.30-06Springfield6 小时前
人工智能概念之七:集成学习思想(Bagging、Boosting、Stacking)
人工智能·算法·机器学习·集成学习
说私域7 小时前
基于开源AI智能名片链动2+1模式S2B2C商城小程序的超级文化符号构建路径研究
人工智能·小程序·开源
永洪科技7 小时前
永洪科技荣获商业智能品牌影响力奖,全力打造”AI+决策”引擎
大数据·人工智能·科技·数据分析·数据可视化·bi
shangyingying_17 小时前
关于小波降噪、小波增强、小波去雾的原理区分
人工智能·深度学习·计算机视觉
书玮嘎8 小时前
【WIP】【VLA&VLM——InternVL系列】
人工智能·深度学习
猫头虎8 小时前
猫头虎 AI工具分享:一个网页抓取、结构化数据提取、网页爬取、浏览器自动化操作工具:Hyperbrowser MCP
运维·人工智能·gpt·开源·自动化·文心一言·ai编程