1.利用train_dataset来提取数据集,并且提取第一个数据集图像,没有batch维度
2.添加 batch 维度: [C, H, W] -> [1, C, H, W]img = img.unsqueeze(0)
import sys
import os
from pathlib import Path
解决 OpenMP 库冲突问题
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
添加项目根目录到 Python 路径
project_root = Path(file).parent.parent
sys.path.insert(0, str(project_root))
from utils.hanfeng_dataset import HanfengDataset
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
from models.UNeXt.UNeXt_Stripe_Conv import UNext_Stripe_Conv
========== 1. 准备一个存储激活值的字典 ==========
activations = {}
data_root = r"D:\github\DSNet-main\data\hanfeng"
def get_activation(name):
"""定义hook函数"""
def hook(model, input, output):
activations[name] = output.detach()
return hook
========== 2. 给模型的层注册hook ==========
model = UNext_Stripe_Conv(num_classes=1)
model.eval()
给所有卷积层注册hook
for name, layer in model.named_modules():
if isinstance(layer, nn.Conv2d):
layer.register_forward_hook(get_activation(name))
========== 3. 前向传播,自动捕获激活值 ==========
定义数据预处理 transform
import albumentations as A
from albumentations.pytorch import ToTensorV2
transform = A.Compose([
A.Resize(224, 224), # 调整大小,避免显存溢出
A.Normalize(mean=[0.105, 0.105, 0.105], std=[0.203, 0.203, 0.203]), # ImageNet 标准化
ToTensorV2(), # 转换为 tensor,自动变成 [C, H, W]
])
train_dataset = HanfengDataset(data_root, mode='trainval', transform=transform, auto_fix_path=True)
print(f"✓ 训练集加载成功: {len(train_dataset)} 个样本")
img, mask, meta = train_dataset[0]
print(f"图像形状: {img.shape}") # 现在应该是 [C, H, W]
# 添加 batch 维度: [C, H, W] -> [1, C, H, W]
img = img.unsqueeze(0)
with torch.no_grad():
output = model(img)
========== 4. 可视化任意层 ==========
def visualize_layer(layer_name, num_channels=16):
"""显示某层的前num_channels个通道"""
act = activations[layer_name][0] # [C, H, W]
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
if i < min(num_channels, act.shape[0]):
ax.imshow(act[i].cpu(), cmap='viridis')
ax.set_title(f'Ch {i}')
ax.axis('off')
plt.suptitle(f'{layer_name}')
plt.show()
使用
print(f"捕获了 {len(activations)} 层")
first_conv = list(activations.keys())[0]
visualize_layer(first_conv)