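# 四种不同的模型可视化方法: 推荐使用 torchinfo 打印模型摘要 (summary) + 权重分布可视化
# (Four ways to visualize a model; recommended: torchinfo summary + weight-distribution plots.)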
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchinfo import summary
import numpy as np
# 1. Example model definition
class CNN(nn.Module):
    """Small two-stage convolutional classifier used as the visualization subject.

    Sized for 3x32x32 inputs: each (3x3 conv -> ReLU -> 2x2 max-pool) stage
    shrinks the spatial size 32 -> 30 -> 15 and then 15 -> 13 -> 6. The
    flattened feature vector therefore has 32 * 6 * 6 elements, which two
    fully connected layers map to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        feature_extractor = [
            nn.Conv2d(3, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        classifier_head = [
            nn.Linear(32 * 6 * 6, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        # Indexed nn.Sequential containers keep parameter names of the form
        # "conv_layers.<i>.weight" / "fc_layers.<i>.bias" and allow
        # positional access such as conv_layers[0].
        self.conv_layers = nn.Sequential(*feature_extractor)
        self.fc_layers = nn.Sequential(*classifier_head)

    def forward(self, x):
        """Apply the conv stages, flatten everything but dim 0, then classify."""
        features = self.conv_layers(x)
        return self.fc_layers(features.flatten(start_dim=1))
model = CNN()

# 2. Model summary via torchinfo. verbose=1 makes torchinfo print the table
# with the requested columns for one sample of the given shape.
print("\n=== 模型结构摘要 ===")
sample_shape = (1, 3, 32, 32)  # (batch, channels, height, width)
summary_columns = ["input_size", "output_size", "num_params", "kernel_size"]
summary(model, input_size=sample_shape, col_names=summary_columns, verbose=1)
# 3. Weight distribution visualization
def plot_weight_distribution(model, bins=150):
    """Plot a single log-scale histogram of all weight values in ``model``.

    Every parameter whose name contains ``"weight"`` is flattened and pooled
    into one histogram; biases and other parameters are skipped.

    Args:
        model: the ``torch.nn.Module`` whose parameters are inspected.
        bins: number of histogram bins (defaults to 150).
    """
    # Concatenate the per-tensor NumPy arrays in one native call rather than
    # extending a Python list element by element, which boxes every weight as
    # a Python float (slow and memory-hungry for large models).
    weight_arrays = [
        param.detach().cpu().numpy().ravel()
        for name, param in model.named_parameters()
        if "weight" in name
    ]
    # np.concatenate raises on an empty list; fall back to an empty array so a
    # model without weights still yields an (empty) plot instead of a crash.
    all_weights = np.concatenate(weight_arrays) if weight_arrays else np.empty(0)

    plt.figure(figsize=(12, 6))
    plt.hist(all_weights, bins=bins, alpha=0.7, color='blue', edgecolor='black')
    plt.title('Model Weights Distribution')
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency (log scale)')
    # Log y-axis keeps sparsely populated tail bins visible next to the peak.
    plt.yscale('log')
    plt.grid(True, alpha=0.3)
    plt.show()


print("\n=== 权重分布图 ===")
plot_weight_distribution(model)
# 4. Per-layer weight visualization (extra method)
def plot_layer_weights(model, ncols=3):
    """Draw one histogram per weight tensor of ``model`` in a grid of subplots.

    Only parameters whose name contains ``"weight"`` are plotted; biases are
    skipped. The grid has at most ``ncols`` columns and as many rows as needed.

    Args:
        model: the ``torch.nn.Module`` whose parameters are inspected.
        ncols: maximum number of subplot columns (defaults to 3).
    """
    # Filter before numbering subplots so that biases do not consume grid
    # slots. Size the grid from the number of weight tensors so that models
    # with many layers do not overflow a fixed 3x3 layout (matplotlib raises
    # ValueError for an out-of-range subplot index).
    weights = [
        (name, param)
        for name, param in model.named_parameters()
        if 'weight' in name
    ]
    ncols = max(1, min(ncols, len(weights)))
    nrows = (len(weights) + ncols - 1) // ncols  # ceiling division

    plt.figure(figsize=(15, 10))
    for slot, (name, param) in enumerate(weights, start=1):
        plt.subplot(nrows, ncols, slot)
        layer_weights = param.detach().cpu().numpy().flatten()
        plt.hist(layer_weights, bins=100, alpha=0.7)
        plt.title(f'{name} weight distribution')
        plt.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.show()


print("\n=== 逐层权重分布 ===")
plot_layer_weights(model)
# 5. Weight-matrix (conv kernel) visualization (extra method)
def visualize_weight_matrix(model, max_kernels=16):
    """Show the kernels of ``model.conv_layers[0]`` as small images.

    The whole weight tensor is min-max normalised to [0, 1] with one shared
    scale, so the kernels remain comparable to each other. Up to
    ``max_kernels`` filters are drawn, 8 per row.

    NOTE(review): each kernel is shown as an (H, W, C) image, which matplotlib
    only accepts for C == 3 (RGB) or C == 4 (RGBA). The example CNN has
    3 input channels; other first layers may need a different rendering.

    Args:
        model: a module exposing ``conv_layers[0].weight``.
        max_kernels: maximum number of filters to display (defaults to 16).
    """
    conv1_weight = model.conv_layers[0].weight.detach().cpu()

    # Min-max normalise for display. Guard the constant-tensor case, where
    # (max - min) == 0 would otherwise turn every pixel into NaN.
    min_val = torch.min(conv1_weight)
    value_range = torch.max(conv1_weight) - min_val
    if value_range > 0:
        normalized_weights = (conv1_weight - min_val) / value_range
    else:
        normalized_weights = torch.zeros_like(conv1_weight)

    # Never index past the layer's actual number of output filters.
    num_kernels = min(max_kernels, normalized_weights.shape[0])
    ncols = 8
    nrows = max(1, (num_kernels + ncols - 1) // ncols)

    plt.figure(figsize=(15, 2 * nrows))
    for i in range(num_kernels):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(normalized_weights[i].permute(1, 2, 0).numpy())  # CHW -> HWC
        plt.axis('off')
        plt.title(f'Kernel {i+1}')
    plt.suptitle('First Conv Layer Kernels', fontsize=16)
    plt.tight_layout()
    plt.show()


print("\n=== 卷积核可视化 ===")
visualize_weight_matrix(model)