CNN+Transformer+SE注意力机制多分类模型 + SHAP特征重要性分析,pytorch框架

效果一览





代码功能

CNN提取一维序列的局部特征,如光谱峰值、表格数据趋势等。Transformer捕捉一维序列的全局依赖关系,解决长序列建模难题! 弥补CNN在长距离依赖建模上的不足,提升模型的全局特征提取能力。SE注意力机制动态调整特征通道权重,聚焦关键信息,提升分类精度!

支持多类别分类任务,适用于光谱分类、表格数据分类、时间序列分类等场景。

可自定义类别数量

输出训练损失和准确率,并评估训练集和测试集的准确率,精确率,召回率,f1分数,绘制roc曲线,混淆矩阵

结合SHAP(Shapley Additive exPlanations),直观展示每个特征对分类结果的影响!

包括蜂巢图,重要性图,单特征力图,决策图,热图,瀑布图等。

CNN+Transformer+SE注意力机制多分类模型 + SHAP特征重要性分析

模型架构与核心组件

1. CNN(卷积神经网络)

功能

  • 局部特征提取:通过一维卷积核滑动窗口(如核大小=3/5/7),捕获序列中的局部模式(如光谱峰值、数据趋势)。
  • 特征增强 :使用多层卷积堆叠(Conv1D+ReLU+MaxPool),逐步抽象高阶特征,输出维度为 (batch_size, channels, seq_len)

实现代码片段

python 复制代码
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding='same')
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
    def forward(self, x):
        return self.pool(self.relu(self.conv(x)))

2. Transformer

功能

  • 全局依赖建模:利用自注意力机制(Multi-Head Attention)捕捉长序列中的上下文关系。
  • 位置编码:添加可学习的位置编码(Positional Encoding),解决序列顺序问题。
  • 特征融合 :输出全局特征矩阵 (batch_size, seq_len, d_model)

实现代码片段

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, num_layers=2):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
    def forward(self, x):
        x = self.pos_encoder(x.permute(0,2,1))  # 调整维度为 (seq_len, batch, d_model)
        return self.transformer(x).permute(1,0,2)

3. SE(Squeeze-and-Excitation)注意力机制

功能

  • 通道权重动态调整:通过全局平均池化(Squeeze)和全连接层(Excitation),生成通道权重向量。
  • 特征增强 :对CNN或Transformer输出的特征图进行通道级加权,公式:
    ( \text{SE}(x) = x \cdot \sigma(W_2 \cdot \text{ReLU}(W_1 \cdot \text{GAP}(x))) )

实现代码片段

python 复制代码
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels//reduction),
            nn.ReLU(),
            nn.Linear(channels//reduction, channels),
            nn.Sigmoid()
        )
    def forward(self, x):
        weights = self.fc(self.gap(x).squeeze(-1))
        return x * weights.unsqueeze(-1)

4. 多分类任务支持

功能

  • 输出层 :全连接层 + Softmax,支持自定义类别数(num_classes)。
  • 评估指标:准确率、精确率、召回率、F1分数、ROC-AUC、混淆矩阵。

实现代码片段

python 复制代码
# 模型输出层
self.fc = nn.Linear(d_model, num_classes)

# 评估函数
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
def evaluate(y_true, y_pred):
    print(classification_report(y_true, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))
    print("ROC-AUC:", roc_auc_score(y_true, y_pred, multi_class='ovr'))

5. SHAP特征重要性分析

功能

  • 模型可解释性:基于博弈论的SHAP值,量化特征对分类结果的贡献。
  • 可视化工具:支持蜂巢图、决策图、热力图等,适配序列输入。

实现代码片段

python 复制代码
import shap

# 初始化解释器
explainer = shap.DeepExplainer(model, background_data)

# 计算SHAP值
shap_values = explainer.shap_values(test_sample)

# 可视化(示例:特征重要性图)
shap.summary_plot(shap_values, test_sample, plot_type='bar')

模型整合与训练流程

完整模型架构

python 复制代码
class CNNTransformerSE(nn.Module):
    def __init__(self, input_dim, num_classes, d_model=64, nhead=4):
        super().__init__()
        self.cnn = CNNBlock(input_dim, 32)
        self.se1 = SEBlock(32)
        self.transformer = TransformerBlock(d_model, nhead)
        self.se2 = SEBlock(d_model)
        self.fc = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        x = self.cnn(x)          # CNN提取局部特征
        x = self.se1(x)          # SE增强局部特征
        x = self.transformer(x)  # Transformer建模全局依赖
        x = self.se2(x.mean(dim=1))  # SE增强全局特征 + 池化
        return self.fc(x)

训练与评估

python 复制代码
# 数据加载(示例:光谱数据集)
from torch.utils.data import DataLoader
train_loader = DataLoader(SpectrumDataset(), batch_size=32, shuffle=True)

# 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
    # 每轮评估
    model.eval()
    evaluate(test_labels, model(test_data).argmax(axis=1))

应用场景与优势

  • 适用领域:光谱分类(如化学物质识别)、表格数据分类(如医疗诊断)、时间序列预测(如股票趋势分析)。
  • 优势
    • 局部-全局特征互补:CNN捕捉细节,Transformer建模长依赖,SE优化特征权重。
    • 高可解释性:SHAP分析直观展示关键特征,适用于需要决策透明度的场景(如医疗、金融)。
  • 案例数据集 :内置SpectrumDataset示例,支持自定义CSV或NumPy数据输入。

环境依赖

  • 框架:PyTorch ≥1.8.0 + CUDA(可选)
  • 依赖库scikit-learn(评估指标)、shap(可解释性分析)、matplotlib(可视化)
相关推荐
山顶听风5 分钟前
多层感知器MLP实现非线性分类(原理)
人工智能·分类·数据挖掘
山顶听风8 分钟前
MLP实战二:MLP 实现图像数字多分类
人工智能·机器学习·分类
rit84324992 小时前
基于BP神经网络的语音特征信号分类
人工智能·神经网络·分类
yzx9910136 小时前
基于 Q-Learning 算法和 CNN 的强化学习实现方案
人工智能·算法·cnn
海盗儿12 小时前
Attention Is All You Need (Transformer) 以及Transformer pytorch实现
pytorch·深度学习·transformer
春末的南方城市13 小时前
港科大&快手提出统一上下文视频编辑 UNIC,各种视频编辑任务一网打尽,还可进行多项任务组合!
人工智能·计算机视觉·stable diffusion·aigc·transformer
量子-Alex17 小时前
【反无人机检测】C2FDrone:基于视觉Transformer网络的无人机间由粗到细检测
网络·transformer·无人机
苏苏susuus21 小时前
机器学习:集成学习概念和分类、随机森林、Adaboost、GBDT
机器学习·分类·集成学习
l木本I1 天前
大模型低秩微调技术 LoRA 深度解析与实践
python·深度学习·自然语言处理·lstm·transformer
vlln2 天前
【论文解读】MemGPT: 迈向为操作系统的LLM
人工智能·深度学习·自然语言处理·transformer