DeepSpeed-Chat RLHF 阶段代码解读(2) —— PPO 阶段

数据处理

这里和 DeepSpeed-Chat RLHF 阶段代码解读(1) ------ 奖励函数阶段 - 掘金 (juejin.cn) 处理基本一致,唯一的区别是输入不是 prompt + response,而是只有 prompt,response 靠 actor model 生成。

PPO

初始化

PPO 训练要进行模型初始化,这里一共四个模型: Actor model、Reference model: 开始的时候这两个模型是一样的,用途是一样的。 Critic model、Reward model: 开始的时候这两个模型是一样的,但是用途是不一样的,一个是用来产生 critic value,一个是用来产生 reward 的,虽然结构是一样的。另外,critic model 会随着 ppo 训练更新,但是在 ppo 阶段,Reward model 是不变的,比较抽象。

Rewards

Token Level KL-Penalty

python 复制代码
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                    action_mask):

    kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
    rewards = kl_divergence_estimate
    start = prompts.shape[1] - 1
    ends = start + action_mask[:, start:].sum(1) + 1
    reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                self.clip_reward_value)
    batch_size = log_probs.shape[0]
    for j in range(batch_size):
        rewards[j, start:ends[j]][-1] += reward_clip[j]

    return rewards

这里按照参考 4 里说的,除了 eos token 对应的 responese 的 reward score,对其余的每个时间步都增加了一个正则项,因为正则项的格式就是 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn) 提到的 advantage 项为负的情况一模一样,只是这里不涉及 loss 的计算。因此,这里的目的:新策略和之前的策略不一致,增加探索,得到负的 KL 散度,从而提高奖励。

GAE

python 复制代码
def get_advantages_and_returns(self, values, rewards, start):
    # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    # 逆序 advantage, 这样就按时间步顺序得到每个时间步的 advantage
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values[:, start:]
    return advantages.detach(), returns

可以看到每个 rlhf 里的实现,和 DeepSpeed-Chat RLHF 阶段代码解读(0) ------ 原始 PPO 代码解读 - 掘金 (juejin.cn)的实现,没有本质上的区别。

Loss

python 复制代码
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss
  • critic loss
python 复制代码
def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    if self.compute_fp32_loss:
        values = values.float()
        values_clipped = values_clipped.float()
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    # 这是损失函数的核心部分。首先,计算vf_loss1和vf_loss2中较大的那个值,然后乘以mask。
    # 这样做是为了只考虑有效的样本(由mask指示)。然后,取这个乘积的总和,除以mask中有效样本的数量(mask.sum()),得到平均损失。
    # 选取更大的 loss 是为了增加探索,防止过于乐观的估计。
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

RLHF 整体的流程

结合之前的文章,以及本篇文章的数据处理和 PPO 章节,相信读者对 RLHF 无论是原理和代码都有了一定的理解,这里再从整体梳理下使用 PPO 进行 RLHF 的流程。

step1 prompt 输入 actor model 得到 response step2(重要性采样): prompt + response 分别输入到 actor_model 和 reference model 得到 log_probs、ref_log_probs、reward_score、values,这部分的数据可以重复利用。 step3 计算 critic_loss、actor_loss,更新 actor_model。 大致上,ppo 主要就这个三个步骤。

整个流程下来,我的感觉,很繁琐,难训练,所以目前主流大模型很少使用原始的这套 RLHF 流程,更多使用 dpo 算法,而且 RLHF 的数据有限,很难对所有的 response 有一个公平的 rewar,所以下一个系列文章会介绍利用 dpo 的 RLAIF 算法,如 SPIN、self-reward etc。欢迎关注。

参考

  1. Negative KL-divergence RLHF implementation · Issue #736 · huggingface/trl (github.com)[2307.04964] Secrets of RLHF in Large Language Models Part I: PPO (arxiv.org)
  2. 关于ppo阶段,reward分数计算的问题 · Issue #26 · OpenLMLab/MOSS-RLHF (github.com)
  3. 2009.01325v3.pdf (arxiv.org)
相关推荐
张三不嚣张7 小时前
PPO(近端策略优化)算法基本原理
人工智能·算法·强化学习·游戏策划
IT猿手4 天前
最新高性能多目标优化算法:多目标麋鹿优化算法(MOEHO)求解GLSMOP1-GLSMOP9及工程应用---盘式制动器设计,提供完整MATLAB代码
开发语言·算法·机器学习·matlab·强化学习
凳子花❀5 天前
强化学习与深度学习以及相关芯片之间的区别
人工智能·深度学习·神经网络·ai·强化学习
我爱C编程5 天前
基于Qlearning强化学习的机器人路线规划matlab仿真
matlab·机器人·强化学习·路线规划·qlearning·机器人路线规划
IT猿手7 天前
基于PWLCM混沌映射的麋鹿群优化算法(Elk herd optimizer,EHO)的多无人机协同路径规划,MATLAB代码
算法·elk·机器学习·matlab·无人机·聚类·强化学习
IT古董13 天前
【机器学习】机器学习的基本分类-强化学习(Reinforcement Learning, RL)
人工智能·机器学习·分类·强化学习
smartcat201015 天前
PPO系列3 - PPO原理
强化学习
IT猿手15 天前
强化学习路径规划:基于SARSA算法的移动机器人路径规划,可以更改地图大小及起始点,可以自定义障碍物,MATLAB代码
android·算法·机器学习·matlab·迁移学习·强化学习·多目标优化
smartcat201015 天前
PPO系列4 - Reward模型训练
强化学习
不去幼儿园16 天前
【强化学习】策略梯度---REINFORCE算法
人工智能·python·算法·机器学习·强化学习