论文浅尝 | 逐步蒸馏!使用少量训练数据和较小模型超越大语言模型

笔记整理:康婧淇,东南大学硕士生,研究方向为自然语言处理、信息抽取

链接:https://arxiv.org/abs/2305.02301

1. 动机

本文的动机是将大型语言模型(LLMs)的任务特定知识提炼到更小的专业模型中。作者改变了对LLMs的视角,将其视为可以进行推理的代理,能够生成自然语言的理由来解释其预测的标签。这些理由可以包含与任务相关的知识,例如"面积=长度×宽度",这些知识原本需要大量数据才能让小型任务特定模型学习。作者利用提取的理由作为额外的、更丰富的信息来训练小型模型,通过多任务训练设置,同时进行标签预测和理由预测任务。通过逐步蒸馏,我们可以学习到性能优于LLMs的任务特定小型模型,使用的模型参数比LLMs少500倍以上,并且相比传统的微调或提炼方法,所需的训练样本数量要少得多。这些小型模型在4个自然语言处理基准测试中取得了令人期待的实证结果,相比微调和提炼,我们的模型在平均数据集上使用的训练样本数量减少了50%以上,并且在性能上超过了LLMs,同时模型尺寸也大大减小(最多减小了2000倍),从而大大降低了模型部署所需的计算成本。此外,本文的方法还能够在只有无标签数据的情况下,性能与LLMs相当甚至更好。本文通过一个11B T5模型超越了540B参数的PaLM模型的性能。本文进一步展示了当一个小型模型的性能不如LLMs时,逐步提炼方法相比标准提炼方法能够更有效地利用额外的无标签数据来匹配LLMs的性能。

2. 贡献

本文的主要贡献包括:

(1)提出了一种逐步蒸馏的方法,从大型语言模型(LLMs)中提取rationales,并将其作为训练小型任务特定模型的信息监督。通过使用rationales,这种方法可以减少训练小型模型所需的训练数据集,并降低实现甚至超过原始LLMs性能所需的模型大小。

(2)实验结果表明,相比于微调和蒸馏,所提出模型在平均数据集上使用50%以下的训练样本数量时,性能更好(最高可减少85%)。此外,该模型在比LLMs小得多的模型尺寸下表现出色(最高可减少2000倍),从而大大降低了模型部署所需的计算成本。

(3)本文的方法可以更高效地利用额外的无标记数据,以提高小型模型的性能,相比标准的蒸馏方法,逐步蒸馏方法可以更有效地匹配LLMs的性能。

3. 方法

本文提出了逐步蒸馏新范式,利用 LLM 对其预测的推理能力,以数据高效率的方式训练更小的模型。整体框架如图 1所示。该范式有两个简单的步骤:首先,给定一个 LLM 和一个无标签的数据集,提示 LLM 生成输出标签以及证明该标签成立的理由。理由用自然语言解释,为模型预测的标签提供支持。理由是当前自监督 LLM 的一个涌现的行为属性。

图1:逐步蒸馏方法框架图

4. 实验

作者在实验中验证了逐步蒸馏的有效性。首先,与标准的微调和任务蒸馏方法相比,逐步蒸馏有助于实现更好的性能,训练实例的数量少得多,大幅提高了学习小型特定任务模型的数据效率。

其次,研究表明,逐步蒸馏方法以更小的模型大小超越了 LLM 的性能,与大型语言模型相比,大大降低了部署成本。

最后,本文研究了逐步蒸馏方法在超过 LLM 的性能方面所需的最低资源,包括训练示例数量和模型大小。论文展示了逐步蒸馏方法通过使用更少的数据和更小的模型,同时提高了数据效率和部署效率。

5. 总结

本文介绍了一种名为逐步蒸馏的新机制,用于训练较小的模型并减少训练数据的需求。该方法通过在多任务框架中利用大型语言模型(LLMs)的理由作为额外的监督来训练小模型。研究发现,与微调和蒸馏相比,这种机制在使用更少的标记/未标记训练样本的情况下实现了更好的性能。此外,与少样本提示的LLMs相比,使用更小的模型尺寸也能实现更好的性能。通过这种方法,研究人员成功地减小了模型尺寸和所需数据量,使得微调的770M T5模型在使用仅80%的可用数据时就能胜过少样本提示的540B PaLM模型,而标准微调相同的T5模型即使使用100%的数据集也难以达到相同的效果。这项研究对于解决大型语言模型在实际应用中的内存和计算资源需求问题具有重要意义。


OpenKG

OpenKG(中文开放知识图谱)旨在推动以中文为核心的知识图谱数据的开放、互联及众包,并促进知识图谱算法、工具及平台的开源开放。

点击阅读原文 ,进入 OpenKG 网站。

相关推荐
余炜yw8 分钟前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐25 分钟前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
如若1231 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr1 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络
ChaseDreamRunner1 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习
Guofu_Liao1 小时前
大语言模型---梯度的简单介绍;梯度的定义;梯度计算的方法
人工智能·语言模型·矩阵·llama
我爱学Python!1 小时前
大语言模型与图结构的融合: 推荐系统中的新兴范式
人工智能·语言模型·自然语言处理·langchain·llm·大语言模型·推荐系统
果冻人工智能1 小时前
OpenAI 是怎么“压力测试”大型语言模型的?
人工智能·语言模型·压力测试
日出等日落1 小时前
Windows电脑本地部署llamafile并接入Qwen大语言模型远程AI对话实战
人工智能·语言模型·自然语言处理
麦麦大数据1 小时前
Python棉花病虫害图谱系统CNN识别+AI问答知识neo4j vue+flask深度学习神经网络可视化
人工智能·python·深度学习