TabNet 流程图集合(Mermaid)

TabNet 流程图集合(Mermaid)

本文档包含 TabNet 架构的多个关键流程图,使用 Mermaid 语法绘制


目录

  1. TabNet 整体架构流程
  2. 单个决策步骤详细流程
  3. 注意力变换器流程
  4. 特征变换器流程
  5. 可解释性计算流程
  6. 自监督预训练流程
  7. 训练流程
  8. 模型选择决策树

1. TabNet 整体架构流程

生成掩码
生成掩码
生成掩码
输入特征 f
Batch Normalization
初始化 P0 = 1, a0 = 0
决策步骤 1
决策步骤 2
决策步骤 N
d1
d2
dN
ReLU
ReLU
ReLU
求和聚合
全连接层
输出预测
M1
M2
MN
可解释性分析

说明

  • 蓝色:输入
  • 绿色:核心决策步骤
  • 红色:最终输出
  • 黄色:可解释性分支

2. 单个决策步骤详细流程

决策步骤_i
前一步输出 ai-1
注意力变换器
先验缩放 Pi-1
特征掩码 Mi
原始特征 f
逐元素相乘
被选择的特征

Mi * f
特征变换器
分割
决策输出 di

维度 Nd
注意力输出 ai

维度 Na
更新先验

Prior Update Rule
传递到聚合
传递到下一步
传递到下一步

关键点

  • 黄色:特征掩码(可解释性来源)
  • 绿色:用于最终决策
  • 蓝色:传递给下一步
  • 橙色:控制特征重用

3. 注意力变换器流程

Sparsemax_特性
输出稀疏
大部分为 0
和为 1
上一步输出

ai-1
全连接层
Batch Norm
注意力分数
先验缩放

Pi-1
逐元素相乘
调制后的分数
Sparsemax

归一化
特征掩码

Mi

核心机制

  • 先验调制:已使用的特征权重降低
  • Sparsemax:产生稀疏掩码(大部分为0)
  • 结果:每步只选择少数关键特征

4. 特征变换器流程

独立层(步骤特定)
共享层(所有步骤共用)
前 N_d 维
后 N_a 维
GLU Block 结构
FC
BN
GLU 激活
Residual √0.5
选中的特征

Mi · f
FC + BN + GLU

  • Residual
    FC + BN + GLU

  • Residual
    FC + BN + GLU

  • Residual
    FC + BN + GLU

  • Residual
    分割
    di

用于决策
ai

给下一步

设计亮点

  • 蓝色:共享层(参数效率)
  • 绿色:独立层(步骤特异性)
  • 组合优势:平衡参数复用和表达能力

5. 可解释性计算流程

全局可解释性_所有样本
局部可解释性_单个样本
M1: 步骤1掩码
可视化

被选中的特征
M2: 步骤2掩码
可视化

被选中的特征
MN: 步骤N掩码
可视化

被选中的特征
d1
计算贡献度 η1
d2
计算贡献度 η2
dN
计算贡献度 ηN
加权掩码

η1 × M1
加权掩码

η2 × M2
加权掩码

ηN × MN
求和
归一化

使总和为 1
聚合特征重要性

M_agg
对所有样本求平均
全局特征重要性

两层可解释性

  1. 局部:查看每个步骤选择了哪些特征
  2. 全局:量化每个特征的整体贡献

6. 自监督预训练流程

训练多轮
原始特征 f
生成随机掩码 S

Bernoulli, p_s = 0.3
已知特征

1 - S * f
未知特征

S * f
TabNet Encoder

Prior: 1 - S
编码表示

d_out
TabNet Decoder

多步 Feature Transformer
重建特征

f_reconstructed
仅输出被掩盖部分

S * f_reconstructed
重建损失

Masked Reconstruction Loss
反向传播

参数更新
预训练完成的 Encoder
迁移到监督任务

微调

关键步骤

  1. 掩码:随机隐藏30%的特征
  2. 编码:用剩余特征学习表示
  3. 解码:重建被隐藏的特征
  4. 微调:用于下游监督任务

7. 训练流程









开始训练
加载数据

训练集 + 验证集
初始化模型

设置超参数
开始 Epoch
获取一个 Batch
前向传播
计算任务损失

分类/回归
计算稀疏损失

L_sparse
生成掩码
总损失

L_total = L_task + λ·L_sparse
反向传播
梯度裁剪

防止梯度爆炸
更新参数

Adam 优化器
还有 Batch?
验证集评估
验证损失

是否最优?
保存最佳模型

重置 patience
patience += 1
达到最大

Epoch?
patience

> 阈值?
Early Stopping

训练结束
学习率调度

ReduceLROnPlateau

训练要点

  • 梯度裁剪:防止梯度爆炸
  • Early Stopping:防止过拟合
  • 学习率调度:自适应调整
  • 保存最佳模型:基于验证集

8. 模型选择决策树



















需要表格数据建模
数据量 > 10k?
小数据集路径
需要可解释性?
需要

可解释性?
简单模型

Logistic/Trees
尝试 XGBoost

或 TabNet
需要

端到端学习?
有无标签

数据?
✅ 使用 TabNet
训练速度

要求高?
使用 XGBoost

  • SHAP 解释
    ✅ 使用 TabNet

内生可解释性
预训练

数据量大?
数据量

> 100k?
✅ 使用 TabNet

  • 预训练
    TabNet 或

树模型都可
TabNet vs

XGBoost 对比
优先 XGBoost
TabNet

更好?
✅ 使用 TabNet
使用 XGBoost

决策建议

  • 绿色:推荐 TabNet
  • 蓝色:推荐树模型(XGBoost/LightGBM)
  • 灰色:简单模型

9. TabNet vs 树模型对比流程

