import os
import warnings
from typing import List, Dict, Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
class FeatureVisualizer:
    """Capture and plot intermediate feature maps of a torchvision CNN.

    Forward hooks registered on named sub-modules store each layer's output
    (detached and moved to CPU) in ``self.features``. The ``visualize_*``
    methods turn such tensors into matplotlib figures and
    ``analyze_layer_statistics`` summarises them numerically.
    """

    # Channel statistics passed to ``transforms.Normalize`` and reused to undo
    # the normalisation when the preprocessed image is displayed.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    # Number of leading channels shown in the channel-correlation panel.
    MAX_CORR_CHANNELS = 20
    # Accepted values for the ``method`` / ``visualization_type`` arguments.
    SINGLE_LAYER_METHODS = ('grid', 'heatmap', 'max_activation')
    VISUALIZATION_TYPES = ('single', 'comparison', 'hierarchy')

    def __init__(self, model_name: str = 'resnet18', pretrained: bool = True):
        """Load the model, move it to the available device and set eval mode.

        Args:
            model_name: One of 'resnet18', 'resnet50', 'vgg16', 'vgg19',
                'alexnet'.
            pretrained: Whether to load pretrained weights.

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Captured layer outputs (layer name -> tensor) and live hook handles.
        self.features = {}
        self.hooks = []
        self.model = self._load_model(model_name, pretrained)
        self.model.to(self.device)
        self.model.eval()

    def _load_model(self, model_name: str, pretrained: bool):
        """Build the requested torchvision model.

        Args:
            model_name: Key into the table of supported model builders.
            pretrained: Whether to load pretrained weights.

        Returns:
            The constructed model.

        Raises:
            ValueError: If ``model_name`` is not in the supported table.
        """
        model_dict = {
            'resnet18': models.resnet18,
            'vgg16': models.vgg16,
            'alexnet': models.alexnet,
            'resnet50': models.resnet50,
            'vgg19': models.vgg19,
        }
        if model_name not in model_dict:
            raise ValueError(f"模型 {model_name} 不支持")
        builder = model_dict[model_name]
        try:
            # Newer torchvision releases take ``weights=``; 'IMAGENET1K_V1' is
            # the weight set the legacy ``pretrained=True`` flag maps to.
            return builder(weights='IMAGENET1K_V1' if pretrained else None)
        except TypeError:
            # Older torchvision releases only know the ``pretrained`` keyword.
            return builder(pretrained=pretrained)

    def register_hooks(self, layer_names: List[str]):
        """Register forward hooks on the named sub-modules.

        Existing hooks and captured features are discarded first. Names that
        do not match any entry of ``self.model.named_modules()`` trigger a
        warning instead of being ignored silently.

        Args:
            layer_names: Names of the layers whose outputs should be captured.
        """
        self._remove_hooks()
        wanted = set(layer_names)
        matched = set()
        for name, module in self.model.named_modules():
            if name not in wanted:
                continue
            # ``name=name`` binds the current value; a plain closure would see
            # only the last name of the loop.
            hook = module.register_forward_hook(
                lambda m, inp, out, name=name: self._hook_fn(name, out)
            )
            self.hooks.append(hook)
            matched.add(name)
        missing = sorted(wanted - matched)
        if missing:
            warnings.warn(
                f"Layers not found in model and skipped: {missing}",
                stacklevel=2,
            )

    def _hook_fn(self, name: str, output):
        """Store a detached CPU copy of ``output`` under ``name``."""
        self.features[name] = output.detach().cpu()

    def _remove_hooks(self):
        """Remove every registered hook and clear the captured features."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        self.features.clear()

    def preprocess_image(self, image_path: str, img_size: int = 224):
        """Load an image and turn it into a normalised model input.

        Args:
            image_path: Path of the image file.
            img_size: Side length the image is resized to (square).

        Returns:
            A tuple ``(input_tensor, image)``: the batched tensor on
            ``self.device`` and the RGB PIL image it was built from.
        """
        preprocess = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=list(self.IMAGENET_MEAN),
                std=list(self.IMAGENET_STD),
            ),
        ])
        # ``convert`` returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')
        input_tensor = preprocess(image).unsqueeze(0)
        return input_tensor.to(self.device), image

    @staticmethod
    def _normalize_map(feat_map: np.ndarray) -> np.ndarray:
        """Min-max scale ``feat_map`` to [0, 1]; epsilon guards constant maps."""
        low = feat_map.min()
        high = feat_map.max()
        return (feat_map - low) / (high - low + 1e-8)

    def visualize_single_layer(
        self,
        features: torch.Tensor,
        layer_name: str,
        num_channels: int = 16,
        method: str = 'grid'
    ):
        """Plot a row of feature maps taken from one layer.

        Args:
            features: Feature tensor indexed as [batch, channels, height, width];
                only the first batch element is used.
            layer_name: Layer name used in the figure title.
            num_channels: Maximum number of channels to draw.
            method: 'grid' draws evenly spaced channels with the viridis
                colormap, 'heatmap' draws them with OpenCV's JET colormap,
                'max_activation' draws the channels with the highest mean
                activation (viridis).

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If ``method`` is unknown or no channel can be drawn.
        """
        if method not in self.SINGLE_LAYER_METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self.SINGLE_LAYER_METHODS}"
            )
        feature_maps = features[0].detach().cpu()
        num_total_channels = feature_maps.shape[0]
        num_channels = min(num_channels, num_total_channels)
        if num_channels < 1:
            raise ValueError("At least one channel is required for visualization")

        if method == 'max_activation':
            # Rank channels by their mean activation, strongest first.
            channel_activations = feature_maps.reshape(num_total_channels, -1).mean(dim=1)
            selected_channels = torch.argsort(channel_activations, descending=True)[:num_channels]
        else:
            # Spread the selection evenly over all channels.
            selected_channels = torch.linspace(0, num_total_channels - 1, num_channels).long()

        # ``squeeze=False`` keeps ``axes`` two-dimensional even for one panel.
        fig, axes = plt.subplots(
            1, num_channels, figsize=(3 * num_channels, 3), squeeze=False
        )
        for ax, channel_idx in zip(axes[0], selected_channels.tolist()):
            feat_map = self._normalize_map(feature_maps[channel_idx].numpy())
            if method == 'heatmap':
                heatmap = cv2.applyColorMap(np.uint8(255 * feat_map), cv2.COLORMAP_JET)
                # OpenCV returns BGR; matplotlib expects RGB.
                ax.imshow(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB))
            else:
                ax.imshow(feat_map, cmap='viridis')
            ax.set_title(f'Ch {channel_idx}')
            ax.axis('off')
        fig.suptitle(f'{layer_name} Feature Maps', fontsize=16, y=1.05)
        fig.tight_layout()
        return fig

    def visualize_layer_comparison(
        self,
        layer_features: Dict[str, torch.Tensor],
        num_channels: int = 8,
        save_path: Optional[str] = None
    ):
        """Plot the strongest channels of several layers, one row per layer.

        Args:
            layer_features: Mapping from layer name to feature tensor indexed
                as [batch, channels, height, width].
            num_channels: Number of columns, i.e. channels drawn per layer.
            save_path: If given, the figure is also written to this path.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If ``layer_features`` is empty or ``num_channels`` < 1.
        """
        if not layer_features:
            raise ValueError("layer_features is empty; nothing to visualize")
        if num_channels < 1:
            raise ValueError("num_channels must be at least 1")

        num_layers = len(layer_features)
        # ``squeeze=False`` gives a 2-D axes array for every rows/cols combination.
        fig, axes = plt.subplots(
            num_layers, num_channels,
            figsize=(2.5 * num_channels, 2.5 * num_layers),
            squeeze=False,
        )
        for layer_idx, (layer_name, features) in enumerate(layer_features.items()):
            feature_maps = features[0].detach().cpu()
            num_total_channels = feature_maps.shape[0]
            # Pick the channels with the highest mean activation.
            channel_activations = feature_maps.reshape(num_total_channels, -1).mean(dim=1)
            selected_channels = torch.argsort(
                channel_activations, descending=True
            )[:num_channels].tolist()

            for ch_idx, ax in enumerate(axes[layer_idx]):
                if ch_idx >= len(selected_channels):
                    # Layer has fewer channels than columns: blank the panel.
                    ax.axis('off')
                    continue
                channel_idx = selected_channels[ch_idx]
                ax.imshow(
                    self._normalize_map(feature_maps[channel_idx].numpy()),
                    cmap='viridis',
                )
                # Each row selects its own channels, so every panel is titled.
                ax.set_title(f'Ch {channel_idx}', fontsize=10)
                # Hide ticks and frame but keep the axis itself so that the
                # row label set below stays visible.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if ch_idx == 0:
                    ylabel = f'{layer_name}\n({feature_maps.shape[-2]}x{feature_maps.shape[-1]})'
                    ax.set_ylabel(ylabel, fontsize=10)
        fig.suptitle('Convolutional Layer Feature Map Comparison', fontsize=16, y=1.02)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        return fig

    def visualize_spatial_hierarchy(
        self,
        features_dict: Dict[str, torch.Tensor],
        save_path: Optional[str] = None
    ):
        """Plot four summary panels per layer.

        The panels are: channel-averaged map, channel-wise maximum map,
        correlation between the first ``MAX_CORR_CHANNELS`` channels, and a
        histogram of all activation values.

        Args:
            features_dict: Mapping from layer name to feature tensor indexed
                as [batch, channels, height, width].
            save_path: If given, the figure is also written to this path.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If ``features_dict`` is empty.
        """
        if not features_dict:
            raise ValueError("features_dict is empty; nothing to visualize")

        num_layers = len(features_dict)
        # ``squeeze=False`` keeps ``axes[i, j]`` valid for a single layer too.
        fig, axes = plt.subplots(
            num_layers, 4, figsize=(16, 4 * num_layers), squeeze=False
        )
        for i, (layer_name, features) in enumerate(features_dict.items()):
            feature_maps = features[0].detach().cpu()

            # 1. Mean over channels.
            avg_feat = self._normalize_map(feature_maps.mean(dim=0).numpy())
            # 2. Maximum over channels.
            max_activation = self._normalize_map(feature_maps.max(dim=0).values.numpy())
            # 3. Channel correlation. Correlations are pairwise, so restricting
            #    the input to the displayed channels gives the same values as
            #    slicing the full matrix. ``atleast_2d`` covers a single channel.
            flat = feature_maps.reshape(feature_maps.shape[0], -1)
            channel_corr = torch.atleast_2d(
                torch.corrcoef(flat[:self.MAX_CORR_CHANNELS])
            ).numpy()
            # 4. Distribution of all activation values.
            activations = feature_maps.reshape(-1).numpy()

            axes[i, 0].imshow(avg_feat, cmap='viridis')
            axes[i, 0].set_title(f'{layer_name}\nAvg Feature Map', fontsize=10)
            axes[i, 0].axis('off')

            axes[i, 1].imshow(max_activation, cmap='viridis')
            axes[i, 1].set_title('Max Activation Map', fontsize=10)
            axes[i, 1].axis('off')

            im = axes[i, 2].imshow(channel_corr, cmap='coolwarm', vmin=-1, vmax=1)
            axes[i, 2].set_title('Channel Correlation', fontsize=10)
            fig.colorbar(im, ax=axes[i, 2])

            axes[i, 3].hist(activations, bins=50, alpha=0.7, color='skyblue')
            axes[i, 3].set_title('Activation Distribution', fontsize=10)
            axes[i, 3].set_xlabel('Activation Value')
            axes[i, 3].set_ylabel('Frequency')
        fig.suptitle('Feature Map Spatial Hierarchy Analysis', fontsize=16, y=1.02)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        return fig

    def analyze_layer_statistics(self, features_dict: Dict[str, torch.Tensor]):
        """Compute summary statistics for the first batch element of each layer.

        Args:
            features_dict: Mapping from layer name to feature tensor indexed
                as [batch, channels, height, width].

        Returns:
            A list with one dict per layer holding the keys 'Layer',
            'Channels', 'Spatial Size', 'Mean Activation', 'Std Activation',
            'Max Activation', 'Sparsity' (fraction of values with magnitude
            below 0.01) and 'Non-linearity' (mean of the rectified values
            divided by the plain mean plus 1e-8).
        """
        stats_data = []
        for layer_name, features in features_dict.items():
            feature_maps = features[0]
            mean_activation = feature_maps.mean()
            near_zero = feature_maps.abs() < 0.01
            stats_data.append({
                'Layer': layer_name,
                'Channels': feature_maps.shape[0],
                'Spatial Size': f"{feature_maps.shape[1]}x{feature_maps.shape[2]}",
                'Mean Activation': float(mean_activation),
                'Std Activation': float(feature_maps.std()),
                'Max Activation': float(feature_maps.max()),
                'Sparsity': float(near_zero.sum() / feature_maps.numel()),
                'Non-linearity': float(
                    feature_maps.relu().mean() / (mean_activation + 1e-8)
                ),
            })
        return stats_data

    def _save_input_preview(self, input_tensor, original_image, save_dir: str):
        """Save a side-by-side figure of the original and preprocessed image."""
        fig, (ax_original, ax_processed) = plt.subplots(1, 2, figsize=(10, 4))
        ax_original.imshow(original_image)
        ax_original.set_title('Original Image')
        ax_original.axis('off')

        # Undo the normalisation so the tensor is displayable as an RGB image.
        img_np = input_tensor[0].cpu().permute(1, 2, 0).numpy()
        img_np = img_np * np.array(self.IMAGENET_STD) + np.array(self.IMAGENET_MEAN)
        img_np = np.clip(img_np, 0, 1)
        ax_processed.imshow(img_np)
        ax_processed.set_title('Preprocessed Image')
        ax_processed.axis('off')

        fig.savefig(os.path.join(save_dir, 'input_image.png'), dpi=300, bbox_inches='tight')
        return fig

    @staticmethod
    def _print_statistics(stats):
        """Print the per-layer statistics as a plain-text report."""
        print("\n" + "="*80)
        print("LAYER STATISTICS ANALYSIS")
        print("="*80)
        for stat in stats:
            print(f"\n{stat['Layer']}:")
            print(f" Channels: {stat['Channels']}")
            print(f" Spatial Size: {stat['Spatial Size']}")
            print(f" Mean Activation: {stat['Mean Activation']:.4f}")
            print(f" Std Activation: {stat['Std Activation']:.4f}")
            print(f" Max Activation: {stat['Max Activation']:.4f}")
            print(f" Sparsity: {stat['Sparsity']:.2%}")
            print(f" Non-linearity Ratio: {stat['Non-linearity']:.4f}")

    def run_visualization(
        self,
        image_path: str,
        layer_names: List[str],
        visualization_type: str = 'comparison',
        save_dir: str = './visualizations/'
    ):
        """Run the whole pipeline: preprocess, capture, plot, report.

        Args:
            image_path: Path of the input image.
            layer_names: Names of the layers to capture.
            visualization_type: 'single', 'comparison' or 'hierarchy'.
            save_dir: Directory the figures are written to (created if needed).

        Returns:
            The list of per-layer statistics dicts from
            ``analyze_layer_statistics``.

        Raises:
            ValueError: If ``visualization_type`` is unknown or none of the
                requested layers produced an output.
        """
        if visualization_type not in self.VISUALIZATION_TYPES:
            raise ValueError(
                f"Unknown visualization_type {visualization_type!r}; "
                f"expected one of {self.VISUALIZATION_TYPES}"
            )
        os.makedirs(save_dir, exist_ok=True)

        # 1. Preprocess the image.
        input_tensor, original_image = self.preprocess_image(image_path)

        figures = []
        try:
            # 2. Register hooks and run the forward pass.
            self.register_hooks(layer_names)
            with torch.no_grad():
                self.model(input_tensor)
            if not self.features:
                raise ValueError(
                    f"None of the requested layers produced an output: {list(layer_names)}"
                )

            # 3. Save the input preview.
            figures.append(
                self._save_input_preview(input_tensor, original_image, save_dir)
            )

            # 4. Draw the requested visualization.
            if visualization_type == 'single':
                for layer_name in layer_names:
                    if layer_name not in self.features:
                        # Unknown layer: already reported by register_hooks.
                        continue
                    fig = self.visualize_single_layer(
                        self.features[layer_name],
                        layer_name,
                        num_channels=12,
                        method='grid'
                    )
                    figures.append(fig)
                    fig.savefig(
                        os.path.join(save_dir, f'{layer_name}_features.png'),
                        dpi=300, bbox_inches='tight'
                    )
                    plt.show()
            elif visualization_type == 'comparison':
                figures.append(self.visualize_layer_comparison(
                    self.features,
                    num_channels=8,
                    save_path=os.path.join(save_dir, 'layer_comparison.png')
                ))
                plt.show()
            else:  # 'hierarchy'
                figures.append(self.visualize_spatial_hierarchy(
                    self.features,
                    save_path=os.path.join(save_dir, 'spatial_hierarchy.png')
                ))
                plt.show()

            # 5. Compute statistics while the captured features still exist.
            stats = self.analyze_layer_statistics(self.features)
        finally:
            # 6. Always detach the hooks, even when a step above failed.
            self._remove_hooks()
            # In interactive mode ``plt.show()`` does not block, so closing
            # here would make the windows disappear immediately.
            if not plt.isinteractive():
                for fig in figures:
                    plt.close(fig)

        self._print_statistics(stats)
        return stats
