Accurate predictions on small data with a tabular foundation model
Accurate predictions on small data with a tabular foundation model | Nature
使用一种基于表格的模型来对小型数据实现准确预测
Abstract:
基于其他列来填充标签列中缺失值的基本预测任务对于各种应用至关重要。
Main:
然而,这些传统的机器学习模型有几个缺点。未经重大修改,它们在分布外的预测表现较差,并且难以将知识从一个数据集转移到另一个数据集 。最后,由于它们不传播梯度,因此很难与神经网络结合使用。
这种新的监督式表格学习方法可以应用于任何小型到中等规模的数据集,并且在样本数量最多为 10,000 个和特征数量最多为 500 个的数据集中表现出色。
Result:
Methods:
表格的结构设计
基于因果模型合成数据
TabPFN 的性能依赖于生成合适的合成训练数据集,这些数据集能够捕捉真实世界表格数据的特征和挑战。为了生成这样的数据集,我们开发了一种基于结构因果模型( SCMs )的方法。 SCMs 提供了一个正式的框架,用于表示数据背后的因果关系和生成过程。
- 生成流程首先采样高级超参数,例如数据集大小、特征数量和难度级别,以控制每个合成数据集的整体属性。
- 基于这些超参数,我们构建一个结构因果模型,该模型编码生成数据集的计算函数。每个节点包含一个向量,计算图中的每条边根据连接类型实现一个函数。
- 在第一步中,使用随机噪声变量生成初始化数据,并将其输入到图的根节点中,然后通过计算图传播以生成每个样本。
- 在第二步中,我们在图中随机采样特征和目标节点的位置,分别标记为F和T。
- 在第三步中,我们提取在采样的特征和目标节点位置处的中间数据表示。
- 在第四步中,我们对提取的数据进行后处理。
- 我们检索最终的数据集。
- 我们绘制特征对之间的交互图,节点颜色表示样本的类别。
- 首先,由于transformer是为序列设计的,它们将输入数据视为单个序列,而不是利用表格结构。
- 其次,机器学习模型通常用于拟合-预测模型中,在这种模型中,模型仅在训练集上拟合一次,然后重复用于多个测试数据集。
- 然而,基于transformer的ICL算法在一个步骤中接收训练和测试数据,因此同时执行训练和预测 。因此,当重新使用已拟合的模型时,它必须重新计算训练集上的计算。
- transformer架构是灵活的深度学习和基础模型的首选架构。使用所谓的注意力机制在序列项之间结合信息,从而使它们能够有效地捕捉长程依赖性并学习数据中的复杂关系。
- TabPFN解决了其中两个关键限制。
- 数据生成: 定义了一个生成过程(称为我们的先验),用于合成具有不同特征与目标变量关系的多样化表格数据集,旨在捕捉模型可能遇到的各种潜在情景。定义了一个生成过程(称为我们的先验),用于合成具有不同特征与目标变量关系的多样化表格数据集,旨在捕捉模型可能遇到的各种潜在情景。
- 预训练 :我们训练一个变换器模型,即我们的 PFN,来预测所有合成数据集中被掩盖的目标值,输入特征和未掩盖的样本作为上下文提供给模型。此步骤仅在模型开发期间执行一次,学习一个通用的学习算法,以便预测任何数据集。
- 真实世界预测:经过训练的模型现在可以应用于任意未见过的真实世界数据集。训练样本作为上下文提供给模型,模型通过 ICL(in-context learning,即上下文学习)预测这些未见数据集的标签。
- TabPFN 利用上下文学习( ICL ) ,这是导致大型语言模型表现出惊人性能的相同机制,生成了一种完全学习的强大表格预测算法。尽管 ICL 最初是在大型语言模型中观察到的,但最近的研究表明,通过 ICL ,转换器可以学习诸如逻辑回归等简单算法。先验数据拟合网络( PFNs )表明,即使是复杂的算法,如高斯过程和贝叶斯神经网络,也可以通过 ICL **进行近似。**ICL 使我们能够学习更广泛的可能算法空间,包括那些不存在封闭形式解的情况。
- TabPFN 的核心思想是生成大量的合成表格数据集,然后训练基于 transformer 的神经网络来学习解决这些合成预测任务。这种方法利用了 ICL 作为基于示例的声明式编程框架,用于算法的设计。
- ICL方法与标准的监督深度学习有着根本性的区别。通常,模型是根据数据集进行训练,在单个样本或批次上根据手工设计的权重更新算法(如Adam24)更新模型参数。在推理时,学习到的模型被应用于测试样本。相比之下,我们的方法是在多个数据集上进行训练,并且在推理时应用于整个数据集,而不是单个样本。在应用于实际数据集之前,模型会在数百万个代表不同预测任务的合成数据集上进行一次预训练。在推理时,模型接收一个包含标注训练样本和未标注测试样本的未见过的数据集,并在一个单一的神经网络前向传递中对这个数据集进行训练和预测。
- 引入了 TabPFN,这是一种针对小型到中型表格数据的基础模型。
- 在人工智能的历史上,手动创建的算法组件已经被性能更好的端到端学习组件所取代。在计算机视觉中,如SIFT(尺度不变特征变换)和HOG(方向梯度直方图)等手工设计的特征已被学习到的卷积所取代。在自然语言处理中,基于语法的方法已被学习到的转换器所取代。在游戏中使用的定制开局和终局库的设计已被端到端学习策略所取代。在这里,我们将这种端到端学习扩展到无处不在的表格数据领域。
- 表格数据的多样性使它们与未处理的文本和图像等模态区分开来。例如,在语言模型中,一个词的意义在不同文档中是一致的,而在表格数据集中,相同的值可能意味着完全不同的东西 。这种专业化导致了大量较小的、独立的数据集和相关模型的激增。举例来说,在流行的表格基准测试网站 openml.org 上,截至撰写时,76% 的数据集包含不到 10,000 行。
- 深度学习方法在处理表格数据时传统上一直面临困难,因为数据集之间以及原始数据本身的异质性 :表格包含各种尺度和类型的列,也称为特征(布尔型、分类型、有序型、整型、浮点型),还有不平衡或缺失的数据、不重要的特征、异常值等。这使得非深度学习方法,如基于树的模型,成为迄今为止最强有力的竞争者。
- 在2.8秒内,TabPFN在一个分类设置中超越了一个经过4小时调优的强大基线组合。
- 作为一种生成式变换器基础模型,该模型还允许微调、数据生成、密度估计和学习可重用嵌入。
- TabPFN是一种通过在数百万个合成数据集上学习而来的学习算法,展示了这种方法在算法开发中的强大能力。
- 通过提高不同领域的建模能力,TabPFN有潜力加速科学发现并在各个领域中增强重要决策。
- 尽管深度学习已经革新了从原始数据中的学习,并带来了众多高调的成功案例,但在过去的20年里,梯度提升决策树在表格数据领域占据主导地位。
- 在这里,我们介绍了表格先验拟合网络(TabPFN),这是一种表格基础模型,它在多达10,000个样本的数据集上显著优于所有先前的方法,并且训练时间大大减少。
- 表格数据,即按行列组织的电子表格,在从生物医学到粒子物理、经济学和气候科学等各个科学领域中无处不在。