树模型优势场景
TabNet 优势场景
数据量 > 50k
需要可解释性
有无标签数据
端到端学习
复杂决策边界
多模态融合
数据量 < 10k
训练速度优先
高基数分类特征
成熟工具链需求
CPU 环境
表格学习任务
✅ 选择 TabNet
✅ 选择树模型
性能优异

可解释

灵活集成
训练快速

稳定可靠

工具成熟


10. Sparsemax vs Softmax 对比流程

示例对比
输入: 1, 2, 3, 100
Softmax:

~0, ~0, ~0, ~1
Sparsemax:

0, 0, 0, 1
注意力分数

logits
Softmax 归一化
Sparsemax 归一化
密集输出

所有值 > 0
稀疏输出

大部分值 = 0
问题:

所有特征都有权重

难以解释
优势1:

自动特征选择
优势2:

高可解释性
优势3:

参数高效

关键区别

  • Softmax:所有值 > 0(密集)
  • Sparsemax:大部分值 = 0(稀疏)

11. 预训练 vs 监督学习性能对比

< 10k
10k-100k
> 100k 是

有预训练
自监督预训练

掩码特征预测
学到通用表示
监督微调
性能: 提升
无预训练(仅监督学习)
随机初始化参数
监督训练
性能: 基线
数据情况
标注数据量
预训练提升显著

+3-5%
预训练有帮助

+1-2%
预训练提升较小

+0.3-0.5%
有大量

无标签数据?
✅ 强烈推荐预训练
可选预训练

实验结论(Higgs 数据集):

  • 1k 样本:预训练提升 3.9%
  • 10k 样本:预训练提升 1.4%
  • 100k 样本:预训练提升 0.27%

12. 完整数据流图(端到端)

单次前向传播
预处理


继续训练
Early Stop
原始表格数据
数据预处理
数值特征:

标准化
分类特征:

Embedding
缺失值:

填充/标记
特征矩阵 f
是否预训练?
自监督预训练阶段
监督学习阶段
训练 Encoder
训练循环
特征
决策步骤1
决策步骤2
决策步骤N
聚合
输出
计算损失
反向传播
更新参数
验证
加载最佳模型
推理阶段
新数据
预测
可解释性分析
特征掩码可视化
特征重要性排序
输出结果 + 解释


使用说明

如何渲染这些流程图

  1. 在 Markdown 编辑器中
    • 支持 Mermaid 的编辑器(如 Typora, Obsidian, VS Code + 插件)
    • 直接复制代码即可自动渲染
  2. 在 GitHub/GitLab
    • 将代码块放在 README.md
    • 自动渲染为图表
  3. 在线工具
  4. 导出为图片
    • 使用 Mermaid CLI: mmdc -i input.mmd -o output.png
    • 或在线编辑器中点击"Export"

自定义样式

您可以修改颜色:
节点
节点

常用颜色代码:

  • 绿色:#4caf50(成功/推荐)
  • 蓝色:#2196f3(信息/次要)
  • 黄色:#ffeb3b(警告/重要)
  • 橙色:#ff9800(注意)
  • 红色:#f44336(错误/停止)

流程图总结

流程图编号 名称 用途
1 整体架构 理解 TabNet 全局结构
2 决策步骤 深入单步处理流程
3 注意力变换器 理解特征选择机制
4 特征变换器 理解特征处理流程
5 可解释性 如何计算特征重要性
6 预训练 自监督学习流程
7 训练流程 完整训练循环
8 模型选择 何时用 TabNet
9 模型对比 TabNet vs 树模型
10 Sparsemax 稀疏归一化原理
11 预训练对比 预训练的效果
12 端到端流程 从数据到预测

推荐使用顺序

初学者路径

  1. 整体架构(图1) → 了解全局
  2. 决策步骤(图2) → 理解核心
  3. 训练流程(图7) → 实践指导
  4. 模型选择(图8) → 应用场景

进阶研究路径

  1. 注意力变换器(图3) → 深入机制
  2. 特征变换器(图4) → 架构细节
  3. 可解释性(图5) → 解释方法
  4. 预训练(图6) → 高级技巧

实战应用路径

  1. 模型选择(图8) → 决定是否使用
  2. 端到端流程(图12) → 完整流程
  3. 训练流程(图7) → 实现细节
  4. 可解释性(图5) → 结果分析
相关推荐
EmmaXLZHONG9 小时前
Reinforce Learning Concept Flow Chart (强化学习概念流程图)
人工智能·深度学习·机器学习·流程图
Nice__J1 天前
Mermaid (代码转流程图)语法详解
网络·流程图
IT_Octopus2 天前
JVM G1 CMS 垃圾收集器工作流程简化流程图
java·jvm·流程图
数说星榆1812 天前
前后端分离开发流程-泳道图设计与应用
论文阅读·职场和发展·毕业设计·流程图·职场发展·论文笔记·毕设
数说星榆1812 天前
项目管理流程图-泳道图模板免费下载
论文阅读·毕业设计·流程图·论文笔记·毕设
程途拾光1582 天前
产品功能验收泳道图-流程图模板下载
论文阅读·职场和发展·毕业设计·流程图·课程设计·论文笔记·毕设
檐下翻书1733 天前
招聘SOP流程图-泳道图模板详细教程
论文阅读·毕业设计·流程图·图论·论文笔记·毕设
轩情吖3 天前
Qt多元素控件之QTableWidget
开发语言·c++·qt·表格·控件·qtablewidget
数说星榆1814 天前
好用的PC电脑流程图软件无需下载在线绘制流程图模板大全
大数据·论文阅读·电脑·流程图·论文笔记
檐下翻书1734 天前
PC端免费在线流程图工具新手快速制作专业流程图教程
论文阅读·架构·毕业设计·流程图·论文笔记