自学习AI的新突破:微软·谷歌联合研发V-STaR,通过学习错误大幅提升解题准确率!

引言:自学习在大型语言模型中的新进展

该文提出了一种新的方法------V-STaR,传统的模型只采用正确答案训练生成器,而V-STaR利用在训练过程中生成的正确和错误答案来训练一个验证器,该验证器使用DPO加强生成器的性能。而在推理时,V-STaR可以从多个候选答案中选择一个最佳答案。

论文概览:V-STaR方法的提出背景与核心贡献

1. 论文标题、机构、论文链接

  • 论文标题:V-STaR: Training Verifiers for Self-Taught Reasoners
  • 机构:Université de Montréal; Microsoft Research; Google Deepmind
  • 论文链接:arxiv.org/pdf/2402.06...

V-STaR方法详解:结合正确与错误答案的自学习

1. V-STaR的核心思想

V-STaR(Verification for Self-Taught Reasoners)是一种新颖的自学习方法,它在传统的自学习方法基础上进行了改进。

  • 传统方法,如STaR,通过迭代地对大型语言模型(LLMs)进行微调,使用自生成的正确方案来提高其解决问题的能力。
  • 然而,这些方法通常会丢弃过程中产生的大量错误答案,而这些错误答案可能包含有价值的信息。
  • 语言模型可以通过学习正确和错误答案之间的差异,识别生成的错误的模式,从而提供更准确解决方案。

V-STaR的核心思想是在自学习过程中利用正确和错误的答案来训练一个验证器,该验证器使用DPO(Direct Preference Optimization)来判断模型生成的答案的正确性。

2. V-STaR的流程

  • 首先,利用训练数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> D S F T D_{SFT} </math>DSFT微调预训练LLM <math xmlns="http://www.w3.org/1998/Math/MathML"> G b a s e G_{base} </math>Gbase,从而得到生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G S F T G_{SFT} </math>GSFT。
  • 接下来,我们从生成器中为训练数据中的每个问题采样k个答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ i , j ∼ G ( y ∣ x i ) , j = 1 , 2 , ... , k \hat{y}{i,j} \sim G(y|x_i), \quad j=1,2,\ldots,k </math>y^i,j∼G(y∣xi),j=1,2,...,k。其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ D q u e r y x \in D{query} </math>x∈Dquery。
  • 使用真实答案或人工标注生成的答案以得到生成的每一个答案的正确性 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z。
  • 只使用正确的生成答案 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> z = 1 z = 1 </math>z=1的答案)作为生成器的训练数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> D G E N D_{GEN} </math>DGEN,表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x i , y ^ i , j ) (x_i, \hat{y}{i,j}) </math>(xi,y^i,j)。这里注意, <math xmlns="http://www.w3.org/1998/Math/MathML"> D G E N = ( x i , y ^ i , j ) ∪ D S F T D{GEN} = (x_i, \hat{y}{i,j}) \cup D{SFT} </math>DGEN=(xi,y^i,j)∪DSFT
  • 同时,将生成的所有的 答案与其正确性标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z一起放入验证器训练集 <math xmlns="http://www.w3.org/1998/Math/MathML"> D V E R D_{VER} </math>DVER中,表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x i , y ^ i , j , z i , j ) (x_i, \hat{y}{i,j}, z{i,j}) </math>(xi,y^i,j,zi,j),这样验证器就可以从生成器的错误答案中学习。
  • 在下一轮迭代 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t中,通过利用 <math xmlns="http://www.w3.org/1998/Math/MathML"> D G E N D_{GEN} </math>DGEN对生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G b a s e G_{base} </math>Gbase进行微调得到生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G t G_t </math>Gt

这时候可能有读者要问:那我们训练出的验证器究竟有什么用呢?

验证器的作用是在推理时候,如下图所示:

