神经网络中间层特征图可视化(输入为音频)(二)

相比方法一个人感觉这种方法更好

复制代码
import librosa
import numpy as np
import utils
import torch
import matplotlib.pyplot as plt

class Hook:
    def __init__(self):
        self.features = None

    def hook_fn(self, module, input, output):
        self.features = output

# 创建钩子的实例
hook = Hook()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def extract_mbe(_y, _sr, _nfft, _nb_mel):
    #梅尔频谱
    spec = librosa.core.spectrum._spectrogram(y=_y, n_fft=_nfft, hop_length=_nfft // 2, power=1)[0]
    mel_basis = librosa.filters.mel(sr=_sr, n_fft=_nfft, n_mels=_nb_mel)
    mel_spec = np.log(np.dot(mel_basis, spec).T)
    return mel_spec       #最后必须是[frames, dimensions]

def preprocess_data(X, seq_len, nb_ch):
    # split into sequences
    X = utils.split_in_seqs(X, seq_len)
    X = utils.split_multi_channels(X, nb_ch)
    # Convert to PyTorch tensors
    X = torch.Tensor(X)
    X = X.permute(0,1,3,2)   #x形状为[709,2,40,256],【总样本数,通道数,特征维度,像素宽度】
    return X

# 提取梅尔频谱特征
audio_path = "a011.wav"
y, sr = librosa.load(audio_path, sr=44100)
mel = extract_mbe(y, sr, 2048, 64)

value = preprocess_data(mel, 256, 1).to(device)     #value 为输入模型的样本特征


model = torch.load(f'best_model_2.pth')

# 将钩子注册到需要的层
model.cnn1.register_forward_hook(hook.hook_fn)

# 假设`input_data`是你的输入张量
output = model(value)

# 访问存储的特征
retnet_features = hook.features
#print(retnet_features.shape)
# 可视化特征(假设retnet_features是一个张量)
retnet_features = retnet_features.permute(0, 2, 1, 3)
#retnet_features = retnet_features.transpose(1, 2)
#print(retnet_features.shape)
retnet_features = torch.cat([retnet_features[i] for i in range(10)], dim=2)
#print(retnet_features.shape)

# 可视化批次中第一个样本的特定通道
plt.imshow(retnet_features.sum(1).detach().cpu().numpy(), cmap='viridis', origin='lower')   #[高,通道, 宽]
# plt.imshow(retnet_features.detach().cpu().numpy(), cmap='viridis', origin='lower')   #[高,宽]
plt.show()
相关推荐
@鱼香肉丝没有鱼19 小时前
Transformer底层原理—位置编码
人工智能·深度学习·transformer·位置编码
深度学习实战训练营19 小时前
HRNet:深度高分辨率表示学习用于人体姿态估计-k学长深度学习专栏
人工智能·深度学习
架构师李哲19 小时前
让智能家居“听懂人话”:我用4B模型+万条数据,教会了它理解复杂指令
深度学习·aigc
CoovallyAIHub20 小时前
是什么支撑L3自动驾驶落地?读懂AI驾驶与碰撞预测
深度学习·算法·计算机视觉
碧海银沙音频科技研究院20 小时前
论文写作word插入公式显示灰色解决办法
人工智能·深度学习·算法
qq_3106585120 小时前
mediasoup源码走读(十二)——router
服务器·c++·音视频
听风吹等浪起20 小时前
机器学习算法:随机梯度下降算法
人工智能·深度学习·算法·机器学习
拉姆哥的小屋20 小时前
【深度学习实战】基于CyclePatch框架的电池寿命预测:从NASA数据集到Transformer模型的完整实现
人工智能·深度学习·transformer
音视频牛哥21 小时前
SmartMediakit技术白皮书:与主流云厂商(PaaS)的技术定位对比与选型指南
人工智能·深度学习·机器学习·音视频·gb28181对接·rtsp服务器·rtsp播放器rtmp播放器
Rabbit_QL21 小时前
【LLM基础教程】从序列切分到上下文窗口02_三种数据切分方法
深度学习·语言模型