TabNet 流程图集合(Mermaid)

TabNet 流程图集合(Mermaid)

本文档包含 TabNet 架构的多个关键流程图,使用 Mermaid 语法绘制


目录

  1. TabNet 整体架构流程
  2. 单个决策步骤详细流程
  3. 注意力变换器流程
  4. 特征变换器流程
  5. 可解释性计算流程
  6. 自监督预训练流程
  7. 训练流程
  8. 模型选择决策树

1. TabNet 整体架构流程

生成掩码
生成掩码
生成掩码
输入特征 f
Batch Normalization
初始化 P0 = 1, a0 = 0
决策步骤 1
决策步骤 2
决策步骤 N
d1
d2
dN
ReLU
ReLU
ReLU
求和聚合
全连接层
输出预测
M1
M2
MN
可解释性分析

说明

  • 蓝色:输入
  • 绿色:核心决策步骤
  • 红色:最终输出
  • 黄色:可解释性分支

2. 单个决策步骤详细流程

决策步骤_i
前一步输出 ai-1
注意力变换器
先验缩放 Pi-1
特征掩码 Mi
原始特征 f
逐元素相乘
被选择的特征

Mi * f
特征变换器
分割
决策输出 di

维度 Nd
注意力输出 ai

维度 Na
更新先验

Prior Update Rule
传递到聚合
传递到下一步
传递到下一步

关键点

  • 黄色:特征掩码(可解释性来源)
  • 绿色:用于最终决策
  • 蓝色:传递给下一步
  • 橙色:控制特征重用

3. 注意力变换器流程

Sparsemax_特性
输出稀疏
大部分为 0
和为 1
上一步输出

ai-1
全连接层
Batch Norm
注意力分数
先验缩放

Pi-1
逐元素相乘
调制后的分数
Sparsemax

归一化
特征掩码

Mi

核心机制

  • 先验调制:已使用的特征权重降低
  • Sparsemax:产生稀疏掩码(大部分为0)
  • 结果:每步只选择少数关键特征

4. 特征变换器流程

独立层(步骤特定)
共享层(所有步骤共用)
前 N_d 维
后 N_a 维
GLU Block 结构
FC
BN
GLU 激活
Residual √0.5
选中的特征

Mi · f
FC + BN + GLU

  • Residual
    FC + BN + GLU

  • Residual
    FC + BN + GLU

  • Residual
    FC + BN + GLU

  • Residual
    分割
    di

用于决策
ai

给下一步

设计亮点

  • 蓝色:共享层(参数效率)
  • 绿色:独立层(步骤特异性)
  • 组合优势:平衡参数复用和表达能力

5. 可解释性计算流程

全局可解释性_所有样本
局部可解释性_单个样本
M1: 步骤1掩码
可视化

被选中的特征
M2: 步骤2掩码
可视化

被选中的特征
MN: 步骤N掩码
可视化

被选中的特征
d1
计算贡献度 η1
d2
计算贡献度 η2
dN
计算贡献度 ηN
加权掩码

η1 × M1
加权掩码

η2 × M2
加权掩码

ηN × MN
求和
归一化

使总和为 1
聚合特征重要性

M_agg
对所有样本求平均
全局特征重要性

两层可解释性

  1. 局部:查看每个步骤选择了哪些特征
  2. 全局:量化每个特征的整体贡献

6. 自监督预训练流程

训练多轮
原始特征 f
生成随机掩码 S

Bernoulli, p_s = 0.3
已知特征

1 - S * f
未知特征

S * f
TabNet Encoder

Prior: 1 - S
编码表示

d_out
TabNet Decoder

多步 Feature Transformer
重建特征

f_reconstructed
仅输出被掩盖部分

S * f_reconstructed
重建损失

Masked Reconstruction Loss
反向传播

参数更新
预训练完成的 Encoder
迁移到监督任务

微调

关键步骤

  1. 掩码:随机隐藏30%的特征
  2. 编码:用剩余特征学习表示
  3. 解码:重建被隐藏的特征
  4. 微调:用于下游监督任务

7. 训练流程









开始训练
加载数据

训练集 + 验证集
初始化模型

设置超参数
开始 Epoch
获取一个 Batch
前向传播
计算任务损失

分类/回归
计算稀疏损失

L_sparse
生成掩码
总损失

L_total = L_task + λ·L_sparse
反向传播
梯度裁剪

防止梯度爆炸
更新参数

Adam 优化器
还有 Batch?
验证集评估
验证损失

是否最优?
保存最佳模型

重置 patience
patience += 1
达到最大

Epoch?
patience

> 阈值?
Early Stopping

训练结束
学习率调度

ReduceLROnPlateau

训练要点

  • 梯度裁剪:防止梯度爆炸
  • Early Stopping:防止过拟合
  • 学习率调度:自适应调整
  • 保存最佳模型:基于验证集

8. 模型选择决策树



















需要表格数据建模
数据量 > 10k?
小数据集路径
需要可解释性?
需要

可解释性?
简单模型

Logistic/Trees
尝试 XGBoost

或 TabNet
需要

端到端学习?
有无标签

数据?
✅ 使用 TabNet
训练速度

要求高?
使用 XGBoost

  • SHAP 解释
    ✅ 使用 TabNet

内生可解释性
预训练

数据量大?
数据量

> 100k?
✅ 使用 TabNet

  • 预训练
    TabNet 或

树模型都可
TabNet vs

XGBoost 对比
优先 XGBoost
TabNet

更好?
✅ 使用 TabNet
使用 XGBoost

决策建议

  • 绿色:推荐 TabNet
  • 蓝色:推荐树模型(XGBoost/LightGBM)
  • 灰色:简单模型

