PPO/GRPO算法在RLHF中的实现

PPO(Proximal Policy Optimization)

策略函数优化目标
JPPO(θ)=Eq∼P(Q), o∼πθold(O∣q)[1∣o∣∑t=1∣o∣min⁡(πθ(ot∣q,o<t)πθold(ot∣q,o<t)At,  clip⁡ ⁣(πθ(ot∣q,o<t)πθold(ot∣q,o<t),1−ϵ, 1+ϵ)At)] \mathcal{J}{\mathrm{PPO}}(\theta)=\mathbb{E}{q \sim P(Q),\, o \sim \pi_{\theta_{\text{old}}}(O \mid q)} \left[ \frac{1}{|o|} \sum_{t=1}^{|o|} \min \left( \frac{\pi_{\theta}(o_t \mid q, o_{<t})} {\pi_{\theta_{\text{old}}}(o_t \mid q, o_{<t})} A_t,\; \operatorname{clip}\!\left( \frac{\pi_{\theta}(o_t \mid q, o_{<t})} {\pi_{\theta_{\text{old}}}(o_t \mid q, o_{<t})}, 1-\epsilon,\, 1+\epsilon \right) A_t \right) \right] JPPO(θ)=Eq∼P(Q),o∼πθold(O∣q) ∣o∣1t=1∑∣o∣min(πθold(ot∣q,o<t)πθ(ot∣q,o<t)At,clip(πθold(ot∣q,o<t)πθ(ot∣q,o<t),1−ϵ,1+ϵ)At)

其中 AtA_tAt 为优势函数(Advantage Function),用于度量每步动作(在LLM中,动作ata_tat即为第ttt个token oto_tot)相对于平均水平的优劣。

优势函数与GAE(Generalized Advantage Estimation)

优势函数的标准定义为:
A(st,at)=Q(st,at)−V(st) A(s_t, a_t) = Q(s_t, a_t) - V(s_t) A(st,at)=Q(st,at)−V(st)

在实践中,真实的 Q(s,a)Q(s, a)Q(s,a) 与 V(s)V(s)V(s) 不可知,需从样本轨迹中估计。GAE 是一种高效的估计方法,其核心是 TD误差(Temporal-Difference Error)
δt=rt+γV(st+1)−V(st) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)−V(st)

其中 rtr_trt 为时刻 ttt 的奖励,γ\gammaγ 为折扣因子,V(st)V(s_t)V(st) 为状态值函数估计。

基于TD误差,GAE通过以下递推公式计算优势函数:
AtGAE(γ,λ)=δt+γλAt+1GAE(γ,λ) A_t^{\text{GAE}(\gamma, \lambda)} = \delta_t + \gamma \lambda A_{t+1}^{\text{GAE}(\gamma, \lambda)} AtGAE(γ,λ)=δt+γλAt+1GAE(γ,λ)

其代码实现如下:

python 复制代码
lastgaelam = 0
advantages_reversed = []
gen_length = responses.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
    lastgaelam = delta + args.gamma * args.lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values

奖励函数设计

在语言模型中,每个token oto_tot 的即时奖励 rtr_trt 由两部分构成:
rt=rφ(q,o≤t)−βlog⁡πθ(ot∣q,o<t)πref(ot∣q,o<t) r_t = r_\varphi(q, o_{\le t}) - \beta \log \frac{\pi_\theta(o_t|q, o_{<t})}{\pi_{\text{ref}}(o_t|q, o_{<t})} rt=rφ(q,o≤t)−βlogπref(ot∣q,o<t)πθ(ot∣q,o<t)

