TabNet 流程图集合(Mermaid)
本文档包含 TabNet 架构的多个关键流程图,使用 Mermaid 语法绘制
目录
1. TabNet 整体架构流程
生成掩码
生成掩码
生成掩码
输入特征 f
Batch Normalization
初始化 P0 = 1, a0 = 0
决策步骤 1
决策步骤 2
决策步骤 N
d1
d2
dN
ReLU
ReLU
ReLU
求和聚合
全连接层
输出预测
M1
M2
MN
可解释性分析
说明:
- 蓝色:输入
- 绿色:核心决策步骤
- 红色:最终输出
- 黄色:可解释性分支
2. 单个决策步骤详细流程
决策步骤_i
前一步输出 ai-1
注意力变换器
先验缩放 Pi-1
特征掩码 Mi
原始特征 f
逐元素相乘
被选择的特征
Mi * f
特征变换器
分割
决策输出 di
维度 Nd
注意力输出 ai
维度 Na
更新先验
Prior Update Rule
传递到聚合
传递到下一步
传递到下一步
关键点:
- 黄色:特征掩码(可解释性来源)
- 绿色:用于最终决策
- 蓝色:传递给下一步
- 橙色:控制特征重用
3. 注意力变换器流程
Sparsemax_特性
输出稀疏
大部分为 0
和为 1
上一步输出
ai-1
全连接层
Batch Norm
注意力分数
先验缩放
Pi-1
逐元素相乘
调制后的分数
Sparsemax
归一化
特征掩码
Mi
核心机制:
- 先验调制:已使用的特征权重降低
- Sparsemax:产生稀疏掩码(大部分为0)
- 结果:每步只选择少数关键特征
4. 特征变换器流程
独立层(步骤特定)
共享层(所有步骤共用)
前 N_d 维
后 N_a 维
GLU Block 结构
FC
BN
GLU 激活
Residual √0.5
选中的特征
Mi · f
FC + BN + GLU
-
Residual
FC + BN + GLU -
Residual
FC + BN + GLU -
Residual
FC + BN + GLU -
Residual
分割
di
用于决策
ai
给下一步
设计亮点:
- 蓝色:共享层(参数效率)
- 绿色:独立层(步骤特异性)
- 组合优势:平衡参数复用和表达能力
5. 可解释性计算流程
全局可解释性_所有样本
局部可解释性_单个样本
M1: 步骤1掩码
可视化
被选中的特征
M2: 步骤2掩码
可视化
被选中的特征
MN: 步骤N掩码
可视化
被选中的特征
d1
计算贡献度 η1
d2
计算贡献度 η2
dN
计算贡献度 ηN
加权掩码
η1 × M1
加权掩码
η2 × M2
加权掩码
ηN × MN
求和
归一化
使总和为 1
聚合特征重要性
M_agg
对所有样本求平均
全局特征重要性
两层可解释性:
- 局部:查看每个步骤选择了哪些特征
- 全局:量化每个特征的整体贡献
6. 自监督预训练流程
训练多轮
原始特征 f
生成随机掩码 S
Bernoulli, p_s = 0.3
已知特征
1 - S * f
未知特征
S * f
TabNet Encoder
Prior: 1 - S
编码表示
d_out
TabNet Decoder
多步 Feature Transformer
重建特征
f_reconstructed
仅输出被掩盖部分
S * f_reconstructed
重建损失
Masked Reconstruction Loss
反向传播
参数更新
预训练完成的 Encoder
迁移到监督任务
微调
关键步骤:
- 掩码:随机隐藏30%的特征
- 编码:用剩余特征学习表示
- 解码:重建被隐藏的特征
- 微调:用于下游监督任务
7. 训练流程
是
否
是
否
是
否
否
是
开始训练
加载数据
训练集 + 验证集
初始化模型
设置超参数
开始 Epoch
获取一个 Batch
前向传播
计算任务损失
分类/回归
计算稀疏损失
L_sparse
生成掩码
总损失
L_total = L_task + λ·L_sparse
反向传播
梯度裁剪
防止梯度爆炸
更新参数
Adam 优化器
还有 Batch?
验证集评估
验证损失
是否最优?
保存最佳模型
重置 patience
patience += 1
达到最大
Epoch?
patience
> 阈值?
Early Stopping
训练结束
学习率调度
ReduceLROnPlateau
训练要点:
- 梯度裁剪:防止梯度爆炸
- Early Stopping:防止过拟合
- 学习率调度:自适应调整
- 保存最佳模型:基于验证集
8. 模型选择决策树
否
是
是
否
是
否
是
否
是
否
是
否
是
否
是
否
是
否
需要表格数据建模
数据量 > 10k?
小数据集路径
需要可解释性?
需要
可解释性?
简单模型
Logistic/Trees
尝试 XGBoost
或 TabNet
需要
端到端学习?
有无标签
数据?
✅ 使用 TabNet
训练速度
要求高?
使用 XGBoost
- SHAP 解释
✅ 使用 TabNet
内生可解释性
预训练
数据量大?
数据量
> 100k?
✅ 使用 TabNet
- 预训练
TabNet 或
树模型都可
TabNet vs
XGBoost 对比
优先 XGBoost
TabNet
更好?
✅ 使用 TabNet
使用 XGBoost
决策建议:
- 绿色:推荐 TabNet
- 蓝色:推荐树模型(XGBoost/LightGBM)
- 灰色:简单模型
9. TabNet vs 树模型对比流程
树模型优势场景
TabNet 优势场景
数据量 > 50k
需要可解释性
有无标签数据
端到端学习
复杂决策边界
多模态融合
数据量 < 10k
训练速度优先
高基数分类特征
成熟工具链需求
CPU 环境
表格学习任务
✅ 选择 TabNet
✅ 选择树模型
性能优异
可解释
灵活集成
训练快速
稳定可靠
工具成熟
10. Sparsemax vs Softmax 对比流程
示例对比
输入: 1, 2, 3, 100
Softmax:
~0, ~0, ~0, ~1
Sparsemax:
0, 0, 0, 1
注意力分数
logits
Softmax 归一化
Sparsemax 归一化
密集输出
所有值 > 0
稀疏输出
大部分值 = 0
问题:
所有特征都有权重
难以解释
优势1:
自动特征选择
优势2:
高可解释性
优势3:
参数高效
关键区别:
- Softmax:所有值 > 0(密集)
- Sparsemax:大部分值 = 0(稀疏)
11. 预训练 vs 监督学习性能对比
< 10k
10k-100k
> 100k 是
否
有预训练
自监督预训练
掩码特征预测
学到通用表示
监督微调
性能: 提升
无预训练(仅监督学习)
随机初始化参数
监督训练
性能: 基线
数据情况
标注数据量
预训练提升显著
+3-5%
预训练有帮助
+1-2%
预训练提升较小
+0.3-0.5%
有大量
无标签数据?
✅ 强烈推荐预训练
可选预训练
实验结论(Higgs 数据集):
- 1k 样本:预训练提升 3.9%
- 10k 样本:预训练提升 1.4%
- 100k 样本:预训练提升 0.27%
12. 完整数据流图(端到端)
单次前向传播
预处理
是
否
继续训练
Early Stop
原始表格数据
数据预处理
数值特征:
标准化
分类特征:
Embedding
缺失值:
填充/标记
特征矩阵 f
是否预训练?
自监督预训练阶段
监督学习阶段
训练 Encoder
训练循环
特征
决策步骤1
决策步骤2
决策步骤N
聚合
输出
计算损失
反向传播
更新参数
验证
加载最佳模型
推理阶段
新数据
预测
可解释性分析
特征掩码可视化
特征重要性排序
输出结果 + 解释
使用说明
如何渲染这些流程图
- 在 Markdown 编辑器中 :
- 支持 Mermaid 的编辑器(如 Typora, Obsidian, VS Code + 插件)
- 直接复制代码即可自动渲染
- 在 GitHub/GitLab :
- 将代码块放在 README.md 中
- 自动渲染为图表
- 在线工具 :
- Mermaid Live Editor
- 粘贴代码即可编辑和导出
- 导出为图片 :
- 使用 Mermaid CLI:
mmdc -i input.mmd -o output.png - 或在线编辑器中点击"Export"
- 使用 Mermaid CLI:
自定义样式
您可以修改颜色:
节点
节点
常用颜色代码:
- 绿色:
#4caf50(成功/推荐) - 蓝色:
#2196f3(信息/次要) - 黄色:
#ffeb3b(警告/重要) - 橙色:
#ff9800(注意) - 红色:
#f44336(错误/停止)
流程图总结
| 流程图编号 | 名称 | 用途 |
|---|---|---|
| 1 | 整体架构 | 理解 TabNet 全局结构 |
| 2 | 决策步骤 | 深入单步处理流程 |
| 3 | 注意力变换器 | 理解特征选择机制 |
| 4 | 特征变换器 | 理解特征处理流程 |
| 5 | 可解释性 | 如何计算特征重要性 |
| 6 | 预训练 | 自监督学习流程 |
| 7 | 训练流程 | 完整训练循环 |
| 8 | 模型选择 | 何时用 TabNet |
| 9 | 模型对比 | TabNet vs 树模型 |
| 10 | Sparsemax | 稀疏归一化原理 |
| 11 | 预训练对比 | 预训练的效果 |
| 12 | 端到端流程 | 从数据到预测 |
推荐使用顺序
初学者路径:
- 整体架构(图1) → 了解全局
- 决策步骤(图2) → 理解核心
- 训练流程(图7) → 实践指导
- 模型选择(图8) → 应用场景
进阶研究路径:
- 注意力变换器(图3) → 深入机制
- 特征变换器(图4) → 架构细节
- 可解释性(图5) → 解释方法
- 预训练(图6) → 高级技巧
实战应用路径:
- 模型选择(图8) → 决定是否使用
- 端到端流程(图12) → 完整流程
- 训练流程(图7) → 实现细节
- 可解释性(图5) → 结果分析