9. TabNet vs 树模型对比流程

树模型优势场景
TabNet 优势场景
数据量 > 50k
需要可解释性
有无标签数据
端到端学习
复杂决策边界
多模态融合
数据量 < 10k
训练速度优先
高基数分类特征
成熟工具链需求
CPU 环境
表格学习任务
✅ 选择 TabNet
✅ 选择树模型
性能优异

可解释

灵活集成
训练快速

稳定可靠

工具成熟


10. Sparsemax vs Softmax 对比流程

示例对比
输入: 1, 2, 3, 100
Softmax:

~0, ~0, ~0, ~1
Sparsemax:

0, 0, 0, 1
注意力分数

logits
Softmax 归一化
Sparsemax 归一化
密集输出

所有值 > 0
稀疏输出

大部分值 = 0
问题:

所有特征都有权重

难以解释
优势1:

自动特征选择
优势2:

高可解释性
优势3:

参数高效

关键区别

  • Softmax:所有值 > 0(密集)
  • Sparsemax:大部分值 = 0(稀疏)

11. 预训练 vs 监督学习性能对比

< 10k
10k-100k
> 100k 是

有预训练
自监督预训练

掩码特征预测
学到通用表示
监督微调
性能: 提升
无预训练(仅监督学习)
随机初始化参数
监督训练
性能: 基线
数据情况
标注数据量
预训练提升显著

+3-5%
预训练有帮助

+1-2%
预训练提升较小

+0.3-0.5%
有大量

无标签数据?
✅ 强烈推荐预训练
可选预训练

实验结论(Higgs 数据集):

  • 1k 样本:预训练提升 3.9%
  • 10k 样本:预训练提升 1.4%
  • 100k 样本:预训练提升 0.27%

12. 完整数据流图(端到端)

单次前向传播
预处理


继续训练
Early Stop
原始表格数据
数据预处理
数值特征:

标准化
分类特征:

Embedding
缺失值:

填充/标记
特征矩阵 f
是否预训练?
自监督预训练阶段
监督学习阶段
训练 Encoder
训练循环
特征
决策步骤1
决策步骤2
决策步骤N
聚合
输出
计算损失
反向传播
更新参数
验证
加载最佳模型
推理阶段
新数据
预测
可解释性分析
特征掩码可视化
特征重要性排序
输出结果 + 解释


使用说明

如何渲染这些流程图

  1. 在 Markdown 编辑器中
    • 支持 Mermaid 的编辑器(如 Typora, Obsidian, VS Code + 插件)
    • 直接复制代码即可自动渲染
  2. 在 GitHub/GitLab
    • 将代码块放在 README.md
    • 自动渲染为图表
  3. 在线工具
  4. 导出为图片
    • 使用 Mermaid CLI: mmdc -i input.mmd -o output.png
    • 或在线编辑器中点击"Export"

自定义样式

您可以修改颜色:
节点
节点

常用颜色代码:

  • 绿色:#4caf50(成功/推荐)
  • 蓝色:#2196f3(信息/次要)
  • 黄色:#ffeb3b(警告/重要)
  • 橙色:#ff9800(注意)
  • 红色:#f44336(错误/停止)

流程图总结

流程图编号 名称 用途
1 整体架构 理解 TabNet 全局结构
2 决策步骤 深入单步处理流程
3 注意力变换器 理解特征选择机制
4 特征变换器 理解特征处理流程
5 可解释性 如何计算特征重要性
6 预训练 自监督学习流程
7 训练流程 完整训练循环
8 模型选择 何时用 TabNet
9 模型对比 TabNet vs 树模型
10 Sparsemax 稀疏归一化原理
11 预训练对比 预训练的效果
12 端到端流程 从数据到预测

推荐使用顺序

初学者路径

  1. 整体架构(图1) → 了解全局
  2. 决策步骤(图2) → 理解核心
  3. 训练流程(图7) → 实践指导
  4. 模型选择(图8) → 应用场景

进阶研究路径

  1. 注意力变换器(图3) → 深入机制
  2. 特征变换器(图4) → 架构细节
  3. 可解释性(图5) → 解释方法
  4. 预训练(图6) → 高级技巧

实战应用路径

  1. 模型选择(图8) → 决定是否使用
  2. 端到端流程(图12) → 完整流程
  3. 训练流程(图7) → 实现细节
  4. 可解释性(图5) → 结果分析
相关推荐
川西胖墩墩17 小时前
团队协作泳道图制作工具 PC中文免费
大数据·论文阅读·人工智能·架构·流程图
其美杰布-富贵-李18 小时前
TabNet: 注意力驱动的可解释表格学习架构
学习·表格数据·tabnet
min1811234561 天前
产品开发跨职能流程图在线生成工具
人工智能·microsoft·信息可视化·架构·机器人·流程图
AC赳赳老秦1 天前
量化交易脚本开发:DeepSeek生成技术指标计算与信号触发代码
数据库·elasticsearch·信息可视化·流程图·数据库架构·memcached·deepseek
帅次1 天前
系统设计方法论全解:原则、模型与用户体验核心要义
设计模式·流程图·软件工程·软件构建·需求分析·设计规范·规格说明书
檐下翻书1732 天前
产品开发跨职能流程图在线生成工具
大数据·人工智能·架构·流程图·论文笔记
程序员zgh3 天前
类AI技巧 —— 文字描述+draw.io 自动生成图表
c语言·c++·ai作画·流程图·ai编程·甘特图·draw.io
神探小白牙3 天前
使用@antv/x6绘制流程图
流程图
爱好读书4 天前
SQL生成ER图|AI生成ER图
数据库·sql·毕业设计·流程图·课程设计