论文总结
1、本研究介绍了Trompt,一种用于表格数据分析的新型网络架构。Trompt 利用提示学习来确定单个样本中不同的特征重要性。
2、Trompt由多个Trompt和一个共享的下游Trompt单元组成。每个Trompt单元负责特征提取,而下游Trompt单元负责预测。
摘要
表格数据可以说是金融、医疗和电子商务等多个实际领域中最常用的数据结构之一。然而,基于最近发布的一份表格基准测试,我们可以看到深度神经网络在表格数据集上仍落后于基于树的模型(Grinsztajn 等,2022)。本文提出了Trompt------即Tabular Prompt------一种受语言模型提示学习启发的新颖架构。提示学习的核心是通过一组外部提示调整大型预训练模型,而无需直接修改模型。基于这一思想,特罗姆普特将表格数据的学习策略分为两部分,分别是表的内在信息和样本间的多样化信息。Trompt的评估基于上述基准。实验结果表明,Trompt的表现优于最先进的深度神经网络,且可与基于树的模型相媲美(见图1)。

图1。基准测试结果。
引言
表格数据在许多现实应用中发挥着重要作用,例如银行用来评估公司信誉的财务报表、医生用来识别患者病因的诊断报告,以及电子商务平台用来发现客户潜在兴趣的客户记录。一般来说,表格数据可用于记录由异构特征组成的活动,并且有许多实际用途。 另一方面,深度学习在多个领域取得了巨大成功,包括计算机视觉、自然语言处理(NLP)和机器人技术(He 等,2016;Redmon 等,2016;Gu 等,2017;Devlin 等,2018)。除了卓越的性能外,深度学习端到端优化特性还有诸多优势,包括(i)利用流式数据的在线学习(Sahoo 等,2017),(ii)多模型集成,包含不同类型的输入,如图像和文本(Ramachandram 和 Taylor,2017),以及(iii)实现半监督学习和生成建模的表征学习(Van Engelen 和 Hoos, 2020;Goodfellow 等,2020)。 因此,研究人员致力于将深度学习应用于表格数据,方法包括(i)变换器(Huang 等,2020;Somepalli 等,2021;Gorishniy 等,2021)或(ii)归纳偏倚调查(Katzir 等,2020;Arik 和 Pfister,2021)。 尽管许多先前的论文声称已达到最前沿,但进一步研究表明,之前的研究是在有利的数据集上进行评估,基于树的模型在表格数据领域仍表现出更优的表现(Borisov 等,2021;Gorishniy 等,2021;Shwartz-Ziv 和 Armon,2022)。为了公平对比在不同算法之间,(Grinsztajn 等,2022)提出了表式数据的标准基准。该基准测试在本作中称为Grinsztajn45,包含来自不同领域的45个精心策划的数据集。 本文提出了一种新颖的提示启发架构------Trompt,即Tabular Prompt的缩写。提示学习在近年来语言模型的发展中发挥了重要作用。例如,GPT-3 在适当的提示工程下能够很好地处理各种任务(Radford 等,2018;Brown 等,2020)。在Trompt中,提示词被用来推导不同样本中不同的特征重要性。图2中,Trompt由多个Trompt和一个共享的下游Trompt单元组成。每个Trompt单元负责特征提取,而下游Trompt单元负责预测。 Trompt 的性能基于 Grinsztajn45 基准测试进行评估,并与三种深度学习模型和五种基于树的模型进行比较。图1展示了Grinsztajn45的整体评估结果。横轴表示超参数搜索迭代次数,纵轴表示归一化性能。图1中,Trompt持续优于最先进的深度学习模型(SAINT和FT-Transformer),深度学习模型与树状模型之间的差距缩小。
我们的主要贡献总结如下:• 实验基于公认的表格基准Grinsztajn45进行。此外,我们还将两个表现良好的树状模型------LightGBM(Ke等,2017)和CatBoost(Prokhorenkova 等,2018)加入基线。
• Trompt在深度学习模型中实现了最先进的性能,缩小了深度学习模型与树状模型之间的性能差距。
• 进行了全面的实证研究和消融测试以验证Trompt的设计。这些结果进一步阐明了未来表式神经网络架构设计的研究方向。
相关工作
本节首先讨论语言模型的提示学习。其次,我们讨论了表格神经网络的两个研究分支:变换器和归纳偏倚研究。最后,我们讨论了特罗姆普特与相关作品的差异,并突出我们作品的独特性。
提示学习
提示学习的目的是将下游任务的输入和输出转换为用于构建预训练模型的原始任务。与微调不同,微调通常涉及修改任务并更新模型权重,带有提示的预训练模型可以专注于一个任务。通过提示学习,少量数据甚至零射程也能取得良好结果(Radford 等,2018;Brown 等,2020)。提示学习的出现极大提升了预训练模型的应用多样性,这些模型体积过大,普通用户难以微调。 要提示语言模型,可以在句子前插入任务特定的提示,并提示模型调整其针对不同任务的响应(Brown 等,2020)。提示可以是离散的也可以是软的。前者由自然语言词汇中的离散词汇组成(Radford 等,2018;Brown 等,2020),而后者则是学习表征(Li 和 Liang,2021;Lester等,2021)。
表格神经网络
**Transformer。**自2017年以来,自注意力机制彻底改变了自然语言处理(Vaswani 等,2017),并很快被计算机视觉、强化学习和语音识别等其他领域采用(Dosovitskiy 等,2020;Chen 等,2021;Zhang 等,2020)。Transformer块的目的是捕捉特征之间的关系,这些关系也可以应用于表格数据。 TabTransformer(Huang 等,2020)是首个基于 Transformer 的表神经网络。然而,TabTransformer 只向变换器块输入了类别特征,忽略了类别特征和数值特征之间的潜在关系。FT-Transformer(Gorishniy 等,2021)通过向变压器模块输入类别和数值特征来解决了这一问题。SAINT(Somepalli等,2021)通过关注不仅在特征维度,还对样本维度进行关注,进一步改进了FT-Transformer。 归纳偏倚调查。深度神经网络在具有明显归纳偏见的任务中表现良好。例如,卷积神经网络(CNN)在图像上表现良好。CNN的核设计用于捕捉局部图案,因为相邻像素通常相互关联(LeCun等,1995)。循环神经网络(RNN)在语言理解中被广泛应用,因为词语之间的因果关系通过循环单元得到了良好封装(Rumelhart 等,1986)。然而,与其他常见任务不同,表格数据的归纳偏差尚未被充分发现。 鉴于基于树的模型一直是表数据领域的最先进技术(Borisov 等,2021;Gorishniy 等,2021;Shwartz-Ziv 和 Armon,2022)、Net-DNF(Katzir 等,2020)和 TabNet(Arik 和 Pfister,2021)假设表格数据的归纳偏差可能是学习过程基于树的模型策略。该策略是通过选择部分特征并从非叶节点的特征中推导最优分割,来寻找最优的根到叶决策路径。为了模拟该学习策略,TabNet采用了顺序注意力和稀疏正则化。另一方面,Net-DNF理论上证明决策树等价于某种析取范式(DNF),并提出析取神经范式以模拟DNF公式。

