可视化模块

目录

可视化送入网络的图片

送入的数据为imgs,其大小为(8,3,256,256),并以2行8列进行展示

python 复制代码
import matplotlib.pyplot as plt
import numpy as np

# 假设你的张量名为 tensor,形状为 (8, 3, 256, 256)
# 假设通道顺序为 RGB

# 将张量的数据格式转换为 (8, 256, 256, 3)
tensor = imgs.permute(0, 2, 3, 1)

# 创建一个 2x4 的子图布局,8 张图像
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

for i in range(8):
    # 选择子图
    ax = axes[i // 4, i % 4]

    # 获取第 i 张图像的数据
    image = tensor[i].numpy()

    # 确保图像的像素值在 [0, 1] 范围内
    image = np.clip(image, 0, 1)

    # 绘制图像
    ax.imshow(image)
    ax.set_title(f'Image {i + 1}')
    ax.axis('off')

plt.tight_layout()
plt.show()

可视化网络层的热力图

python 复制代码
import torch
import matplotlib.pyplot as plt

# 创建一个空的列表来存储该层的输出
activation = []

# 定义一个钩子函数,用于获取该层的输出
def hook_fn(module, input, output):
    activation.append(output)

# 注册钩子到网络的fam4层
model.fam4.register_forward_hook(hook_fn)

# 初始化一个子图,排列方式为2x4
fig, axs = plt.subplots(2, 4, figsize=(16, 8))


# 将输入数据图片传递给网络进行前向传播
output = model(imgs)  

for i in range(8):
    # 获取钩子记录的该层的输出
    layer_output = activation[0]

    # 计算热力图
    heatmap = layer_output.mean(dim=1, keepdim=True)  # 在通道维度上取平均值

    # 可视化热力图
    axs[i // 4, i % 4].imshow(heatmap[i, 0].cpu().detach().numpy(), cmap='viridis')
    axs[i // 4, i % 4].set_title(f'Image {i + 1}')
    axs[i // 4, i % 4].axis('off')

plt.show()
相关推荐
勾股导航15 小时前
蚁群优化算法
人工智能·pytorch·python
All The Way North-1 天前
【LSTM系列·终篇】PyTorch nn.LSTM 终极指南:从API原理到双向多层实战,彻底告别维度错误!
pytorch·rnn·lstm·多层lstm·api详解·序列模型·双向lstm
deep_drink1 天前
【论文精读(三)】PointMLP:大道至简,无需卷积与注意力的纯MLP点云网络 (ICLR 2022)
人工智能·pytorch·python·深度学习·3d·point cloud
lanbo_ai2 天前
基于yolov10的火焰、火灾检测系统,支持图像、视频和摄像实时检测【pytorch框架、python源码】
pytorch·python·yolo
盼小辉丶2 天前
PyTorch实战(29)——使用TorchServe部署PyTorch模型
人工智能·pytorch·深度学习·模型部署
IRevers3 天前
【YOLO】YOLO-Master 腾讯轻量级YOLO架构超越YOLO-13(含检测和分割推理)
图像处理·人工智能·pytorch·python·yolo·transformer·边缘计算
小锋java12343 天前
【技术专题】PyTorch2 深度学习 - 张量(Tensor)的定义与操作
pytorch·深度学习
归一码字3 天前
DDPG手写讲解
人工智能·pytorch
七夜zippoe4 天前
图神经网络实战:从社交网络到推荐系统的工业级应用
网络·人工智能·pytorch·python·神经网络·cora
本是少年4 天前
构建 HuggingFace 图像-文本数据集指南
pytorch·transformer