Training language models to follow instructions with human feedback 论文阅读

论文原文:https://arxiv.org/pdf/2203.02155

论文简介

语言模型越大并不意味着它能更好的理解用户的意图,因此在这篇论文中,展示了根据人的反馈对模型进行微调,使得语言模型能够在各种人物上更好的理解用户的意图。在评估中,1.3B参数的InstructGPT模型的输出比175B GPT-3的输出更受欢迎,尽管参数少了100倍。此外,InstructGPT模型虽然在公共的数据上的效果有所降低,但是真实性和减少有害方面生成的能力提升。论文表明,尽管InstructGPT仍然会犯一些简单的错误,但根据人类反馈进行微调是能够理解人类意图的一个有效的方式和方向。
**相当于是,OpenAI提出了"align"的概念,希望模型的输出与人类的意图"对齐",其用的方法是RLHF(Reinforcement Learning from Human Feedback)基于人类反馈的强化学习。**

方法和实验细节

Collect demonstration data, and train a supervised policy. (收集范例数据,并以有监督方式训练)

我们的打标签者提供了输入提示分布(prompt distribution)上所需行为的范例(有关此分布的详细信息,请参阅第 3.2 节)。 然后,我们使用有监督学习在该数据集上微调预训练的 GPT-3 模型。这部分就是根据prompts,也就是写的各种问题,进行标注,将prompts和标注的对话作为人工标注的数据集,对预训练的GPT-3进行有监督微调

Collect comparison data, and train a reward model. (收集比较数据,训练奖励模型)

我们收集了模型输出之间比较的数据集,其中打标记者根据输入标明了他们更喜欢的输出。 然后我们训练奖励模型来预测人类偏好的输出。用上一步得到的SFT模型生成各种问题的答案,再对这些答案进行比较(排序式)标注,如D>C>A=B,基于这个标注数据集,在去掉最后的嵌入层的SFT模型基础上进行有监督学习训练一个RM(reward model),这样使用模型来模仿标注者进行打分

Optimize a policy against the reward model using PPO. (使用PPO针对奖励模型优化策略)

我们使用RM奖励模型的输出作为标量奖励。 我们使用 PPO 算法微调监督策略以优化此奖励。

步骤2和步骤3可以不断迭代; 收集当前最佳策略的更多比较数据,用于训练新的 RM,然后训练新的策略。 在实践中,我们的大部分比较数据来自监管的学习,也有一些来自我们的PPO学习。用上一步的RM模型进行打分,然后分数就可以用强化学习来对SFT模型进行优化

数据集

打标签者提供了输入提示分布(prompt distribution)上所需行为的范例,根据论文所说,为了训练第一个InstructGPT模型,打标签者需要自己编写提示,分为三种:

  • Plain:只是要求标记者提出一个任意的任务,同时确保任务具有足够的多样性。
  • Few-shot:要求标注者提出一条指令,以及针对该指令的多个查询/相应对。
  • User-based:在OpenAI API的候补名单申请中陈述了许多用例,要求标注者提出与这些用例相对应的提示。
    根据这些提示,生成了三个用于微调过程的不同数据集:(1)SFT数据集,带有用于训练SFT模型的打标签者范例数据,(2)RM数据集,带有用于训练的模型已被打标签者分了等级的数据,(3)PPO数据集,没有任何人工标签,用于RLHF微调的输入。SFT数据集包含大约13k个训练提示数据(来自API和标记者编写),RM数据集有33k个训练提示数据(来自API和打标记者编写),PPO数据集有31k个训练提示数据(仅来自API)。

    上表显示了API提示(特别是RM数据集)的用例类别的分布,大多数用例都是生成的,而不是分类或QA。在表二中展示了一些说明性提示(由研究人员编写,以模仿提交给InstructGPT模型的提示类型)。

任务

训练任务来自两个来源:(1)由标注者编写的提示数据集和(2)提交给API上的早期InstructGPT模型的提示数据集。这些提示非常多样化,包括生成、问答、对话、摘要、提取和其他自然语言任务。数据集超过96%是英语。

对于每个自然语言提示,任务通常是通过自然语言指令直接指定的(例如"写一个关于聪明青蛙的故事"),但也可以通过少数例子间接指定(例如给出两个青蛙故事的例子,并提示模型生成一个新的)或隐含的连续(例如提供一个关于青蛙的故事的开始)。在每种情况下,我们都要求标注者尽最大努力推断出写提示的用户的意图,并要求他们跳过任务非常不清楚的输入(相当于当任务非常不清楚的时候,可以跳过回答,避免答非所问)。此外,在我们提供给他们的指示和他们的最佳判断的指导下,标注者还需考虑到隐含的意图,如回应的真实性,以及潜在的有害输出,如有偏见或有毒的语言。

模型

