【论文阅读—深度学习处理表格数据】ResNet-like & FT Transformer

[2106.11959] Revisiting Deep Learning Models for Tabular Data(论文链接)

https://github.com/yandex-research/tabular-dl-revisiting-models(代码链接)

这篇论文主要讲的是:

  • 现在很多研究人员提出了各种新的深度学习方法来处理表格数据(像 Excel 里的那种数据)。

  • 这些方法都说自己效果不错,但问题是:大家用的实验数据集和测试方法都不一样,所以根本没法直接比较,谁更好也不清楚

  • 此外,目前还缺少一些"简单又好用"的基准模型(baseline),让别人能方便地拿来对比。

于是,作者做了两件事:

  1. 总结梳理了现有的主要深度学习方法,看看它们的思路和优缺点。

  2. 提出了两个简单但效果很强的模型

    • 第一个是类似 ResNet 的结构(在图像识别里很常见),结果发现它在表格数据上也能作为一个强有力的基线模型。

    • 第二个是稍微改造过的 Transformer,用来处理表格数据,效果比很多其他模型都好。

他们用统一的实验流程,把这些模型在很多任务上做了公平比较。最后还把深度学习模型和"梯度提升树"(一种传统但很强的机器学习方法,比如 XGBoost)做对比,结论是:

👉 目前还没有一个方法能在所有情况下一直是最优的。

Introduction

研究背景

  • 深度学习在图像、语音、文本等领域非常成功,所以大家很想把它用到表格数据上(比如工业生产里的传感器数据、金融数据、医疗数据、比赛里的数据表)。

  • 在表格数据任务里,深度学习的主要对手是 梯度提升决策树(GBDT,比如 XGBoost、LightGBM、CatBoost),它们已经很强大了。

  • 如果深度学习能在表格数据上表现好,还有一个额外的好处:它可以和图像、语音等数据结合,做成"多模态"的端到端模型,用梯度优化一起训练,非常方便。

目前的问题

  • 已经有很多研究提出了新的深度学习模型来处理表格数据,但大家用的数据集和实验流程都不一样。

  • 这就导致:

    • 论文之间没法公平比较,不知道到底哪个模型更好。

    • 也不清楚深度学习是否真的比 GBDT 更强。

  • 此外,虽然模型很多,但还缺少一些 既简单又稳定的"基准模型",能让大家更容易对比和复现。现在大家常用的 MLP(多层感知机)太弱,根本挑战不到其他模型。

作者的思路

  • 他们怀疑:其实一些在图像和 NLP 里已经验证过的经典深度学习模块,在表格任务上可能被忽视了。

  • 所以,他们尝试把这些"成熟的架构"改造一下用在表格数据上,结果提出了两个简单的模型:

    • ResNet-like(残差网络的变体)

    • FT-Transformer(他们改造过的 Transformer)

论文主要贡献

  • 系统地评估了各种表格数据的深度学习模型。

  • 发现 ResNet-like 是一个很强的基线,推荐以后大家用它做比较。

  • 提出了 FT-Transformer,一个简单但强大的新方案。

  • 最后确认,目前没有任何模型能在所有任务上完全超越 GBDT。

传统方法:决策树集成(GBDT)

目前在表格数据任务里,GBDT(梯度提升树) 是最常用、最强的"浅层"方法。有三个非常流行的库:XGBoost、LightGBM、CatBoost,它们在细节上有区别,但在大多数任务上性能差不多。所以,GBDT 一直是机器学习比赛和工业界的首选。

深度学习方法:主要有三大类

可微分的树

传统的树模型不能通过梯度下降训练(因为分裂是"硬"的,没法求导)。有些研究者就把树的分裂函数"平滑化",让树也能用梯度优化,方便和深度学习结合。这些方法在部分任务上能超过 GBDT,但作者的实验发现,它们并不比一个调好参数的 ResNet 稳定

基于注意力的模型

因为注意力机制(Transformer)在图像、文本等领域大获成功,有人也尝试把它用在表格数据上。结果表明:虽然一些注意力模型表现还行,但经过精调的 ResNet 依然更强。不过,作者提出了一种更合适的 Transformer 改造版本(FT-Transformer),在大多数任务上超过了 ResNet。

显示建模特征交叉

在推荐系统和点击率预测中,很多人批评 MLP(多层感知机),因为它不擅长表达特征之间的"乘积关系"(比如"年龄 × 收入"这种交互)。有些工作尝试在 MLP 里加上"特征乘积"的模块。但作者的实验发现,这些方法并不比调好的基线模型更好。

其他杂项模型

还有一些方法不能明确归类到上面三类,比如一些比较特殊的架构设计,但效果也没有形成统一结论。

Models for tabular data

研究目标

找出简单但好用的深度学习基线模型,而不是复杂又难调的架构。尽量使用大家熟悉的深度学习"积木"(比如 MLP、ResNet、Transformer)来搭建模型。这样模型更容易理解、实现,也更容易调出好结果。

