大语言模型-RLHF(七)-PPO实践(Proximal Policy Optimization)原理&实现&代码逐行注释

open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO实践。

生活中,我们经常会遇到,希望chatgpt在指定内容范围内回答问题。目前的解决方案大致可以分为两大类,一类是知识库外挂,代表作如langchain。把chatgpt的结果转换为向量在知识库里检索。如下图,本质上最终还是一种向量检索,chatgpt的能力其实是打了一个大的折扣。

另外一类是扩展现有LLM模型的Context处理长度,把候选直接作为llm模型的Context。这里涉及到两个问题,一个是如何扩展Context长度,一个是如何让llm模型只在指定Context内回答问题。今天我们ppo优化主要解决llm模型只在指定Context内回答问题。


样本

我们在1000篇文章中随机选择30篇作为prompt,让模型从这30篇文章中选择出我们想要的文章。

python 复制代码
        #随机选择30篇作为prompt
       random_articles = df.sample(n=31)
       random_article = random_articles.iloc[0]
       cat = random_article['category']
       article_list = [title + ' (' + cat + ')' for title, cat in zip(random_articles['title'], random_articles['category'])]
       input_str = construct_input(article_list, cat)
       input_ids = tokenizer.encode(input_str, return_tensors='pt').to('cuda')

模型准确率判定

可以回答多篇结果,如果模型有我们希望的回答的结果,加1分,不符合减1分。

python 复制代码
        #判断命中条数
       for ans in answer.split('\n'):
           similarity_threshold = 0.9  # 相似度阈值
           # 判断是否在input中且分类是否一致
           if is_similar(ans, article_list, similarity_threshold):
               positive_num = positive_num +1
               break
       print(i, 'accuracy:', positive_num / (i+1))

rm样本制作

第一种

正例:选择一条在prompt中符合条件的新闻为正例

负例:随机选择一条不在prompt中的新闻作为负例,

第二种,

正例:sft一次预测多条,从预测的结果中,挑选出符合条件的为正

负例:sft一次预测多条,从预测的结果中,挑选出不符合条件的为负

比较的结果是第二种方案会好一些。

也可以参考这篇博文ChatGLM-RLHF(三)-RM(Reward Model)实现&代码逐行注释_Pillars-Creation的博客-CSDN博客

ppo训练预测

ppo原理前一章节已经讲了,传送门ChatGLM-RLHF(六)-PPO(Proximal Policy Optimization)原理&实现&代码逐行注释_Pillars-Creation的博客-CSDN博客

需要注意的就是,因为训练时候需要加载sft和rm两个模型, 你需要一个大一点显存的gpu,本例在A100,40G显存上跑通。如果显存小了容易报显存不足的错误。

训练结果

原始预测结果

sft预测结果

ppo预测结果

几点体会,

1,好的sft可以解决大部分的问题,从上面实验看简单sft训练后准确率就可以得到明显提升

2,要根据自身需要定制好的rm样本和loss。有时候单纯根据sft样本,模型可能很难总结出你真正的目的,rm可以帮助模型更好的理解人的期望。

3,rm单独使用效果不一定比sft效果更好,这也比较好理解,rm需要人工标注pair对,数量总是有限的,并且这个pair对,是否清晰表达给了模型用户的全部意图,容易顾此失彼。所以rm我们更多用在最后,结合ppo纠正模型。

4,rm过程可以进行多次,把自己的目标拆解成几个rm过程,更容易达到我们的目标

5,PPO过程确实帮助模型效果得到了提升,并且可以从比较粗劣的rm结果和sft模型对比中学到知识。

完整代码可以参考:

GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM-PPO: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

相关推荐
知舟不叙1 小时前
基于OpenCV实现实时颜色检测
人工智能·opencv·计算机视觉·颜色检测
蓑雨春归2 小时前
探索Agent的发展潜力:大模型与具身智能的融合
人工智能
每日新鲜事3 小时前
Lavazza拉瓦萨再度牵手兰博基尼汽车 百年咖啡注入超跑速度
人工智能
说私域3 小时前
传统企业数字化转型:以定制开发开源 AI 智能名片 S2B2C 商城小程序源码为核心的销售环节突破
大数据·人工智能·开源
geneculture4 小时前
社会应用融智学的人力资源模式:潜能开发评估;认知基建资产
人工智能·课程设计·融智学的重要应用·三级潜能开发系统·人力资源升维·认知基建·认知银行
仙人掌_lz6 小时前
Qwen-3 微调实战:用 Python 和 Unsloth 打造专属 AI 模型
人工智能·python·ai·lora·llm·微调·qwen3
美林数据Tempodata7 小时前
大模型驱动数据分析革新:美林数据智能问数解决方案破局传统 BI 痛点
数据库·人工智能·数据分析·大模型·智能问数
硅谷秋水7 小时前
NORA:一个用于具身任务的小型开源通才视觉-语言-动作模型
人工智能·深度学习·机器学习·计算机视觉·语言模型·机器人
正儿八经的数字经7 小时前
人工智能100问☞第46问:AI是如何“学习”的?
人工智能·学习
飞哥数智坊8 小时前
别卷提示词了!像带新人一样“带”AI,产出效率翻倍
人工智能