卷积神经网络中间层特征图的可视化

python 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image


# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        # x = self.bn(x)
        # x = self.relu(x)
        # x = self.pool(x)
        return x


if __name__ == '__main__':
    # 设置 CPU 张量的随机数种子
    torch.manual_seed(42)

    # 创建模型实例
    model = SimpleCNN()

    # 加载并预处理图片
    img_path = r'E:\photo\123.jpg'
    img = Image.open(img_path).convert('RGB')  # 读取的默认格式为 RGB,这里可去掉 convert()
    preprocess = transforms.Compose([transforms.Resize((960, 960)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    img_tensor = preprocess(img).unsqueeze(0)  # (1, C, H, W)

    # 不计算梯度,进行一次前向传播
    with torch.no_grad():
        output = model(img_tensor)

    # 模型输出的图片大小
    print("Output size after conv layer:", output.size())

    # 可视化原始图片
    plt.imshow(img)
    plt.title("Original Image")
    plt.axis('off')
    plt.show()

    # 可视化卷积层后的图片
    for i in range(output.size()[1]):
        plt.subplot(output.size()[1]//4, 4, i+1)
        plt.imshow(output[0, i, :, :].cpu().detach().numpy())
        plt.axis('off')
    plt.tight_layout()
    plt.subplots_adjust(hspace=0.05)
    plt.suptitle('After Conv2d')
    plt.show()

原图大小为:(872, 1280, 3)

原图如下所示:

图片经过卷积层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])

图片经过卷积层后,得到的特征图如下所示:

图片经过 BN 层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])

图片经过 BN 层后,得到的特征图如下所示:

图片经过 ReLU 层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])

图片经过 ReLU 层后,得到的特征图如下所示:

图片经过最大池化层后,得到的特征图大小为:torch.Size([1, 8, 480, 480])

图片经过最大池化层后,得到的特征图如下所示:

PIL 中的 Image.open(img_path) 读取的图片维度为 (W, H, C),读取的图片模式默认为 RGB;

OpenCV 中的 cv2.imread(img_path) 读取的图像维度为 (H, W, C),读取的图片模式默认为 BGR;

Image 图像数据转换为 np.ndarray 时,格式会从 (W, H, C) 转换为 (H, W, C)。

transforms.Resize((960, 960):旨在改变图像的大小,默认使用双线性插值(Bilinear);还支持最近邻插值(Nearest)、双三次插值(Bicubic)。

transforms.ToTensor():将 PIL 图像或 Numpy 数组中的整数像素值转换为 torch.FloatTensor 类型的浮点数;如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 的数据类型是 np.uint8,则将像素值从 [0, 255] 归一化到 [0.0, 1.0],这是通过将每个像素值除以 255 来实现的;将 [H, W, C] 的图像格式转换为 [C, H, W] 的 tensor 格式。

相关推荐
黎燃4 小时前
短视频平台内容推荐算法优化:从协同过滤到多模态深度学习
人工智能
飞哥数智坊5 小时前
多次尝试用 CodeBuddy 做小程序,最终我放弃了
人工智能·ai编程
后端小肥肠6 小时前
别再眼馋 10w + 治愈漫画!Coze 工作流 3 分钟出成品,小白可学
人工智能·aigc·coze
唐某人丶9 小时前
教你如何用 JS 实现 Agent 系统(2)—— 开发 ReAct 版本的“深度搜索”
前端·人工智能·aigc
FIT2CLOUD飞致云9 小时前
九月月报丨MaxKB在不同规模医疗机构的应用进展汇报
人工智能·开源
阿里云大数据AI技术9 小时前
【新模型速递】PAI-Model Gallery云上一键部署Qwen3-Next系列模型
人工智能
袁庭新9 小时前
全球首位AI机器人部长,背负反腐重任
人工智能·aigc
机器之心9 小时前
谁说Scaling Law到头了?新研究:每一步的微小提升会带来指数级增长
人工智能·openai
算家计算10 小时前
AI配音革命!B站最新开源IndexTTS2本地部署教程:精准对口型,情感随心换
人工智能·开源·aigc
量子位10 小时前
马斯克周末血裁xAI 500人
人工智能·ai编程