统一任务设定

数据集包含数值型特征(numerical)和类别型特征(categorical)。数据被分成 训练集(train)、验证集(val)、测试集(test)。任务分为三类:二分类、多分类和回归。

三种主要模型

MLP(多层感知机)
  • 最基础的深度学习模型。

  • 结构:线性层 → ReLU → Dropout → 线性层 → ... → 输出

  • 简单易用,但在表格数据任务里,往往不是最优解。

ResNet(残差网络改造版)
  • 原本 ResNet 在图像识别很成功(解决了"深层网络难训练"的问题)。

  • 作者尝试在表格数据上用一个简化版 ResNet

  • 核心思想:在每一层里,加一个"捷径连接"(skip connection),让输入可以直接绕过部分计算流向输出。

  • 这样有助于训练,能学到更深层次的特征。

  • 作者认为它会比 MLP 更适合一些需要"深层表达"的任务,其实就是在原来MLP的基础上加上残差链接

FT-Transformer(特征标记器 + Transformer)
  • 这是作者的创新点。

  • 步骤:

    1. 把每个特征(数值/类别)先转换成embedding(向量表示)

    2. Transformer 模型处理这些特征向量(像处理句子里的词一样)。

    3. 最后用 [CLS] token 的输出作为整体表示,再做预测。

  • 好处:Transformer 的自注意力机制可以自动学到特征之间的关系,特别适合表格这种"混合信息"的数据。

  • 作者发现,这个方法在多数任务上能超过 ResNet。

Experiments

作者想回答两个核心问题:

  1. 不同的深度学习(DL)架构在表格数据上的真实表现差别有多大?

  2. 和传统的梯度提升树(GBDT,例如 XGBoost、CatBoost)相比,DL 模型到底有没有优势?

数据集和实验

我们使用了11 个多样化的公开数据集(详细描述见附录)。对于每个数据集,都只有一个固定的训练集--验证集--测试集划分,因此所有算法都使用相同的划分。

这些数据集包括:

  • California Housing (CA):房价预测(房地产数据,Kelley Pace 和 Barry, 1997)

  • Adult (AD):收入预测(Kohavi, 1996)

  • Helena (HE):匿名数据集(Guyon 等, 2019)

  • Jannis (JA):匿名数据集(Guyon 等, 2019)

  • Higgs (HI):模拟的物理粒子数据(Baldi 等, 2014;使用 OpenML 上的 98K 样本版本,Vanschoren 等, 2014)

  • ALOI (AL):图像数据(Geusebroek 等, 2005)

  • Epsilon (EP):模拟物理实验数据

  • Year (YE):音频特征数据(Bertin-Mahieux 等, 2011)

  • Covertype (CO):森林特征数据(Blackard 和 Dean, 2000)

  • Yahoo (YA):搜索查询数据(Chapelle 和 Chang, 2011)

  • Microsoft (MI):搜索查询数据(Qin 和 Liu, 2013)

在学习排序(learning-to-rank)任务上,我们采用逐点(pointwise)方法,并把 MicrosoftYahoo 的排序问题视为回归问题来处理。

实验结果(简单总结)

  1. MLP 依然能打:普通的多层感知机(MLP)经过认真调参,表现不差,是很好的基线。

  2. ResNet-like 很强:在 MLP 上加残差连接,效果显著提升,变成一个强力的 baseline。

  3. FT-Transformer 最稳:改造过的 Transformer(FT-T)在大多数任务上表现最好,而且在各种数据集上都很稳定。

  4. NODE(树状神经网络)也不错,但结构复杂,而且在很多数据集上还是比不过 ResNet。

  5. 集成(Ensemble)有效果:把多个模型预测结果平均,性能更强;ResNet 和 FT-Transformer 的提升尤其明显。

Analysis

FT-Transformer 什么时候比 ResNet 更好?

作者做了一个合成实验

  • 构造一些虚拟任务,把目标值设计成 一部分像 GBDT 容易拟合的函数(fGBDT) ,另一部分像 深度学习模型容易拟合的函数(fDL)

  • 用一个参数 α 来控制这两部分的比例。

结果:

  • 在更偏向 ResNet 友好的任务上(fDL 占主导),ResNet 和 FT-Transformer 表现差不多,都比 CatBoost 强。

  • 但当任务更接近 GBDT 风格时,ResNet 的性能掉下去了,而 FT-Transformer 依然稳健。

结论:FT-Transformer 更"万能",既能处理 DL 友好的任务,也能处理 GBDT 擅长的任务,而 ResNet 更挑剔。

Ablation Study(消融实验)

作者测试了 FT-Transformer 的一些设计细节到底有没有用:

和 AutoInt 的比较

  • AutoInt 也是基于注意力的模型,但细节设计和 Transformer 不一样。

  • 实验表明:FT-Transformer 的主干(标准 Transformer 结构)更强,整体优于 AutoInt。

