数据处理
这里和 DeepSpeed-Chat RLHF 阶段代码解读(1) ------ 奖励函数阶段 - 掘金 (juejin.cn) 处理基本一致,唯一的区别是输入不是 prompt + response,而是只有 prompt,response 靠 actor model 生成。
PPO
初始化
PPO 训练要进行模型初始化,这里一共四个模型: Actor model、Reference model: 开始的时候这两个模型是一样的,用途是一样的。 Critic model、Reward model: 开始的时候这两个模型是一样的,但是用途是不一样的,一个是用来产生 critic value,一个是用来产生 reward 的,虽然结构是一样的。另外,critic model 会随着 ppo 训练更新,但是在 ppo 阶段,Reward model 是不变的,比较抽象。
Rewards
Token Level KL-Penalty
python
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
action_mask):
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
rewards = kl_divergence_estimate
start = prompts.shape[1] - 1
ends = start + action_mask[:, start:].sum(1) + 1
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
self.clip_reward_value)
batch_size = log_probs.shape[0]
for j in range(batch_size):
rewards[j, start:ends[j]][-1] += reward_clip[j]
return rewards
这里按照参考 4 里说的,除了 eos token 对应的 responese 的 reward score,对其余的每个时间步都增加了一个正则项,因为正则项的格式就是 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn) 提到的 advantage 项为负的情况一模一样,只是这里不涉及 loss 的计算。因此,这里的目的:新策略和之前的策略不一致,增加探索,得到负的 KL 散度,从而提高奖励。
GAE
python
def get_advantages_and_returns(self, values, rewards, start):
# Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
for t in reversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
# 逆序 advantage, 这样就按时间步顺序得到每个时间步的 advantage
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values[:, start:]
return advantages.detach(), returns
可以看到每个 rlhf 里的实现,和 DeepSpeed-Chat RLHF 阶段代码解读(0) ------ 原始 PPO 代码解读 - 掘金 (juejin.cn)的实现,没有本质上的区别。
Loss
- actor loss actor_loss 的核心是重要性采样,重要性采样的核心思想是:旧策略的采样在一定时间段内,可以用于新策略的训练,提高数据的有效利用。关于 loss 为什么可以写成 -ratio * advantage 的证明,可以看: 深度强化学习(DRL)算法 2 ------ PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn)。 使用 torch.max 的原因,和 critic loss 一样,为了避免乐观估计,增加探索。
python
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
## policy gradient loss
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
- critic loss
python
def critic_loss_fn(self, values, old_values, returns, mask):
## value loss
values_clipped = torch.clamp(
values,
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
if self.compute_fp32_loss:
values = values.float()
values_clipped = values_clipped.float()
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
# 这是损失函数的核心部分。首先,计算vf_loss1和vf_loss2中较大的那个值,然后乘以mask。
# 这样做是为了只考虑有效的样本(由mask指示)。然后,取这个乘积的总和,除以mask中有效样本的数量(mask.sum()),得到平均损失。
# 选取更大的 loss 是为了增加探索,防止过于乐观的估计。
vf_loss = 0.5 * torch.sum(
torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
return vf_loss
RLHF 整体的流程
结合之前的文章,以及本篇文章的数据处理和 PPO 章节,相信读者对 RLHF 无论是原理和代码都有了一定的理解,这里再从整体梳理下使用 PPO 进行 RLHF 的流程。
step1 prompt 输入 actor model 得到 response step2(重要性采样): prompt + response 分别输入到 actor_model 和 reference model 得到 log_probs、ref_log_probs、reward_score、values,这部分的数据可以重复利用。 step3 计算 critic_loss、actor_loss,更新 actor_model。 大致上,ppo 主要就这个三个步骤。
整个流程下来,我的感觉,很繁琐,难训练,所以目前主流大模型很少使用原始的这套 RLHF 流程,更多使用 dpo 算法,而且 RLHF 的数据有限,很难对所有的 response 有一个公平的 rewar,所以下一个系列文章会介绍利用 dpo 的 RLAIF 算法,如 SPIN、self-reward etc。欢迎关注。