DeepSpeed-Chat RLHF 阶段代码解读(2) —— PPO 阶段

数据处理

这里和 DeepSpeed-Chat RLHF 阶段代码解读(1) ------ 奖励函数阶段 - 掘金 (juejin.cn) 处理基本一致,唯一的区别是输入不是 prompt + response,而是只有 prompt,response 靠 actor model 生成。

PPO

初始化

PPO 训练要进行模型初始化,这里一共四个模型: Actor model、Reference model: 开始的时候这两个模型是一样的,用途是一样的。 Critic model、Reward model: 开始的时候这两个模型是一样的,但是用途是不一样的,一个是用来产生 critic value,一个是用来产生 reward 的,虽然结构是一样的。另外,critic model 会随着 ppo 训练更新,但是在 ppo 阶段,Reward model 是不变的,比较抽象。

Rewards

Token Level KL-Penalty

python 复制代码
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                    action_mask):

    kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
    rewards = kl_divergence_estimate
    start = prompts.shape[1] - 1
    ends = start + action_mask[:, start:].sum(1) + 1
    reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                self.clip_reward_value)
    batch_size = log_probs.shape[0]
    for j in range(batch_size):
        rewards[j, start:ends[j]][-1] += reward_clip[j]

    return rewards

这里按照参考 4 里说的,除了 eos token 对应的 responese 的 reward score,对其余的每个时间步都增加了一个正则项,因为正则项的格式就是 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn) 提到的 advantage 项为负的情况一模一样,只是这里不涉及 loss 的计算。因此,这里的目的:新策略和之前的策略不一致,增加探索,得到负的 KL 散度,从而提高奖励。

GAE

python 复制代码
def get_advantages_and_returns(self, values, rewards, start):
    # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    # 逆序 advantage, 这样就按时间步顺序得到每个时间步的 advantage
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values[:, start:]
    return advantages.detach(), returns

可以看到每个 rlhf 里的实现,和 DeepSpeed-Chat RLHF 阶段代码解读(0) ------ 原始 PPO 代码解读 - 掘金 (juejin.cn)的实现,没有本质上的区别。

Loss

python 复制代码
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss
  • critic loss
python 复制代码
def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    if self.compute_fp32_loss:
        values = values.float()
        values_clipped = values_clipped.float()
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    # 这是损失函数的核心部分。首先,计算vf_loss1和vf_loss2中较大的那个值,然后乘以mask。
    # 这样做是为了只考虑有效的样本(由mask指示)。然后,取这个乘积的总和,除以mask中有效样本的数量(mask.sum()),得到平均损失。
    # 选取更大的 loss 是为了增加探索,防止过于乐观的估计。
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

RLHF 整体的流程

结合之前的文章,以及本篇文章的数据处理和 PPO 章节,相信读者对 RLHF 无论是原理和代码都有了一定的理解,这里再从整体梳理下使用 PPO 进行 RLHF 的流程。

step1 prompt 输入 actor model 得到 response step2(重要性采样): prompt + response 分别输入到 actor_model 和 reference model 得到 log_probs、ref_log_probs、reward_score、values,这部分的数据可以重复利用。 step3 计算 critic_loss、actor_loss,更新 actor_model。 大致上,ppo 主要就这个三个步骤。

整个流程下来,我的感觉,很繁琐,难训练,所以目前主流大模型很少使用原始的这套 RLHF 流程,更多使用 dpo 算法,而且 RLHF 的数据有限,很难对所有的 response 有一个公平的 rewar,所以下一个系列文章会介绍利用 dpo 的 RLAIF 算法,如 SPIN、self-reward etc。欢迎关注。

参考

  1. Negative KL-divergence RLHF implementation · Issue #736 · huggingface/trl (github.com)[2307.04964] Secrets of RLHF in Large Language Models Part I: PPO (arxiv.org)
  2. 关于ppo阶段,reward分数计算的问题 · Issue #26 · OpenLMLab/MOSS-RLHF (github.com)
  3. 2009.01325v3.pdf (arxiv.org)
相关推荐
字节跳动开源3 天前
最高提升20倍吞吐量!豆包大模型团队发布全新 RLHF 框架,现已开源!
开源·llm·强化学习
DataFountain数据科学13 天前
《文心一言插件设计与开发》赛题三等奖方案 | NoteTable
大数据·人工智能·数学建模·文心一言·强化学习·数据竞赛
人工智能培训咨询叶梓23 天前
语言模型与人类反馈的深度融合:Chain of Hindsight技术
人工智能·深度学习·语言模型·自然语言处理·性能优化·强化学习·大模型微调
Gaoshu1011 个月前
◇【论文_20170828 v2】PPO 算法〔OpenAI〕: Proximal Policy Optimization Algorithms
强化学习·论文整理
lijianhua_97121 个月前
先进制造aps专题二十六 基于强化学习的人工智能ai生产排程aps模型简介
人工智能·强化学习·aps
Gaoshu1012 个月前
《强化学习的数学原理》(2024春)_西湖大学赵世钰 Ch10 Actor-Critic 方法 » P2
笔记·强化学习
Nicolas8932 个月前
【算法业务】基于Multi-Armed Bandits的个性化push文案自动优选算法实践
强化学习·推荐算法·多臂老虎机·个性化推送系统·push系统·用户激活·文案优选
机器白学2 个月前
【强化学习系列】Gym库使用——创建自己的强化学习环境3:矢量化环境+奖励函数设计
强化学习
荒野火狐2 个月前
【FreeRL】我的深度学习库构建思想
人工智能·深度学习·强化学习·dqn
Nicolas8933 个月前
【大模型理论篇】强化学习RL与大模型智能体
大模型·llm·强化学习·策略梯度·dqn·rl·智能体