GPT-RLHF :深入解析奖励模型 (Reward Model)

GPT-RLHF :深入解析奖励模型 (Reward Model)

从 ChatGPT 到 Claude,再到各种开源大模型,它们惊艳表现的背后,除了不断增长的参数规模,还有一个至关重要的技术------基于人类反馈的强化学习 (RLHF)

RLHF 是一套复杂的流程,它旨在让大语言模型(LLM)的输出更符合人类的偏好、价值观,使其变得"有用且无害"。而在这套流程中,奖励模型 (Reward Model, RM) 扮演着"人类偏好代理人"的关键角色。它就像一个指南针,为强化学习阶段的 LLM 指明了优化的方向。

本文将深入探讨奖励模型 (RM):从它在 RLHF 中的位置开始,解析其模型结构,并详细举例说明其核心训练算法------Pairwise Ranking Loss。

一、 RLHF 的"三步走"战略

要理解 RM,首先要明白它在整个 RLHF 流程中的位置。RLHF 通常分为三个阶段:

  1. 阶段 1: SFT (监督微调)

    • 目标: 让预训练的 LLM 学会模仿人类的回答方式。
    • 做法: 收集少量高质量的"标注者示范数据"((prompt, response) 对),在预训练模型的基础上进行监督微调(Supervised Fine-Tuning)。
    • 产出: SFT 模型。这个模型已经能按指令回答问题,但其回答的质量、安全性和有用性还不稳定。
  2. 阶段 2: 训练奖励模型 (RM)

    • 目标: 训练一个模型,使其能够模拟人类的偏好,为任何 (prompt, response) 对打分。
    • 做法: 拿一个 SFT 模型,对同一个 prompt 生成 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 个不同的回答。人类标注者对这 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 个回答进行排序(从最好到最差)。RM 的任务不是预测一个绝对分数,而是学习这个"排序"。
    • 产出: RM 模型。这是我们本文的主角。
  3. 阶段 3: PPO (强化学习)

    • 目标: 使用 RM 作为"奖励函数",通过强化学习(如 PPO 算法)来优化 SFT 模型,使其生成的回答能获得更高的 RM 分数。
    • 做法: SFT 模型(此时称为策略模型)生成回答,RM 为这个回答打分(奖励)。PPO 算法根据这个奖励来更新策略模型的参数,使其更倾向于生成高分回答。
    • 产出: 最终的、经过"对齐"的 RLHF 模型。

二、 奖励模型 (RM) 的"庐山真面目":模型结构

一个常见的误解是:奖励模型是一个全新的、神秘的模型。

事实是:奖励模型 (RM) 通常与 SFT 模型使用相同(或相近)的 Transformer 架构。 它们的主要区别不在于"骨架",而在于"头部"------即模型的最后一层如何输出结果。

1. 输入:Prompt 和 Response 的拼接

RM 的任务是判断一个 response <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 对于一个 prompt <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 来说有多好。因此,它需要同时接收 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 作为输入。

  • 输入序列: RM 会将 prompt 和 response 拼接成一个单独的文本序列,并用特殊的分隔符(如 [SEP])隔开,最后加上一个结束符(如 [EOS])。

    • [Prompt 文本] [SEP] [Response 文本] [EOS]

2. 模型主体:Transformer 编码器

这个拼接后的序列被转换成 Token ID,然后输入到 Transformer 模型中。模型通过自注意力机制和前馈网络,充分理解 prompt 和 response 之间的语义关联。

3. 输出:从"预测词"到"预测分"

这是 RM 与 SFT/LLM 最大的不同:

  • SFT/LLM 的头部 (LM Head):

    • 目标是预测下一个词。
    • 它的头部是一个巨大的线性层,将最后一个 Token 的隐藏状态(Hidden State)映射到整个词汇表(例如 50000 维)的 Logits,用于计算下一个词的概率。
  • RM 的头部 (Regression Head):

    • 目标是输出一个标量(一个数字)来代表"质量得分"。

    • 它会进行池化 (Pooling) :只取最后一个 Token (即 [EOS])的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h final h_{\text{final}} </math>hfinal(例如 4096 维),因为这个向量被认为编码了整个序列的语义。

    • 然后,它使用一个回归头(通常是一个简单的线性层)将这个高维向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> h final h_{\text{final}} </math>hfinal 映射为一个标量。

      <math xmlns="http://www.w3.org/1998/Math/MathML"> r = W head ⋅ h final + b head r = W_{\text{head}} \cdot h_{\text{final}} + b_{\text{head}} </math>r=Whead⋅hfinal+bhead

