import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
import cv2
from typing import List, Dict, Optional
class FeatureVisualizer:
    """Visualize convolutional feature maps of a torchvision model.

    Forward hooks are registered on the requested layers; after a forward
    pass the captured activations are kept in ``self.features`` (layer name
    -> detached CPU tensor) and can be plotted or summarized.
    """

    def __init__(self, model_name: str = 'resnet18', pretrained: bool = True):
        """Build the model, move it to the available device and set eval mode.

        Args:
            model_name: One of the names accepted by ``_load_model``
                ('resnet18', 'resnet50', 'vgg16', 'vgg19', 'alexnet').
            pretrained: Whether to load pretrained weights.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model(model_name, pretrained)
        self.model.to(self.device)
        self.model.eval()
        # Captured feature maps, keyed by layer name.
        self.features = {}
        # Handles of the currently registered forward hooks.
        self.hooks = []

    def _load_model(self, model_name: str, pretrained: bool):
        """Instantiate the requested torchvision model.

        Args:
            model_name: Key into the table of supported constructors.
            pretrained: Forwarded to the torchvision constructor.

        Returns:
            The constructed model.

        Raises:
            ValueError: If ``model_name`` is not a supported model.
        """
        model_dict = {
            'resnet18': models.resnet18,
            'vgg16': models.vgg16,
            'alexnet': models.alexnet,
            'resnet50': models.resnet50,
            'vgg19': models.vgg19,
        }
        if model_name not in model_dict:
            raise ValueError(f"模型 {model_name} 不支持")
        return model_dict[model_name](pretrained=pretrained)

    def register_hooks(self, layer_names: List[str]):
        """Register forward hooks on the named layers.

        Any previously registered hooks (and captured features) are removed
        first. Names that do not match a module in ``self.model`` are
        silently skipped.

        Args:
            layer_names: Module names as reported by ``named_modules()``.
        """
        self._remove_hooks()
        wanted = set(layer_names)
        for name, module in self.model.named_modules():
            if name in wanted:
                # Bind ``name`` as a default argument so every hook keeps its
                # own layer name instead of the loop's last value.
                hook = module.register_forward_hook(
                    lambda m, inp, out, name=name: self._hook_fn(name, out)
                )
                self.hooks.append(hook)

    def _hook_fn(self, name: str, output):
        """Store a detached CPU copy of a layer's output under ``name``."""
        self.features[name] = output.detach().cpu()

    def _remove_hooks(self):
        """Remove all registered hooks and drop the captured features."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        self.features.clear()

    def preprocess_image(self, image_path: str, img_size: int = 224):
        """Load an image and turn it into a normalized input batch.

        Args:
            image_path: Path of the image file.
            img_size: Side length the image is resized to (square).

        Returns:
            A ``(input_tensor, image)`` pair: the batch of one image on
            ``self.device`` and the RGB PIL image it was built from.
        """
        preprocess = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        image = Image.open(image_path).convert('RGB')
        input_tensor = preprocess(image).unsqueeze(0)
        return input_tensor.to(self.device), image

    def visualize_single_layer(
        self,
        features: torch.Tensor,
        layer_name: str,
        num_channels: int = 16,
        method: str = 'grid'
    ):
        """Plot a row of feature maps taken from one layer.

        Args:
            features: Activations indexed as (batch, channel, height, width);
                only the first batch element is shown.
            layer_name: Name used in the figure title.
            num_channels: Maximum number of channels to draw.
            method: 'grid' draws evenly spaced channels with the viridis
                colormap, 'heatmap' draws evenly spaced channels with
                OpenCV's JET colormap, 'max_activation' draws the channels
                with the highest mean activation.

        Returns:
            The matplotlib figure.
        """
        feature_maps = features[0].detach().cpu()
        num_total_channels = feature_maps.shape[0]
        num_channels = min(num_channels, num_total_channels)

        if method == 'max_activation':
            # Rank channels by mean activation and keep the strongest ones.
            channel_activations = feature_maps.reshape(num_total_channels, -1).mean(dim=1)
            selected_channels = torch.argsort(channel_activations, descending=True)[:num_channels]
        else:
            # Spread the selection evenly over all channels.
            selected_channels = torch.linspace(0, num_total_channels - 1, num_channels).long()

        # squeeze=False always yields a 2-D axes array, so a single channel
        # needs no special case.
        fig, axes = plt.subplots(
            1, num_channels, figsize=(3 * num_channels, 3), squeeze=False
        )
        # tolist() gives plain ints, so titles read "Ch 5", not "Ch tensor(5)".
        for ax, channel_idx in zip(axes[0], selected_channels.tolist()):
            feat_map = feature_maps[channel_idx].numpy()
            # Min-max normalize to [0, 1]; the epsilon guards constant maps.
            feat_map = (feat_map - feat_map.min()) / (feat_map.max() - feat_map.min() + 1e-8)
            if method == 'heatmap':
                heatmap = cv2.applyColorMap(np.uint8(255 * feat_map), cv2.COLORMAP_JET)
                # OpenCV returns BGR while matplotlib expects RGB.
                ax.imshow(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB))
            else:
                # 'grid' and 'max_activation' differ only in channel choice.
                ax.imshow(feat_map, cmap='viridis')
            ax.set_title(f'Ch {channel_idx}')
            ax.axis('off')

        fig.suptitle(f'{layer_name} Feature Maps', fontsize=16, y=1.05)
        fig.tight_layout()
        return fig

    def visualize_layer_comparison(
        self,
        layer_features: Dict[str, torch.Tensor],
        num_channels: int = 8,
        save_path: Optional[str] = None
    ):
        """Plot the strongest channels of several layers, one row per layer.

        Args:
            layer_features: Mapping of layer name to activations indexed as
                (batch, channel, height, width).
            num_channels: Number of channels drawn per layer.
            save_path: If given, the figure is also written to this path.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If ``layer_features`` is empty.
        """
        num_layers = len(layer_features)
        if num_layers == 0:
            raise ValueError("layer_features is empty; nothing to visualize")

        fig, axes = plt.subplots(
            num_layers, num_channels,
            figsize=(2.5 * num_channels, 2.5 * num_layers),
            squeeze=False,
        )
        for layer_idx, (layer_name, features) in enumerate(layer_features.items()):
            feature_maps = features[0].detach().cpu()
            num_total_channels = feature_maps.shape[0]
            # Pick the channels with the highest mean activation.
            channel_activations = feature_maps.reshape(num_total_channels, -1).mean(dim=1)
            selected_channels = torch.argsort(
                channel_activations, descending=True
            )[:num_channels].tolist()

            for ch_idx, ax in enumerate(axes[layer_idx]):
                # Hide ticks and frame instead of calling axis('off'), which
                # would also hide the row label set below.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if ch_idx >= len(selected_channels):
                    # The layer has fewer channels than columns.
                    continue

                channel_idx = selected_channels[ch_idx]
                feat_map = feature_maps[channel_idx].numpy()
                feat_map = (feat_map - feat_map.min()) / (feat_map.max() - feat_map.min() + 1e-8)
                ax.imshow(feat_map, cmap='viridis')
                # Selected channels differ per layer, so label every cell.
                ax.set_title(f'Ch {channel_idx}', fontsize=10)
                if ch_idx == 0:
                    ylabel = f'{layer_name}\n({feature_maps.shape[-2]}x{feature_maps.shape[-1]})'
                    ax.set_ylabel(ylabel, fontsize=10)

        fig.suptitle('Convolutional Layer Feature Map Comparison', fontsize=16, y=1.02)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        return fig

    def visualize_spatial_hierarchy(
        self,
        features_dict: Dict[str, torch.Tensor],
        save_path: Optional[str] = None
    ):
        """Plot four summary panels per layer.

        The panels are: channel-averaged map, channel-wise maximum map,
        correlation between the first 20 channels, and a histogram of all
        activation values.

        Args:
            features_dict: Mapping of layer name to activations indexed as
                (batch, channel, height, width).
            save_path: If given, the figure is also written to this path.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If ``features_dict`` is empty.
        """
        num_layers = len(features_dict)
        if num_layers == 0:
            raise ValueError("features_dict is empty; nothing to visualize")

        # squeeze=False keeps axes 2-D so axes[i, j] works for one layer too.
        fig, axes = plt.subplots(
            num_layers, 4, figsize=(16, 4 * num_layers), squeeze=False
        )
        for i, (layer_name, features) in enumerate(features_dict.items()):
            feature_maps = features[0].detach().cpu()

            # 1. Mean over channels, min-max normalized.
            avg_feat = feature_maps.mean(dim=0).numpy()
            avg_feat = (avg_feat - avg_feat.min()) / (avg_feat.max() - avg_feat.min() + 1e-8)

            # 2. Maximum over channels, min-max normalized.
            max_activation = feature_maps.max(dim=0)[0].numpy()
            max_activation = (max_activation - max_activation.min()) / (
                max_activation.max() - max_activation.min() + 1e-8
            )

            # 3. Channel correlation. Only the first 20 channels are shown,
            # and correlating just those rows gives the same 20x20 values as
            # slicing the full matrix.
            flat = feature_maps.reshape(feature_maps.shape[0], -1)
            # atleast_2d covers the scalar returned for a single channel.
            channel_corr = np.atleast_2d(torch.corrcoef(flat[:20]).numpy())

            # 4. All activation values, flattened for the histogram.
            activations = feature_maps.reshape(-1).numpy()

            axes[i, 0].imshow(avg_feat, cmap='viridis')
            axes[i, 0].set_title(f'{layer_name}\nAvg Feature Map', fontsize=10)
            axes[i, 0].axis('off')

            axes[i, 1].imshow(max_activation, cmap='viridis')
            axes[i, 1].set_title('Max Activation Map', fontsize=10)
            axes[i, 1].axis('off')

            im = axes[i, 2].imshow(channel_corr, cmap='coolwarm', vmin=-1, vmax=1)
            axes[i, 2].set_title('Channel Correlation', fontsize=10)
            fig.colorbar(im, ax=axes[i, 2])

            axes[i, 3].hist(activations, bins=50, alpha=0.7, color='skyblue')
            axes[i, 3].set_title('Activation Distribution', fontsize=10)
            axes[i, 3].set_xlabel('Activation Value')
            axes[i, 3].set_ylabel('Frequency')

        fig.suptitle('Feature Map Spatial Hierarchy Analysis', fontsize=16, y=1.02)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        return fig

    def analyze_layer_statistics(self, features_dict: Dict[str, torch.Tensor]):
        """Compute summary statistics for each layer's activations.

        Args:
            features_dict: Mapping of layer name to activations indexed as
                (batch, channel, height, width); only the first batch
                element is analyzed.

        Returns:
            A list with one dict per layer holding the keys 'Layer',
            'Channels', 'Spatial Size', 'Mean Activation', 'Std Activation',
            'Max Activation', 'Sparsity' (fraction of values with absolute
            value below 0.01) and 'Non-linearity' (mean of the rectified
            activations divided by the plain mean).
        """
        stats_data = []
        for layer_name, features in features_dict.items():
            feature_maps = features[0]
            mean_activation = feature_maps.mean()
            stats_data.append({
                'Layer': layer_name,
                'Channels': feature_maps.shape[0],
                'Spatial Size': f"{feature_maps.shape[1]}x{feature_maps.shape[2]}",
                'Mean Activation': float(mean_activation),
                'Std Activation': float(feature_maps.std()),
                'Max Activation': float(feature_maps.max()),
                'Sparsity': float((feature_maps.abs() < 0.01).sum() / feature_maps.numel()),
                # The epsilon only guards an exactly-zero mean.
                'Non-linearity': float(feature_maps.relu().mean() / (mean_activation + 1e-8)),
            })
        return stats_data

    def run_visualization(
        self,
        image_path: str,
        layer_names: List[str],
        visualization_type: str = 'comparison',
        save_dir: str = './visualizations/'
    ):
        """Run the full pipeline: preprocess, forward pass, plot, summarize.

        Args:
            image_path: Path of the input image.
            layer_names: Names of the layers to visualize.
            visualization_type: 'single', 'comparison' or 'hierarchy'; any
                other value skips the feature-map plots.
            save_dir: Directory the figures are written to (created if
                missing).

        Returns:
            The per-layer statistics from ``analyze_layer_statistics``.
        """
        import os

        os.makedirs(save_dir, exist_ok=True)

        # 1. Preprocess the image.
        input_tensor, original_image = self.preprocess_image(image_path)

        # 2. Register hooks and run the forward pass.
        self.register_hooks(layer_names)
        try:
            with torch.no_grad():
                self.model(input_tensor)

            # 3. Show the original image next to the de-normalized input.
            input_fig, (ax_orig, ax_pre) = plt.subplots(1, 2, figsize=(10, 4))
            ax_orig.imshow(original_image)
            ax_orig.set_title('Original Image')
            ax_orig.axis('off')

            img_np = input_tensor[0].cpu().permute(1, 2, 0).numpy()
            # Undo the normalization applied in preprocess_image.
            img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img_np = np.clip(img_np, 0, 1)
            ax_pre.imshow(img_np)
            ax_pre.set_title('Preprocessed Image')
            ax_pre.axis('off')
            input_fig.savefig(
                os.path.join(save_dir, 'input_image.png'), dpi=300, bbox_inches='tight'
            )

            # 4. Draw the requested visualization.
            if visualization_type == 'single':
                for layer_name in layer_names:
                    if layer_name not in self.features:
                        # The name matched no module, so nothing was captured.
                        print(f"Warning: no features captured for layer '{layer_name}', skipping")
                        continue
                    fig = self.visualize_single_layer(
                        self.features[layer_name],
                        layer_name,
                        num_channels=12,
                        method='grid'
                    )
                    fig.savefig(
                        os.path.join(save_dir, f'{layer_name}_features.png'),
                        dpi=300, bbox_inches='tight'
                    )
                    plt.show()
            elif visualization_type == 'comparison':
                self.visualize_layer_comparison(
                    self.features,
                    num_channels=8,
                    save_path=os.path.join(save_dir, 'layer_comparison.png')
                )
                plt.show()
            elif visualization_type == 'hierarchy':
                self.visualize_spatial_hierarchy(
                    self.features,
                    save_path=os.path.join(save_dir, 'spatial_hierarchy.png')
                )
                plt.show()

            # 5. Compute statistics before the features are cleared.
            stats = self.analyze_layer_statistics(self.features)
        finally:
            # 6. Always detach the hooks, even when a step above failed.
            self._remove_hooks()

        print("\n" + "=" * 80)
        print("LAYER STATISTICS ANALYSIS")
        print("=" * 80)
        for stat in stats:
            print(f"\n{stat['Layer']}:")
            print(f" Channels: {stat['Channels']}")
            print(f" Spatial Size: {stat['Spatial Size']}")
            print(f" Mean Activation: {stat['Mean Activation']:.4f}")
            print(f" Std Activation: {stat['Std Activation']:.4f}")
            print(f" Max Activation: {stat['Max Activation']:.4f}")
            print(f" Sparsity: {stat['Sparsity']:.2%}")
            print(f" Non-linearity Ratio: {stat['Non-linearity']:.4f}")

        return stats
# Usage example
def main():
    """Run a demo visualization on ResNet-18.

    The layers to inspect are chosen from the model's class name. If the
    placeholder image path does not exist, a random image is generated and
    used instead.
    """
    visualizer = FeatureVisualizer(model_name='resnet18', pretrained=True)

    # Choose the layers to visualize according to the model architecture.
    model_type = type(visualizer.model).__name__
    if model_type == 'ResNet':
        layer_names = ['layer1.0.conv1', 'layer2.0.conv1', 'layer3.0.conv1', 'layer4.0.conv1']
    elif model_type == 'VGG':
        layer_names = ['features.0', 'features.5', 'features.10', 'features.17']
    else:
        # Generic fallback: the first four convolution layers.
        layer_names = []
        for name, module in visualizer.model.named_modules():
            if isinstance(module, nn.Conv2d):
                layer_names.append(name)
                if len(layer_names) >= 4:
                    break
    print(f"Selected layers for visualization: {layer_names}")

    # Replace with the path of a real image.
    image_path = "path/to/your/image.jpg"
    try:
        visualizer.run_visualization(
            image_path=image_path,
            layer_names=layer_names,
            visualization_type='comparison',  # or 'single' / 'hierarchy'
            save_dir='./feature_visualizations/'
        )
    except FileNotFoundError:
        print("示例:使用随机图像进行演示")
        # randint's upper bound is exclusive, so 256 covers the uint8 range.
        dummy_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        Image.fromarray(dummy_image).save('dummy_image.jpg')
        visualizer.run_visualization(
            image_path='dummy_image.jpg',
            layer_names=layer_names,
            visualization_type='comparison',
            save_dir='./feature_visualizations/'
        )
# Run the demo only when the file is executed as a script.
if __name__ == "__main__":
    main()