Feature biases(特征偏置项)有没有必要?

  • FT-Transformer 在 embedding 层里为每个特征加了一个偏置(bias)。

  • 去掉偏置再跑实验 → 性能下降。

  • 说明偏置设计是有帮助的。

从注意力图获取特征重要性

作者提出用 Transformer 的注意力图(attention maps) 来解释模型:

  • 计算 [CLS] token 的注意力分布 → 得到每个特征的重要性分数。

  • 优点:很高效,只要一次前向传播就能得到。

对比方法:

  • Integrated Gradients (IG):通用但很慢。

  • Permutation Test (PT):可靠但要做很多次前向传播(特征数 + 1)。

实验结果:

  • 注意力图方法(AM)的特征重要性和 PT 的排名高度相关,效果和 IG 差不多。

  • 但效率远高于 IG 和 PT。

结论 :用注意力图来解释 FT-Transformer 的特征重要性,是一种 便宜又有效的可解释性方法

加载FT-Transformer处理表格数据

python 复制代码
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from torch.utils.data import Dataset, DataLoader
from rtdl_revisiting_models import FTTransformer  # 注意要安装 rtdl-revisiting

# ========= 1. 加载数据 =========
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    "age", "workclass", "fnlwgt", "education", "education-num",
    "marital-status", "occupation", "relationship", "race", "sex",
    "capital-gain", "capital-loss", "hours-per-week", "native-country", "income"
]
df = pd.read_csv(url, names=columns, sep=",", skipinitialspace=True)

# ========= 2. 预处理 =========
# 标签:>50K / <=50K
df["income"] = LabelEncoder().fit_transform(df["income"])  # <=50K=0, >50K=1

# 数值特征 / 类别特征
num_features = ["age", "fnlwgt", "education-num", "capital-gain", "capital-loss", "hours-per-week"]
cat_features = [col for col in df.columns if col not in num_features + ["income"]]

# 数值标准化
scaler = StandardScaler()
df[num_features] = scaler.fit_transform(df[num_features])

# 类别编码
for col in cat_features:
    df[col] = LabelEncoder().fit_transform(df[col])

X = df[num_features + cat_features].values
y = df["income"].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# ========= 3. Dataset & Dataloader =========
class TabDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self): return len(self.X)
    def __getitem__(self, idx): return self.X[idx], self.y[idx]

train_loader = DataLoader(TabDataset(X_train, y_train), batch_size=256, shuffle=True)
test_loader = DataLoader(TabDataset(X_test, y_test), batch_size=256)

# ========= 4. 构建 FT-Transformer =========
device = "cuda" if torch.cuda.is_available() else "cpu"

model = FTTransformer(
    d_numerical=len(num_features),
    categories=[int(df[col].nunique()) for col in cat_features],  # 每个类别特征的取值个数
    d_token=32,
    n_layers=3,
    n_heads=4,
    d_ffn_factor=2.0,
    attention_dropout=0.1,
    ffn_dropout=0.1,
    residual_dropout=0.1,
    d_out=2  # 二分类
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# ========= 5. 训练 =========
for epoch in range(5):  # 演示只训练 5 轮
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        out = model(xb)
        loss = criterion(out, yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} | Loss {loss.item():.4f}")

# ========= 6. 测试 =========
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        out = model(xb)
        preds = out.argmax(1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

print(f"Test Accuracy: {correct / total:.4f}")
相关推荐
张较瘦_8 小时前
[论文阅读] 软件工程 | REST API模糊测试的“标准化革命”——WFC与WFD如何破解行业三大痛点
论文阅读
大千AI助手8 小时前
梯度爆炸问题:深度学习中的「链式核弹」与拆弹指南
人工智能·深度学习·梯度·反向传播·梯度爆炸·指数效应·链式法则
荒野饮冰室9 小时前
目标检测定位损失函数:Smooth L1 loss 、IOU loss及其变体
深度学习·神经网络·损失函数·iou
百度森森9 小时前
禾赛AT128P 128 线混合固态激光雷达使用
深度学习
星期天要睡觉9 小时前
深度学习——ResNet 卷积神经网络
pytorch·深度学习·神经网络
悟乙己10 小时前
使用 BayesFlow 神经网络简化贝叶斯推断的案例分享(二)
人工智能·深度学习·神经网络
HUIMU_12 小时前
DAY20-新世纪DL(DeepLearning/深度学习)战士:终(目标检测/YOLO)3
深度学习·yolo·目标检测·滑动窗口·非极大值抑制·交并比·bouding box
AI绘画哇哒哒13 小时前
【值得收藏】手把手教你用PyTorch构建Transformer英汉翻译系统,从训练到推理
人工智能·pytorch·ai·语言模型·程序员·大模型·transformer
Vizio<13 小时前
《D (R,O) Grasp:跨机械手灵巧抓取的机器人 - 物体交互统一表示》论文解读
论文阅读·人工智能·机器人