在推理时候,验证器模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> V T V^T </math>VT会在生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> G T G^T </math>GT生成多个候选答案后,选出候选答案中的正确答案。从而增强整个模型架构的性能。

3. 验证器模型的训练方式

<math xmlns="http://www.w3.org/1998/Math/MathML"> V T V^T </math>VT采用了DPO的训练方法,即使不知道DPO也没关系,严格来说,其还是使用了对比学习的方案,通过同时学习正类样例和负类样例来进行学习。

这里,我们先看数据集 <math xmlns="http://www.w3.org/1998/Math/MathML"> D V E R D_{VER} </math>DVER的构建方式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> D V E R = { ( x i , y i , 1 + , y i , 1 − ) , ⋯   , ( x i , y i , m + , y i , m − ) } i = 1 N \mathcal{D}{\mathrm{VER}}=\left\{\left(x_i, y{i, 1}^{+}, y_{i, 1}^{-}\right), \cdots,\left(x_i, y_{i, m}^{+}, y_{i, m}^{-}\right)\right\}_{i=1}^N </math>DVER={(xi,yi,1+,yi,1−),⋯,(xi,yi,m+,yi,m−)}i=1N

其将正类样例和负类样例组合在一起,形成了训练数据集,而组合方式就是笛卡尔积
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ( y i + , y i − ) ∈ { y ^ i , j ∣ z i , j = 1 } × { y ^ i , j ∣ z i , j = 0 } . \left(y_i^{+}, y_i^{-}\right) \in\left\{\hat{y}{i, j} \mid z{i, j}=1\right\} \times\left\{\hat{y}{i, j} \mid z{i, j}=0\right\} . </math>(yi+,yi−)∈{y^i,j∣zi,j=1}×{y^i,j∣zi,j=0}.

最后,采用DPO的损失函数进行训练,即:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> − E ( x , y + , y − ) ∼ D V E R [ log ⁡ σ ( r ^ ( x , y + ) − r ^ ( x , y − ) ) ] , -\mathbb{E}{\left(x, y^{+}, y^{-}\right) \sim \mathcal{D}{\mathrm{VER}}}\left[\log \sigma\left(\hat{r}\left(x, y^{+}\right)-\hat{r}\left(x, y^{-}\right)\right)\right], </math>−E(x,y+,y−)∼DVER[logσ(r^(x,y+)−r^(x,y−))],

DPO的目标是提高正确答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y + y^+ </math>y+相对于错误答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y − y^- </math>y−的相对对数概率。DPO损失函数引导验证器增加问题x的正确答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y + y^+ </math>y+的概率,同时减少错误答案 <math xmlns="http://www.w3.org/1998/Math/MathML"> y − y^- </math>y−的概率。研究发现,使用DPO训练的验证器比使用ORM(outcome-supervised reward model)风格的验证器更有效。

实验:评估V-STaR的有效性

1. 实验数据集与模型设置

文中所用的数据集包括:

  • 用于数学推理任务的数据集:GSM8K和MATH
  • 代码生成任务的数据集:MBPP和HumanEval

基础LLM:微调LLaMA2和CodeLLaMA模型。

2. 实验过程与迭代策略

在每次迭代中,我们从上一次迭代训练出的生成器中为每个查询采样k=16个完成项。

  • 对于GSM8K,最开始的训练样本来自仅在原始GSM8K训练数据上训练了2个时期的生成器。
  • 对于MBPP,这些数据来自预训练的CodeLLaMA模型

实验运行了3次迭代,并在每次迭代中采样K=16个答案来增强DGEN和DVER。

在推理时,我们使用生成器为每个测试问题生成128个候选答案。

实验结果与分析:V-STaR在数学推理和代码生成任务中的表现

1. V-STaR与现有方法的对比

V-STaR(Verification for Self-Taught Reasoners)是一种新提出的方法,它在自学习的过程中不仅利用正确的答案,还利用错误的答案来训练一个验证器。这种方法使用DPO(Direct Preference Optimization)来训练验证器,该验证器在推理时用于从多个候选答案中选择最佳答案。

