DeepSpeed-Chat RLHF 阶段代码解读(2) —— PPO 阶段

数据处理

这里和 DeepSpeed-Chat RLHF 阶段代码解读(1) ------ 奖励函数阶段 - 掘金 (juejin.cn) 处理基本一致,唯一的区别是输入不是 prompt + response,而是只有 prompt,response 靠 actor model 生成。

PPO

初始化

PPO 训练要进行模型初始化,这里一共四个模型: Actor model、Reference model: 开始的时候这两个模型是一样的,用途是一样的。 Critic model、Reward model: 开始的时候这两个模型是一样的,但是用途是不一样的,一个是用来产生 critic value,一个是用来产生 reward 的,虽然结构是一样的。另外,critic model 会随着 ppo 训练更新,但是在 ppo 阶段,Reward model 是不变的,比较抽象。

Rewards

Token Level KL-Penalty

python 复制代码
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                    action_mask):

    kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
    rewards = kl_divergence_estimate
    start = prompts.shape[1] - 1
    ends = start + action_mask[:, start:].sum(1) + 1
    reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                self.clip_reward_value)
    batch_size = log_probs.shape[0]
    for j in range(batch_size):
        rewards[j, start:ends[j]][-1] += reward_clip[j]

    return rewards

这里按照参考 4 里说的,除了 eos token 对应的 responese 的 reward score,对其余的每个时间步都增加了一个正则项,因为正则项的格式就是 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn) 提到的 advantage 项为负的情况一模一样,只是这里不涉及 loss 的计算。因此,这里的目的:新策略和之前的策略不一致,增加探索,得到负的 KL 散度,从而提高奖励。

GAE

python 复制代码
def get_advantages_and_returns(self, values, rewards, start):
    # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    # 逆序 advantage, 这样就按时间步顺序得到每个时间步的 advantage
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values[:, start:]
    return advantages.detach(), returns

可以看到每个 rlhf 里的实现,和 DeepSpeed-Chat RLHF 阶段代码解读(0) ------ 原始 PPO 代码解读 - 掘金 (juejin.cn)的实现,没有本质上的区别。

Loss

python 复制代码
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss
  • critic loss
python 复制代码
def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    if self.compute_fp32_loss:
        values = values.float()
        values_clipped = values_clipped.float()
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    # 这是损失函数的核心部分。首先,计算vf_loss1和vf_loss2中较大的那个值,然后乘以mask。
    # 这样做是为了只考虑有效的样本(由mask指示)。然后,取这个乘积的总和,除以mask中有效样本的数量(mask.sum()),得到平均损失。
    # 选取更大的 loss 是为了增加探索,防止过于乐观的估计。
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

RLHF 整体的流程

结合之前的文章,以及本篇文章的数据处理和 PPO 章节,相信读者对 RLHF 无论是原理和代码都有了一定的理解,这里再从整体梳理下使用 PPO 进行 RLHF 的流程。

step1 prompt 输入 actor model 得到 response step2(重要性采样): prompt + response 分别输入到 actor_model 和 reference model 得到 log_probs、ref_log_probs、reward_score、values,这部分的数据可以重复利用。 step3 计算 critic_loss、actor_loss,更新 actor_model。 大致上,ppo 主要就这个三个步骤。

整个流程下来,我的感觉,很繁琐,难训练,所以目前主流大模型很少使用原始的这套 RLHF 流程,更多使用 dpo 算法,而且 RLHF 的数据有限,很难对所有的 response 有一个公平的 rewar,所以下一个系列文章会介绍利用 dpo 的 RLAIF 算法,如 SPIN、self-reward etc。欢迎关注。

参考

  1. Negative KL-divergence RLHF implementation · Issue #736 · huggingface/trl (github.com)[2307.04964] Secrets of RLHF in Large Language Models Part I: PPO (arxiv.org)
  2. 关于ppo阶段,reward分数计算的问题 · Issue #26 · OpenLMLab/MOSS-RLHF (github.com)
  3. 2009.01325v3.pdf (arxiv.org)
相关推荐
仙人掌_lz3 天前
深度理解用于多智能体强化学习的单调价值函数分解QMIX算法:基于python从零实现
python·算法·强化学习·rl·价值函数
Mr.Winter`4 天前
深度强化学习 | 图文详细推导软性演员-评论家SAC算法原理
人工智能·深度学习·神经网络·机器学习·数据挖掘·机器人·强化学习
IT猿手4 天前
基于强化学习 Q-learning 算法求解城市场景下无人机三维路径规划研究,提供完整MATLAB代码
神经网络·算法·matlab·人机交互·无人机·强化学习·无人机三维路径规划
仙人掌_lz6 天前
理解多智能体深度确定性策略梯度MADDPG算法:基于python从零实现
python·算法·强化学习·策略梯度·rl
仙人掌_lz7 天前
深入理解深度Q网络DQN:基于python从零实现
python·算法·强化学习·dqn·rl
IT猿手7 天前
基于 Q-learning 的城市场景无人机三维路径规划算法研究,可以自定义地图,提供完整MATLAB代码
深度学习·算法·matlab·无人机·强化学习·qlearning·无人机路径规划
Two summers ago8 天前
arXiv2025 | TTRL: Test-Time Reinforcement Learning
论文阅读·人工智能·机器学习·llm·强化学习
仙人掌_lz9 天前
为特定领域微调嵌入模型:打造专属的自然语言处理利器
人工智能·ai·自然语言处理·embedding·强化学习·rl·bge
碣石潇湘无限路10 天前
【AI】基于生活案例的LLM强化学习(入门帖)
人工智能·经验分享·笔记·生活·openai·强化学习
人类发明了工具11 天前
【强化学习】强化学习算法 - 多臂老虎机问题
机器学习·强化学习·多臂老虎机