图2。拟建的Trompt整体架构
Tprompt的独特性
我们认为表格数据的列重要性并非对所有样本都不变,可以分为多种模态。由于提示学习的诞生是为了将模型适应多项任务,因此该概念被用来处理多种模态。为此,Tprompt将表格数据的学习策略分为两部分。第一部分类似于预训练模型,侧重于学习表的内在列信息。第二部分类似于提示,侧重于多样化不同样本的特征重要性。 据我们理解,Trompt是第一个受提示启发的表格神经网络。与基于变压器的模型相比,特隆普特学习的是分离的列的重要性,而不是关注列之间的相互作用。与 TabNet 和 Net-DNF 相比,Trompt 通过模拟提示学习而非决策树的分支分裂来处理多种模态
Tprompt
在本节中,我们将详细介绍Trompt的架构设计。如图2所示,Tprompt由多个Tprompt单元和共享的下游Tprompt单元组成。每个Tprompt单元负责特征提取和提供多样化表示,而Tprompt下游单元负责预测。Tprompt单元和下游Tprompt单元的详细信息分别在第3.1节和第3.2节讨论。在第3.3节,我们进一步讨论了Trompt的提示学习。
Tprompt单元
图3展示了一个错视胞体的结构,可以分为三部分。第一部分基于列嵌入(Ecolumn)、前一个单元的输出(Oprev)和提示嵌入(Eprompt)推导特征重要性(Mimportance)。第二部分将输入转换为特征嵌入(Efeature),分别有两条路径用于类别列和数值列。第三部分扩展了Efeature,用于后续的乘法。 第一部分的细节见第3.1.1节,第二和第三部分的细节见第3.1.2节。最后,特罗姆普特单元的输出生成过程见第3.1.3节。
推导特征重要性
设E列∈RC是列嵌入×Eprompt∈RP是提示嵌入×。C 是数据集定义的表列数,P 和 d 分别是提示词数量和隐藏维度的超参数。Ecolumn 和 Eprompt 都是输入独立且可训练的。设Oprev ∈ RB×P ×为前一个单元的输出,B 是批次大小。 Oprev与提示嵌入融合为方程(1)和(2)。由于 Eprompt 是输入独立且没有批处理维度的,Eprompt 通过堆栈操作扩展为 SEprompt,作为方程(1)。之后,我们将 SEprompt 和 Oprev 连接起来,并将连接张量的维数减回至 RB×P ×d,作为最终加法,称为方程(2)。