其中 rφr_\varphirφ 为奖励模型的输出,πref\pi_{\text{ref}}πref 为参考模型(初始SFT模型),β\betaβ 为KL散度系数。实际实现中,通常只在序列的最后一个词元(EOS 处赋予奖励模型给出的分数,中间词元的奖励主要为KL惩罚。以下是trl库中的实现:

python 复制代码
logr = ref_logprobs - logprobs
kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[actual_start, actual_end] += scores  

价值函数更新

Critic网络的目标是逼近状态的真实价值,其更新目标(Returns)为 Gt=At+V(st)G_t = A_t + V(s_t)Gt=At+V(st),即:
V(st)=Eat∼π[Q(st,at)]=Eat∼π[A(st,at)+V(st)] V(s_t) = \mathbb{E}{a_t \sim \pi} [Q(s_t, a_t)] = \mathbb{E}{a_t \sim \pi} [A(s_t, a_t) + V(s_t)] V(st)=Eat∼π[Q(st,at)]=Eat∼π[A(st,at)+V(st)]

在实践中,使用 returns = advantages + values 作为目标值进行回归,以更新价值函数网络。


GRPO(Group Relative Policy Optimization)

GRPO是PPO的一种变体,主要省去了价值函数网络,简化了优势函数的计算。

策略函数优化目标
JGRPO(θ)=Eq∼P(Q), {oi}i=1G∼πθold(O∣q)1G∑i=1G1∣oi∣∑t=1∣oi∣{min⁡[πθ(oi,t∣q,oi,<t)πθold(oi,t∣q,oi,<t)A^i,t,  clip⁡ ⁣(πθ(oi,t∣q,oi,<t)πθold(oi,t∣q,oi,<t),1−ϵ, 1+ϵ)A^i,t]−β DKL ⁣[πθ ∥ πref]} \mathcal{J}{\mathrm{GRPO}}(\theta)=\mathbb{E}{q \sim P(Q),\, \{o_i\}{i=1}^G\sim \pi{\theta_{\text{old}}}(O \mid q)} \frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left\{ \min \left[ \frac{\pi_{\theta}(o_{i,t} \mid q, o_{i, <t})} {\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i, <t})} \hat{A}{i,t},\; \operatorname{clip}\!\left( \frac{\pi{\theta}(o_{i,t} \mid q, o_{i, <t})} {\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i, <t})}, 1-\epsilon,\, 1+\epsilon \right) \hat{A}{i,t} \right] -\beta \, \mathbb{D}{\mathrm{KL}}\!\left[ \pi_\theta \,\|\, \pi_{\mathrm{ref}} \right] \right\} JGRPO(θ)=Eq∼P(Q),{oi}i=1G∼πθold(O∣q)G1i=1∑G∣oi∣1t=1∑∣oi∣{min[πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)A^i,t,clip(πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t),1−ϵ,1+ϵ)A^i,t]−βDKL[πθ∥πref]}

优势函数计算

GRPO的优势函数计算更为简单,直接对组内轨迹的累计奖励进行归一化
A^i,t=r~i=ri−mean⁡(r)std⁡(r) \hat{A}_{i,t}=\tilde{r}_i=\frac{r_i - \operatorname{mean}(r)}{\operatorname{std}(r)} A^i,t=r~i=std(r)ri−mean(r)

优势函数移除了PPO算法中的KL惩罚项。
KL散度估计与作用

为防止策略过度偏离原始的SFT模型,目标函数中显式加入了KL惩罚项 β DKL[πθ ∥ πref]\beta \, \mathbb{D}{\mathrm{KL}}[\pi\theta \,\|\, \pi_{\mathrm{ref}}]βDKL[πθ∥πref]。其无偏估计为:
DKL ⁣(πθ ∥ πref)≈πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−log⁡πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−1 \mathbb{D}{\mathrm{KL}}\!\left( \pi\theta \,\|\, \pi_{\mathrm{ref}} \right) \approx \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1 DKL(πθ∥πref)≈πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−logπθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−1

总结

在GRPO框架中:

  • πref\pi_{\mathrm{ref}}πref 是固定的初始SFT模型,用于提供稳定性约束,防止策略"遗忘"初始知识。
  • πθold\pi_{\theta_{\text{old}}}πθold 是上一次策略更新前的参数,用于计算重要性采样比率和实现裁剪(Clip),以控制单步更新幅度。
  • Clip机制 通过限制比率防止与上一步策略偏离太远,保证更新稳定性。
  • KL惩罚项 则约束策略不偏离初始SFT模型太远,保持生成质量与安全性。

这种设计使GRPO在简化价值函数的同时,通过组内归一化和双重约束(Clip与KL)维持了训练的稳定与高效。

相关推荐
leoufung2 小时前
Word Break:深度理解 DP 前缀结束点的核心思想
算法·word·动态规划
Aaron15882 小时前
三种主流接收机架构(超外差、零中频、射频直采)对比及发展趋势浅析
c语言·人工智能·算法·fpga开发·架构·硬件架构·信号处理
乐迪信息5 小时前
乐迪信息:目标检测算法+AI摄像机:煤矿全场景识别方案
人工智能·物联网·算法·目标检测·目标跟踪·语音识别
前端小L11 小时前
贪心算法专题(十):维度权衡的艺术——「根据身高重建队列」
javascript·算法·贪心算法
方得一笔11 小时前
自定义常用的字符串函数(strlen,strcpy,strcmp,strcat)
算法
Xの哲學11 小时前
Linux SMP 实现机制深度剖析
linux·服务器·网络·算法·边缘计算
wuk99812 小时前
使用PCA算法进行故障诊断的MATLAB仿真
算法·matlab
额呃呃12 小时前
二分查找细节理解
数据结构·算法
无尽的罚坐人生12 小时前
hot 100 283. 移动零
数据结构·算法·双指针