DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读

为了理解 DeepSpeed-Chat RLHF 的 RLHF 全部过程,这个系列会分三篇文章分别介绍: 原始 PPO 代码解读RLHF 奖励函数代码解读RLHF PPO 代码解读 这是系列的第一篇文章,我们来一步一步的看 PPO 算法的代码实现,对于 PPO 算法原理不太了解的同学,可以参考之前的文章:

  1. 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn)
  2. 深度强化学习(DRL)算法 2 ------ PPO 之 GAE 篇 - 掘金 (juejin.cn)

Clipped Surrogate 函数实现

python 复制代码
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):
    end = start + args.minibatch_size
    mb_inds = b_inds[start:end]

    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
    logratio = newlogprob - b_logprobs[mb_inds]
    ratio = logratio.exp()

    mb_advantages = b_advantages[mb_inds]
    if args.norm_adv:
        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

    # Policy loss
    pg_loss1 = -mb_advantages * ratio
    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
    pg_loss = torch.max(pg_loss1, pg_loss2).mean()

Clipped Surrogate 函数的实现很简单,这里不再赘述,理解算法原理,代码自然而然就可以看懂,核心是 get_action_and_value 函数的理解。

python 复制代码
def get_action_and_value(self, x, action=None):
    logits = self.actor(x)
    # probs 相当于计算 softmax
    probs = Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    # probs.log_prob(action) 计算的是 p(a|s) 的 log 形式,方便计算 Clipped Surrogate 函数里的 ratio
    return action, probs.log_prob(action), probs.entropy(), self.critic(x) 

GAE 实现

直接来看 gae 可能比较抽象,我们先来看蒙特卡洛方法实现的优势估计,对蒙特卡洛方法不熟悉的同学,可以参考之前的文章。 深度强化学习(DRL)算法 附录 3 ------ 蒙特卡洛方法(MC)和时序差分(TD) - 掘金 (juejin.cn) 两种方法都采用了反向迭代(因为反向迭代更好计算)的方式来实现优势估计。

python 复制代码
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)

returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):
    if t == args.num_steps - 1:
        nextnonterminal = 1.0 - next_done
        next_return = last_value
    else:
        nextnonterminal = 1.0 - dones[t+1]
        next_return = returns[t+1]
    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values

上面的代码做了什么事情呢,last_value 对应最后的 step(对应 step t) 产生的期望回报,如果 step t-1 整个流程没有结束,那么 t-1 时刻的期望回报就是 reward(t-1) + args.gamma * nextnonterminal * next_return,这样一步一步往后推,就可以计算每一个 step 的期望回报,从而得到每一步的优势,还没理解的话,看下面每个时间步的拆解。关于 last_value 的使用,这里由于没有后续的回报可以累积,所以直接使用 last_value 作为最后一个时间步的回报。关于下面为啥用 return[t-1] 替换原始公式的 value[t-1],这样计算的话就相当于蒙特卡洛方法的优势估计,如果next_return = returns[t+1] 改成 next_value = values[t+1] 就相当于 TD(1) 的优势估计。

python 复制代码
# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我们可以看到一步一步往前推,最后就得到蒙特卡洛方法的优势估计

理解了上面讲的蒙特卡洛方法实现的优势估计,再来看 gae 的实现,我们可以看到代码实现上十分的相似,只是多了 delta 的计算,这里的 delta 对应的就是之前 PPO GAE 篇里介绍的 delta。

python 复制代码
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)

advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
    if t == args.num_steps - 1:
        nextnonterminal = 1.0 - next_done
        nextvalues = last_value
    else:
        nextnonterminal = 1.0 - dones[t+1]
        nextvalues = values[t+1]
    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values

这里通过反向迭代的方式计算 GAE advantage,可能理解上比较抽象,举个例子,就很好理解了:

python 复制代码
# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...

可以看到,逐项展开,每一时刻的 GAE Advantage 和 PPO GAE 篇里介绍的公式是一模一样的,这里 GAE 就是一种数学公式表达,核心思想是 n step 的优势估计的加权平均,通过数学技巧恰好是上面的形式。

参考

  1. The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
  2. HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION
相关推荐
kkkkkkkkk_12016 小时前
【强化学习】07周博磊强化学习纲要学习笔记——第四课上
学习·强化学习
free-elcmacom10 小时前
机器学习高阶教程<2>优化理论实战:BERT用AdamW、强化学习爱SGD
人工智能·python·机器学习·bert·强化学习·大模型训练的优化器选择逻辑
AI-Frontiers1 天前
小白也能看懂的RLHF-PPO:原理篇
强化学习
传说故事1 天前
RL中的同步和异步(On-Policy & Off-Policy)的通俗解释
人工智能·强化学习
视觉&物联智能3 天前
【杂谈】-RL即服务:解锁新一轮自主浪潮
人工智能·ai·chatgpt·aigc·强化学习·agi·deepseek
自动化小秋葵4 天前
强化学习-数据操作与预处理
强化学习
kkkkkkkkk_12015 天前
【强化学习】06周博磊强化学习纲要学习笔记——第三课下
笔记·学习·强化学习
i.ajls7 天前
强化学习入门-5(MAPPO)
笔记·机器学习·强化学习·mappo
kkkkkkkkk_12018 天前
【强化学习】05周博磊强化学习纲要学习笔记——第三课上
笔记·学习·强化学习
强化学习与机器人控制仿真9 天前
ProtoMotions 3 入门教程(一)开源 GPU 加速人形机器人强化学习仿真训练框架
人工智能·stm32·深度学习·机器人·强化学习·人形机器人·模仿学习