大模型从失败中学习 —— 微调大模型以提升Agent性能

人工智能咨询培训老师叶梓 转载标明出处

以往的研究在微调LLMs作为Agent时,通常只使用成功的交互轨迹,而丢弃了未完成任务的轨迹。这不仅造成了数据和资源的浪费,也可能限制了微调过程中可能的优化路径。论文《Learning From Failure: Integrating Negative Examples when Fine-tuning Large Language Models as Agents》提出了负面感知训练(Negative-Aware Training, NAT)方法,通过适当的数据清洗和微调策略,使大模型能够从失败中学习,旨在提高模型在数学推理、多跳问答和策略性问答任务中的性能。

方法

图1为先前的方法和作者的NAT范式。在数据收集阶段,收集了LLMs与环境(工具)之间的交互。在数据处理阶段,先前的方法简单地过滤掉负面样本,而NAT通过在任务查询中添加提示来重新格式化轨迹,根据它们是正面还是负面。图1(c)给出了重格式化的正面和负面轨迹的示例。这里省略了系统提示,以简化说明。

如图1所示,Agent框架中任务解决过程被详细划分。首先,LLM被提供了一个系统提示,概述了(a)要解决的具体任务(例如,"解决一个数学问题"),(b)任务执行允许使用的工具,以及(c)预期的动作空间和输出格式(例如,finish[N]表示N是最终答案)。其次,引入一个查询实例。以ReAct格式提示模型回答查询,包括推理文本(称为"thoughts")和"actions"。最后,在互动阶段,系统使用预定义的工具执行LLM生成的动作,将结果观察返回给LLM,并提示后续动作,直到生成任务的完成动作,或交互轮次超过预定义阈值。

对于数学任务,作者设计了一个由SymPy实现的计算器,它接受数学表达式作为输入并输出结果。对于两个问答任务,作者设计了一个搜索工具,使用Serper 2 API。它接受搜索查询作为输入并返回谷歌搜索结果。他们进一步使用MPNet和DPR对搜索结果进行重新排名。

负面感知训练范式的流程包括数据收集、数据清洗、负面感知重格式化和微调四个阶段。其中负面感知重格式化是范式的核心部分,使Agent调整得更好。

  • 数据收集:对于每个任务,获得初始问题和相应的真实答案作为种子数据。然后使用GPT-3.5生成三次轨迹,每次使用不同的温度(0.2、0.5和0.7)。这能够收集多样化的正面和负面样本。通过比较预测答案和真实答案,可以将每个轨迹标记为正面或负面。

  • 负面感知重格式化:在Agent调整过程中区分正面样本和负面样本有助于教模型辨别成功和不成功的结果。附加一个字符串后缀,告诉模型训练样本是正面还是负面。对于正面样本,附加"Please generate a solution that correctly answers the question." 对于负面样本,附加"Please generate a solution that incorrectly answers the question。"

  • 微调和推理:使用重格式化后的轨迹对LLMs进行微调。损失只计算LLM生成的文本部分,这与微调聊天模型类似。在推理过程中,只使用正面样本的提示来提示微调后的Agent。

表格1展示了作者的方法与其他论文方法的比较。通过这些结果,可以看出NAT方法在不同任务上相较于其他方法有显著的性能提升。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory ------ 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加下方微信或评论留言,即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory。关注享粉丝福利,限时免费录播讲解。

LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

作者在数学推理、多跳问答和策略性问答任务上进行了实验。对于数学任务,他们使用GSM8k作为种子数据,并在GSM8k、ASDiv、SVAMP和MultiArith上测试性能。对于问答任务,他们在HotpotQA和StrategyQA上分别收集轨迹并测试性能。

主要将NAT与两个基线进行比较。"Vanilla"设置仅使用正面样本对LLMs进行微调,这是之前研究的做法。第二个设置是"Negative-Unaware Training (NUT)",它在未添加任何前缀或后缀的情况下纳入负面样本。

在LLaMA-2-Chat 7B和13B模型上进行实验,所有模型均微调2个epoch,批量大小为64。使用余弦调度器,总步骤的3%作为预热。最大学习率设置为5×10^-5。使用4×A100 GPU和DeepSpeed ZeRO 3阶段进行模型训练。

表2展示了数学任务的总体结果,观察到:(1) 纳入负面样本可以提高模型性能;(2) 采用负面感知训练(NAT)的模型不仅优于仅使用正面样本训练的相应模型(Vanilla),而且也优于直接纳入负面样本训练的相同模型(NUT);(3) 当正面样本较少或模型较小时,NAT的改进更为显著。具体来说,使用2k正面样本的7B模型,NAT实现了8.74%的性能提升,而使用5k正面样本的13B模型,性能提升为0.52%。

