[COLM 2024] V-STaR: Training Verifiers for Self-Taught Reasoners

  • 本文是对 STaR 的改进方法,COLM 是 Conference On Language Models,大模型领域新出的会议,在国际上很知名,不过目前还没有被列入 ccf list(新会议一般不会列入);
  • 作者来自高校、微软研究院和 Google Deepmind;

读完STaR后,最直观的想法,1)数据利用率不够,合理化依然没有利用好剩下的数据,而剩下的数据对提高模型性能可能很重要(硬骨头),此外,模型没有一次性答对的样本也没有用上(即剩下的样本),这些一次没答对的数据中,可能部分步骤是有益的,能否利用上?或者错误的步骤能否也利用上?

2)由于LLM有随机性,应该让它多次输出,然后选最好的答案。如何选最好,比起用分类任务来做,不如用排序任务来做。因此可能需要QA模型多次输出,然后让一个模型来对答案排序。

背景知识

  • STaR:参考另一篇我写的 blog

  • DPO

  • RFT/ORM

  • 参考策略(Reference Policy):一个已经训练好的基础模型,作为行为的参考标准

  • 代理(Proxy):在机器学习中,当我们难以直接优化某个目标时,可以优化一个相关的、更容易计算的替代目标

  • causal language modeling objective:指的就是 GPT 这种注意力,next token prediction时注意力只能基于前面的token,(与之相对的 BERT 是双向的,可以利用后面的token)。

Abstract

大型语言模型 (LLMs) 的常见自我改进方法(例如 STaR)会在自我生成的解决方案上迭代微调 LLMs,以提高其解决问题的能力。然而,这些方法丢弃了在此过程中生成的大量不正确的解决方案,可能忽略了此类解决方案中的有价值的信息。为了解决这个缺点,我们提出了 V-STaR,它利用自我改进过程中生成的所有正确和错误的解决方案来训练使用 DPO 的验证器,以判断模型生成的解决方案的正确性。该验证器用于在推理时从许多候选解决方案中选择一个解决方案。运行 V-STaR 进行多次迭代会产生逐渐更好的推理器和验证器,与使用 LLaMA2 模型的常见代码生成和数学推理基准的现有自我改进和验证方法相比,测试精度提高了 4% 到 17%。

1 Introduction

学习识别和纠正错误是人类智力的一个特征(Metcalfe,2017)。在处理复杂的任务时,例如编码或解决数学问题,我们可以识别推理中的错误并探索解决方案的替代路径。为了提高 LLMs 的推理性能,有几种方法利用 LLMs 的能力来生成解决方案并在训练期间检查这些解决方案的正确性,例如使用测试用例进行代码生成。这些自我完善的方法,例如 STaR Zelikman 等人 (2022)、RFT (Yuan et al., 2023) 和 ReST EM (Singh et al., 2023),通过对它们的自生成的解法进行微调来改进 LLMs并可选地迭代运行此过程。然而,所有这些方法都是数据效率低下的,因为它们只使用正确的解决方案,并丢弃不正确的解决方案,这通常是模型生成的解决方案的很大一部分,特别是对于具有挑战性的推理任务。

与自我改进正交,改进 LLM 推理的另一个有希望的方向是在测试时使用学习好的 LLM 验证器(Cobbe 等人,2021;Wang 等人,2023b)。具体来说,LLM生成多个候选解决方案,验证者对这些解决方案进行排名并选择最佳的一个。此类验证器通过在从冻结 LLM 生成的解决方案数据集上微调 LLM 进行训练,并标有最终正确性(ORM,Cobbe 等人,2021)或步骤分步人工注释(Lightman 等人,2024)。使用此类验证器允许 LLMs 权衡额外的测试时计算以获得更好的性能。

本文提出了 Verification for Self-Taught Reasoners (V-STaR)。 V-STaR 的关键思想是在迭代自我改进过程中利用正确和错误的 LLM 生成的解决方案来训练使用 DPO 的验证者。此外还使用正确的解决方案将 LLM 训练为生成器。迭代的自我改进过程产生了逐步改进的生成器,并在增强数据上进行了训练,从而为验证者训练带来了更高质量的完成结果和更具挑战性的负面示例。在测试时,验证者对生成器中的多个候选解决方案进行排序,并选择最佳的一个。

