CNN+Transformer+SE注意力机制多分类模型 + SHAP特征重要性分析,pytorch框架

效果一览





代码功能

CNN提取一维序列的局部特征,如光谱峰值、表格数据趋势等。Transformer捕捉一维序列的全局依赖关系,解决长序列建模难题! 弥补CNN在长距离依赖建模上的不足,提升模型的全局特征提取能力。SE注意力机制动态调整特征通道权重,聚焦关键信息,提升分类精度!

支持多类别分类任务,适用于光谱分类、表格数据分类、时间序列分类等场景。

可自定义类别数量

输出训练损失和准确率,并评估训练集和测试集的准确率,精确率,召回率,f1分数,绘制roc曲线,混淆矩阵

结合SHAP(Shapley Additive exPlanations),直观展示每个特征对分类结果的影响!

包括蜂巢图,重要性图,单特征力图,决策图,热图,瀑布图等。

CNN+Transformer+SE注意力机制多分类模型 + SHAP特征重要性分析

模型架构与核心组件

1. CNN(卷积神经网络)

功能

  • 局部特征提取:通过一维卷积核滑动窗口(如核大小=3/5/7),捕获序列中的局部模式(如光谱峰值、数据趋势)。
  • 特征增强 :使用多层卷积堆叠(Conv1D+ReLU+MaxPool),逐步抽象高阶特征,输出维度为 (batch_size, channels, seq_len)

实现代码片段

python 复制代码
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding='same')
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
    def forward(self, x):
        return self.pool(self.relu(self.conv(x)))

2. Transformer

功能

  • 全局依赖建模:利用自注意力机制(Multi-Head Attention)捕捉长序列中的上下文关系。
  • 位置编码:添加可学习的位置编码(Positional Encoding),解决序列顺序问题。
  • 特征融合 :输出全局特征矩阵 (batch_size, seq_len, d_model)

实现代码片段

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, num_layers=2):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
    def forward(self, x):
        x = self.pos_encoder(x.permute(0,2,1))  # 调整维度为 (seq_len, batch, d_model)
        return self.transformer(x).permute(1,0,2)

3. SE(Squeeze-and-Excitation)注意力机制

功能

  • 通道权重动态调整:通过全局平均池化(Squeeze)和全连接层(Excitation),生成通道权重向量。
  • 特征增强 :对CNN或Transformer输出的特征图进行通道级加权,公式:
    ( \text{SE}(x) = x \cdot \sigma(W_2 \cdot \text{ReLU}(W_1 \cdot \text{GAP}(x))) )

实现代码片段

python 复制代码
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels//reduction),
            nn.ReLU(),
            nn.Linear(channels//reduction, channels),
            nn.Sigmoid()
        )
    def forward(self, x):
        weights = self.fc(self.gap(x).squeeze(-1))
        return x * weights.unsqueeze(-1)

4. 多分类任务支持

功能

  • 输出层 :全连接层 + Softmax,支持自定义类别数(num_classes)。
  • 评估指标:准确率、精确率、召回率、F1分数、ROC-AUC、混淆矩阵。

实现代码片段

python 复制代码
# 模型输出层
self.fc = nn.Linear(d_model, num_classes)

# 评估函数
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
def evaluate(y_true, y_pred):
    print(classification_report(y_true, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))
    print("ROC-AUC:", roc_auc_score(y_true, y_pred, multi_class='ovr'))

5. SHAP特征重要性分析

功能

  • 模型可解释性:基于博弈论的SHAP值,量化特征对分类结果的贡献。
  • 可视化工具:支持蜂巢图、决策图、热力图等,适配序列输入。

实现代码片段

python 复制代码
import shap

# 初始化解释器
explainer = shap.DeepExplainer(model, background_data)

# 计算SHAP值
shap_values = explainer.shap_values(test_sample)

# 可视化(示例:特征重要性图)
shap.summary_plot(shap_values, test_sample, plot_type='bar')

模型整合与训练流程

完整模型架构

python 复制代码
class CNNTransformerSE(nn.Module):
    def __init__(self, input_dim, num_classes, d_model=64, nhead=4):
        super().__init__()
        self.cnn = CNNBlock(input_dim, 32)
        self.se1 = SEBlock(32)
        self.transformer = TransformerBlock(d_model, nhead)
        self.se2 = SEBlock(d_model)
        self.fc = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        x = self.cnn(x)          # CNN提取局部特征
        x = self.se1(x)          # SE增强局部特征
        x = self.transformer(x)  # Transformer建模全局依赖
        x = self.se2(x.mean(dim=1))  # SE增强全局特征 + 池化
        return self.fc(x)

训练与评估

python 复制代码
# 数据加载(示例:光谱数据集)
from torch.utils.data import DataLoader
train_loader = DataLoader(SpectrumDataset(), batch_size=32, shuffle=True)

# 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
    # 每轮评估
    model.eval()
    evaluate(test_labels, model(test_data).argmax(axis=1))

应用场景与优势

  • 适用领域:光谱分类(如化学物质识别)、表格数据分类(如医疗诊断)、时间序列预测(如股票趋势分析)。
  • 优势
    • 局部-全局特征互补:CNN捕捉细节,Transformer建模长依赖,SE优化特征权重。
    • 高可解释性:SHAP分析直观展示关键特征,适用于需要决策透明度的场景(如医疗、金融)。
  • 案例数据集 :内置SpectrumDataset示例,支持自定义CSV或NumPy数据输入。

环境依赖

  • 框架:PyTorch ≥1.8.0 + CUDA(可选)
  • 依赖库scikit-learn(评估指标)、shap(可解释性分析)、matplotlib(可视化)
相关推荐
蹦蹦跳跳真可爱5892 小时前
Python----卷积神经网络(卷积为什么能识别图像)
人工智能·python·深度学习·神经网络·计算机视觉·cnn
盼小辉丶7 小时前
PyTorch生成式人工智能实战(3)——分类任务详解
人工智能·pytorch·分类
JOYCE_Leo1618 小时前
一文详解卷积神经网络中的卷积层和池化层原理 !!
人工智能·深度学习·cnn·卷积神经网络
契合qht53_shine1 天前
深度学习 视觉处理(CNN) day_01
人工智能·深度学习·cnn
何仙鸟1 天前
卷积神经网络实战(1)
人工智能·神经网络·cnn
视觉AI1 天前
SiamMask中的分类分支、回归分支与Mask分支,有何本质差异?
计算机视觉·分类·回归
Mu先生Ai世界1 天前
AI 生成 3D 技术解析:驱动力、价值主张与核心挑战 (AI+3D 产品经理笔记 S2E01)
人工智能·游戏·3d·aigc·transformer·产品经理·vr
白熊1881 天前
【计算机视觉】CV实战项目- Four-Flower:基于TensorFlow的花朵分类实战指南
计算机视觉·分类·tensorflow
视觉语言导航1 天前
复杂地形越野机器人导航新突破!VERTIFORMER:数据高效多任务Transformer助力越野机器人移动导航
人工智能·深度学习·机器人·transformer·具身智能
巷9551 天前
卷积神经网络迁移学习:原理与实践指南
人工智能·cnn·迁移学习