# day46打卡 (Day 46 check-in)

# @浙大疏锦行

import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np

import matplotlib.pyplot as plt

from torchvision import models, transforms

from PIL import Image

import cv2

from typing import List, Dict, Optional

class FeatureVisualizer:
    """Visualizer for convolutional feature maps (卷积特征图可视化器).

    Forward hooks registered via ``register_hooks`` store each hooked
    module's output (detached, moved to CPU) in ``self.features``, keyed by
    the module's name as reported by ``model.named_modules()``. The
    ``visualize_*`` methods render those tensors with matplotlib and return
    the created figure.
    """

    # ImageNet normalization statistics. They are shared by preprocessing and
    # by the de-normalization used when plotting the model input, so the two
    # cannot drift apart.
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, model_name: str = 'resnet18', pretrained: bool = True):
        """Load the model and prepare storage for hooks and captured features.

        Args:
            model_name: One of 'resnet18', 'resnet50', 'vgg16', 'vgg19', 'alexnet'.
            pretrained: Whether to load pretrained weights.

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model(model_name, pretrained)
        self.model.to(self.device)
        self.model.eval()

        # Layer name -> captured feature map (filled in by forward hooks).
        self.features = {}
        # Handles of the currently registered hooks, kept so they can be removed.
        self.hooks = []

    def _load_model(self, model_name: str, pretrained: bool):
        """Build the requested torchvision model.

        Prefers the ``weights=`` API (torchvision >= 0.13, where ``pretrained=``
        is deprecated). It falls back to the legacy ``pretrained=`` flag when the
        installed torchvision does not accept ``weights``.

        Raises:
            ValueError: If ``model_name`` is not one of the supported models.
        """
        model_dict = {
            'resnet18': models.resnet18,
            'vgg16': models.vgg16,
            'alexnet': models.alexnet,
            'resnet50': models.resnet50,
            'vgg19': models.vgg19,
        }
        if model_name not in model_dict:
            raise ValueError(f"模型 {model_name} 不支持")

        builder = model_dict[model_name]
        try:
            # "DEFAULT" selects the best available pretrained weights.
            return builder(weights='DEFAULT' if pretrained else None)
        except TypeError:
            # Older torchvision: the builder forwards the unknown ``weights``
            # kwarg to the model constructor, which rejects it.
            return builder(pretrained=pretrained)

    def register_hooks(self, layer_names: List[str]):
        """Register forward hooks on the named sub-modules.

        Any previously registered hooks (and captured features) are removed
        first. Names that match no module in ``self.model.named_modules()``
        are reported with a ``UserWarning``. Previously they were silently
        ignored, which later surfaced as a confusing KeyError.

        Args:
            layer_names: Names of the modules whose outputs should be captured.
        """
        import warnings

        self._remove_hooks()
        wanted = set(layer_names)
        found = set()
        for name, module in self.model.named_modules():
            if name in wanted:
                # Bind ``name`` as a default argument so each hook keeps its own
                # layer name (avoids the late-binding closure pitfall).
                hook = module.register_forward_hook(
                    lambda m, inp, out, name=name: self._hook_fn(name, out)
                )
                self.hooks.append(hook)
                found.add(name)

        missing = [n for n in layer_names if n not in found]
        if missing:
            warnings.warn(f"Layers not found in model and skipped: {missing}")

    def _hook_fn(self, name: str, output):
        """Store a detached CPU copy of a module's ``output`` under ``name``."""
        self.features[name] = output.detach().cpu()

    def _remove_hooks(self):
        """Remove all registered hooks and discard the captured features."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        self.features.clear()

    def preprocess_image(self, image_path: str, img_size: int = 224):
        """Load an image and turn it into a normalized model input.

        Args:
            image_path: Path to the image file.
            img_size: Side length the image is resized to (aspect ratio is not kept).

        Returns:
            ``(input_tensor, image)``: a ``[1, 3, img_size, img_size]`` tensor on
            ``self.device``, and the RGB PIL image.
        """
        preprocess = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD),
        ])

        # The ``with`` block closes the file handle. ``convert`` returns an
        # independent, fully loaded image, so it stays usable afterwards.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        input_tensor = preprocess(image).unsqueeze(0)
        return input_tensor.to(self.device), image

    @staticmethod
    def _top_channels(feature_maps: torch.Tensor, k: int) -> torch.Tensor:
        """Return indices of the ``k`` channels with the largest mean activation.

        ``feature_maps`` is indexed as ``[channels, ...]``. Indices are ordered
        from strongest to weakest.
        """
        num_total_channels = feature_maps.shape[0]
        channel_activations = feature_maps.reshape(num_total_channels, -1).mean(dim=1)
        return torch.argsort(channel_activations, descending=True)[:k]

    @staticmethod
    def _normalize01(arr: np.ndarray) -> np.ndarray:
        """Min-max scale ``arr`` to [0, 1].

        The epsilon avoids division by zero for constant maps.
        """
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    def visualize_single_layer(
        self,
        features: torch.Tensor,
        layer_name: str,
        num_channels: int = 16,
        method: str = 'grid'
    ):
        """Plot several channels of one layer's feature maps in a single row.

        Args:
            features: Feature tensor indexed ``[batch, channels, height, width]``.
                Only the first batch element is shown.
            layer_name: Layer name used in the figure title.
            num_channels: Number of channels to show (capped at the channel count).
            method: How to pick and draw the channels:
                'grid' picks evenly spaced channels and draws them with viridis.
                'heatmap' picks evenly spaced channels and draws them with the
                OpenCV JET colormap.
                'max_activation' picks the channels with the highest mean
                activation and draws them with viridis.

        Returns:
            The matplotlib Figure.

        Raises:
            ValueError: If ``method`` is not one of the values above.
        """
        if method not in ('grid', 'heatmap', 'max_activation'):
            raise ValueError(f"Unknown method: {method!r}")

        feature_maps = features[0]  # first batch element: [channels, H, W]
        num_total_channels = feature_maps.shape[0]
        num_channels = min(num_channels, num_total_channels)

        if method == 'max_activation':
            selected_channels = self._top_channels(feature_maps, num_channels)
        else:
            # Evenly spaced channel indices across the whole layer.
            selected_channels = torch.linspace(0, num_total_channels - 1, num_channels).long()

        # squeeze=False always yields a 2-D axes array, even for one channel.
        fig, axes = plt.subplots(1, num_channels, figsize=(3 * num_channels, 3), squeeze=False)

        for ax, channel_idx in zip(axes[0], selected_channels.tolist()):
            feat_map = self._normalize01(feature_maps[channel_idx].numpy())
            if method == 'heatmap':
                heatmap = cv2.applyColorMap(np.uint8(255 * feat_map), cv2.COLORMAP_JET)
                # OpenCV produces BGR; matplotlib expects RGB.
                ax.imshow(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB))
            else:
                # 'grid' and 'max_activation' differ only in channel selection.
                ax.imshow(feat_map, cmap='viridis')
            ax.set_title(f'Ch {channel_idx}')
            ax.axis('off')

        fig.suptitle(f'{layer_name} Feature Maps', fontsize=16, y=1.05)
        fig.tight_layout()
        return fig

    def visualize_layer_comparison(
        self,
        layer_features: Dict[str, torch.Tensor],
        num_channels: int = 8,
        save_path: Optional[str] = None
    ):
        """Plot the most strongly activated channels of several layers, one row per layer.

        Args:
            layer_features: Mapping from layer name to feature tensor
                ``[batch, channels, height, width]``. The first batch element is used.
            num_channels: Number of channels shown per layer.
            save_path: If given, the figure is also saved there.

        Returns:
            The matplotlib Figure.

        Raises:
            ValueError: If ``layer_features`` is empty.
        """
        if not layer_features:
            raise ValueError("layer_features is empty; nothing to visualize")

        num_layers = len(layer_features)
        # squeeze=False gives a [num_layers, num_channels] array for all sizes.
        fig, axes = plt.subplots(num_layers, num_channels,
                                 figsize=(2.5 * num_channels, 2.5 * num_layers),
                                 squeeze=False)

        # Hide ticks and frames, but keep the axes "on" so the per-row
        # y-labels remain visible (``axis('off')`` would hide them too).
        for ax in axes.flat:
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)

        for row, (layer_name, features) in enumerate(layer_features.items()):
            feature_maps = features[0]
            selected_channels = self._top_channels(feature_maps, num_channels)

            for col, channel_idx in enumerate(selected_channels.tolist()):
                ax = axes[row, col]
                ax.imshow(self._normalize01(feature_maps[channel_idx].numpy()), cmap='viridis')
                # Each row picks its own channels, so title every cell.
                ax.set_title(f'Ch {channel_idx}', fontsize=10)

            ylabel = f'{layer_name}\n({feature_maps.shape[-2]}x{feature_maps.shape[-1]})'
            axes[row, 0].set_ylabel(ylabel, fontsize=10)

        fig.suptitle('Convolutional Layer Feature Map Comparison', fontsize=16, y=1.02)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        return fig

    def visualize_spatial_hierarchy(
        self,
        features_dict: Dict[str, torch.Tensor],
        save_path: Optional[str] = None
    ):
        """Plot four per-layer summaries, one row per layer.

        The four columns are:
        1. the channel-averaged map;
        2. the channel-wise maximum map;
        3. the channel correlation matrix (first 20 channels);
        4. a histogram of all activation values.

        Args:
            features_dict: Mapping from layer name to feature tensor
                ``[batch, channels, height, width]``. The first batch element is used.
            save_path: If given, the figure is also saved there.

        Returns:
            The matplotlib Figure.

        Raises:
            ValueError: If ``features_dict`` is empty.
        """
        if not features_dict:
            raise ValueError("features_dict is empty; nothing to visualize")

        num_layers = len(features_dict)
        fig, axes = plt.subplots(num_layers, 4, figsize=(16, 4 * num_layers), squeeze=False)

        for i, (layer_name, features) in enumerate(features_dict.items()):
            feature_maps = features[0]
            num_ch = feature_maps.shape[0]

            # 1. Channel-averaged map.
            avg_feat = self._normalize01(feature_maps.mean(dim=0).numpy())
            # 2. Channel-wise maximum map.
            max_activation = self._normalize01(feature_maps.max(dim=0).values.numpy())
            # 3. Pearson correlation between channels. With a single channel
            #    corrcoef returns a 0-dim tensor, so force 2-D before slicing.
            #    Constant channels yield NaN, which imshow leaves blank.
            channel_corr = torch.atleast_2d(
                torch.corrcoef(feature_maps.reshape(num_ch, -1))
            )[:20, :20].numpy()
            # 4. All activation values, for the histogram.
            activations = feature_maps.reshape(-1).numpy()

            axes[i, 0].imshow(avg_feat, cmap='viridis')
            axes[i, 0].set_title(f'{layer_name}\nAvg Feature Map', fontsize=10)
            axes[i, 0].axis('off')

            axes[i, 1].imshow(max_activation, cmap='viridis')
            axes[i, 1].set_title('Max Activation Map', fontsize=10)
            axes[i, 1].axis('off')

            im = axes[i, 2].imshow(channel_corr, cmap='coolwarm', vmin=-1, vmax=1)
            axes[i, 2].set_title('Channel Correlation', fontsize=10)
            fig.colorbar(im, ax=axes[i, 2])

            axes[i, 3].hist(activations, bins=50, alpha=0.7, color='skyblue')
            axes[i, 3].set_title('Activation Distribution', fontsize=10)
            axes[i, 3].set_xlabel('Activation Value')
            axes[i, 3].set_ylabel('Frequency')

        fig.suptitle('Feature Map Spatial Hierarchy Analysis', fontsize=16, y=1.02)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        return fig

    def analyze_layer_statistics(self, features_dict: Dict[str, torch.Tensor]):
        """Compute summary statistics for each layer's feature maps.

        Args:
            features_dict: Mapping from layer name to feature tensor
                ``[batch, channels, height, width]``. The first batch element is used.

        Returns:
            A list with one dict per layer, in ``features_dict`` order.
        """
        stats_data = []
        for layer_name, features in features_dict.items():
            feature_maps = features[0]
            stats_data.append({
                'Layer': layer_name,
                'Channels': feature_maps.shape[0],
                'Spatial Size': f"{feature_maps.shape[1]}x{feature_maps.shape[2]}",
                'Mean Activation': float(feature_maps.mean()),
                'Std Activation': float(feature_maps.std()),
                'Max Activation': float(feature_maps.max()),
                # Fraction of elements whose magnitude is below 0.01.
                'Sparsity': float((feature_maps.abs() < 0.01).float().mean()),
                # Mean of the ReLU-ed map relative to the raw mean. This is
                # numerically unstable when the raw mean is close to zero.
                'Non-linearity': float(feature_maps.relu().mean() / (feature_maps.mean() + 1e-8)),
            })
        return stats_data

    def _plot_input_image(self, original_image, input_tensor: torch.Tensor, save_path: str):
        """Plot the raw image next to the de-normalized model input and save the figure."""
        fig, (ax_raw, ax_input) = plt.subplots(1, 2, figsize=(10, 4))

        ax_raw.imshow(original_image)
        ax_raw.set_title('Original Image')
        ax_raw.axis('off')

        # Undo the Normalize transform so the input is displayable again.
        img_np = input_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
        img_np = img_np * np.array(self._IMAGENET_STD) + np.array(self._IMAGENET_MEAN)
        ax_input.imshow(np.clip(img_np, 0, 1))
        ax_input.set_title('Preprocessed Image')
        ax_input.axis('off')

        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        return fig

    @staticmethod
    def _print_statistics(stats):
        """Print the per-layer statistics produced by ``analyze_layer_statistics``."""
        print("\n" + "=" * 80)
        print("LAYER STATISTICS ANALYSIS")
        print("=" * 80)
        for stat in stats:
            print(f"\n{stat['Layer']}:")
            print(f" Channels: {stat['Channels']}")
            print(f" Spatial Size: {stat['Spatial Size']}")
            print(f" Mean Activation: {stat['Mean Activation']:.4f}")
            print(f" Std Activation: {stat['Std Activation']:.4f}")
            print(f" Max Activation: {stat['Max Activation']:.4f}")
            print(f" Sparsity: {stat['Sparsity']:.2%}")
            print(f" Non-linearity Ratio: {stat['Non-linearity']:.4f}")

    def run_visualization(
        self,
        image_path: str,
        layer_names: List[str],
        visualization_type: str = 'comparison',
        save_dir: str = './visualizations/'
    ):
        """Run the full pipeline: preprocess, forward with hooks, plot, and report stats.

        Args:
            image_path: Path of the input image.
            layer_names: Names of the modules to visualize.
            visualization_type: 'single', 'comparison' or 'hierarchy'.
            save_dir: Directory the figures are written to. It is created if missing.

        Returns:
            The per-layer statistics from ``analyze_layer_statistics``.

        Raises:
            ValueError: If ``visualization_type`` is not recognized.
        """
        import os

        if visualization_type not in ('single', 'comparison', 'hierarchy'):
            raise ValueError(f"Unknown visualization_type: {visualization_type!r}")

        os.makedirs(save_dir, exist_ok=True)

        # 1. Preprocess the image.
        input_tensor, original_image = self.preprocess_image(image_path)

        # 2. Register hooks and run the forward pass. The try/finally
        #    guarantees the hooks are removed even if plotting fails.
        self.register_hooks(layer_names)
        try:
            with torch.no_grad():
                self.model(input_tensor)

            # 3. Show the raw and preprocessed input.
            self._plot_input_image(original_image, input_tensor,
                                   os.path.join(save_dir, 'input_image.png'))

            # 4. Feature-map visualizations.
            if visualization_type == 'single':
                for layer_name in layer_names:
                    if layer_name not in self.features:
                        continue  # unknown layer; register_hooks already warned
                    fig = self.visualize_single_layer(
                        self.features[layer_name],
                        layer_name,
                        num_channels=12,
                        method='grid'
                    )
                    fig.savefig(os.path.join(save_dir, f'{layer_name}_features.png'),
                                dpi=300, bbox_inches='tight')
                    plt.show()
            elif visualization_type == 'comparison':
                self.visualize_layer_comparison(
                    self.features,
                    num_channels=8,
                    save_path=os.path.join(save_dir, 'layer_comparison.png')
                )
                plt.show()
            else:  # 'hierarchy'
                self.visualize_spatial_hierarchy(
                    self.features,
                    save_path=os.path.join(save_dir, 'spatial_hierarchy.png')
                )
                plt.show()

            # 5. Statistics.
            stats = self.analyze_layer_statistics(self.features)
            self._print_statistics(stats)
        finally:
            # 6. Cleanup.
            self._remove_hooks()

        return stats

# Usage example (使用示例)
def main():
    """Demo: visualize four layers of a pretrained ResNet-18.

    If the configured image path does not exist (``FileNotFoundError``), a
    random RGB image is written to ``dummy_image.jpg`` and used instead.

    Returns:
        The per-layer statistics returned by ``run_visualization``.
    """
    visualizer = FeatureVisualizer(model_name='resnet18', pretrained=True)

    # Pick the layers to visualize based on the model architecture.
    arch = type(visualizer.model).__name__
    if arch == 'ResNet':
        layer_names = ['layer1.0.conv1', 'layer2.0.conv1', 'layer3.0.conv1', 'layer4.0.conv1']
    elif arch == 'VGG':
        layer_names = ['features.0', 'features.5', 'features.10', 'features.17']
    else:
        # Generic fallback: the first four Conv2d layers.
        layer_names = [
            name for name, module in visualizer.model.named_modules()
            if isinstance(module, nn.Conv2d)
        ][:4]

    print(f"Selected layers for visualization: {layer_names}")

    # Replace with the path of a real image.
    image_path = "path/to/your/image.jpg"
    save_dir = './feature_visualizations/'

    try:
        stats = visualizer.run_visualization(
            image_path=image_path,
            layer_names=layer_names,
            visualization_type='comparison',  # or 'single', 'hierarchy'
            save_dir=save_dir
        )
    except FileNotFoundError:
        print("示例:使用随机图像进行演示")
        # Random RGB image as a stand-in input (randint's upper bound is exclusive).
        dummy_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        Image.fromarray(dummy_image).save('dummy_image.jpg')
        stats = visualizer.run_visualization(
            image_path='dummy_image.jpg',
            layer_names=layer_names,
            visualization_type='comparison',
            save_dir=save_dir
        )
    return stats

# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()

# 相关推荐 (related posts, scraped from the original blog page):
# 哈里谢顿20 小时前
# 一条 Python 语句在 C 扩展里到底怎么跑
# python
# Edward.W21 小时前
# Python uv:新一代Python包管理工具,彻底改变开发体验
# 开发语言·python·uv
# 小熊officer21 小时前
# Python字符串
# 开发语言·数据库·python
# 月疯21 小时前
# 各种信号的模拟(ECG信号、质谱图、EEG信号),方便U-net训练
# 开发语言·python
# 小鸡吃米…1 天前
# 机器学习中的回归分析
# 人工智能·python·机器学习·回归
# AC赳赳老秦1 天前
# Python 爬虫进阶:DeepSeek 优化反爬策略与动态数据解析逻辑
# 开发语言·hadoop·spring boot·爬虫·python·postgresql·deepseek
# 浩瀚之水_csdn1 天前
# Python 三元运算符详解
# 开发语言·python
# Yuner20001 天前
# Python机器学习:从入门到精通
# python
# Amelia1111111 天前
# day47
# python