我们根据经验评估 V-STaR 提高 LLMs 的推理能力:(1)数学问题: GSM8K 和 MATH 子集, (2) 代码生成问题:MBPP 和 HumanEval。微调 LLaMA2 和 CodeLLaMA ,我们将 V-STaR 与其他自我改进(RFT、STaR)和基于验证的方法(ORM)、self-consistency 进行比较,以及使用非迭代 V-STaR 基线(RFT + Verifier)相同数量的生成样本来引导生成器和验证器。 V-STaR 的效果非常好,与之前的自我改进和基于验证的数学推理方法相比,测试准确率在数学推理上提高了 6% 到 17%,代码生成上提高了 4% 到 12%。值得注意的是,7B V-STaR 在 GSM8K 上超越了基础 LLaMA2 70B(8-shot),并且几乎与 HumanEval 上的 CodeLLaMA 34B(zero-shot)相匹配。

本文贡献:

  • 提出V-STaR,一种简单而有效的方法,它使用从LLM迭代生成的正确和错误解决方案来训练更好的生成器和验证器。 V-STaR 在数学推理和代码生成方面优于先前的自我改进方法(RFT、STaR)以及 ORM 验证。
  • 作为次要贡献,我们发现 DPO 对于训练验证者而言比 ORM 方法更有效。我们还提出了 Best-of-k 公式,类似于 Pass@k,通过验证可靠地评估测试性能。

2 Preliminaries

2.1 Self-improvement approaches

Self-Taught Reasoner: STaR 略。

Rejection Sampling Fine-tuning: RFT 方法(非迭代方法),先在 SFT 数据上微调,然后每个问题采样k个solution,只保留正确的,然后在这周方法增强的数据集上微调(只进行一次)。

STaR†:相比 STaR 的一次回答变为多次回答,并且保持迭代。单次迭代里类似 RFT。

2.2 Test-time verification

Cobbe 训练了 verifiers,也就是 outcome-supervised reward model (ORM),用于评估候选解决方案对于给定问题的正确概率。在测试时,语言模型 G 生成许多候选解决方案,并选择验证者排名最高的一个,也称为 best-of-k。类似 RFT,每个问题生成k个回答,然后对其标注 z i , j z_{i,j} zi,j是一个二元标签(0,1),表示正确与否。然后V在Dver上微调,给定(x,y)后预测(z)。

2.3 Preference learning with DPO

利用人类反馈微调预训练 LLMs 可以在下游任务中带来巨大的性能提升。DPO 改进了上述方法,在微调过程中不使用单独训练的奖励模型。(详细看 DPO 解读论文,本文略)。

3 V-STaR: Verifiers for self-taught reasoners

现有的自我改进方法,丢弃了模型生成的不正确的解决方案。然而,不正确的解决方案也可能包含有价值的信息,语言模型可以从给定问题的正确和错误解决方案之间的差异中学习,并识别生成中的错误模式,从而增强其提供更准确解决方案的能力。在这项工作中,我们提出了 V-STaR,它在迭代过程中利用错误和正确生成的解决方案来训练更好的生成器和验证器。

V-STaR 训练和 ORM 等之间的区别是我们的验证者训练数据是迭代收集的,每次迭代都来自更好的生成器,而 ORM 仅从固定生成器收集数据,该固定生成器仅在原始 SFT 数据上进行微调。我们将此 ORM 方法作为基线进行比较,如第 4.1 节中所述。

3.1Training verifiers with DPO

一些补充:训练 verifier 有两个思路,把 next token prediction 和二元分类损失加权做多任务学习,另一个思路是,二元分类本质上是想要鼓励模型输出分类正确的答案,这个其实本质上是鼓励模型按照二元分类正确的方向回答,是一种偏好。所以也可以视为 preference learning(偏好学习)。通过偏好学习可以直接统一这个多任务目标。这本质上变成了一种 alignment (对齐)。

