卷积神经网络中间层特征图的可视化

python 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image


# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        # x = self.bn(x)
        # x = self.relu(x)
        # x = self.pool(x)
        return x


if __name__ == '__main__':
    # 设置 CPU 张量的随机数种子
    torch.manual_seed(42)

    # 创建模型实例
    model = SimpleCNN()

    # 加载并预处理图片
    img_path = r'E:\photo\123.jpg'
    img = Image.open(img_path).convert('RGB')  # 读取的默认格式为 RGB,这里可去掉 convert()
    preprocess = transforms.Compose([transforms.Resize((960, 960)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    img_tensor = preprocess(img).unsqueeze(0)  # (1, C, H, W)

    # 不计算梯度,进行一次前向传播
    with torch.no_grad():
        output = model(img_tensor)

    # 模型输出的图片大小
    print("Output size after conv layer:", output.size())

    # 可视化原始图片
    plt.imshow(img)
    plt.title("Original Image")
    plt.axis('off')
    plt.show()

    # 可视化卷积层后的图片
    for i in range(output.size()[1]):
        plt.subplot(output.size()[1]//4, 4, i+1)
        plt.imshow(output[0, i, :, :].cpu().detach().numpy())
        plt.axis('off')
    plt.tight_layout()
    plt.subplots_adjust(hspace=0.05)
    plt.suptitle('After Conv2d')
    plt.show()

原图大小为:(872, 1280, 3)

原图如下所示:

图片经过卷积层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])

图片经过卷积层后,得到的特征图如下所示:

图片经过 BN 层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])

图片经过 BN 层后,得到的特征图如下所示:

图片经过 ReLU 层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])

图片经过 ReLU 层后,得到的特征图如下所示:

图片经过最大池化层后,得到的特征图大小为:torch.Size([1, 8, 480, 480])

图片经过最大池化层后,得到的特征图如下所示:

PIL 中的 Image.open(img_path) 读取的图片维度为 (W, H, C),读取的图片模式默认为 RGB;

OpenCV 中的 cv2.imread(img_path) 读取的图像维度为 (H, W, C),读取的图片模式默认为 BGR;

Image 图像数据转换为 np.ndarray 时,格式会从 (W, H, C) 转换为 (H, W, C)。

transforms.Resize((960, 960):旨在改变图像的大小,默认使用双线性插值(Bilinear);还支持最近邻插值(Nearest)、双三次插值(Bicubic)。

transforms.ToTensor():将 PIL 图像或 Numpy 数组中的整数像素值转换为 torch.FloatTensor 类型的浮点数;如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 的数据类型是 np.uint8,则将像素值从 [0, 255] 归一化到 [0.0, 1.0],这是通过将每个像素值除以 255 来实现的;将 [H, W, C] 的图像格式转换为 [C, H, W] 的 tensor 格式。

相关推荐
AdCj33 分钟前
OpenAI 如何安全运行 Codex:Agent 时代的“AI 安全操作系统
人工智能·安全
Vwms4 分钟前
2026年电商行业WMS系统选型指南
大数据·人工智能·产品运营
Bnews12 分钟前
机器人高精度轨迹定位设备选型指南:赋能前沿科研创新
人工智能
nashane13 分钟前
HarmonyOS 6学习:Navigation Dialog模式与智能Web长截图融合实践
人工智能·pytorch·python
月光技术杂谈18 分钟前
拆解中国移动AI-eSIM:当一张SIM卡开始“调用大模型”,运营商到底在赌什么?
人工智能·esim·安全认证·中国移动·ai-esim·tocken·大模型调度
用户43305141438120 分钟前
CLI 和 REPL 工作流
人工智能
xiaoxiaoxiaolll21 分钟前
Nature Communications:三维超原子库+原子层保护,突破全彩VR超透镜量产瓶颈
人工智能·算法
用户43305141438121 分钟前
编写一个最基本的 ReAct Agent
人工智能
用户51914958484524 分钟前
WordPress Portfolleo 插件漏洞利用工具 (CVE-2024-49653)
人工智能·aigc
俊哥V25 分钟前
每日 AI 研究简报 · 2026-05-13
人工智能·ai