1.利用train_dataset来提取数据集,并且提取第一个数据集图像,没有batch维度
2.添加 batch 维度: C, H, W -> 1, C, H, Wimg = img.unsqueeze(0)
import sys
import os
from pathlib import Path
解决 OpenMP 库冲突问题
os.environ'KMP_DUPLICATE_LIB_OK' = 'TRUE'
添加项目根目录到 Python 路径
project_root = Path(file).parent.parent
sys.path.insert(0, str(project_root))
from utils.hanfeng_dataset import HanfengDataset
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
from models.UNeXt.UNeXt_Stripe_Conv import UNext_Stripe_Conv
========== 1. 准备一个存储激活值的字典 ==========
activations = {}
data_root = r"D:\github\DSNet-main\data\hanfeng"
def get_activation(name):
"""定义hook函数"""
def hook(model, input, output):
activationsname = output.detach()
return hook
========== 2. 给模型的层注册hook ==========
model = UNext_Stripe_Conv(num_classes=1)
model.eval()
给所有卷积层注册hook
for name, layer in model.named_modules():
if isinstance(layer, nn.Conv2d):
layer.register_forward_hook(get_activation(name))
========== 3. 前向传播,自动捕获激活值 ==========
定义数据预处理 transform
import albumentations as A
from albumentations.pytorch import ToTensorV2
transform = A.Compose([
A.Resize(224, 224), # 调整大小,避免显存溢出
A.Normalize(mean=0.105, 0.105, 0.105, std=0.203, 0.203, 0.203), # ImageNet 标准化
ToTensorV2(), # 转换为 tensor,自动变成 C, H, W
])
train_dataset = HanfengDataset(data_root, mode='trainval', transform=transform, auto_fix_path=True)
print(f"✓ 训练集加载成功: {len(train_dataset)} 个样本")
img, mask, meta = train_dataset0
print(f"图像形状: {img.shape}") # 现在应该是 C, H, W
# 添加 batch 维度: C, H, W -> 1, C, H, W
img = img.unsqueeze(0)
with torch.no_grad():
output = model(img)
========== 4. 可视化任意层 ==========
def visualize_layer(layer_name, num_channels=16):
"""显示某层的前num_channels个通道"""
act = activationslayer_name0 # C, H, W
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
if i < min(num_channels, act.shape0):
ax.imshow(acti.cpu(), cmap='viridis')
ax.set_title(f'Ch {i}')
ax.axis('off')
plt.suptitle(f'{layer_name}')
plt.show()
使用
print(f"捕获了 {len(activations)} 层")
first_conv = list(activations.keys())0
visualize_layer(first_conv)