当前的 LLM 验证者是通过语言建模和二元分类损失相结合的方式进行训练的,这两个目标可以通过离线偏好学习方法来统一,例如 DPO。其中近似参考策略是语言建模目标的 proxy,而分类损失是奖励建模的 proxy。根据经验,我们发现使用 LoRA 适配器时,DPO 验证器比 ORM 形式的验证器更好。

  • 偏好对数据集:正确答案和错误答案的笛卡尔积构成的 preference pair;

4 Empirical results

为了证明 V-STaR 的有效性,我们在两个广泛使用的数据集上进行了实验:用于解决数学问题的 GSM8K (Cobbe et al., 2021) 和用于解决代码生成问题的 MBPP (Austin et al., 2021)。们还使用 Hendrycks' MATH (Hendrycks et al., 2021) 和 HumanEval (Chen et al., 2021) 评估 V-STaR 的迁移泛化性能。具体来说,对于数学推理,我们仅使用 GSM8K 训练数据训练生成器和验证器,并在整个 GSM8K 测试集和 MATH 测试集的子集上对其进行评估。对于代码生成,我们使用 MBPP 训练数据训练模型,并在 MBPP 和 HumanEval 的完整测试集上对其进行评估。

Models

在实验中,我们使用 LoRA (Hu et al., 2022) 微调 LLaMA2 (Touvron et al., 2023) 和 CodeLLaMA (Rozière et al., 2023) 7B 和 13B 模型。生成器使用因果语言建模目标进行训练,我们的基线 (V-STaR[1 Iter]) 和 V-STaR 验证器使用 DPO 进行训练。 DPO 的参考策略 G S F T G_{SFT} GSFT 分别在 GSM8K 和 MBPP 的 2 和 3 epoch 的原始训练数据上进行训练。

Data generation

对于每次迭代, k=16 补全是从前一次迭代的生成器中对每个查询进行采样的。对于 GSM8K,第一个迭代样本来自仅在原始 GSM8K 训练数据上训练 2 个 epochs 的生成器。对于 MBPP,此数据来自 3 次预训练的 CodeLLaMA(参见 §B)。通过检查数学问题的最终答案和运行编码问题的测试用例来标记完成的正确性。

4.1 Baselines and metrics

在推理时,我们使用生成器为每个测试问题生成 128 个候选解决方案。我们报告生成器的 Pass@1准确度,它估计从生成器随机采样的答案正确的概率。所有基于验证者的方法都使用 best-of-64 精度,使用(eq 3)中的公式以及 self-consistency baseline。

生成了128个,只用了 best-of-64,有意思

4.2 Reliable estimation of Best-of-k accuracy

  • 原本的方法:生成k个回答,用 V 排序,选最高分作为答案,重复这个过程多次取平均值。
  • 优化:【主要是去掉了多次重复】,生成一组固定的 N (N>k,本文N=128,k=64),然后不放回抽取 k 个样本,直接采用其中 V 打分最高的作为结果。best-of-k 通过重复这个过程 M 次取平均得到(等效数学公式 eq3)。
    【这里是个组合公式:这个公式大概思路是,先统计 N choose k 有多少 case,然后再来计算从 N 中 选 k 能选对的 case 数量,得到 k 次抽样能正确的概率,然后分子大意是,因为k个里面每次都是以score最大作为最终answer,所以不妨先按V的分数降序排序,然后进行组合,先假设一定会选中最大的,然后通过剩下的里面抽(k-1)个,来计算这种 k 抽样的 case 有多少次,同理,固定第二大的,然后从剩下的中抽样(k-1)个,之所以不抽样比他大的,是因为如果选到比他大的,情况就退化为第一种情况,最终结果选最大的。。通过这个公式计算可以避免M次重复实验】

4.3 V-STaR on Math Reasoning and Code Generation

【质疑1:哥们,你这里就用了几十条样本微调?V-STaR 难道不是数据量更大才能体现出优越性吗?STaR用了1w条,这里差了这么多?】

如图 2 所示,V-STaR 在 LLaMA2 7B 和 13B 模型的 GSM8K、MBPP、MATH 子集和 HumanEval 测试集上相比 baseline 显示出增益。在数学方面,我们报告测试精度比 STaR† 和 Verification 提高了 6% 到 17%,代码生成任务提高了 4% 到 12%。图 3 中 V-STaR [1 iter] 的增益表明,与具有相同生成预算的非迭代方法相比,迭代生成收集验证者训练数据的解决方案可带来更好的分布和质量。我们在图 7 中显示了 MBPP 上生成器和验证器每次迭代的增益。我们还在 MBPP 上训练了生成器和验证器的四次迭代,这导致了 0.3% 的边际增益。