这个最终输出的标量 <math xmlns="http://www.w3.org/1998/Math/MathML"> r r </math>r,就是 RM 对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x , y ) (x, y) </math>(x,y) 对的奖励分数。

三、 RM 如何训练?Pairwise Ranking Loss 详解

我们已经知道,RM 的训练数据是人类的"排序",而不是"打分"。因为让人类在两个回答中选一个更好的("我更喜欢 A 而不是 B")远比给一个回答打绝对分数("这个回答是 7.5 分")要容易和准确。

RM 的训练目标是:模型给 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w y_w </math>yw(赢家)的分数 应该高于 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl(输家)的分数。

1. 损失函数公式

为了实现这个目标,RM 使用了 Pairwise Ranking Loss(成对排序损失),其公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L ( θ ) = − 1 ( K 2 ) E ( x , y w , y l ) [ log ⁡ ( σ ( r θ ( x , y w ) − r θ ( x , y l ) ) ) ] L(\theta) = - \frac{1}{\binom{K}{2}} \mathbb{E}{(x, y_w, y_l)} [\log(\sigma(r\theta(x, y_w) - r_\theta(x, y_l)))] </math>L(θ)=−(2K)1E(x,yw,yl)[log(σ(rθ(x,yw)−rθ(x,yl)))]

我们来拆解这个公式:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ) L(\theta) </math>L(θ):我们要最小化的总损失, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 是 RM 的所有参数。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ( K 2 ) \binom{K}{2} </math>(2K):从 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 个排序好的响应中,可以抽出的"赢家-输家"配对总数。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x , y w , y l ) (x, y_w, y_l) </math>(x,yw,yl):一个数据点,包含 Prompt <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x、更受偏好的响应 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w y_w </math>yw (winner) 和较差的响应 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl (loser)。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y w ) r_\theta(x, y_w) </math>rθ(x,yw) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y l ) r_\theta(x, y_l) </math>rθ(x,yl):RM 对赢家和输家(在 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 的上下文中)的打分。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y w ) − r θ ( x , y l ) r_\theta(x, y_w) - r_\theta(x, y_l) </math>rθ(x,yw)−rθ(x,yl):核心部分 。我们希望这个差值尽可能大(为正)。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( ⋅ ) \sigma(\cdot) </math>σ(⋅):Sigmoid 函数 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( z ) = 1 / ( 1 + e − z ) \sigma(z) = 1 / (1 + e^{-z}) </math>σ(z)=1/(1+e−z))。它将(-∞, +∞)的差值压缩到 (0, 1) 之间,可以被理解为"RM 认为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w y_w </math>yw 优于 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl 的概率"。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( ⋅ ) \log(\cdot) </math>log(⋅):对数似然。我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( ⋅ ) \sigma(\cdot) </math>σ(⋅) 趋近于 1(即 RM 100% 确定 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w > y l y_w > y_l </math>yw>yl),而 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( 1 ) = 0 \log(1) = 0 </math>log(1)=0。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> − ( ⋅ ) -(\cdot) </math>−(⋅):取负号。因为我们要最大化对数似然 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( σ ) \log(\sigma) </math>log(σ)(使其接近 0),所以我们要最小化它的相反数(使其从一个较大的负值接近 0)。

2. Pairwise Ranking Loss 计算举例

假设我们有 1 个 Prompt <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x,收集了 <math xmlns="http://www.w3.org/1998/Math/MathML"> K = 4 K=4 </math>K=4 个响应,人类排序为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 2 > y 3 > y 4 y_1 > y_2 > y_3 > y_4 </math>y1>y2>y3>y4。

