自学习AI的新突破:微软·谷歌联合研发V-STaR,通过学习错误大幅提升解题准确率!

引言:自学习在大型语言模型中的新进展

该文提出了一种新的方法------V-STaR,传统的模型只采用正确答案训练生成器,而V-STaR利用在训练过程中生成的正确和错误答案来训练一个验证器,该验证器使用DPO加强生成器的性能。而在推理时,V-STaR可以从多个候选答案中选择一个最佳答案。

论文概览:V-STaR方法的提出背景与核心贡献

1. 论文标题、机构、论文链接

  • 论文标题:V-STaR: Training Verifiers for Self-Taught Reasoners
  • 机构:Université de Montréal; Microsoft Research; Google Deepmind
  • 论文链接:arxiv.org/pdf/2402.06...

V-STaR方法详解:结合正确与错误答案的自学习

1. V-STaR的核心思想

V-STaR(Verification for Self-Taught Reasoners)是一种新颖的自学习方法,它在传统的自学习方法基础上进行了改进。

  • 传统方法,如STaR,通过迭代地对大型语言模型(LLMs)进行微调,使用自生成的正确方案来提高其解决问题的能力。
  • 然而,这些方法通常会丢弃过程中产生的大量错误答案,而这些错误答案可能包含有价值的信息。
  • 语言模型可以通过学习正确和错误答案之间的差异,识别生成的错误的模式,从而提供更准确解决方案。

V-STaR的核心思想是在自学习过程中利用正确和错误的答案来训练一个验证器,该验证器使用DPO(Direct Preference Optimization)来判断模型生成的答案的正确性。

2. V-STaR的流程

  • 首先,利用训练数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> D S F T D_{SFT} </math>DSFT微调预训练LLM <math xmlns="http://www.w3.org/1998/Math/MathML"> G b a s e G_{base} </math>Gbase,从而得到生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G S F T G_{SFT} </math>GSFT。
  • 接下来,我们从生成器中为训练数据中的每个问题采样k个答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ i , j ∼ G ( y ∣ x i ) , j = 1 , 2 , ... , k \hat{y}{i,j} \sim G(y|x_i), \quad j=1,2,\ldots,k </math>y^i,j∼G(y∣xi),j=1,2,...,k。其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ D q u e r y x \in D{query} </math>x∈Dquery。
  • 使用真实答案或人工标注生成的答案以得到生成的每一个答案的正确性 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z。
  • 只使用正确的生成答案 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> z = 1 z = 1 </math>z=1的答案)作为生成器的训练数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> D G E N D_{GEN} </math>DGEN,表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x i , y ^ i , j ) (x_i, \hat{y}{i,j}) </math>(xi,y^i,j)。这里注意, <math xmlns="http://www.w3.org/1998/Math/MathML"> D G E N = ( x i , y ^ i , j ) ∪ D S F T D{GEN} = (x_i, \hat{y}{i,j}) \cup D{SFT} </math>DGEN=(xi,y^i,j)∪DSFT
  • 同时,将生成的所有的 答案与其正确性标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z一起放入验证器训练集 <math xmlns="http://www.w3.org/1998/Math/MathML"> D V E R D_{VER} </math>DVER中,表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x i , y ^ i , j , z i , j ) (x_i, \hat{y}{i,j}, z{i,j}) </math>(xi,y^i,j,zi,j),这样验证器就可以从生成器的错误答案中学习。
  • 在下一轮迭代 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t中,通过利用 <math xmlns="http://www.w3.org/1998/Math/MathML"> D G E N D_{GEN} </math>DGEN对生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G b a s e G_{base} </math>Gbase进行微调得到生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G t G_t </math>Gt

这时候可能有读者要问:那我们训练出的验证器究竟有什么用呢?

验证器的作用是在推理时候,如下图所示:

在推理时候,验证器模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> V T V^T </math>VT会在生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G T G^T </math>GT生成多个候选答案后,选出候选答案中的正确答案。从而增强整个模型架构的性能。

3. 验证器模型的训练方式

<math xmlns="http://www.w3.org/1998/Math/MathML"> V T V^T </math>VT采用了DPO的训练方法,即使不知道DPO也没关系,严格来说,其还是使用了对比学习的方案,通过同时学习正类样例和负类样例来进行学习。

这里,我们先看数据集 <math xmlns="http://www.w3.org/1998/Math/MathML"> D V E R D_{VER} </math>DVER的构建方式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> D V E R = { ( x i , y i , 1 + , y i , 1 − ) , ⋯   , ( x i , y i , m + , y i , m − ) } i = 1 N \mathcal{D}{\mathrm{VER}}=\left\{\left(x_i, y{i, 1}^{+}, y_{i, 1}^{-}\right), \cdots,\left(x_i, y_{i, m}^{+}, y_{i, m}^{-}\right)\right\}_{i=1}^N </math>DVER={(xi,yi,1+,yi,1−),⋯,(xi,yi,m+,yi,m−)}i=1N

其将正类样例和负类样例组合在一起,形成了训练数据集,而组合方式就是笛卡尔积
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ( y i + , y i − ) ∈ { y ^ i , j ∣ z i , j = 1 } × { y ^ i , j ∣ z i , j = 0 } . \left(y_i^{+}, y_i^{-}\right) \in\left\{\hat{y}{i, j} \mid z{i, j}=1\right\} \times\left\{\hat{y}{i, j} \mid z{i, j}=0\right\} . </math>(yi+,yi−)∈{y^i,j∣zi,j=1}×{y^i,j∣zi,j=0}.