图3。Tprompt单元的架构
出于与Eprompt相同的原因,E列扩展为SE列,作为方程(3)。随后,特征重要性通过方程(4)推导,其中⊗是批次矩阵乘法,⊺是批次转置,软最大值应用于列轴。

第一部分的输出是 Mimportance 的 ∈ RB×P ×C,该指标涵盖了 P 个提示词产生的特征重要性。注意列嵌入未连接到输入,提示嵌入与前一个单元格的输出融合。在第3.3节,我们进一步讨论了这些设计及其与NLP提示学习的联系。
构造和展开特征嵌入
在Trompt中,范式特征通过嵌入层嵌入,数值特征通过密集层嵌入,正如之前的研究(Somepalli等,2021;Gorishniy 等,2021)。嵌入构造过程在图3的第二部分中展示了,其中Efeature ∈ RB×C×d 是批次的特征嵌入。 Mimportance 和 Efeature 的形状分别是 RB×P ×C 和 RB×C×d。由于 Efeature 缺乏提示维度,Trompt 将 Efeature 扩展为 Eˆfeature ∈ RB×P ×C×d 以在图3第三部分中通过密集层容纳 P 提示。
生成输出
Tprompt单元的输出是元素相乘的逐列和,取ˆ特征和微重要性,作为方程(5),其中⊙为元素相乘。注意,在按元素乘法时,微重要性的形状被视为RB×P ×C×1。此外,由于列为第三轴,形状从RB×P ×C×d缩小为RB×P ×d,按列求和。

Tprompt下游
Tprompt下游基于Tprompt单元的输出做出预测,该输出包含对应P个提示嵌入的表示。为了聚合这些表示,首先通过稠密层和软极大激活函数(方程6)推导出每个提示的权重。之后,加权和作为方程(7)计算出来。 随后通过两个密集层作为方程(8)进行预测,其中T为目标维度。


对于分类任务,T 是目标类别的数量。对于回归任务,T设置为1。如图2所示,样本通过一个Tprompt单元获得预测,因此在所有单元格中都能得到多个预测。在训练过程中,每个预测的损失会分别计算,并将损失加总以更新模型权重。而在推断过程中,所有单元格的预测被简单地平均为最终预测,称为方程(9),其中L是Tprompt单元的数量。