步骤 1:找出所有偏好对

总共有 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 4 2 ) = 6 \binom{4}{2} = 6 </math>(24)=6 个偏好对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( y w , y l ) (y_w, y_l) </math>(yw,yl):

( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 2 y_1, y_2 </math>y1,y2), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 3 y_1, y_3 </math>y1,y3), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 4 y_1, y_4 </math>y1,y4), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 3 y_2, y_3 </math>y2,y3), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 4 y_2, y_4 </math>y2,y4), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 , y 4 y_3, y_4 </math>y3,y4)

步骤 2:获取当前 RM 的打分

我们将这 4 个 (prompt, response) 对输入 RM,假设(在一个训练批次中)模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 给出了以下分数:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 1 ) = 2.5 r_\theta(x, y_1) = 2.5 </math>rθ(x,y1)=2.5
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 2 ) = 1.9 r_\theta(x, y_2) = 1.9 </math>rθ(x,y2)=1.9
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 3 ) = 2.1 r_\theta(x, y_3) = 2.1 </math>rθ(x,y3)=2.1 <-- 注意!RM 搞错了!
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 4 ) = − 1.0 r_\theta(x, y_4) = -1.0 </math>rθ(x,y4)=−1.0

分析: 人类偏好 <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 > y 3 y_2 > y_3 </math>y2>y3,但 RM 却给出了 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1.9 < 2.1 1.9 < 2.1 </math>1.9<2.1 的分数。我们期望损失函数能够"惩罚"这个错误。

步骤 3:计算每个偏好对的 Loss

我们计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> L pair = log ⁡ ( σ ( r w − r l ) ) L_{\text{pair}} = \log(\sigma(r_w - r_l)) </math>Lpair=log(σ(rw−rl)):

偏好对 (w, l) 对应的人类偏好 RM 分数 rw​,rl​ 差值 (rw​−rl​) σ(差值) (概率) log(σ) (单对损失)
( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 2 y_1, y_2 </math>y1,y2) <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 2 y_1 > y_2 </math>y1>y2 (2.5, 1.9) <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.6 0.6 </math>0.6 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.646 0.646 </math>0.646 <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.437 -0.437 </math>−0.437
( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 3 y_1, y_3 </math>y1,y3) <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 3 y_1 > y_3 </math>y1>y3 (2.5, 2.1) <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.4 0.4 </math>0.4 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.599 0.599 </math>0.599 <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.512 -0.512 </math>−0.512
( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 4 y_1, y_4 </math>y1,y4) <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 4 y_1 > y_4 </math>y1>y4 (2.5, -1.0) <math xmlns="http://www.w3.org/1998/Math/MathML"> 3.5 3.5 </math>3.5 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.971 0.971 </math>0.971 <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.029 -0.029 </math>−0.029
( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 3 y_2, y_3 </math>y2,y3) <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 > y 3 y_2 > y_3 </math>y2>y3 (1.9, 2.1) <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.2 -0.2 </math>−0.2 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.450 0.450 </math>0.450 <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.798 -0.798 </math>−0.798
( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 4 y_2, y_4 </math>y2,y4) <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 > y 4 y_2 > y_4 </math>y2>y4 (1.9, -1.0) <math xmlns="http://www.w3.org/1998/Math/MathML"> 2.9 2.9 </math>2.9 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.948 0.948 </math>0.948 <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.053 -0.053 </math>−0.053
( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 , y 4 y_3, y_4 </math>y3,y4) <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 > y 4 y_3 > y_4 </math>y3>y4 (2.1, -1.0) <math xmlns="http://www.w3.org/1998/Math/MathML"> 3.1 3.1 </math>3.1 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.957 0.957 </math>0.957 <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.044 -0.044 </math>−0.044

