基于多模态特征融合和可解释性深度学习的工业压缩机异常分类与预测性维护智能诊断(Python)

算法首先加载已训练模型对不同振动模式(正常、4种异常)的预测结果和对应的SHAP(SHapley Additive exPlanations)特征重要性值,这些值来源于一个多模态深度学习模型(结合了原始时域信号和STFT时频特征)。系统通过两种主要方式可视化模型决策依据:一是将原始振动信号的SHAP重要性值以红色渐变背景的形式叠加在时域波形图上,直观展示哪些时间点对模型分类贡献最大;二是分别展示STFT频域特征的SHAP重要性热力图和模型内部注意力机制生成的注意力图,揭示模型在频域和时间维度上的关注重点。整个过程对五个振动类别(正常+四种异常)分别进行可视化,并自动保存高分辨率图像,为工业设备状态监测和故障诊断提供直观的可解释性分析工具。

第一步,系统初始化与数据加载:创建必要的目录结构用于存储可视化结果,从指定路径加载原始振动数据(包含正常和四种异常模式的时域信号)、模型预测结果索引文件、SHAP特征重要性值以及注意力机制生成的注意力图,同时对振动信号的三个通道(X轴、Y轴、Z轴)进行定义。

第二步,数据预处理与时域对齐:针对每个振动类别,提取模型预测为当前类别的代表性时间窗口,确保原始信号数据与对应的SHAP重要性值在时间长度上保持一致,通过截取最小公共长度实现时域对齐,为后续的可视化叠加准备匹配的数据结构。

第三步,时域信号与SHAP重要性叠加可视化:对每个振动类别和每个传感器通道,计算SHAP值的全局归一化范围,将SHAP重要性映射为0到1的归一化值,使用红色渐变色系将归一化后的SHAP值作为半透明背景,与黑色的原始时域振动波形进行叠加显示,形成重要性热力图与原始信号的双重视觉层。

第四步,频域特征重要性可视化:加载短时傅里叶变换(STFT)后的SHAP重要性值,计算所有类别和通道的全局最小最大值用于统一颜色映射,为每个振动类别创建包含三个通道(X、Y、Z轴)的STFT SHAP热力图,使用红蓝对比色系展示频域特征对模型决策的正负向贡献。

第五步,注意力机制可视化:提取多模态模型中的交叉注意力图,展示不同注意力头在处理查询-键关系时的关注模式,使用绿色渐变色系(viridis)统一显示所有类别和注意力头的激活强度,揭示模型内部信息交互的重点区域。

第六步,多类别并行处理与输出:对五种振动模式(正常、异常1-4)分别执行上述可视化流程,生成时域、频域和注意力三个维度的可视化结果,同时支持屏幕交互式查看和文件保存两种输出方式,保存为高分辨率PNG图像并按类别组织存储结构。

复制代码
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

# Define class names for different vibration patterns
class_names = ["normal", "abnormal_1", "abnormal_2", "abnormal_3", "abnormal_4"]
# Define channel names for triaxial vibration sensor data
channel_names = ["x-axis", "y-axis", "z-axis"]

# Define file paths for raw data, SHAP values, and prediction indices
RAW_FOLDER = "test_raw/test_raw_follow"
SHAP_FOLDER = "shap_attention_results/multimodal"
IDX_FILE = "final_results/multimodal/trial_1/predicted_indices.json"

# Create directory for saving visualization results
SAVE_DIR = "figures/raw_shap_per_class"
os.makedirs(SAVE_DIR, exist_ok=True)

# Load prediction indices from JSON file
with open(IDX_FILE) as f:
    pred_idx = json.load(f)

# Load raw vibration data for each class (format: N_windows × 3_channels × L_timepoints)
raw_data = {
    cls: np.load(os.path.join(RAW_FOLDER, f"{cls}_raw_tensors.npy"))
    for cls in class_names
}

# Load SHAP values for raw signals (format: 5_classes × 3_channels × L_timepoints)
raw_shap = np.load(os.path.join(SHAP_FOLDER, "raw_shap_values.npy"))

# Calculate global min and max SHAP values for consistent normalization across all plots
gmin, gmax = np.abs(raw_shap).min(), np.abs(raw_shap).max()

担任《Mechanical System and Signal Processing》《中国电机工程学报》《宇航学报》《控制与决策》等期刊审稿专家,擅长领域:信号滤波/降噪,机器学习/深度学习,时间序列预分析/预测,设备故障诊断/缺陷检测/异常检测

相关推荐
狮驼岭的小钻风2 分钟前
汽车V模型开发流程、ASPICE、汽车功能安全的基石是国际标准 ISO 26262
网络·安全·汽车
wazmlp0018873694 分钟前
第五次python作业
服务器·开发语言·python
OPEN-Source5 分钟前
大模型实战:把自定义 Agent 封装成一个 HTTP 服务
人工智能·agent·deepseek
尘缘浮梦5 分钟前
websockets简单例子1
开发语言·python
不懒不懒6 分钟前
【从零开始:PyTorch实现MNIST手写数字识别全流程解析】
人工智能·pytorch·python
zhangshuang-peta6 分钟前
从REST到MCP:为何及如何为AI代理升级API
人工智能·ai agent·mcp·peta
helloworld也报错?7 分钟前
基于CrewAI创建一个简单的智能体
人工智能·python·vllm
机器学习之心9 分钟前
基于GRU门控循环单元的轴承剩余寿命预测MATLAB实现
深度学习·matlab·gru·轴承剩余寿命预测
wukangjupingbb10 分钟前
Gemini 3和GPT-5.1在多模态处理上的对比
人工智能·gpt·机器学习
明月照山海-10 分钟前
机器学习周报三十四
人工智能·机器学习