# Usage example
def main():
    """Demonstrate the visualizer on an image, falling back to random noise.

    Picks a handful of layers that fit the model architecture, runs the
    'comparison' visualization on ``image_path`` and, if that file does not
    exist, repeats the run on a generated random image.
    """
    visualizer = FeatureVisualizer(model_name='resnet18', pretrained=True)

    # Choose the layers to visualize according to the model architecture.
    model_type = type(visualizer.model).__name__
    if model_type == 'ResNet':
        layer_names = ['layer1.0.conv1', 'layer2.0.conv1', 'layer3.0.conv1', 'layer4.0.conv1']
    elif model_type == 'VGG':
        layer_names = ['features.0', 'features.5', 'features.10', 'features.17']
    else:
        # Generic fallback: the first four convolution layers.
        layer_names = [
            name
            for name, module in visualizer.model.named_modules()
            if isinstance(module, nn.Conv2d)
        ][:4]
    print(f"Selected layers for visualization: {layer_names}")

    # Replace with the path of a real image.
    image_path = "path/to/your/image.jpg"
    try:
        visualizer.run_visualization(
            image_path=image_path,
            layer_names=layer_names,
            visualization_type='comparison',  # or 'single', 'hierarchy'
            save_dir='./feature_visualizations/'
        )
    except FileNotFoundError:
        print("示例:使用随机图像进行演示")
        # Build a random image as a stand-in; the upper bound of ``randint``
        # is exclusive, so 256 is needed to cover the full uint8 range.
        dummy_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        Image.fromarray(dummy_image).save('dummy_image.jpg')
        visualizer.run_visualization(
            image_path='dummy_image.jpg',
            layer_names=layer_names,
            visualization_type='comparison',
            save_dir='./feature_visualizations/'
        )
# Run the demo only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()