算法首先加载已训练模型对不同振动模式(正常、4种异常)的预测结果和对应的SHAP(SHapley Additive exPlanations)特征重要性值,这些值来源于一个多模态深度学习模型(结合了原始时域信号和STFT时频特征)。系统通过两种主要方式可视化模型决策依据:一是将原始振动信号的SHAP重要性值以红色渐变背景的形式叠加在时域波形图上,直观展示哪些时间点对模型分类贡献最大;二是分别展示STFT频域特征的SHAP重要性热力图和模型内部注意力机制生成的注意力图,揭示模型在频域和时间维度上的关注重点。整个过程对五个振动类别(正常+四种异常)分别进行可视化,并自动保存高分辨率图像,为工业设备状态监测和故障诊断提供直观的可解释性分析工具。
第一步,系统初始化与数据加载:创建必要的目录结构用于存储可视化结果,从指定路径加载原始振动数据(包含正常和四种异常模式的时域信号)、模型预测结果索引文件、SHAP特征重要性值以及注意力机制生成的注意力图,同时对振动信号的三个通道(X轴、Y轴、Z轴)进行定义。
第二步,数据预处理与时域对齐:针对每个振动类别,提取模型预测为当前类别的代表性时间窗口,确保原始信号数据与对应的SHAP重要性值在时间长度上保持一致,通过截取最小公共长度实现时域对齐,为后续的可视化叠加准备匹配的数据结构。
第三步,时域信号与SHAP重要性叠加可视化:对每个振动类别和每个传感器通道,计算SHAP值的全局归一化范围,将SHAP重要性映射为0到1的归一化值,使用红色渐变色系将归一化后的SHAP值作为半透明背景,与黑色的原始时域振动波形进行叠加显示,形成重要性热力图与原始信号的双重视觉层。
第四步,频域特征重要性可视化:加载短时傅里叶变换(STFT)后的SHAP重要性值,计算所有类别和通道的全局最小最大值用于统一颜色映射,为每个振动类别创建包含三个通道(X、Y、Z轴)的STFT SHAP热力图,使用红蓝对比色系展示频域特征对模型决策的正负向贡献。
第五步,注意力机制可视化:提取多模态模型中的交叉注意力图,展示不同注意力头在处理查询-键关系时的关注模式,使用绿色渐变色系(viridis)统一显示所有类别和注意力头的激活强度,揭示模型内部信息交互的重点区域。
第六步,多类别并行处理与输出:对五种振动模式(正常、异常1-4)分别执行上述可视化流程,生成时域、频域和注意力三个维度的可视化结果,同时支持屏幕交互式查看和文件保存两种输出方式,保存为高分辨率PNG图像并按类别组织存储结构。
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
# Define class names for different vibration patterns
class_names = ["normal", "abnormal_1", "abnormal_2", "abnormal_3", "abnormal_4"]
# Define channel names for triaxial vibration sensor data
channel_names = ["x-axis", "y-axis", "z-axis"]
# Define file paths for raw data, SHAP values, and prediction indices
RAW_FOLDER = "test_raw/test_raw_follow"
SHAP_FOLDER = "shap_attention_results/multimodal"
IDX_FILE = "final_results/multimodal/trial_1/predicted_indices.json"
# Create directory for saving visualization results
SAVE_DIR = "figures/raw_shap_per_class"
os.makedirs(SAVE_DIR, exist_ok=True)
# Load prediction indices from JSON file
with open(IDX_FILE) as f:
pred_idx = json.load(f)
# Load raw vibration data for each class (format: N_windows × 3_channels × L_timepoints)
raw_data = {
cls: np.load(os.path.join(RAW_FOLDER, f"{cls}_raw_tensors.npy"))
for cls in class_names
}
# Load SHAP values for raw signals (format: 5_classes × 3_channels × L_timepoints)
raw_shap = np.load(os.path.join(SHAP_FOLDER, "raw_shap_values.npy"))
# Calculate global min and max SHAP values for consistent normalization across all plots
gmin, gmax = np.abs(raw_shap).min(), np.abs(raw_shap).max()















担任《Mechanical System and Signal Processing》《中国电机工程学报》《宇航学报》《控制与决策》等期刊审稿专家,擅长领域:信号滤波/降噪,机器学习/深度学习,时间序列预分析/预测,设备故障诊断/缺陷检测/异常检测