Out-of-domain performance of V-STaR 在 MBPP 上训练的生成器和验证者在 HumanEval 上进行评估,而在 GSM8K 上训练的生成器和验证者在 MATH 测试集的子集上进行评估(见图 2 和图 4)。一般来说,我们观察到所有方法的绝对 Pass@1 和 Best-of-64 分数较低,因为这两个任务被认为比 GSM8K 和 MBPP 更难。也就是说,迭代 V-STaR 在任务和跨模型大小上都优于基线 和 V-STaR [1 iter]。利用不正确的解决方案来训练验证器比仅使用使用正确的模型生成的解决方案(STaR or RFT)来 bootstrap 有很大的改进。虽然由于计算约束,我们使用 LoRA 适配器,但我们假设 V-STaR 的增益可能会因完全参数微调而更大。

Best-of-k accuracy. 图 5 显示了 k = 1 到 k = 64 的测试准确度,每个测试问题计算 128 个候选解决方案。Best-of-1 等价于 Pass@1 并忽略验证者分数。k ≥ 16 的最佳 k 饱和,V-STaR [1 Iter] 和 V-STaR 之间的差距保持一致。

4.4 Comparing DPO vs. ORM verifiers

我们使用 LoRA 适配器训练了 ORM-style 验证器,如 §2.2 中所述。与基于 DPO 的验证器相比,这些验证器的性能似乎确实相对较差。图 5(a) 显示了在相同训练数据上使用 DPO 训练的 V-STaR [1 Iter] 和 ORM 风格验证器之间的比较。在 GSM8K 任务中,ORM 无法有效地搜索生成的候选解决方案,以获取候选数量超过 4 的候选解决方案。对于超过 16 的候选解决方案数量,ORM 样式验证器的性能也比 MBPP 中基于 DPO 的验证器差。

4.5 How many completions can V-STaR be extended to?

图 6 显示了 V-STaR 7B 在 GSM8K 上的最佳 K 准确度,由方程 eq3 计算。V-STaR 在搜索大量候选解决方案时优于多数投票 (Wang et al., 2023c)(参见附录中的 §F)。虽然V-STaR在k≤64时比多数投票更有效,但对于较大的k值,性能差距开始略有下降,类似于Cobbe等人(2021)报告中的性能下降。此外,V-STaR 可用于任何问题解决任务,我们可以在其中验证正确性,而多数投票不适用于代码生成等任务。我们还尝试将验证者分数与重新排序策略相结合,例如加权重新排序和加权多数投票 Liu et al. (2023),但没有观察到性能提升。

4.6 Evaluating DPO verifier as a generator

由于 DPO 微调模型也可以用作生成器,我们评估了 DPO 验证器的生成能力。图 6 显示了 V-STaR 验证器在训练更新上的 Pass@1 和 Best-of-64, β 系数代表与 DPO 目标中的 SFT 策略接近程度。验证者的解决能力在仅少量训练更新之后便开始下降。相比之下,使用 DPO 目标进行验证似乎是有效的,因为模型的 Best-of-64 仅通过 2k 训练更新显着增加。

4.7 Should the verifier be in the training loop?

【这一节感觉完全是为了应付审稿人加的。你写了个寂寞啊,这个图根本看不懂】

可以在每次迭代时训练中间验证者并过滤正确的解决方案。我们尝试将验证器放入训练循环中,以从生成器中筛选出正确的解决方案,以进行下一次训练迭代。为此,我们从生成器中对每个查询k=64 完成进行了采样,标记其正确性,并根据验证者分数仅选取前 8 个,从错误集中抽取尽可能多的样本,以便每个查询的正确和错误完成总数为 16 或更少。在 MBPP 循环中使用验证器运行 3 次迭代后,最终的 Best-of-64准确度和 Pass@1分别为 53.2 和 46.34。