实验结果表明:

  1. V-STaR在数学推理任务上的测试准确率比现有的自学习和验证方法提高了4%到17%,在代码生成任务上提高了4%到12%。

  2. 使用V-STaR微调的7B模型在GSM8K数据集上的表现超过了基础LLaMA2 70B模型(8-shot),并且在HumanEval数据集上几乎与CodeLLaMA 34B模型(zero-shot)相当。

2. V-STaR在不同数据集上的迁移能力

实验中还测试了V-STaR的迁移能力。

  1. 在数学推理方面,仅使用GSM8K训练数据训练的生成器和验证器在整个GSM8K测试集和MATH测试集的子集上进行了评估。
  2. 在代码生成方面,使用MBPP训练数据训练的模型在MBPP的完整测试集和HumanEval测试集上进行了评估。
  3. 结果显示,V-STaR在这些任务上的迁移表现优于其他方法。

讨论:DPO与传统ORM验证器的比较

1. DPO在训练验证器中的优势

实验结果表明:

  1. DPO训练的验证器比传统的ORM(outcome-supervised reward model)风格的验证器表现更好。
  2. DPO验证器在使用LoRA适配器时表现更佳,表明DPO在训练验证器时具有更高的样本效率。

2. DPO验证器作为生成器的性能评估

DPO训练的模型也可以作为生成器使用。实验评估了DPO验证器作为生成器的能力,结果显示,在仅有2000次训练更新的情况下,DPO作为生成器的能力下降,而其作为验证器的性能则显著提高。这表明DPO目标用于验证时非常有效,能够在有限的训练更新中显著提高模型的验证性能。

结论与未来展望:V-STaR的潜力与挑战

V-STaR通过迭代过程中利用正确和错误的生成答案来训练更好的生成器和验证器,显示了在多步推理任务中提升LLMs性能的潜力。我们的实证评估表明,V-STaR在数学推理和代码生成任务上相对于其他自学习和基于验证的方法显示出6%至17%的绝对改进。使用V-STaR微调的7B模型在GSM8K上的性能超过了基础LLaMA2 70B(8-shot),并且在HumanEval上几乎与CodeLLaMA 34B(zero-shot)相匹配。

相关推荐
boooo_hhh23 分钟前
深度学习笔记16-VGG-16算法-Pytorch实现人脸识别
pytorch·深度学习·机器学习
AnnyYoung27 分钟前
华为云deepseek大模型平台:deepseek满血版
人工智能·ai·华为云
INDEMIND1 小时前
INDEMIND:AI视觉赋能服务机器人,“零”碰撞避障技术实现全天候安全
人工智能·视觉导航·服务机器人·商用机器人
慕容木木1 小时前
【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体的替代品,可本地部署+知识库,注册即可有750w的token使用
人工智能·火山引擎·deepseek·deepseek r1
南 阳1 小时前
百度搜索全面接入DeepSeek-R1满血版:AI与搜索的全新融合
人工智能·chatgpt
企鹅侠客2 小时前
开源免费文档翻译工具 可支持pdf、word、excel、ppt
人工智能·pdf·word·excel·自动翻译
冰淇淋百宝箱2 小时前
AI 安全时代:SDL与大模型结合的“王炸组合”——技术落地与实战指南
人工智能·安全
Elastic 中国社区官方博客3 小时前
Elasticsearch Open Inference API 增加了对 Jina AI 嵌入和 Rerank 模型的支持
大数据·人工智能·elasticsearch·搜索引擎·ai·全文检索·jina
AWS官方合作商3 小时前
Amazon Lex:AI对话引擎重构企业服务新范式
人工智能·ai·机器人·aws
workflower3 小时前
Prompt Engineering的重要性
大数据·人工智能·设计模式·prompt·软件工程·需求分析·ai编程