python第35天打卡

三种不同的模型可视化方法:推荐torchinfo打印summary+权重分布可视化

python 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchinfo import summary
import numpy as np

# 1. 定义示例模型
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)
        x = self.fc_layers(x)
        return x

model = CNN()

# 2. 模型摘要可视化 (使用torchinfo)
print("\n=== 模型结构摘要 ===")
summary(
    model, 
    input_size=(1, 3, 32, 32),  # (batch, channels, height, width)
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
    verbose=1
)

# 3. 权重分布可视化
def plot_weight_distribution(model):
    plt.figure(figsize=(12, 6))
    
    # 收集所有权重
    all_weights = []
    for name, param in model.named_parameters():
        if 'weight' in name:
            flattened = param.detach().cpu().numpy().flatten()
            all_weights.extend(flattened)
    
    # 绘制直方图
    plt.hist(all_weights, bins=150, alpha=0.7, color='blue', edgecolor='black')
    plt.title('Model Weights Distribution')
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency (log scale)')
    plt.yscale('log')
    plt.grid(True, alpha=0.3)
    plt.show()

print("\n=== 权重分布图 ===")
plot_weight_distribution(model)

# 4. 逐层权重可视化 (额外方法)
def plot_layer_weights(model):
    plt.figure(figsize=(15, 10))
    
    for i, (name, param) in enumerate(model.named_parameters()):
        if 'weight' in name:
            plt.subplot(3, 3, i+1)
            layer_weights = param.detach().cpu().numpy().flatten()
            plt.hist(layer_weights, bins=100, alpha=0.7)
            plt.title(f'{name} weight distribution')
            plt.grid(True, alpha=0.2)
    
    plt.tight_layout()
    plt.show()

print("\n=== 逐层权重分布 ===")
plot_layer_weights(model)

# 5. 权重矩阵可视化 (额外方法)
def visualize_weight_matrix(model):
    plt.figure(figsize=(15, 4))
    
    # 获取第一个卷积层的权重
    conv1_weight = model.conv_layers[0].weight.detach().cpu()
    
    # 归一化权重用于显示
    min_val = torch.min(conv1_weight)
    max_val = torch.max(conv1_weight)
    normalized_weights = (conv1_weight - min_val) / (max_val - min_val)
    
    # 绘制权重矩阵
    for i in range(16):  # 显示前16个卷积核
        plt.subplot(2, 8, i+1)
        plt.imshow(normalized_weights[i].permute(1, 2, 0))  # CHW -> HWC
        plt.axis('off')
        plt.title(f'Kernel {i+1}')
    
    plt.suptitle('First Conv Layer Kernels', fontsize=16)
    plt.tight_layout()
    plt.show()

print("\n=== 卷积核可视化 ===")
visualize_weight_matrix(model)

@浙大疏锦行

相关推荐
猫头虎几秒前
PyCharm 2025.3 最新变化:值得更新吗?
ide·爬虫·python·pycharm·beautifulsoup·ai编程·pip
ekprada5 分钟前
DAY45 TensorBoard深度学习可视化工具
人工智能·python
轻竹办公PPT7 分钟前
PPT生成效率提升的方法:AI生成PPT实战说明
人工智能·python·powerpoint
YJlio8 分钟前
Python 一键拆分 PDF:按“目录/章节”建文件夹 + 每页单独导出(支持书签识别&正文识别)
开发语言·python·pdf
IT方大同9 分钟前
C语言进制转化
c语言·开发语言
Amelia11111115 分钟前
day30
python
野生风长18 分钟前
从零开始的C语言:文件操作与数据存储(上)(文件的分类,文件的打开和关闭)
c语言·开发语言
我是哈哈hh29 分钟前
【Python数据分析】数据可视化(全)
开发语言·python·信息可视化·数据挖掘·数据分析
魔镜前的帅比30 分钟前
LangGraph(流程化控制)
python·langchain
yaoh.wang31 分钟前
力扣(LeetCode) 69: x 的平方根 - 解法思路
python·算法·leetcode·面试·职场和发展·牛顿法·二分法