观察:

  • 对于 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 3 y_2, y_3 </math>y2,y3) 这一对,由于 RM 搞错了,差值为负(-0.2)。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( − 0.2 ) ≈ 0.450 \sigma(-0.2) \approx 0.450 </math>σ(−0.2)≈0.450,这个概率小于 0.5,代表 RM 认为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl (即 <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 y_3 </math>y3) 赢的概率更高。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( 0.450 ) ≈ − 0.798 \log(0.450) \approx -0.798 </math>log(0.450)≈−0.798,这是一个很大的负数,因此它贡献了最大的"惩罚"(Loss)。
  • 对于 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 4 y_1, y_4 </math>y1,y4) 这种 RM 预测正确的对,差值很大 (3.5), <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( σ ) \log(\sigma) </math>log(σ) 接近 0,贡献的 Loss 很小。

步骤 4:计算最终的 Loss

  1. 求和 (Sum):

    <math xmlns="http://www.w3.org/1998/Math/MathML"> ( − 0.437 ) + ( − 0.512 ) + ( − 0.029 ) + ( − 0.798 ) + ( − 0.053 ) + ( − 0.044 ) = − 1.873 (-0.437) + (-0.512) + (-0.029) + (-0.798) + (-0.053) + (-0.044) = -1.873 </math>(−0.437)+(−0.512)+(−0.029)+(−0.798)+(−0.053)+(−0.044)=−1.873

  2. 取平均 (Divide by 6):

    <math xmlns="http://www.w3.org/1998/Math/MathML"> − 1.873 / 6 ≈ − 0.312 -1.873 / 6 \approx -0.312 </math>−1.873/6≈−0.312

  3. 取负 (Negate):

    <math xmlns="http://www.w3.org/1998/Math/MathML"> loss ( θ ) = − ( − 0.312 ) = 0.312 \text{loss}(\theta) = -(-0.312) = 0.312 </math>loss(θ)=−(−0.312)=0.312

这个 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.312 0.312 </math>0.312 就是这个 batch 的最终损失值。在反向传播时,这个损失会促使模型参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 进行更新,目标是拉高 <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 2 ) r_\theta(x, y_2) </math>rθ(x,y2) 的分数压低 <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 3 ) r_\theta(x, y_3) </math>rθ(x,y3) 的分数 ,使 <math xmlns="http://www.w3.org/1998/Math/MathML"> r w − r l r_w - r_l </math>rw−rl 的差值变大,直到 Loss 趋近于 0。

总结

奖励模型 (RM) 是 RLHF 流程中承上启下的关键。它通过一个巧妙的 Pairwise Ranking Loss,将人类模糊、相对的"偏好排序"数据,转化为了一个可以被优化的、输出标量分数的神经网络。

这个 RM 的质量,直接决定了 PPO 阶段强化学习的天花板。一个能准确理解人类偏好的 RM,是训练出"有用且无害"的 AI 助手的真正指南。

相关推荐
kk_net88993 小时前
PyTorch Geometric 图神经网络实战利器
人工智能·pytorch·神经网络·其他
新智元3 小时前
只要强化学习 1/10 成本!翁荔的 Thinking Machines 盯上了 Qwen 的黑科技
人工智能·openai
No.Ada3 小时前
基于脑电图(EEG)的认知负荷检测实验范式与深度神经网络的系统综述 论文笔记
论文阅读·人工智能·dnn
CV视觉3 小时前
智能体综述:探索基于大型语言模型的智能体:定义、方法与前景
人工智能·语言模型·chatgpt·stable diffusion·prompt·aigc·agi
新智元3 小时前
90 后王虹连夺两大「菲尔兹奖」风向标!韦神都来听她讲课,陶哲轩盛赞
人工智能·openai
MicroTech20254 小时前
微算法科技(NASDAQ MLGO)探索自适应差分隐私机制(如AdaDP),根据任务复杂度动态调整噪声
人工智能·科技·算法
预测模型的开发与应用研究5 小时前
从入门到实操:贝叶斯分析完整技术步骤与核心R包指南
开发语言·人工智能·r语言
TaoSense5 小时前
Milvus向量数据库介绍
大数据·人工智能
海森大数据5 小时前
AI突破“化学空间困境”:一场药物设计的范式革命
人工智能·语言模型