[PyTorch]即插即用的热力图生成

先上张效果图,本来打算移植霹雳老师的使用Pytorch实现Grad-CAM并绘制热力图。但是看了下代码,需要骨干网络按照标准写法(即将特征层封装为features数组),而我写的网络图省事并没有进行封装,改造网络的代价又太大了,所以干脆直接重写一个。

一、生成热力图

大致可以分为三步:①读取图片;②前向传递运算;③用特征向量生成特征图。而图片的resize图简单可以直接用t**ransforms,**后面反正也是直接resize回来的,并不会造成变形。

python 复制代码
# 加载一个transforms用于变形,input_shape为预设的图像尺寸
transform = transforms.Compose([transforms.Resize((input_shape[0],input_shape[1])),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])
image = Image.open(image_path)     #image_path为文件路径
input_tensor = transform(image)    #将图片转换为tensor类型
input_batch = input_tensor.unsqueeze(0)    #为tensor添加batch维度

# 前向传递
model.eval()
with torch.no_grad():
    output = model(input_batch)

使用特征图生成热力图的原理是:将该维度上所有的tensor进行叠加,然后将生成的矩阵变形回输入向量的尺寸

python 复制代码
heatmap = torch.sum(output, dim=1)    #所有通道求和
max_value = torch.max(heatmap)
min_value = torch.min(heatmap)
heatmap = (heatmap-min_value)/(max_value-min_value)*255

heatmap = heatmap.cpu().numpy().astype(np.uint8).transpose(1,2,0)  # 提取热力图

heatmap = cv2.resize(heatmap, input_shape,interpolation=cv2.INTER_LINEAR)  # 还原尺寸

# 将矩阵转换为image类
heatmap=cv2.applyColorMap(heatmap,cv2.COLORMAP_JET)
heatimg = Image.fromarray(heatmap)

二、叠加原图

直接使用plt进行叠加!

python 复制代码
    # 将热力图叠加到原图上
    org_size = image.size
    heatimg = heatimg.resize(org_size)    #将热力图变回输入图像的尺寸
    plt.axis('off')
    plt.imshow(image)
    plt.imshow(heatimg, alpha=0.5)  # alpha为热力图的透明度

    # 显示叠加后的图形
    plt.show()

三、总结

这段代码和霹雳老师的Grad-CAM对比优劣都很明显,优点是代码比较简单。上可以通过插入前向传递的环境直接得到任何层的热力图。但缺点就是不能关注特定的类别,且生成的热力图也不是很美观。

相关推荐
nancy_princess2 小时前
clip实验
人工智能·深度学习
飞哥数智坊2 小时前
TRAE Friends@济南第4次活动:100+极客集结,2小时极限编程燃爆全场!
人工智能
AI自动化工坊2 小时前
ProofShot实战:给AI编码助手添加可视化验证,提升前端开发效率3倍
人工智能·ai·开源·github
飞哥数智坊2 小时前
一场直播涨粉 2 万的背后!OpenClaw + 飞书,正在重塑软件交付的方式
人工智能
飞哥数智坊3 小时前
养虾记第3期:安装、调教、落地,这场沙龙我们全聊了
人工智能
再不会python就不礼貌了3 小时前
从工具到个人助理——AI Agent的原理、演进与安全风险
人工智能·安全·ai·大模型·transformer·ai编程
AI医影跨模态组学3 小时前
Radiother Oncol 空军军医大学西京医院等团队:基于纵向CT的亚区域放射组学列线图预测食管鳞状细胞癌根治性放化疗后局部无复发生存期
人工智能·深度学习·医学影像·影像组学
A尘埃3 小时前
神经网络的激活函数+损失函数
人工智能·深度学习·神经网络·激活函数
没有不重的名么3 小时前
Pytorch深度学习快速入门教程
人工智能·pytorch·深度学习
有为少年4 小时前
告别“唯语料论”:用合成抽象数据为大模型开智
人工智能·深度学习·神经网络·算法·机器学习·大模型·预训练