表3和表4展示了在HotpotQA和StrategyQA上的结果。在这里,NAT-2是NAT的一个变体,它将负面数据分为两类,并为每类使用不同的提示。在HotpotQA上,NAT-2在EM和f1分数上分别比没有负面样本提高了2%和6%。与NUT相比,NAT在EM和f1上仍然分别提高了约1%。在StrategyQA上,NAT比没有负面样本和NUT分别提高了8%和3%。

表格2到4展示了LLMs在微调为Agent时从负面样本中学习的能力。作者深入探讨了可能影响负面感知训练有效性的各种因素。他们试图回答以下问题:(1) 给定固定数量的正面样本,应该使用多少负面数据?(2) 模型从负面轨迹中学到了什么?(3) 所有类型的负面样本都有益吗?(4) 哪些因素促成了负面感知训练(NAT)优于负面无感知训练(NUT)?

训练样本数量的影响:初步分析关注负面样本数量的影响。保持正面样本数量恒定在2k和5k,同时将负面样本从0调整到12k。结果显示,随着负面数据量的增加,性能得到提升,当负面样本数量约为11k时性能趋于平稳。

数据质量的重要性:从不同模型中获取负面数据,以研究负面数据质量在NAT中的影响。具体来说,将来自GPT-3.5的数据视为高质量示例。相比之下,使用微调后的LLaMA-2-7B模型生成了10k负面样本作为低质量数据的代表。实验结果强调了数据质量在NAT中的关键作用。

模型通过NAT学到了什么:分析了由LLaMA-2-7B训练的GSM8K测试集的轨迹,这些模型分别使用正面样本(Vanilla)、NUT和NAT。表6显示了不同训练策略下模型的准确性、动作错误率(错误调用工具的百分比)和平均回合数。纳入负面样本也引入了更多的动作错误,这可能导致微调模型与Vanilla相比有更多的动作错误。然而,在纳入负面样本后,NUT和NAT的准确性都提高了。这表明负面样本主要通过教授模型更好的"思想"(即推理和规划)来起作用。

负面样本与正面样本的相似作用:为了进一步探索模型是否像从正面轨迹中学习一样从负面轨迹中学习,作者随机抽取了100个成功的轨迹作为开发集,并测量了使用2k正面样本(不与开发集重叠)和不同数量负面样本训练的模型的困惑度。图4显示了随着负面数据量的增加,困惑度的变化。随着更多负面数据的纳入,困惑度降低,这表明模型学会用失败轨迹的知识来适应成功的轨迹。然而,这条曲线在最后似乎是水平的,并且与4k和5k正面样本的曲线之间仍然存在很大差距,这表明一些来自成功轨迹的属性或知识永远无法从失败的轨迹中学到。

添加提示的选择:已经有不同的研究表明,提示对LLM性能至关重要。在这里,作者探索了添加提示的可解释性。具体而言是提示的内容使LLMs能够从成功和失败的轨迹中不同地学习,还是仅仅区分这些轨迹?他们提出了两组提示。一组是具有可解释性的提示,例如让模型生成正确或错误的轨迹。另一组是没有可解释性的提示。例如,可以为查询添加不同的字母作为前缀。表7显示了使用可解释和不可解释提示训练的模型的结果。不同的提示在性能上没有显示出大的差异,这表明NAT的性能提升来自于简单地区分正面和负面数据。

实验结果表明,与传统的仅使用正面样本或简单地结合正面和负面样本的方法相比,NAT方法在多个任务和模型尺寸上都显示出了优越的性能。特别是在数据稀缺的场景下,NAT的性能提升更为显著。

论文链接:https://arxiv.org/pdf/2402.11651

代码链接:https://github.com/Reason-Wang/NAT

相关推荐
理想不理想v17 分钟前
前端项目性能优化(详细)
前端·性能优化
PP东2 小时前
ES6学习Generator 函数(生成器)(八)
javascript·学习·es6
云起无垠2 小时前
【论文速读】| FirmRCA:面向 ARM 嵌入式固件的后模糊测试分析,并实现高效的基于事件的故障定位
人工智能·自动化
小屁不止是运维4 小时前
麒麟操作系统服务架构保姆级教程(二)ssh远程连接
linux·运维·服务器·学习·架构·ssh
Leweslyh4 小时前
物理信息神经网络(PINN)八课时教案
人工智能·深度学习·神经网络·物理信息神经网络
love you joyfully5 小时前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
该醒醒了~5 小时前
PaddlePaddle推理模型利用Paddle2ONNX转换成onnx模型
人工智能·paddlepaddle
小树苗1935 小时前
DePIN潜力项目Spheron解读:激活闲置硬件,赋能Web3与AI
人工智能·web3
follycat5 小时前
bestphp‘s revenge
学习·web安全
职业考试资料墙5 小时前
二级建造师考试题库及答案
学习·考试·题库