最后,采用DPO的损失函数进行训练,即:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> − E ( x , y + , y − ) ∼ D V E R [ log ⁡ σ ( r ^ ( x , y + ) − r ^ ( x , y − ) ) ] , -\mathbb{E}{\left(x, y^{+}, y^{-}\right) \sim \mathcal{D}{\mathrm{VER}}}\left[\log \sigma\left(\hat{r}\left(x, y^{+}\right)-\hat{r}\left(x, y^{-}\right)\right)\right], </math>−E(x,y+,y−)∼DVER[logσ(r^(x,y+)−r^(x,y−))],

DPO的目标是提高正确答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y + y^+ </math>y+相对于错误答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y − y^- </math>y−的相对对数概率。DPO损失函数引导验证器增加问题x的正确答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y + y^+ </math>y+的概率,同时减少错误答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y − y^- </math>y−的概率。研究发现,使用DPO训练的验证器比使用ORM(outcome-supervised reward model)风格的验证器更有效。

实验:评估V-STaR的有效性

1. 实验数据集与模型设置

文中所用的数据集包括:

  • 用于数学推理任务的数据集:GSM8K和MATH
  • 代码生成任务的数据集:MBPP和HumanEval

基础LLM:微调LLaMA2和CodeLLaMA模型。

2. 实验过程与迭代策略

在每次迭代中,我们从上一次迭代训练出的生成器中为每个查询采样k=16个完成项。

  • 对于GSM8K,最开始的训练样本来自仅在原始GSM8K训练数据上训练了2个时期的生成器。
  • 对于MBPP,这些数据来自预训练的CodeLLaMA模型

实验运行了3次迭代,并在每次迭代中采样K=16个答案来增强DGEN和DVER。

在推理时,我们使用生成器为每个测试问题生成128个候选答案。

实验结果与分析:V-STaR在数学推理和代码生成任务中的表现

1. V-STaR与现有方法的对比

V-STaR(Verification for Self-Taught Reasoners)是一种新提出的方法,它在自学习的过程中不仅利用正确的答案,还利用错误的答案来训练一个验证器。这种方法使用DPO(Direct Preference Optimization)来训练验证器,该验证器在推理时用于从多个候选答案中选择最佳答案。

实验结果表明:

  1. V-STaR在数学推理任务上的测试准确率比现有的自学习和验证方法提高了4%到17%,在代码生成任务上提高了4%到12%。

  2. 使用V-STaR微调的7B模型在GSM8K数据集上的表现超过了基础LLaMA2 70B模型(8-shot),并且在HumanEval数据集上几乎与CodeLLaMA 34B模型(zero-shot)相当。

2. V-STaR在不同数据集上的迁移能力

实验中还测试了V-STaR的迁移能力。

  1. 在数学推理方面,仅使用GSM8K训练数据训练的生成器和验证器在整个GSM8K测试集和MATH测试集的子集上进行了评估。
  2. 在代码生成方面,使用MBPP训练数据训练的模型在MBPP的完整测试集和HumanEval测试集上进行了评估。
  3. 结果显示,V-STaR在这些任务上的迁移表现优于其他方法。

讨论:DPO与传统ORM验证器的比较

1. DPO在训练验证器中的优势

实验结果表明:

  1. DPO训练的验证器比传统的ORM(outcome-supervised reward model)风格的验证器表现更好。
  2. DPO验证器在使用LoRA适配器时表现更佳,表明DPO在训练验证器时具有更高的样本效率。

2. DPO验证器作为生成器的性能评估

DPO训练的模型也可以作为生成器使用。实验评估了DPO验证器作为生成器的能力,结果显示,在仅有2000次训练更新的情况下,DPO作为生成器的能力下降,而其作为验证器的性能则显著提高。这表明DPO目标用于验证时非常有效,能够在有限的训练更新中显著提高模型的验证性能。

结论与未来展望:V-STaR的潜力与挑战

V-STaR通过迭代过程中利用正确和错误的生成答案来训练更好的生成器和验证器,显示了在多步推理任务中提升LLMs性能的潜力。我们的实证评估表明,V-STaR在数学推理和代码生成任务上相对于其他自学习和基于验证的方法显示出6%至17%的绝对改进。使用V-STaR微调的7B模型在GSM8K上的性能超过了基础LLaMA2 70B(8-shot),并且在HumanEval上几乎与CodeLLaMA 34B(zero-shot)相匹配。

相关推荐
最新快讯1 小时前
科技快讯 | 阿里云百炼MCP服务上线;英伟达官宣:CUDA 工具链将全面原生支持 Python
人工智能
__Benco2 小时前
OpenHarmony子系统开发 - 热管理(一)
人工智能·harmonyos
吴法刚3 小时前
14-Hugging Face 模型微调训练(基于 BERT 的中文评价情感分析(二分类))
人工智能·深度学习·自然语言处理·分类·langchain·bert·langgraph
碳基学AI3 小时前
北京大学DeepSeek内部研讨系列:AI在新媒体运营中的应用与挑战|122页PPT下载方法
大数据·人工智能·python·算法·ai·新媒体运营·产品运营
是店小二呀4 小时前
Llama 4革命性发布与绿色AI前沿研究
人工智能·llama
2301_799755344 小时前
文件内容课堂总结
人工智能
杰克逊的日记4 小时前
AI集群设计
人工智能·ai·gpu·ai集群·pytorach
技术程序猿华锋4 小时前
Zotero PDF Translate 翻译插件使用OpenAI API配置教程
人工智能·chatgpt·机器翻译
龙萱坤诺4 小时前
GPT-4o-image模型:开启AI图片编辑新时代
人工智能·深度学习
SeaTunnel4 小时前
【同步教程】基于Apache SeaTunnel从MySQL同步到MySQL——Demo方舟计划
大数据·人工智能·apache·etl