我们的结果表明,让验证者参与训练循环并不能为这项任务带来实质性的好处。 V-STaR 更简单,循环中没有验证器,并且不需要在每次迭代时训练验证器;然而,我们没有在每次迭代时尝试其他任务和生成器的不同采样策略。我们将这个问题的更详细的研究留给未来的工作。

4.8 Gains across V-STaR iteration

图 7 显示了 MBPP 上生成器和验证器每次迭代所取得的改进。验证者迭代(从 Ver.1 到 Ver.3)的增益比生成器迭代更大,这凸显了验证者的重要性。

具有挑战性的多步推理任务推动了LLMs的创新研究,例如通过中间步骤来回答给定问题(Wei等人,2022;Kojima等人,2022)。最近的大量工作研究了如何提高这些中间步骤的正确性并降低获得正确解决方案的成本。

Self-training and self-improvement. STaR、reinforced self-training (Gulcehre et al., 2023)、rejection fine-tuning (Yuan et al., 2023) 等等。

  • 对比损失来使正确的解决方案比错误的解决方案更有可能
  • 使用成功解决方案的中间状态作为监督来改进 credit 分配
  • 基于强化学习的 LLM 微调是很困难的,除非它通过一些监督微调步骤进行初始化
  • 使用更强大的LLM来编辑较小模型生成的不正确的基本原理,并为其微调提供积极的数据

Training verifiers

  • 为数学推理任务引入了验证器------对推理链进行评分或排名的模型,
  • 过程监督(rationale的正确性)相对于结果监督(无论答案是否正确)增强了微调LLMs的性能
  • 为单个推理步骤导出奖励信号的方法,结合解决solution-level and step-level验证器,并用辅助信息(例如程序执行结果)增强验证器
  • rationale generation 被视为图搜索问题,要么使用 stepwise verifier 来指导搜索,要么通过蒙特卡洛来估计步骤的质量

验证者可以被视为基于人类注释训练的奖励模型------通过人类反馈来训练满足验证者需求的 RL 实例(Ziegler 等人,2019)------或者基于合成数据,从而形成 RL 的形式具有人工智能反馈(Bai et al., 2022;Yang et al., 2023)。验证者也可以被视为生成模型,例如通过以指示解决方案的正标签或负标签的控制令牌为条件(Korbak 等人,2023)或通过提取分数作为候选者之后特殊令牌的可能性解决方案(Liu 等人,2023)。

6 Conclusion

略。

Appendix

A Algorithm

相关推荐
一 铭16 小时前
《Hands_On_LLM》8.2 RAG: 利用语言模型进行语义搜索(Semantic Search with Language Models)
人工智能·语言模型·大模型·llm
网安打工仔20 小时前
斯坦福李飞飞最新巨著《AI Agent综述》
人工智能·自然语言处理·大模型·llm·agent·ai大模型·大模型入门
健忘的派大星20 小时前
【AI大模型】根据官方案例使用milvus向量数据库打造问答RAG系统
人工智能·ai·语言模型·llm·milvus·agi·rag
Milkha2 天前
大模型训练工具,小白也能轻松搞定!
llm·模型训练
HyperAI超神经2 天前
超越 GPT-4o!从 HTML 到 Markdown,一键整理复杂网页;AI 对话不再冰冷,大模型对话微调数据集让响应更流畅
人工智能·深度学习·llm·html·数据集·多模态·gpt-4o
阿正的梦工坊3 天前
使用Sum计算Loss和解决梯度累积(Gradient Accumulation)的Bug
llm
yuanlulu3 天前
昇腾环境ppstreuct部署问题记录
人工智能·深度学习·llm·ocr·ppstructure
高性能服务器3 天前
英伟达 2025 CES:GPU与智算中心协同驱动 GPU算力智能变革
大数据·语言模型·llm·aigc·gpu算力·智算中心·ai算力
uncle_ll5 天前
ChatGPT大模型极简应用开发-目录
人工智能·gpt·chatgpt·大模型·llm
AI趋势预见5 天前
基于金融新闻的大型语言模型强化学习在投资组合管理中的应用
人工智能·深度学习·神经网络·语言模型·自然语言处理·金融·llm