GPT-RLHF :深入解析奖励模型 (Reward Model)
从 ChatGPT 到 Claude,再到各种开源大模型,它们惊艳表现的背后,除了不断增长的参数规模,还有一个至关重要的技术------基于人类反馈的强化学习 (RLHF) 。
RLHF 是一套复杂的流程,它旨在让大语言模型(LLM)的输出更符合人类的偏好、价值观,使其变得"有用且无害"。而在这套流程中,奖励模型 (Reward Model, RM) 扮演着"人类偏好代理人"的关键角色。它就像一个指南针,为强化学习阶段的 LLM 指明了优化的方向。
本文将深入探讨奖励模型 (RM):从它在 RLHF 中的位置开始,解析其模型结构,并详细举例说明其核心训练算法------Pairwise Ranking Loss。
一、 RLHF 的"三步走"战略
要理解 RM,首先要明白它在整个 RLHF 流程中的位置。RLHF 通常分为三个阶段:
-
阶段 1: SFT (监督微调)
- 目标: 让预训练的 LLM 学会模仿人类的回答方式。
- 做法: 收集少量高质量的"标注者示范数据"((prompt, response) 对),在预训练模型的基础上进行监督微调(Supervised Fine-Tuning)。
- 产出: SFT 模型。这个模型已经能按指令回答问题,但其回答的质量、安全性和有用性还不稳定。
-
阶段 2: 训练奖励模型 (RM)
- 目标: 训练一个模型,使其能够模拟人类的偏好,为任何 (prompt, response) 对打分。
- 做法: 拿一个 SFT 模型,对同一个 prompt 生成 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 个不同的回答。人类标注者对这 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 个回答进行排序(从最好到最差)。RM 的任务不是预测一个绝对分数,而是学习这个"排序"。
- 产出: RM 模型。这是我们本文的主角。
-
阶段 3: PPO (强化学习)
- 目标: 使用 RM 作为"奖励函数",通过强化学习(如 PPO 算法)来优化 SFT 模型,使其生成的回答能获得更高的 RM 分数。
- 做法: SFT 模型(此时称为策略模型)生成回答,RM 为这个回答打分(奖励)。PPO 算法根据这个奖励来更新策略模型的参数,使其更倾向于生成高分回答。
- 产出: 最终的、经过"对齐"的 RLHF 模型。
二、 奖励模型 (RM) 的"庐山真面目":模型结构
一个常见的误解是:奖励模型是一个全新的、神秘的模型。
事实是:奖励模型 (RM) 通常与 SFT 模型使用相同(或相近)的 Transformer 架构。 它们的主要区别不在于"骨架",而在于"头部"------即模型的最后一层如何输出结果。
1. 输入:Prompt 和 Response 的拼接
RM 的任务是判断一个 response <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 对于一个 prompt <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 来说有多好。因此,它需要同时接收 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 作为输入。
-
输入序列: RM 会将 prompt 和 response 拼接成一个单独的文本序列,并用特殊的分隔符(如
[SEP])隔开,最后加上一个结束符(如[EOS])。[Prompt 文本] [SEP] [Response 文本] [EOS]
2. 模型主体:Transformer 编码器
这个拼接后的序列被转换成 Token ID,然后输入到 Transformer 模型中。模型通过自注意力机制和前馈网络,充分理解 prompt 和 response 之间的语义关联。
3. 输出:从"预测词"到"预测分"
这是 RM 与 SFT/LLM 最大的不同:
-
SFT/LLM 的头部 (LM Head):
- 目标是预测下一个词。
- 它的头部是一个巨大的线性层,将最后一个 Token 的隐藏状态(Hidden State)映射到整个词汇表(例如 50000 维)的 Logits,用于计算下一个词的概率。
-
RM 的头部 (Regression Head):
-
目标是输出一个标量(一个数字)来代表"质量得分"。
-
它会进行池化 (Pooling) :只取最后一个 Token (即
[EOS])的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h final h_{\text{final}} </math>hfinal(例如 4096 维),因为这个向量被认为编码了整个序列的语义。 -
然后,它使用一个回归头(通常是一个简单的线性层)将这个高维向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> h final h_{\text{final}} </math>hfinal 映射为一个标量。
<math xmlns="http://www.w3.org/1998/Math/MathML"> r = W head ⋅ h final + b head r = W_{\text{head}} \cdot h_{\text{final}} + b_{\text{head}} </math>r=Whead⋅hfinal+bhead
-
这个最终输出的标量 <math xmlns="http://www.w3.org/1998/Math/MathML"> r r </math>r,就是 RM 对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x , y ) (x, y) </math>(x,y) 对的奖励分数。
三、 RM 如何训练?Pairwise Ranking Loss 详解
我们已经知道,RM 的训练数据是人类的"排序",而不是"打分"。因为让人类在两个回答中选一个更好的("我更喜欢 A 而不是 B")远比给一个回答打绝对分数("这个回答是 7.5 分")要容易和准确。
RM 的训练目标是:模型给 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w y_w </math>yw(赢家)的分数 应该高于 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl(输家)的分数。
1. 损失函数公式
为了实现这个目标,RM 使用了 Pairwise Ranking Loss(成对排序损失),其公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L ( θ ) = − 1 ( K 2 ) E ( x , y w , y l ) [ log ( σ ( r θ ( x , y w ) − r θ ( x , y l ) ) ) ] L(\theta) = - \frac{1}{\binom{K}{2}} \mathbb{E}{(x, y_w, y_l)} [\log(\sigma(r\theta(x, y_w) - r_\theta(x, y_l)))] </math>L(θ)=−(2K)1E(x,yw,yl)[log(σ(rθ(x,yw)−rθ(x,yl)))]
我们来拆解这个公式:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ) L(\theta) </math>L(θ):我们要最小化的总损失, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 是 RM 的所有参数。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ( K 2 ) \binom{K}{2} </math>(2K):从 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 个排序好的响应中,可以抽出的"赢家-输家"配对总数。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x , y w , y l ) (x, y_w, y_l) </math>(x,yw,yl):一个数据点,包含 Prompt <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x、更受偏好的响应 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w y_w </math>yw (winner) 和较差的响应 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl (loser)。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y w ) r_\theta(x, y_w) </math>rθ(x,yw) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y l ) r_\theta(x, y_l) </math>rθ(x,yl):RM 对赢家和输家(在 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 的上下文中)的打分。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y w ) − r θ ( x , y l ) r_\theta(x, y_w) - r_\theta(x, y_l) </math>rθ(x,yw)−rθ(x,yl):核心部分 。我们希望这个差值尽可能大(为正)。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( ⋅ ) \sigma(\cdot) </math>σ(⋅):Sigmoid 函数 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( z ) = 1 / ( 1 + e − z ) \sigma(z) = 1 / (1 + e^{-z}) </math>σ(z)=1/(1+e−z))。它将(-∞, +∞)的差值压缩到 (0, 1) 之间,可以被理解为"RM 认为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w y_w </math>yw 优于 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl 的概率"。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> log ( ⋅ ) \log(\cdot) </math>log(⋅):对数似然。我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( ⋅ ) \sigma(\cdot) </math>σ(⋅) 趋近于 1(即 RM 100% 确定 <math xmlns="http://www.w3.org/1998/Math/MathML"> y w > y l y_w > y_l </math>yw>yl),而 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ( 1 ) = 0 \log(1) = 0 </math>log(1)=0。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> − ( ⋅ ) -(\cdot) </math>−(⋅):取负号。因为我们要最大化对数似然 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ( σ ) \log(\sigma) </math>log(σ)(使其接近 0),所以我们要最小化它的相反数(使其从一个较大的负值接近 0)。
2. Pairwise Ranking Loss 计算举例
假设我们有 1 个 Prompt <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x,收集了 <math xmlns="http://www.w3.org/1998/Math/MathML"> K = 4 K=4 </math>K=4 个响应,人类排序为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 2 > y 3 > y 4 y_1 > y_2 > y_3 > y_4 </math>y1>y2>y3>y4。
步骤 1:找出所有偏好对
总共有 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 4 2 ) = 6 \binom{4}{2} = 6 </math>(24)=6 个偏好对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( y w , y l ) (y_w, y_l) </math>(yw,yl):
( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 2 y_1, y_2 </math>y1,y2), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 3 y_1, y_3 </math>y1,y3), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 4 y_1, y_4 </math>y1,y4), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 3 y_2, y_3 </math>y2,y3), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 4 y_2, y_4 </math>y2,y4), ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 , y 4 y_3, y_4 </math>y3,y4)
步骤 2:获取当前 RM 的打分
我们将这 4 个 (prompt, response) 对输入 RM,假设(在一个训练批次中)模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 给出了以下分数:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 1 ) = 2.5 r_\theta(x, y_1) = 2.5 </math>rθ(x,y1)=2.5
- <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 2 ) = 1.9 r_\theta(x, y_2) = 1.9 </math>rθ(x,y2)=1.9
- <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 3 ) = 2.1 r_\theta(x, y_3) = 2.1 </math>rθ(x,y3)=2.1 <-- 注意!RM 搞错了!
- <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 4 ) = − 1.0 r_\theta(x, y_4) = -1.0 </math>rθ(x,y4)=−1.0
分析: 人类偏好 <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 > y 3 y_2 > y_3 </math>y2>y3,但 RM 却给出了 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1.9 < 2.1 1.9 < 2.1 </math>1.9<2.1 的分数。我们期望损失函数能够"惩罚"这个错误。
步骤 3:计算每个偏好对的 Loss
我们计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> L pair = log ( σ ( r w − r l ) ) L_{\text{pair}} = \log(\sigma(r_w - r_l)) </math>Lpair=log(σ(rw−rl)):
| 偏好对 (w, l) | 对应的人类偏好 | RM 分数 rw,rl | 差值 (rw−rl) | σ(差值) (概率) | log(σ) (单对损失) |
|---|---|---|---|---|---|
| ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 2 y_1, y_2 </math>y1,y2) | <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 2 y_1 > y_2 </math>y1>y2 | (2.5, 1.9) | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.6 0.6 </math>0.6 | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.646 0.646 </math>0.646 | <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.437 -0.437 </math>−0.437 |
| ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 3 y_1, y_3 </math>y1,y3) | <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 3 y_1 > y_3 </math>y1>y3 | (2.5, 2.1) | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.4 0.4 </math>0.4 | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.599 0.599 </math>0.599 | <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.512 -0.512 </math>−0.512 |
| ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 4 y_1, y_4 </math>y1,y4) | <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 > y 4 y_1 > y_4 </math>y1>y4 | (2.5, -1.0) | <math xmlns="http://www.w3.org/1998/Math/MathML"> 3.5 3.5 </math>3.5 | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.971 0.971 </math>0.971 | <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.029 -0.029 </math>−0.029 |
| ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 3 y_2, y_3 </math>y2,y3) | <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 > y 3 y_2 > y_3 </math>y2>y3 | (1.9, 2.1) | <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.2 -0.2 </math>−0.2 | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.450 0.450 </math>0.450 | <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.798 -0.798 </math>−0.798 |
| ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 4 y_2, y_4 </math>y2,y4) | <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 > y 4 y_2 > y_4 </math>y2>y4 | (1.9, -1.0) | <math xmlns="http://www.w3.org/1998/Math/MathML"> 2.9 2.9 </math>2.9 | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.948 0.948 </math>0.948 | <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.053 -0.053 </math>−0.053 |
| ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 , y 4 y_3, y_4 </math>y3,y4) | <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 > y 4 y_3 > y_4 </math>y3>y4 | (2.1, -1.0) | <math xmlns="http://www.w3.org/1998/Math/MathML"> 3.1 3.1 </math>3.1 | <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.957 0.957 </math>0.957 | <math xmlns="http://www.w3.org/1998/Math/MathML"> − 0.044 -0.044 </math>−0.044 |
观察:
- 对于 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 2 , y 3 y_2, y_3 </math>y2,y3) 这一对,由于 RM 搞错了,差值为负(-0.2)。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( − 0.2 ) ≈ 0.450 \sigma(-0.2) \approx 0.450 </math>σ(−0.2)≈0.450,这个概率小于 0.5,代表 RM 认为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y l y_l </math>yl (即 <math xmlns="http://www.w3.org/1998/Math/MathML"> y 3 y_3 </math>y3) 赢的概率更高。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> log ( 0.450 ) ≈ − 0.798 \log(0.450) \approx -0.798 </math>log(0.450)≈−0.798,这是一个很大的负数,因此它贡献了最大的"惩罚"(Loss)。
- 对于 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> y 1 , y 4 y_1, y_4 </math>y1,y4) 这种 RM 预测正确的对,差值很大 (3.5), <math xmlns="http://www.w3.org/1998/Math/MathML"> log ( σ ) \log(\sigma) </math>log(σ) 接近 0,贡献的 Loss 很小。
步骤 4:计算最终的 Loss
-
求和 (Sum):
<math xmlns="http://www.w3.org/1998/Math/MathML"> ( − 0.437 ) + ( − 0.512 ) + ( − 0.029 ) + ( − 0.798 ) + ( − 0.053 ) + ( − 0.044 ) = − 1.873 (-0.437) + (-0.512) + (-0.029) + (-0.798) + (-0.053) + (-0.044) = -1.873 </math>(−0.437)+(−0.512)+(−0.029)+(−0.798)+(−0.053)+(−0.044)=−1.873
-
取平均 (Divide by 6):
<math xmlns="http://www.w3.org/1998/Math/MathML"> − 1.873 / 6 ≈ − 0.312 -1.873 / 6 \approx -0.312 </math>−1.873/6≈−0.312
-
取负 (Negate):
<math xmlns="http://www.w3.org/1998/Math/MathML"> loss ( θ ) = − ( − 0.312 ) = 0.312 \text{loss}(\theta) = -(-0.312) = 0.312 </math>loss(θ)=−(−0.312)=0.312
这个 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.312 0.312 </math>0.312 就是这个 batch 的最终损失值。在反向传播时,这个损失会促使模型参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 进行更新,目标是拉高 <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 2 ) r_\theta(x, y_2) </math>rθ(x,y2) 的分数 并压低 <math xmlns="http://www.w3.org/1998/Math/MathML"> r θ ( x , y 3 ) r_\theta(x, y_3) </math>rθ(x,y3) 的分数 ,使 <math xmlns="http://www.w3.org/1998/Math/MathML"> r w − r l r_w - r_l </math>rw−rl 的差值变大,直到 Loss 趋近于 0。
总结
奖励模型 (RM) 是 RLHF 流程中承上启下的关键。它通过一个巧妙的 Pairwise Ranking Loss,将人类模糊、相对的"偏好排序"数据,转化为了一个可以被优化的、输出标量分数的神经网络。
这个 RM 的质量,直接决定了 PPO 阶段强化学习的天花板。一个能准确理解人类偏好的 RM,是训练出"有用且无害"的 AI 助手的真正指南。