Paper name
MAMMOTH: BUILDING MATH GENERALIST MODELS THROUGH HYBRID INSTRUCTION TUNING
Paper Reading Note
Paper URL: https://arxiv.org/pdf/2309.05653.pdf
Project URL: https://tiger-ai-lab.github.io/MAmmoTH/
Code URL: https://github.com/TIGER-AI-Lab/MAmmoTH
TL;DR
- 2023 年俄亥俄州立大学提出的大语言模型数理性能优化的文章,在训练数据集上做了一些探索,取得了超过现有开源 SOTA 方案(比如 WizardMath)的效果,基于 llama2 作为基座模型,在比较难的 MATH 数据集上精度提升了 22 个点(13.5->44.2)
Introduction
背景
- 尽管 LLM 领域近年来取得了显著进展,但封闭源模型和开源模型之间存在明显差距
- 封闭源模型如GPT-4,PaLM-2 和 Claude 2 在流行的数学推理评测集上占据主导地位,如 GSM8K 和 MATH,而开源模型如 Llama,Falcon,OPT 在所有基准上都远远落后
- 弥合上述这一差距的当前努力有两方面:
- 持续的预训练,如 Galactica 和 MINERVA,继续在超过 100B token 的数学相关网络数据上训练LLM。这种方法提高了模型的一般科学推理能力,但需要高昂的计算成本。
- 针对特定数据集的微调,如拒绝采样微调(RFT)和 WizardMath,使用特定于某些数据集的监督数据微调LLMs。尽管这些方法可以提高领域内性能,但它们不能推广到超出其微调数据范围的更广泛的数学推理任务。例如,RFT 和 WizardMath 可以使 GSM8K 的准确率提高 30%+,这是它们微调的数据集之一,但会导致在领域外数据集(如MMLU-Math 或 AQuA 上的准确率下降高达 10%。
- 数学推理是评估LLMs执行复杂多跳和定量推理能力的重要标志。在以前,这对神经网络来说是一项具有挑战性的任务,它们甚至难以解决基本的加法和减法问题。然而,最近的LLMs在数学推理方面取得了相当大的进展。关键突破是通过CoT提示和PoT提示
- CoT 提示鼓励LLMs在草稿上逐步解决问题,增强了数学推理的准确性和可解释性。但在计算精度以及复杂的数学或算法推理过程方面表现出困难(例如,解二次方程的根和计算矩阵特征值)
- PoT 提示将中间推理过程制定为一个程序,使用外部工具如 Python 执行以计算答案。通过将计算卸载到外部工具,这种方法提高了解决复杂数学问题的鲁棒性(例如,使用 sympy 解二次方程或使用 numpy 计算矩阵特征值)。但是,在处理更抽象的推理情景时,如常识推理、形式逻辑和抽象代数,特别是在没有内置 API 的情况下,PoT 的表现不佳
- 大多数现有的PoT工作局限于像 GPT-4 和 Codex 这样的专有模型。开源模型的 PoT 潜力尚未见分晓。我们的工作旨在通过指令调整优化 LLMs 的 CoT 和 PoT 推理能力。
本文方案
- 本文旨在提出一种轻量级但通用的数学指令调整方法,以增强 LLMs 的一般(即不限于微调任务)数学推理能力。
- 提出了 MAmmoTH,这是一系列专门针对一般数学问题解决而设计的开源大型语言模型(LLMs)
- 提出了 MathInstruct 数据集,由 13 个数学数据集组成,其中 6 个是本文新提出的:
- 包含思维链(CoT)和思维程序(PoT)两种构造方式的混合。CoT和PoT的混合不仅释放了工具使用的潜力,还允许针对不同数学问题采用不同的思维过程。
- 包含思维链(CoT)和思维程序(PoT)两种构造方式的混合。CoT和PoT的混合不仅释放了工具使用的潜力,还允许针对不同数学问题采用不同的思维过程。
- MAmmoTH 系列在所有规模的九个数学推理数据集上都明显优于现有的开源模型,平均准确率提高了 13% 到 29%
- MAmmoTH-7B 在 MATH 评测数据集上达到了 35% 的精度,超过之前最佳的开源 7B模型(WizardMath)25%
- MAmmoTH-34B 模型在 MATH 上实现了 46% 的准确率,甚至超过了 GPT4 的 CoT 结果
Methods
训练数据集准备
- 旨在编制一份高质量且多元化的数学指令调整数据集清单,具有以下特点:
- (1) 广泛涵盖不同数学领域和复杂性水平:确保接触到各种数学知识,促进模型的多功能性。
- 将选择范围缩小到一些广泛采用的高质量数据集,涵盖不同的数学领域和复杂性水平,如 GSM8K、MATH、AQuA、Camel 和 TheoremQA。
- 现有数据集中缺乏大学级别的数学知识,如抽象代数和形式逻辑。使用 GPT-4 合成了 TheoremQA 问题的 CoT 解释,并通过 Self-Instruct 创建了 question-CoT 对,利用网上找到的一些示例作为种子
- (2) 混合的 CoT & PoT 解释:将 CoT 解释和 PoT 解释混合到数据集中。大多数现有数据集提供的程序解释有限,导致 CoT 和 PoT 解释之间存在不平衡。为了填补这一空白,我们利用 GPT-4 来补充选定数据集的 PoT 解释,包括 MATH、AQuA、GSM8K 和 TheoremQA。然后,我们通过将这些由 GPT-4 合成的程序与人工注释的标准答案进行比较,来过滤这些合成的程序,从而确保了添加解释的高质量。按照这些准则,我们的指令数据集,包含 26 万个(指令、回答)对,涵盖了广泛的核心数学领域(算术、代数、概率、微积分、几何等等),包括混合的 CoT 和 PoT 解释,以及在语言和难度水平上提供了多样性。这证明了它的高质量和独特特点。
- (1) 广泛涵盖不同数学领域和复杂性水平:确保接触到各种数学知识,促进模型的多功能性。
训练细节
- 统一数据集格式为 Alpaca 格式
- llama2 作为基础模型,测试了 7B、13B、34B 和 70B 系列模型
- 对于 7B 和 13B 模型,我们使用了学习速率为 2e-5,对于 34B 和 70B 模型,我们使用了学习率为 1e-5
- batch size 为 128
- 训练 3 个 epoch
- 使用 deepspeed zero3 进行训练
Experiments
评测数据集
- 领域内数据集,即和训练集同源
- GSM8K、MATH、AQuA-RAT、NumGLUE
- 领域外数据集,即和训练集不同源
- SVAMP、Mathematics、SimulEq、SAT-Math、MMLUMath
- 这些广泛选择的评估数据集包括来自初中、高中和大学水平的数学问题。其中一些数据集甚至包括形式逻辑和常识推理。选择这些数据集是为了确保全面评估模型在处理陌生情况和不同数学领域时的能力。所选择的评估数据集包括开放式问题和多项选择题。我们对开放式问题采用 PoT 解码(例如,GSM8K、MATH),因为大多数问题可以通过程序来解决。我们对多项选择题采用 CoT 解码(例如,AQuA、MMLU),因为这些数据集中的大部分问题可以通过 CoT 更好地处理。CoT 解码不需要任何触发词,而 PoT 解码需要一个触发短语"让我们写一个程序来解决这个问题"(Let's write a program to solve the problem)。
对比模型
-
对比的基线模型分为以下类别:
- 闭源语言模型(LLM):我们考虑了 4 个闭源语言模型,包括 GPT-4,GPT-4(代码解释器),PaLM-2 Unicorn,Claude-2 和 Codex。GPT-4、PaLM-2 和 Claude-2 使用 CoT 提示,而GPT-4(代码解释器)和 Codex 使用 PoT 提示。
- Llama 基础模型:对于基础模型,我们考虑了Llama-1/2,Llama-2-Chat。
- 代码模型:为了与不同的编码器模型进行比较,我们选择了 Code-Llama,CodeT5+ 和 CodeGen。
- STEM 预训练:我们主要涵盖了 Galactica,以了解专门用于 STEM 知识的模型的性能。
- 指令调整:我们包括了 Orca-Platypus,Vicuna-1.5,Tulu,Platypus-2 和 Guanaco。我们涵盖了一系列使用不同类型数据集训练的模型。
- 数据集特定微调:我们包括了 RFT 和 WizardMath,这两者都是专门针对 GSM8K 和 MATH 数据集进行微调的模型。我们将它们包括在内是为了了解它们的泛化能力。
-
对于大多数基线模型,我们选择 CoT 提示以最大化它们的性能,因为它们在程序生成方面表现不佳。所有的"代码模型"都使用 PoT 提示。对于 GSM8K、MATH、AQuA 和 NumGLUE,我们将评估 8-shot in-context-learning 和 zero-shot 两种设置,以报告最高分数。对于 SVAMP、Mathematics、SimulEq、SAT和MMLU,我们使用 5-shot in-context-learning 以保持与之前的工作一致。我们的 few-shot 示例主要来自 PHP1。对于 MAmmoTH 和 MAmmoTH-Coder,我们始终在 zero-shot 设置下进行评估。对于所有模型,我们允许最大长度为 2048 个 token 以进行解码。
主要评测结果
-
in-domain 评测结果,MAmmoTH 和 MAmmoTH-Coder 可以超过目前开源的 SOTA 模型,其中 MATH 数据集涨点幅度尤为突出。不过 70B 下的模型和 wizardMath 在 gsm8k 数据集上还有 5 个点差距
-
out-of-domain 评测结果,在 out-of-domain 上的涨幅比 in-domain 上还大,作者没有具体讲原因,个人认为这里和训练数据集的多样性有关。另外一个值得关注的点是 codellama 效果明显优于 llama2,MAmmoTH-Coder 效果也明显优于 MAmmoTH,说明基座模型使用代码预训练有利于后续数理的模型迭代
数据集消融实验
-
基于 llama2 7B 模型对不同数据集进行消融实验,PoT 效果优于 CoT,另外混合 CoT 和 PoT 数据训练能有最优效果,同时也远优于只基于 gsm 和 math 训练的效果
-
作者任务这种综合性提升来自两个方面:
- CoT 子集可以帮助维持通用的基于语言的推理技能,以处理 PoT 处理不好的情况,例如 AQuA、SAT 和 MMLU 中的多选题
- PoT 子集可以教会模型如何利用 Python API 以高精度解决复杂的数学问题,例如需要复杂计算的 MATH 问题。以下案例研究展示了 PoT 和 CoT 在解决不同类型数学问题时的各自优势
-
pot 的优势案例:解题步骤明显少于 cot 的例子,数值计算正确率高,
-
pot 无法处理的案例:感觉主要是一些选择题不适合用 pot,比如需要对每个选项做一些简单确认,或一些偏逻辑的计算表达
-
总之,本文显著提升归因于:
- 1)涵盖不同数学领域和复杂级别的多样数据来源
- 2)CoT 和 PoT 指导微调策略的混合
主要子集的影响
- 鉴于在 MAmmoTH 训练中使用了多样的 MathInstruct 数据源,了解每个数据集对模型整体性能的贡献非常重要。我们专注于四个重要的子集:GSM8K、MATH、Camel 和 AQuA。我们进行了一项实验,逐渐将每个数据集添加到训练中,并将其性能与在整个 MathInstruct 上进行微调的性能进行比较
当训练初期的数据不够多样化(例如,仅使用 GSM8K 数据)时,总体的泛化性能非常差:模型只适用于分布内的数据,难以回答超出 GSM 问题范围的问题。然后,逐渐添加其他主要子集后,除了在它们自己的测试集上看到改善之外,我们还可以观察到 MAmmoTH 变得更加擅长数学问题。这些结果强调了多样的数据源对 MAmmoTH 性能的重大影响,这是使 MAmmoTH 成为数学通才的核心方面。这些结果还为未来的数据策划和收集工作提供了有价值的见解(例如,我们应始终收集多样的数据,避免仅收集特定类型的数据)。为了帮助理解本文所提出的 6 个新的精选数据集的贡献,我们将它们从 MathInstruct 中移除,并在现有数据上训练一个模型。如上表所示,我们新精选的数据显著提高了许多数据集的性能,使总体性能提高了6.3%,这反映了新精选数据集的重要性。
Thoughts
- 本文的基本结论还是符合直觉的,比如
- 训练集需要足够丰富,更多来源的数据集训练的模型有更强的泛化性
- PoT 和 CoT 各有优势,结合两者的优点能取得最佳效果
- 基于代码增量预训练后的模型作为基座模型能提升数理能力
- 测试阶段如何同时结合 PoT 和 CoT 进行精度优化本文没有过多介绍,应该有研究空间