DeepSpeed-Chat RLHF 阶段代码解读(2) —— PPO 阶段

数据处理

这里和 DeepSpeed-Chat RLHF 阶段代码解读(1) ------ 奖励函数阶段 - 掘金 (juejin.cn) 处理基本一致,唯一的区别是输入不是 prompt + response,而是只有 prompt,response 靠 actor model 生成。

PPO

初始化

PPO 训练要进行模型初始化,这里一共四个模型: Actor model、Reference model: 开始的时候这两个模型是一样的,用途是一样的。 Critic model、Reward model: 开始的时候这两个模型是一样的,但是用途是不一样的,一个是用来产生 critic value,一个是用来产生 reward 的,虽然结构是一样的。另外,critic model 会随着 ppo 训练更新,但是在 ppo 阶段,Reward model 是不变的,比较抽象。

Rewards

Token Level KL-Penalty

python 复制代码
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                    action_mask):

    kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
    rewards = kl_divergence_estimate
    start = prompts.shape[1] - 1
    ends = start + action_mask[:, start:].sum(1) + 1
    reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                self.clip_reward_value)
    batch_size = log_probs.shape[0]
    for j in range(batch_size):
        rewards[j, start:ends[j]][-1] += reward_clip[j]

    return rewards

这里按照参考 4 里说的,除了 eos token 对应的 responese 的 reward score,对其余的每个时间步都增加了一个正则项,因为正则项的格式就是 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn) 提到的 advantage 项为负的情况一模一样,只是这里不涉及 loss 的计算。因此,这里的目的:新策略和之前的策略不一致,增加探索,得到负的 KL 散度,从而提高奖励。

GAE

python 复制代码
def get_advantages_and_returns(self, values, rewards, start):
    # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    # 逆序 advantage, 这样就按时间步顺序得到每个时间步的 advantage
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values[:, start:]
    return advantages.detach(), returns

可以看到每个 rlhf 里的实现,和 DeepSpeed-Chat RLHF 阶段代码解读(0) ------ 原始 PPO 代码解读 - 掘金 (juejin.cn)的实现,没有本质上的区别。

Loss

python 复制代码
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss
  • critic loss
python 复制代码
def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    if self.compute_fp32_loss:
        values = values.float()
        values_clipped = values_clipped.float()
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    # 这是损失函数的核心部分。首先,计算vf_loss1和vf_loss2中较大的那个值,然后乘以mask。
    # 这样做是为了只考虑有效的样本(由mask指示)。然后,取这个乘积的总和,除以mask中有效样本的数量(mask.sum()),得到平均损失。
    # 选取更大的 loss 是为了增加探索,防止过于乐观的估计。
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

RLHF 整体的流程

结合之前的文章,以及本篇文章的数据处理和 PPO 章节,相信读者对 RLHF 无论是原理和代码都有了一定的理解,这里再从整体梳理下使用 PPO 进行 RLHF 的流程。

step1 prompt 输入 actor model 得到 response step2(重要性采样): prompt + response 分别输入到 actor_model 和 reference model 得到 log_probs、ref_log_probs、reward_score、values,这部分的数据可以重复利用。 step3 计算 critic_loss、actor_loss,更新 actor_model。 大致上,ppo 主要就这个三个步骤。

整个流程下来,我的感觉,很繁琐,难训练,所以目前主流大模型很少使用原始的这套 RLHF 流程,更多使用 dpo 算法,而且 RLHF 的数据有限,很难对所有的 response 有一个公平的 rewar,所以下一个系列文章会介绍利用 dpo 的 RLAIF 算法,如 SPIN、self-reward etc。欢迎关注。

参考

  1. Negative KL-divergence RLHF implementation · Issue #736 · huggingface/trl (github.com)[2307.04964] Secrets of RLHF in Large Language Models Part I: PPO (arxiv.org)
  2. 关于ppo阶段,reward分数计算的问题 · Issue #26 · OpenLMLab/MOSS-RLHF (github.com)
  3. 2009.01325v3.pdf (arxiv.org)
相关推荐
盼小辉丶12 小时前
PyTorch强化学习实战——使用交叉熵方法解决 FrozenLake 环境
人工智能·pytorch·python·强化学习
Luca_kill1 天前
深度解构 Hermes Agent:从“中央调度”到“自我进化”的架构哲学
大模型·强化学习·agent框架·ai架构·hermes agent
盼小辉丶2 天前
PyTorch强化学习实战(6)——交叉熵方法详解与实现
人工智能·pytorch·python·强化学习
盼小辉丶3 天前
PyTorch强化学习实战(5)——PyTorch Ignite 事件驱动机制与实践
人工智能·pytorch·python·强化学习
joshchen2153 天前
强化学习基础(赵世钰)第一章
人工智能·深度学习·算法·机器学习·强化学习
joshchen2154 天前
强化学习基础(赵世钰)第二章 贝尔曼方程
人工智能·python·机器学习·强化学习
星马梦缘6 天前
强化学习实战8.3——用PPO打赢星际争霸【编写自定义环境GYM】
人工智能·强化学习·gymnasium·星际争霸·sc2·starcraft2·sb3
盼小辉丶7 天前
PyTorch强化学习实战(4)——PyTorch基础
人工智能·pytorch·python·强化学习
星马梦缘7 天前
强化学习实战8——用PPO打赢星际争霸【整合版】
强化学习·ppo·星际争霸·sc2·starcraft2·sb3
Narrastory8 天前
Note:强化学习(六)
人工智能·深度学习·强化学习