DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读

为了理解 DeepSpeed-Chat RLHF 的 RLHF 全部过程,这个系列会分三篇文章分别介绍: 原始 PPO 代码解读RLHF 奖励函数代码解读RLHF PPO 代码解读 这是系列的第一篇文章,我们来一步一步的看 PPO 算法的代码实现,对于 PPO 算法原理不太了解的同学,可以参考之前的文章:

  1. 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn)
  2. 深度强化学习(DRL)算法 2 ------ PPO 之 GAE 篇 - 掘金 (juejin.cn)

Clipped Surrogate 函数实现

python 复制代码
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):
    end = start + args.minibatch_size
    mb_inds = b_inds[start:end]

    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
    logratio = newlogprob - b_logprobs[mb_inds]
    ratio = logratio.exp()

    mb_advantages = b_advantages[mb_inds]
    if args.norm_adv:
        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

    # Policy loss
    pg_loss1 = -mb_advantages * ratio
    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
    pg_loss = torch.max(pg_loss1, pg_loss2).mean()

Clipped Surrogate 函数的实现很简单,这里不再赘述,理解算法原理,代码自然而然就可以看懂,核心是 get_action_and_value 函数的理解。

python 复制代码
def get_action_and_value(self, x, action=None):
    logits = self.actor(x)
    # probs 相当于计算 softmax
    probs = Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    # probs.log_prob(action) 计算的是 p(a|s) 的 log 形式,方便计算 Clipped Surrogate 函数里的 ratio
    return action, probs.log_prob(action), probs.entropy(), self.critic(x) 

GAE 实现

直接来看 gae 可能比较抽象,我们先来看蒙特卡洛方法实现的优势估计,对蒙特卡洛方法不熟悉的同学,可以参考之前的文章。 深度强化学习(DRL)算法 附录 3 ------ 蒙特卡洛方法(MC)和时序差分(TD) - 掘金 (juejin.cn) 两种方法都采用了反向迭代(因为反向迭代更好计算)的方式来实现优势估计。

python 复制代码
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)

returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):
    if t == args.num_steps - 1:
        nextnonterminal = 1.0 - next_done
        next_return = last_value
    else:
        nextnonterminal = 1.0 - dones[t+1]
        next_return = returns[t+1]
    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values

上面的代码做了什么事情呢,last_value 对应最后的 step(对应 step t) 产生的期望回报,如果 step t-1 整个流程没有结束,那么 t-1 时刻的期望回报就是 reward(t-1) + args.gamma * nextnonterminal * next_return,这样一步一步往后推,就可以计算每一个 step 的期望回报,从而得到每一步的优势,还没理解的话,看下面每个时间步的拆解。关于 last_value 的使用,这里由于没有后续的回报可以累积,所以直接使用 last_value 作为最后一个时间步的回报。关于下面为啥用 returnt-1 替换原始公式的 valuet-1,这样计算的话就相当于蒙特卡洛方法的优势估计,如果next_return = returnst+1 改成 next_value = valuest+1 就相当于 TD(1) 的优势估计。

python 复制代码
# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我们可以看到一步一步往前推,最后就得到蒙特卡洛方法的优势估计

理解了上面讲的蒙特卡洛方法实现的优势估计,再来看 gae 的实现,我们可以看到代码实现上十分的相似,只是多了 delta 的计算,这里的 delta 对应的就是之前 PPO GAE 篇里介绍的 delta。

python 复制代码
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)

advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
    if t == args.num_steps - 1:
        nextnonterminal = 1.0 - next_done
        nextvalues = last_value
    else:
        nextnonterminal = 1.0 - dones[t+1]
        nextvalues = values[t+1]
    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values

这里通过反向迭代的方式计算 GAE advantage,可能理解上比较抽象,举个例子,就很好理解了:

python 复制代码
# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...

可以看到,逐项展开,每一时刻的 GAE Advantage 和 PPO GAE 篇里介绍的公式是一模一样的,这里 GAE 就是一种数学公式表达,核心思想是 n step 的优势估计的加权平均,通过数学技巧恰好是上面的形式。

参考

  1. The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
  2. HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION
相关推荐
盼小辉丶2 天前
PyTorch强化学习实战——使用高级组件复现DQN
pytorch·深度学习·强化学习
亲爱的阿瞎3 天前
1、强化学习中的回报与策略
强化学习
勾股导航4 天前
REINFORCE算法
人工智能·强化学习·reinforce 算法
勾股导航4 天前
A2C算法
人工智能·强化学习·a2c
勾股导航5 天前
DQN算法
人工智能·强化学习
SP FA5 天前
深度强化学习与控制(二):无模型强化学习
人工智能·强化学习·dqn
盼小辉丶5 天前
PyTorch强化学习实战(10)——强化学习高级组件
人工智能·pytorch·python·强化学习
威化饼的一隅6 天前
【大模型LLM学习】Agentic RL—基于Qwen3-4b训练Travel Planning Agent
大模型·llm·agent·强化学习·智能体·agentic rl·旅游智能体
盼小辉丶8 天前
PyTorch强化学习实战——Atari游戏包装器
pytorch·深度学习·强化学习
viperrrrrrrrrr78 天前
强化学习入门笔记
人工智能·强化学习