目录
可视化送入网络的图片
送入的数据为imgs,其大小为(8,3,256,256),并以2行8列进行展示
python
import matplotlib.pyplot as plt
import numpy as np
# 假设你的张量名为 tensor,形状为 (8, 3, 256, 256)
# 假设通道顺序为 RGB
# 将张量的数据格式转换为 (8, 256, 256, 3)
tensor = imgs.permute(0, 2, 3, 1)
# 创建一个 2x4 的子图布局,8 张图像
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(8):
# 选择子图
ax = axes[i // 4, i % 4]
# 获取第 i 张图像的数据
image = tensor[i].numpy()
# 确保图像的像素值在 [0, 1] 范围内
image = np.clip(image, 0, 1)
# 绘制图像
ax.imshow(image)
ax.set_title(f'Image {i + 1}')
ax.axis('off')
plt.tight_layout()
plt.show()
可视化网络层的热力图
python
import torch
import matplotlib.pyplot as plt
# 创建一个空的列表来存储该层的输出
activation = []
# 定义一个钩子函数,用于获取该层的输出
def hook_fn(module, input, output):
activation.append(output)
# 注册钩子到网络的fam4层
model.fam4.register_forward_hook(hook_fn)
# 初始化一个子图,排列方式为2x4
fig, axs = plt.subplots(2, 4, figsize=(16, 8))
# 将输入数据图片传递给网络进行前向传播
output = model(imgs)
for i in range(8):
# 获取钩子记录的该层的输出
layer_output = activation[0]
# 计算热力图
heatmap = layer_output.mean(dim=1, keepdim=True) # 在通道维度上取平均值
# 可视化热力图
axs[i // 4, i % 4].imshow(heatmap[i, 0].cpu().detach().numpy(), cmap='viridis')
axs[i // 4, i % 4].set_title(f'Image {i + 1}')
axs[i // 4, i % 4].axis('off')
plt.show()