Tprompt的提示学习
Trompt的架构专门为表格数据设计,考虑了这类数据的独特特性以及基于树的模型的卓越性能。与传统操作不同,该设计可能显得非常规且与表格数据特征脱节。在本节中,我们将解释Trompt网络设计背后的理由,以及我们如何将提示学习应用于表格神经网络。 表格数据是结构化的,每列代表一个特定的数据集属性,且在单个样本间保持不变。基于树的模型成功依赖于为单个样本赋予特征重要性。这一概念在TabNet(Arik & Pfister,2021)和Net-DNF(Katzir等,2020)等模型中得到了探讨。然而,基于树的算法并不会明确为单个样本赋予特征重要性。相反,重要性在从根节点到叶节点的路径上隐式变化。只有该路径中涉及的列被视为样本到达对应叶节点的重要特征,代表样本特有的特征重要性。 鉴于表格数据的基本特性和基于树的模型学习策略,Trompt 旨在结合列的内在属性与样本特有的特征重要性,采用 NLP 中以提示学习为灵感的架构(Radford 等,2018;Brown 等,2020)。Trompt利用列嵌入表示每列的内在属性,并用提示嵌入提示列嵌入,生成给定提示词的特征重要性。列嵌入和提示嵌入在样本间都是不变的。然而,在提示嵌入列嵌入之前,提示嵌入会与前一个特罗姆普特单元的输出融合,如方程(2)所示),使输入相关的表示能够流动并推导出样本特有的特征重要性。特罗姆普特中的"提示"机制通过方程(4)中的矩阵乘法实现。 表1中展示了Trompt提示学习方法与自然语言处理的概念类比。需要注意的是,由于表格数据和自然语言处理任务之间的提示学习实现细节存在显著差异,这两者领域存在根本差异。因此,必须做出适当的调整以桥接这两个领域。
实验
本节将呈现实验结果和分析。首先,我们在第4.1节详细说明实验设置和特罗姆普特的构型。其次,Grinsztajn45 上 Trompt 的性能报告于第4.2节。第三,关于超参数和特罗姆普特结构的消融研究见第4.3节。最后,第4.4节通过合成和现实世界数据集探讨了特罗姆普特的可解释性。
实验配置
Trompt 的性能和消融研究主要聚焦于 Grinsztajn45 基准(Grinsztajn 等,2022)1.该基准涵盖多个领域的数据集,采用统一的方法评估不同模型,提供公平且全面的评估。此外,我们评估了 FT-Transformer 和 SAINT 选定数据集中 Trompt 的性能,并将其与最先进的表格神经网络进行比较。 对于可解释性分析,我们遵循TabNet(Arik & Pfister,2021)的实验设置。这涉及使用两个合成数据集(Syn2和Syn4)以及一个真实数据集(mushroom)来可视化注意力掩体。 Grinsztajn45 的设置见 4.1.1 节,Trompt 的实现细节见 4.1.2 节。此外,FT-Transformer和SAINT选择的数据集设置分别载于附录B.2和附录B.3
GRINSZTAJN45的设定
为了公平评估性能,我们遵循Grinsztajn45的配置,包括列车测试数据拆分、数据预处理和评估指标。Grinsztajn45 包含两种任务:分类任务和回归任务。请参阅附录A.1和附录A.2,了解Grinsztajn45的数据集选择标准和数据集归一过程。任务进一步根据(i)数据集规模(中型和大型)和(ii)类别特征的包含(仅数值且异构)进行分组。 此外,我们做了以下调整:(i) 省略了(Grinsztajn 等,2022)中实验结果不完整的模型,(ii) 添加了两个表现良好的基于树的模型以供比较,(iii) Trompt 使用了一个比对手更小的超参数搜索空间。调整细节详见附录A.3和附录A.4。
实现细节
Trompt 是通过 PyTorch 实现的。默认超参数见表2。嵌入的大小和稠密层的隐藏维数配置为 d。注意,架构设计中只有列嵌入和提示嵌入的大小必须相同。隐藏层的 维度 设为 d,以减少超参数并节省计算资源。另一方面,提示词数量和特罗姆普特单元格数设置为P和L。请参阅附录F,了解所有基线和Trompt的超参数搜索空间。

