在这篇 NeurIPS 2023 论文中,来自新加坡国立大学和字节跳动的学者们受人类联想学习的启发,提出了数据集扩增的新范式,有效地提升了深度模型在小数据场景下的性能和泛化能力,极大地降低了人工收集和标注数据的时间和成本。代码已开源。
众所周知,深度神经网络的性能很大程度上依赖于训练数据的数量和质量,这使得深度学习难以广泛地应用在小数据任务上。例如,在医疗等领域的小数据应用场景中,人力收集和标注大规模的数据集往往费时费力。为了解决这一数据稀缺问题并最小化数据收集成本,该论文探索了一个数据集扩增新范式,旨在自动生成新数据从而将目标任务的小数据集扩充为更大且更具信息量的大数据集。这些扩增后的数据集致力于提升模型的性能和泛化能力,并能够用于训练不同的网络结构。
该工作发现只是利用现存方法无法很好地扩充数据集。(1)随机数据增强主要改变图片的表面视觉特征,但不能创造具有新物体内容的图片(如下图的荷花依然是同一个,没有新荷花的生成),因此所引入的信息量有限。更为严重的是,随机数据增强可能会裁剪医学图像的病灶(变)位置,导致样本的重要信息减少,甚至产生噪声数据。(2)直接利用预训练的生成(扩散)模型进行数据集扩增也不能很好地提升模型在目标任务上的性能。这是因为这些生成模型的预训练数据往往与目标数据存在较大的分布差异,这导致它们所生成的数据与目标任务存在一定的分布和类别差距,无法确保所生成的样本带有正确的类别标签且对模型训练有益。
为了更有效地进行数据集扩增,该工作探究了人类的联想学习:给定一个物体,人类可以利用他们累积的先验知识轻易地想象物体的不同变体,例如下图狗子在不同种类、不同颜色、不同形状或不同背景下的变体。这一想象学习的过程对于数据集扩增非常有启发性,因为它不仅是简单地扰动图片中动物体的外观,而是应用丰富的先验知识来创造具有新信息量的变体图片。
然而,我们无法直接建模人类作为先验模型来进行数据想象。但幸运地是,近期的生成模型(如 Stable Diffusion,DALL-E2)已经展现了强大的拟合大规模数据集分布的能力,能够生成内容丰富且逼真的图片。这启发了该论文使用预训练的生成模型作为先验模型,利用它们强大的先验知识来对小数据集进行高效地数据联想和扩增。
基于上述想法,该工作提出了一个新的指导式想象扩增框架(Guided Imagination Framework, GIF)。该方法能够有效提升深度神经网络在自然和医疗图片任务上的分类性能和泛化能力,并极大地减少因人工数据收集和标注所带来的巨大成本。同时,所扩增的数据集也有助于促进模型的迁移学习,并缓解长尾问题。
接下来让我们看看,这一数据集扩增新范式是怎么设计的。
方法
数据集扩增的挑战和指导标准 设计数据集扩增方法会有两个关键挑战:(1)如何使生成的样本带有正确的类别标签?(2)如何确保生成的样本带有新的信息量,从而促进模型训练?为了解决这两个挑战,该工作通过大量的实验发现了两个扩增指导标准:(1)类别一致的信息增强;(2)样本多样性提升。
方法框架 基于所发现扩增指导标准,该工作提出了指导式想象扩增框架(GIF)。对于每个输入的种子样本 x,GIF 首先利用先验生成模型的特征提取器提取样本特征 f,并对该特征进行噪音扰动:。设置噪音(z,b)最简单的方式是采用高斯随机噪声,但是它无法确保所生成的样本具有正确的类别标签并带来更多的信息量。因此,为了进行有效的数据集扩增,GIF 基于其发现的扩增指导标准对噪声扰动进行优化,即。
所用到的扩增指导标准实现如下。类一致的信息量指标:;样本多样性指标:。通过最大化这两个指标,GIF 能够有效优化噪声扰动,从而生成既保持类别一致性,又带来更大信息量的样本。
实验
扩增有效性 GIF 具有更强的扩增有效性:GIF-SD 在 6 个自然数据集上平均提高了 36.9% 分类精度,并在 3 个医疗数据集上平均提高了 13.5% 分类精度。
扩增效率 GIF 具有更强的扩增有效率:在 Cars 和 DTD 数据集上,使用 GIF-SD 进行 5 倍扩增的效果甚至超过了使用随机数据增强进行 20 倍扩增的效果。
可视化结果 现有的数据增强方法无法生成新的图像内容,而 GIF 可以较好地生成带有新内容的样本。
现有的增强方法甚至裁剪医学图像的病变位置,导致样本信息减少甚至产生噪声,而 GIF 可以更好地保持它们的类别语义。
计算和时间成本 与人工数据收集和标注相比,GIF 能够极大地降低数据集扩增的时间和成本。
扩增数据的通用性 一旦完成扩增,这些数据集可以直接用于训练各种不同的神经网络模型结构。
提升模型泛化能力 GIF 有助于提升模型的分布外泛化性能(OOD generalization)。
缓解长尾问题 GIF 有助于缓解长尾问题。
安全性检测 GIF 生成的图像是安全且无害的。
基于上述实验结果,我们有理由相信通过模拟人类的类比与想象学习,该论文所设计的方法能够有效地扩增小数据集,从而提升深度神经网络在小数据任务场景上的落地和应用。