hook来获取模型每层的激活值输出

1.利用train_dataset来提取数据集,并且提取第一个数据集图像,没有batch维度

2.添加 batch 维度: [C, H, W] -> [1, C, H, W]img = img.unsqueeze(0)

import sys

import os

from pathlib import Path

解决 OpenMP 库冲突问题

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

添加项目根目录到 Python 路径

project_root = Path(file).parent.parent

sys.path.insert(0, str(project_root))

from utils.hanfeng_dataset import HanfengDataset

import torch

import torch.nn as nn

import matplotlib.pyplot as plt

import torchvision

from models.UNeXt.UNeXt_Stripe_Conv import UNext_Stripe_Conv

========== 1. 准备一个存储激活值的字典 ==========

activations = {}

data_root = r"D:\github\DSNet-main\data\hanfeng"

def get_activation(name):

"""定义hook函数"""

def hook(model, input, output):

activations[name] = output.detach()

return hook

========== 2. 给模型的层注册hook ==========

model = UNext_Stripe_Conv(num_classes=1)

model.eval()

给所有卷积层注册hook

for name, layer in model.named_modules():

if isinstance(layer, nn.Conv2d):

layer.register_forward_hook(get_activation(name))

========== 3. 前向传播,自动捕获激活值 ==========

定义数据预处理 transform

import albumentations as A

from albumentations.pytorch import ToTensorV2

transform = A.Compose([

A.Resize(224, 224), # 调整大小,避免显存溢出

A.Normalize(mean=[0.105, 0.105, 0.105], std=[0.203, 0.203, 0.203]), # ImageNet 标准化

ToTensorV2(), # 转换为 tensor,自动变成 [C, H, W]

])

train_dataset = HanfengDataset(data_root, mode='trainval', transform=transform, auto_fix_path=True)

print(f"✓ 训练集加载成功: {len(train_dataset)} 个样本")

img, mask, meta = train_dataset[0]

print(f"图像形状: {img.shape}") # 现在应该是 [C, H, W]

# 添加 batch 维度: [C, H, W] -> [1, C, H, W]

img = img.unsqueeze(0)

with torch.no_grad():

output = model(img)

========== 4. 可视化任意层 ==========

def visualize_layer(layer_name, num_channels=16):

"""显示某层的前num_channels个通道"""

act = activations[layer_name][0] # [C, H, W]

fig, axes = plt.subplots(4, 4, figsize=(10, 10))

for i, ax in enumerate(axes.flat):

if i < min(num_channels, act.shape[0]):

ax.imshow(act[i].cpu(), cmap='viridis')

ax.set_title(f'Ch {i}')

ax.axis('off')

plt.suptitle(f'{layer_name}')

plt.show()

使用

print(f"捕获了 {len(activations)} 层")

first_conv = list(activations.keys())[0]

visualize_layer(first_conv)

相关推荐
寻星探路16 小时前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
聆风吟º18 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子18 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder18 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能19 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能57719 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
猫头虎19 小时前
如何排查并解决项目启动时报错Error encountered while processing: java.io.IOException: closed 的问题
java·开发语言·jvm·spring boot·python·开源·maven
h64648564h19 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切19 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
八零后琐话19 小时前
干货:程序员必备性能分析工具——Arthas火焰图
开发语言·python