评估结果
分类任务的结果见第4.2.1节,回归任务的结果在第4.2.2节讨论。分类和回归任务的评估指标分别是准确率和r2分数。本节报告附录B.1中单个数据集的整体结果和遗留结果。此外,FT-Transformer和SAINT选定数据集的评估结果分别收录在附录B.2和附录B.3中。
分类任务
在中等规模的分类任务中,图5显示Trompt优于DNN模型。在有或无范畴特征的任务上,Trompt曲线始终高于深度神经网络(如SAINT、FTTransformer和ResNet)。此外,Trompt缩小了深度神经网络与基于树的模型之间的差距,尤其是在具有异构特征的任务上。在图5b中,特罗姆普特似乎是拥有四个树基模型的领先簇成员。GradientBoostingTree 起初缓慢,但在搜索结束时能追上领先集群。其他深度神经网络形成第二个簇,且与领先簇有间隙。 在大型分类任务中,基于树的模型仍占领先地位,但与深度神经网络的差距尚不明确。这与深度神经网络需要更多样本进行训练相呼应(LeCun等,2015)。图6a显示,特罗姆普特在具有数值特征的任务中表现优于所有模型,图6b则显示特罗姆普特在具有异构特征的任务中,性能与FT-Transformer相当。 在超参数搜索空间较小的情况下,特隆普特曲线相对平坦。平坦曲线也表明特罗姆普特在默认超参数下表现良好。它的经过穷尽搜索后的性能值得未来探索

回归任务
在中等规模回归任务中,图7显示Trompt的曲线优于深度神经网络,因为在有或无类别特征的任务中,Trompt的曲线始终高于SAINT、FT-Transformer和ResNet。图7a中深度神经网络与基于树的模型之间的差距不如图7b明显。在仅涉及数值特征的任务中,特罗姆普特与随机森林的表现相当。对于具有异质特征的任务,特隆普特缩小了差距,但低于所有基于树的模型。 在仅有数值特征的大规模回归任务中,图8a显示特罗姆普特在搜索末端略低于SAINT和傅里变换器。在具有异构特征的大规模回归任务中,图8b显示特罗姆普特在较大优势下优于深度神经网络。 一般来说,深度学习模型不擅长处理范类特征。如图5中所有具有异质特征的任务所示,Trompt缓解了这一弱点。除仅在仅有数值特征的大回归任务中外,Trompt在性能上优于最先进的深度神经网络。
消融实验
在本小节中,我们将讨论Tprompt在超参数和结构设计方面的消融研究结果。有关消融研究的具体设置,请参阅附录C。在主条目中,我们报告了两个主要的消融点:
(i)提示词数量和
(ii)通过密集层扩展特征嵌入的必要性。其他消融可见附录D。 对提示数量的消融。提示嵌入(Eprompt)在推导功能重要性方面起着重要作用。这里我们讨论调整提示数量的效果。 如表3所示,将提示数设为1的结果更差。然而,将默认数字减半或翻倍(128)对表现影响不大。结果表明,只要提示词数量足够适应数据集的模态,Trompt对提示词数量并不敏感。 对扩展特征嵌入的致密层消融。图3的第三部分使用稠密层扩展特征嵌入以适应P提示。这里我们讨论稠密层的必要性。 如表4所示,增加密集层确实能带来更好的结果,这也是Trompt的关键架构设计之一。设计上,添加密集层可实现用Trompt为每个提示生成不同的特征嵌入。没有稠密层,Trompt会被简化为每个提示使用相同的特征嵌入。表3和表4的结果表明,特征重要性的变异------源自提示嵌入层和扩展密集层------是Trompt卓越性能的关键。

可解释性
除了卓越的性能外,基于树的模型还以其可解释性著称。本文探讨Tprompt是否也能提供简明的特征重要性,突出显著特征。为此,我们按照TabNet的实验设计,在合成数据集和现实数据集上进行实验(Arik & Pfister, 2021)。为了推导每个样本的特隆普特特征重要性,将 R×P ×C 中的重要性∈ R∈ R×C 的 Mˆ 重要性,作为方程(10),其中微重要性权重为方程(6)的 Wprompt。 注意所有特罗姆普特胞体都推导出分离特征重要性。我们在此展示了所有细胞的平均结果,并将每个单元的结果保留在附录E.1中

合成数据集。Syn2和Syn4数据集用于研究每个模型学习到的特征重要性(Chen 等,2018)。模型在超采样训练集(10k到100k)上使用默认超参数进行训练,并在随机抽取的20个测试样本上进行评估。该配置与TabNet中的配置相同(Arik & Pfister,2021)。
图9和图10比较了数据集的重要特征与Trompt学到的特征。在Syn2数据集中,特征2--5非常重要(图9a),特隆普特对它们进行了极佳的关注(图9b)。在Syn4数据集中,特征0--1或2--5可能因特征10的值而重要(图10a)。如图10所示,特罗姆普特仍然正确地关注特征0--5,并发现特征10的影响。

