hook来获取模型每层的激活值输出

1.利用train_dataset来提取数据集,并且提取第一个数据集图像,没有batch维度

2.添加 batch 维度: C, H, W -> 1, C, H, Wimg = img.unsqueeze(0)

import sys

import os

from pathlib import Path

解决 OpenMP 库冲突问题

os.environ'KMP_DUPLICATE_LIB_OK' = 'TRUE'

添加项目根目录到 Python 路径

project_root = Path(file).parent.parent

sys.path.insert(0, str(project_root))

from utils.hanfeng_dataset import HanfengDataset

import torch

import torch.nn as nn

import matplotlib.pyplot as plt

import torchvision

from models.UNeXt.UNeXt_Stripe_Conv import UNext_Stripe_Conv

========== 1. 准备一个存储激活值的字典 ==========

activations = {}

data_root = r"D:\github\DSNet-main\data\hanfeng"

def get_activation(name):

"""定义hook函数"""

def hook(model, input, output):

activationsname = output.detach()

return hook

========== 2. 给模型的层注册hook ==========

model = UNext_Stripe_Conv(num_classes=1)

model.eval()

给所有卷积层注册hook

for name, layer in model.named_modules():

if isinstance(layer, nn.Conv2d):

layer.register_forward_hook(get_activation(name))

========== 3. 前向传播,自动捕获激活值 ==========

定义数据预处理 transform

import albumentations as A

from albumentations.pytorch import ToTensorV2

transform = A.Compose([

A.Resize(224, 224), # 调整大小,避免显存溢出

A.Normalize(mean=0.105, 0.105, 0.105, std=0.203, 0.203, 0.203), # ImageNet 标准化

ToTensorV2(), # 转换为 tensor,自动变成 C, H, W

])

train_dataset = HanfengDataset(data_root, mode='trainval', transform=transform, auto_fix_path=True)

print(f"✓ 训练集加载成功: {len(train_dataset)} 个样本")

img, mask, meta = train_dataset0

print(f"图像形状: {img.shape}") # 现在应该是 C, H, W

# 添加 batch 维度: C, H, W -> 1, C, H, W

img = img.unsqueeze(0)

with torch.no_grad():

output = model(img)

========== 4. 可视化任意层 ==========

def visualize_layer(layer_name, num_channels=16):

"""显示某层的前num_channels个通道"""

act = activationslayer_name0 # C, H, W

fig, axes = plt.subplots(4, 4, figsize=(10, 10))

for i, ax in enumerate(axes.flat):

if i < min(num_channels, act.shape0):

ax.imshow(acti.cpu(), cmap='viridis')

ax.set_title(f'Ch {i}')

ax.axis('off')

plt.suptitle(f'{layer_name}')

plt.show()

使用

print(f"捕获了 {len(activations)} 层")

first_conv = list(activations.keys())0

visualize_layer(first_conv)

相关推荐
装不满的克莱因瓶1 小时前
链式法则如何传递参数误差 —— 深入理解神经网络中的梯度传播
人工智能·python·深度学习·神经网络·数学·机器学习·ai
Anastasiozzzz1 小时前
从有限状态机到智能体图:传统 FSM 与 Agent Graph的演进
java·人工智能·python·ai
biter down7 小时前
从 0 到 1 搭建 Python 接口自动化测试框架(博客系统实战)
开发语言·python
肖永威8 小时前
Python多业务并行计算框架插件化演进:从硬编码到动态注册
python·插件化·并行计算·动态注册
yz_aiks8 小时前
Linux Jar包配置Systemd自启动实战:从排查到配置全流程
linux·python·jar·自启动·systemd
不知名的老吴8 小时前
线程的生命周期之线程“插队“
java·开发语言·python
xsc6996759 小时前
从零搭建大模型与智能体平台 - 完整技术详解
python
无风听海11 小时前
多租户系统中的 OIDC:Discovery 端点与联合登录的深度实践
后端·python·flask
CTA终结者11 小时前
期货量化主力换月程序怎么移仓:天勤 underlying_symbol 与任务切换
python·区块链
马士兵教育11 小时前
Java还有前景吗?Java+AI大模型学习路线及项目?
java·人工智能·python·学习·机器学习