我们从GPT-3预训练语言模型开始。这些模型是在广泛分布的互联网数据上进行训练的,可以适应广泛的下游任务,但行为特征不佳。从这些模型开始,我们用三种不同的技术训练模型:

  • 有监督微调(SFT------Supervised fine-tuning),我们使用监督学习对标记器演示中的GPT-3进行微调。我们训练了16个epoch,使用余弦学习率衰减,0.2的残差dropout。我们根据验证集上的RM分数进行最终的SFT模型选择。我们发现SFT模型在1个epoch后对验证损失上过拟合,然而我们发现尽管存在过拟合,但更多epochs的训练有助于RM分数和人类偏好评级。(尽管这个SFT模型训练更多的epoch会产生过拟合,但是这是为了得到后续的RM模型的初始化模型,对RM模型有帮助,并不是直接使用这个SFT模型,所以过拟合没关系
  • 奖励建模(RM------Reward model),从移除了最后的非嵌入层的SFT模型开始(GPT模型最后的softmax层是用于得到每个词的概率,去掉softmax层以后,增加一个线性层来投影,将所有词的输出投影到一个值上面,即输出一个标量的分数 ),我们训练了一个模型来接收提示和相应,并输出标量奖励。在本文中,我们只使用6B RM,这样可以节省大量计算,而且我们发现175B RM训练可能不稳定,因此不太适合用作RL(Reinforcement learning)中的值函数。RM在同一输入的两个模型输出之间进行比较的数据集上训练。他们使用交叉熵损失,将比较作为标签------奖励的差异代表了人类标记者更喜欢一种反应的对数几率。
    为了加速分等级数据的收集,我们向标签提供者提供 K = 4 K=4 K=4和 K = 9 K=9 K=9之间的任何排名相应。这会为显示给标签者的每个提示生成 ( K 2 ) = C K 2 \binom{K}{2}=C_K^2 (2K)=CK2比较。由于分等级数据在每个标记任务中都非常相关,我们发现,如果我们简单地将分等级数据混洗到一个数据集中,在数据集上的一次遍历会导致奖励模型过拟合。相反,我们将每个提示的所有 ( K 2 ) \binom{K}{2} (2K)比较数据作为单个批处理元素进行训练。这在计算上要高效得多,因为它只需要每次完成一次RM的前向传递(而不是超过 ( K 2 ) \binom{K}{2} (2K)次前向传递),而且因为它不在过拟合,大大提高了验证准确性和日志损失。
    具体来说,奖励模型的损失函数为(这里使用的是排序中最常见的pairwise ranking loss,成对排名损失 ):

    这里的 r θ ( x , y ) r_{\theta}(x,y) rθ(x,y)是表示prompt x x x和相应 y y y在参数为 θ \theta θ的奖励模型下的奖励值, y w y_w yw是在prompt x x x下生成的一对响应 y w y_w yw和 y l y_l yl中更受欢迎的那一个, D D D是比较的数据集。每一个排名对 y i , y j y_i,y_j yi,yj的损失是 − l o g ( σ ( y i − y j ) ) -log(\sigma(y_i-y_j)) −log(σ(yi−yj)),换成奖励函数就是 − l o g ( σ ( r θ ( x , y w ) ) − r θ ( x , y l ) ) -log(\sigma(r_{\theta}(x,y_w))-r_{\theta}(x,y_l)) −log(σ(rθ(x,yw))−rθ(x,yl)),然后共 C K 2 C_K^2 CK2个排序对,所以期望除以它。
    目标是最小化这个loss,也就是最大化这两个奖励的差值, l o g ( σ ) log(\sigma) log(σ)最开始的时候是把生成的每个输出对都作为单独的数据混洗到数据集中,这样的话就需要超过 ( K 2 ) \binom{K}{2} (2K)次前向传递,而且输出对之间有重复,这样容易过拟合,所以将所有的输出对都统一作为单个批处理元素进行训练,这样的话就只需要 K K K次前向传递,因为奖励模型只需要算出9个奖励。之所以取 K = 9 K=9 K=9,是因为考虑到人工标注的时候,很大一部分是花在读懂这个prompt,所以在 K = 4 K=4 K=4和 K = 9 K=9 K=9之间,只多了不到一倍的时间,但是标注的数据由6变成了36,多了6倍
    最后,由于RM损失对于奖励的变化是不变的,我们使用偏差对奖励模型进行归一化,以便在进行RL之前,标记器演示的平均得分为0。
  • 强化学习(RL------Reinforcement learning),我们使用PPO在我们的环境中微调了SFT模型。该模型是一个bandit环境,它呈现随机的客户提示并期望对提示的响应。给定提示和相应,它会产生由奖励模型确定的奖励并结束情节。此外,我们在每个token上上添加了SFT模型的每个token的KL惩罚,以减轻奖励模型的过度优化。从RM初始化值函数。我们称这些模型为PPO。
    我们还尝试将预训练梯度混合到PPO梯度中,以修复公共NLP数据集上的性能回归。我们称这些模型为"PPO-ptx"。我们在RL训练中最大化以下组合目标函数:

    其中 π Θ R L \pi_{\Theta}^{RL} πΘRL是学习到的RL策略, π S F T \pi^{SFT} πSFT是有监督训练的模型, D p r e t r a i n D_{pretrain} Dpretrain是预训练分布。KL奖励系数 β \beta β,预训练损失系数 γ \gamma γ分别控制KL惩罚和预训练梯度的强度。对于"PPO"模型, γ \gamma γ被设置为0,除非另有说明,本文中的InstructGPT指的是PPO-ptx模型。对于上面说的31k个prompts数据集 D D D,都使用当前的RL模型,也就是RL策略 π θ R L \pi_{\theta}^{RL} πθRL,输出 y y y,然后用RM模型得到分数 r θ ( x , y ) ,目标函数是希望这个分数最大化 r_{\theta}(x,y),目标函数是希望这个分数最大化 rθ(x,y),目标函数是希望这个分数最大化然后根据这个目标函数,更新RL模型,然后再用RM模型计算得分,反复迭代。
    目标函数中还有两项,在此分别解释一下, β l o g ( π Θ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) \beta log(\pi_{\Theta}^{RL}(y|x)/\pi^{SFT}(y|x)) βlog(πΘRL(y∣x)/πSFT(y∣x))是正则项,这是PPO的主要思想,随着模型的更新,RL产生的输出 y y y和原始的 S F T SFT SFT模型输出的 y y y会逐渐不一样,即数据分布( y ∣ x y|x y∣x)的差异会越来越大, R L RL RL的输出可能会不准,所以论文在loss里加入了一个KL散度 KL ( P ∥ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) = ∫ P ( x ) log ⁡ ( P ( x ) Q ( x ) )   d x \text{KL}(P \parallel Q) = \sum_{x} P(x) \log \left(\frac{P(x)}{Q(x)}\right)= \int P(x) \log \left(\frac{P(x)}{Q(x)}\right)\, dx KL(P∥Q)=∑xP(x)log(Q(x)P(x))=∫P(x)log(Q(x)P(x))dx,用于描述一个概率分布相对于另一个概率分布的非对称性差异,相当于用这个散度来正则,希望RLSFT的输出分布不要偏太远,因为是最大化目标函数,所以要最小化KL散度需要在前面加一个负号。
    γ E x D p r e t r a i n [ l o g ( π Θ R L ( x ) ) ] \gamma E_x ~ D_{pretrain}[log(\pi_{\Theta}^{RL}(x))] γEx Dpretrain[log(πΘRL(x))],由于前两项目标函数只和人类排序部分有关,所以训练出来会导致模型仅仅对排序的结果较好,而在最终任务通用NLP任务上性能会下降,所以论文在loss中加入了GPT-3预训练模型的目标函数, D p r e t r a i n D_{pretrain} Dpretrain表示从训练GPT-3的预训练数据中采样 x x x,然后输入RL模型得到输出概率 π Θ R L ( x ) \pi_{\Theta}^{RL}(x) πΘRL(x),这样相当于是GPT-3本身的损失函数。

    总的来说,如果 γ = 0 \gamma=0 γ=0就是一个PPO函数,否则就是一个PPO加上一个GPT-3的目标函数的结合成为RL模型的目标函数,也就是PPO-ptx

讨论

论文提出,本文使用的"对齐技术"------RLHF,是用于对齐人类系统的一个重要方法。与预训练相比,增加模型对齐的成本是适中的(仅仅标注几万条prompt数据),与训练GPT-3的花费相比(海量的各种数据),只占一小部分。上述结果也表明,RLHF在使语言模型更加helpful(真实性和无害性是被隐式优化了)方面非常有效,甚至比模型增加100倍更有效。所以,在自然语言领域,研究alignment可能比训练更大规模的模型更具性价比。

align也有争议,就是到底要align人类到什么地步,是用户让做什么就做什么,还是要理解用户更深层的、内在的一些东西。此外最后的RL模型也不是必要的,如果在第一步多标数据,在GPT-3微调,步骤会变得简单,可能更加实用。

相关推荐
ZHOU_WUYI2 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1232 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界3 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221513 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2513 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
浊酒南街4 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归
畅联云平台4 小时前
美畅物联丨智能分析,安全管控:视频汇聚平台助力智慧工地建设
人工智能·物联网
加密新世界4 小时前
优化 Solana 程序
人工智能·算法·计算机视觉
hunteritself4 小时前
ChatGPT高级语音模式正在向Web网页端推出!
人工智能·gpt·chatgpt·openai·语音识别
Che_Che_5 小时前
Cross-Inlining Binary Function Similarity Detection
人工智能·网络安全·gnn·二进制相似度检测