**真实世界的数据集。**蘑菇数据集(Dua 和 Graff,2017)被用作现实世界的可视化数据集,称为 TabNet(Arik 和 Pfister,2021)。仅凭气味功能,大多数机器学习模型>测试准确率可达95%(Arik & Pfister,2021)。因此,Odor 的特征重要性被期望很高。 表5展示了特隆普特的三个最重要的特征和五个基于树的模型。如图所示,所有模特都将Odor排在前三名。Trompt的第二和第三名,即鳃大小和鳃色,也出现在其他模型的前三名中。实际上,帽色仅由XGBoost选择。如果排除它,所有模型最重要的功能合并为四个。特隆普特遗漏的是孢子印彩色,这是特罗姆普特的第五名。总体而言,Trompt选定的重要特征与基于树的模型一致,因此可用于机器学习领域熟悉的各种分析。 为了进一步证明实验结果并非临时拼凑,我们在更多真实世界数据集上重复实验。详情请参见附录E.2实验结果

讨论
在本节中,我们将进一步探讨Trompt的"提示"机制。第5.1节阐明了提示学习特罗姆普特如何适用于表格数据的基本假设。此外,由于特罗姆普特部分灵感来源于基于树的模型的学习策略,我们在第5.2节进一步讨论了特罗姆普特与树模型之间的区别。
对Tprompt中"提示"机制的进一步探讨
特隆普特中的"提示"机制被实现为方程(4)。该方程涉及扩展提示嵌入(SˆEprompt ∈ RB×P ×d)的矩阵乘法,以及扩展列嵌入(SEcolumn ∈ RB×C×d)的转置。它会得到 Mimportance ∈ RP ×C ,表示提示到列的功能重要性。矩阵乘法计算了SˆEprompt与SEcolumn之间的基于余弦的距离,并有利于样本特异表示与样本不变的内在属性之间的高度相似性。 为了更清楚地说明,SˆEprompt 由针对单个样本特有的 P 嵌入组成,除了第一个 Trompt Cell,其中 Oprev 是零张量,因为没有之前的 Trompt Cell,如方程(1)和(2)所述。另一方面,SEcolumn由C嵌入组成,代表表式数据集特有的内在属性,如公式(3)所述。 与计算查询与键之间的距离并推导令牌间相似度度量的自注意不同,特朗普特计算方程(4)中SˆEprompt与SEcolumn之间的距离,以推导样本到内在性质的相似度度量。计算的基本思想是捕捉每个样本与表式数据集内在属性之间的距离,我们假设将内在属性纳入模型表格神经网络的生成可能有助于做出良好的预测。
基于树的推测模型与树状模型的区别
正如第3.3节讨论的,利用提示学习推导特征重要性的理念,灵感来自基于树的模型学习算法和表格数据的内在属性。因此,特罗姆普特模型和基于树的模型有一个共同特点,即它们能够实现样本相关的特征重要性。然而,它们之间有两个主要区别。首先,为了结合表格数据的内在属性,Trompt使用列嵌入在样本间共享列信息,而基于树的模型则通过节点拆分特性学习列信息。其次,特罗姆普特模型和基于树的模型使用不同的技术来学习特征重要性。特罗姆普特通过提示学习显式推导特征重要性,而基于树的模型则隐含地在根到叶路径中变化特征重要性。
总结
本研究介绍了Trompt,一种用于表格数据分析的新型网络架构。Trompt 利用提示学习来确定单个样本中不同的特征重要性。我们的评估显示,Trompt的表现优于最先进的深度神经网络(SAINT和FT-Transformer),并缩小了深度神经网络与基于树的模型之间的性能差距。 深度学习中即时学习的出现令人期待。虽然Trompt的设计可能不够直观或适合语言模型提示,但它展示了在表格数据分析中利用提示的潜力。这项工作提出了一种深度神经网络的新策略,以挑战基于树的模型,未来在这方面的研究可以探索更多受提示启发的架构。