从RLHF到PPO:让AI学会说人话

📚 目录

  1. RLHF整体框架:三阶段训练
  2. 第三阶段的四个组件:形象理解
  3. 组件的模型结构:共享Base,替换Head
  4. 训练流程:一次完整迭代
  5. PPO的核心创新:Clip机制
  6. 代码实现与常见问题

📌 前置概念:对齐 vs RLHF

RLHF是什么?

RLHF = 大模型对齐的一种实现方法

markdown 复制代码
大模型对齐(Alignment)
├─ 目标:让AI符合人类价值观
│   ├─ Helpful(有帮助)
│   ├─ Harmless(无害)
│   └─ Honest(诚实)
│
└─ 实现方法(多种):
    ├─ RLHF(最常用)← 本文重点
    │   └─ 从人类反馈中学习
    ├─ Constitutional AI(Anthropic的方法)
    │   └─ 用AI自我批评
    ├─ RLAIF(从AI反馈中学习)
    │   └─ 用AI替代人类反馈
    └─ DPO(直接偏好优化)
        └─ 不用RL,直接优化偏好

关系:

  • 对齐是目标(让AI对齐人类价值观)
  • RLHF是手段(用人类反馈实现对齐)
  • ChatGPT、Claude、Llama2-Chat都用RLHF实现对齐

类比:

ini 复制代码
对齐 = 目的地(把AI训练成好助手)
RLHF = 交通工具(坐飞机去)
PPO = 飞机的引擎(具体技术)

本文重点: 讲解RLHF(最主流的对齐方法)以及其中用到的PPO算法。


🎯 Part 1: RLHF整体框架 ------ ChatGPT怎么训练的

1.1 问题:为什么需要RLHF?

训练前的AI(只有预训练):

复制代码
你:什么是黑洞?
AI:黑洞黑洞黑洞黑洞...(只会重复)

你:帮我写一封邮件
AI:邮件邮件邮件...(没法用)

原因: 预训练只学会了"续写文本",没学会"对话"和"遵循指令"

解决方案: RLHF(Reinforcement Learning from Human Feedback)

  • 从人类反馈中学习
  • 让AI对齐人类偏好

1.2 RLHF三阶段流程

arduino 复制代码
┌─────────────────────────────────────────────────────┐
│              RLHF完整训练流程                       │
├─────────────────────────────────────────────────────┤
│                                                      │
│  【阶段1】监督微调(SFT)                           │
│  ────────────────────────────────────────           │
│  输入:预训练模型(GPT-3)                          │
│  数据:人类标注的高质量对话                         │
│       问题:"什么是黑洞?"                          │
│       回答:"黑洞是引力极强的天体..."              │
│  方法:监督学习(像训练语言模型一样)               │
│  输出:SFT模型(会基本对话了)                      │
│                                                      │
│  目的:让模型学会对话的基本格式                     │
│                                                      │
│  ↓                                                   │
│                                                      │
│  【阶段2】训练奖励模型(RM)                        │
│  ────────────────────────────────────────           │
│  输入:SFT模型                                       │
│  数据:人类偏好数据                                 │
│       问题:"什么是黑洞?"                          │
│       回答A:"黑洞是引力极强的天体..."(详细)     │
│       回答B:"不知道"(敷衍)                       │
│       人类选择:A > B                                │
│  方法:成对比较训练                                 │
│  输出:Reward Model(会打分的裁判)                 │
│                                                      │
│  目的:用RM替代昂贵的人类反馈                       │
│                                                      │
│  ↓                                                   │
│                                                      │
│  【阶段3】PPO强化学习 ← 本文重点                   │
│  ────────────────────────────────────────           │
│  输入:SFT模型、Reward Model                        │
│  方法:强化学习(PPO算法)                          │
│  输出:最终的ChatGPT                                │
│                                                      │
│  目的:根据RM反馈优化模型                           │
│                                                      │
└─────────────────────────────────────────────────────┘

1.3 类比:训练宠物狗

arduino 复制代码
阶段1(SFT):
教狗基本动作
"坐下" → 示范 → 狗学会坐

阶段2(RM训练):
训练一个裁判
裁判学会:哪个动作是"好坐姿"

阶段3(PPO):
狗自己练习
裁判打分 → 狗调整 → 越来越好

🎭 Part 2: 阶段3的四个组件 ------ 形象理解

2.1 组件概览

复制代码
PPO训练系统 = 4个角色配合

┌────────────────────────────────────────┐
│  1. Actor(演员)                      │
│     职责:生成回答                     │
│     状态:训练中,不断改进             │
│     ↓                                   │
│  2. Reward Model(裁判)               │
│     职责:给回答打分                   │
│     状态:固定不变(阶段2训练好的)    │
│     ↓                                   │
│  3. Reference(起点)                  │
│     职责:记住训练前的样子             │
│     状态:固定不变(Actor的初始副本)  │
│     ↓                                   │
│  4. Critic(教练)                     │
│     职责:预测能得多少分               │
│     状态:训练中,辅助Actor学习        │
└────────────────────────────────────────┘

2.2 四个角色的详细作用

角色1:Actor(演员)------ 主角

erlang 复制代码
任务:生成回答

例子:
输入:什么是黑洞?
Actor:黑洞是时空中引力极强的区域,连光都无法逃脱...

状态:要训练的主模型
来源:从SFT模型初始化
目标:根据RM反馈改进回答质量

类比:

  • 选秀节目的选手
  • 要根据评委反馈改进表演

角色2:Reward Model(裁判)------ 评判者

arduino 复制代码
任务:给回答打分

例子:
Actor的回答:"黑洞是时空中引力极强的区域..."
RM打分:8.5分(很好!)

状态:固定不变(阶段2已训练好)
来源:阶段2用人类偏好数据训练
作用:替代人类打分(便宜、快速)

类比:

  • 选秀节目的评委
  • 已经知道什么是"好表演"(从人类学的)

角色3:Reference(起点)------ 锚点

erlang 复制代码
任务:记住训练前的Actor

例子:
Actor训练前:"黑洞是引力极强的天体"(生成概率10%)
Reference:记住这个状态
Actor训练后:"黑洞是..."(概率变成50%)
Reference:对比差异,防止变化太大

状态:固定不变(Actor训练前的复制品)
来源:从SFT模型复制(和Actor初始状态完全一样)
作用:防止Actor偏离初始策略太远(保持多样性)

类比:

  • 选手的"初始风格"
  • 可以改进,但别丢了本来的特色

为什么需要Reference?

arduino 复制代码
没有Reference:
Actor发现"黑洞"这个词得高分
Actor疯狂增加"黑洞"的概率
任何问题都先说"黑洞" ❌

有Reference:
Actor想增加"黑洞"的概率
Reference:等等,你和训练前差太多了
惩罚 = -0.5分
Actor:那我还是保持一定风格 ✅

角色4:Critic(教练)------ 辅助者

arduino 复制代码
任务:预测"这个回答能得多少分"

例子:
Actor生成回答:"黑洞是..."
Critic预测:这个能得8分
实际得分:8.5分
差距:+0.5分(好于预期!)

状态:训练中(和Actor一起训练)
来源:从SFT模型初始化
作用:帮助计算Advantage(相对好坏)

类比:

  • 运动员的教练
  • 预测表现,帮助分析进步方向

Advantage的意义:

ini 复制代码
Advantage = 实际得分 - 预测得分

例子1:
实际 = 8.5分,预测 = 8分
Advantage = +0.5(好于预期)
→ 增加这个回答的概率 ✅

例子2:
实际 = 7分,预测 = 8分
Advantage = -1(差于预期)
→ 降低这个回答的概率 ✅

关键:衡量的是"相对好坏",不是绝对分数
→ 降低方差,训练更稳定

2.3 四个组件的配合

arduino 复制代码
一次训练的完整流程:

你提问:"什么是黑洞?"
  ↓
Actor(演员):
  生成:"黑洞是引力极强的天体..."
  ↓
Reward Model(裁判):
  打分:"这个回答8.5分"
  ↓
Reference(起点):
  检查:"和训练前对比,差异2%,可以"
  总奖励 = 8.5 - 0.02 = 8.48分
  ↓
Critic(教练):
  预测:"我觉得能得8分"
  实际:8.48分
  Advantage = +0.48(好于预期)
  ↓
训练系统:
  "Actor表现好于预期,奖励它"
  "但用PPO慢慢调整,别改太多"

🏗️ Part 3: 组件的模型结构 ------ 共享Base,替换Head

关键洞察:Actor、RM、Critic用的是同一个Base模型,只是最后一层不同

3.1 整体架构

scss 复制代码
┌──────────────────────────────────────────────────┐
│           共享的Base模型(Transformer)          │
│                                                   │
│  所有组件都用同一个预训练模型作为基础:          │
│  - GPT-3                                          │
│  - LLaMA                                          │
│  - 任何预训练的Transformer                       │
│                                                   │
│  Base模型输出:hidden_state [batch, seq_len, 768]│
└──────────────────────────────────────────────────┘
                        │
        ┌───────────────┼───────────────┬──────────┐
        ↓               ↓               ↓          ↓
   ┌─────────┐    ┌──────────┐   ┌─────────┐ ┌─────────┐
   │ Actor   │    │ Critic   │   │  RM     │ │  Ref    │
   │ Head    │    │ Head     │   │ Head    │ │ Head    │
   └─────────┘    └──────────┘   └─────────┘ └─────────┘
        │               │              │           │
        ↓               ↓              ↓           ↓
   LM Head        Value Head      Reward Head  LM Head
   (vocab_size)   (1个值)         (1个值)      (vocab_size)

3.2 具体结构对比

python 复制代码
# 共享的Base模型
base_model = GPT2Model.from_pretrained('gpt2')
# 输出: [batch, seq_len, hidden_size=768]

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# 1. Actor: Base + LM Head
class Actor(nn.Module):
    def __init__(self):
        self.base = base_model  # 共享
        self.lm_head = nn.Linear(768, 50257)  # vocab_size

    def forward(self, input_ids):
        hidden = self.base(input_ids)  # [batch, seq, 768]
        logits = self.lm_head(hidden)  # [batch, seq, 50257]
        probs = softmax(logits)        # 每个token的概率
        return probs

# 输出:词表上的概率分布
# 用途:生成下一个token

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# 2. Critic: Base + Value Head
class Critic(nn.Module):
    def __init__(self):
        self.base = base_model  # 共享(参数独立)
        self.value_head = nn.Linear(768, 1)  # 输出1个值

    def forward(self, input_ids):
        hidden = self.base(input_ids)    # [batch, seq, 768]
        last_hidden = hidden[:, -1, :]   # 取最后token [batch, 768]
        value = self.value_head(last_hidden)  # [batch, 1]
        return value

# 输出:一个标量(状态价值)
# 用途:预测这个状态能得多少分

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# 3. Reward Model: Base + Reward Head
class RewardModel(nn.Module):
    def __init__(self):
        self.base = base_model  # 共享(参数独立)
        self.reward_head = nn.Linear(768, 1)  # 输出1个值

    def forward(self, input_ids):
        hidden = self.base(input_ids)    # [batch, seq, 768]
        last_hidden = hidden[:, -1, :]   # 取最后token [batch, 768]
        reward = self.reward_head(last_hidden)  # [batch, 1]
        return reward

# 输出:一个标量(奖励分数)
# 用途:评价回答质量

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# 4. Reference: 和Actor完全一样
class Reference(nn.Module):
    def __init__(self):
        self.base = base_model  # 共享(参数独立)
        self.lm_head = nn.Linear(768, 50257)

    def forward(self, input_ids):
        # 和Actor一模一样
        hidden = self.base(input_ids)
        logits = self.lm_head(hidden)
        probs = softmax(logits)
        return probs

# 输出:词表上的概率分布(和Actor一样)
# 用途:记住训练前的Actor

3.3 关键点

1. Base模型是共享的(架构),但参数是独立的

ini 复制代码
初始化时:
Actor.base    = copy(SFT模型)  ← 会训练
Critic.base   = copy(SFT模型)  ← 会训练
RM.base       = copy(SFT模型)  ← 固定(阶段2训练好的)
Reference.base = copy(Actor.base初始状态) ← 固定

训练后:
Actor.base    → 参数改变了
Critic.base   → 参数改变了
RM.base       → 不变
Reference.base → 不变

2. Head(最后一层)决定了输出类型

ini 复制代码
LM Head (Actor/Reference):
输入:[batch, seq, 768]
输出:[batch, seq, vocab_size]
用途:生成概率分布

Value Head (Critic):
输入:[batch, 768](只用最后一个token)
输出:[batch, 1]
用途:预测价值

Reward Head (RM):
输入:[batch, 768](只用最后一个token)
输出:[batch, 1]
用途:打分

3. Critic和RM的区别

markdown 复制代码
结构:几乎一样(Base + Linear(768, 1))

训练方式不同:
Critic:
  - 在PPO阶段训练
  - 预测目标:累积奖励
  - 损失函数:MSE(预测值, 实际累积奖励)

RM:
  - 在阶段2训练
  - 训练目标:人类偏好
  - 损失函数:成对比较(A vs B,人类喜欢哪个)

用途不同:
Critic:
  - 辅助Actor训练(计算Advantage)
  - 内部工具

RM:
  - 提供奖励信号
  - 代表人类偏好

3.4 为什么要共享Base?

原因1:节省计算

diff 复制代码
如果每个组件用不同的模型:
- Actor: GPT-3 (175B参数)
- Critic: GPT-3 (175B参数)
- RM: GPT-3 (175B参数)
- Reference: GPT-3 (175B参数)
总共:700B参数 ❌

共享Base(只是Head不同):
- 实际上每个组件还是独立的
- 但初始化时共享权重
- Head很小(768 → 1 或 50257)

原因2:知识共享

diff 复制代码
SFT模型已经学到了:
- 语言理解
- 世界知识
- 基本推理

Actor/Critic/RM都需要这些能力
→ 从同一个SFT初始化
→ 站在巨人肩膀上

🔄 Part 4: 训练流程 ------ 一次完整迭代

4.1 Rollout阶段(收集数据)

python 复制代码
# 伪代码展示一次迭代

# ========== Step 1: 准备Prompt ==========
prompts = ["什么是黑洞?", "如何学习Python?", ...]  # 256个

# ========== Step 2: Actor生成 ==========
responses = []
actor_log_probs = []

for prompt in prompts:
    # Actor生成回复
    response = actor.generate(prompt, max_length=128)
    log_prob = actor.get_log_prob(prompt, response)

    responses.append(response)
    actor_log_probs.append(log_prob)

# 结果:
# prompt: "什么是黑洞?"
# response: "黑洞是时空中引力极强的区域..."
# log_prob: [-0.5, -0.3, -0.2, ...](每个token一个)

# ========== Step 3: RM打分 ==========
rm_rewards = []

with torch.no_grad():  # RM固定
    for prompt, response in zip(prompts, responses):
        reward = reward_model(prompt, response)
        rm_rewards.append(reward)

# 结果:
# rm_reward = 8.5

# ========== Step 4: Reference计算KL ==========
ref_log_probs = []

with torch.no_grad():  # Reference固定
    for prompt, response in zip(prompts, responses):
        ref_log_prob = reference_model(prompt, response)
        ref_log_probs.append(ref_log_prob)

# 计算KL惩罚
kl_penalties = []
beta = 0.1  # KL惩罚系数

for actor_lp, ref_lp in zip(actor_log_probs, ref_log_probs):
    kl = (actor_lp - ref_lp).sum()  # KL散度
    kl_penalty = beta * kl
    kl_penalties.append(kl_penalty)

# 结果:
# kl_penalty = -0.02(负数,是惩罚)

# ========== Step 5: 总奖励 ==========
total_rewards = []

for rm_rew, kl_pen in zip(rm_rewards, kl_penalties):
    total_reward = rm_rew - kl_pen
    total_rewards.append(total_reward)

# 结果:
# total_reward = 8.5 - 0.02 = 8.48

# ========== Step 6: Critic预测 ==========
values = []

with torch.no_grad():
    for prompt, response in zip(prompts, responses):
        value = critic(prompt, response)
        values.append(value)

# 结果:
# value = 8.0(预测能得8分)

# ========== Step 7: 计算Advantage ==========
advantages = []

for total_rew, value in zip(total_rewards, values):
    advantage = total_rew - value
    advantages.append(advantage)

# 结果:
# advantage = 8.48 - 8.0 = 0.48(好于预期!)

# ========== Step 8: 保存数据 ==========
rollout_data = {
    'prompts': prompts,
    'responses': responses,
    'actor_log_probs': actor_log_probs,  # 旧策略
    'advantages': advantages,
    'total_rewards': total_rewards
}

4.2 训练阶段(更新参数)

python 复制代码
# ========== Step 9: 创建DataLoader ==========
dataloader = DataLoader(rollout_data, batch_size=32, shuffle=True)

# ========== Step 10: PPO训练(重点!)==========
num_ppo_epochs = 4  # 重复使用数据

for epoch in range(num_ppo_epochs):
    for batch in dataloader:
        prompts = batch['prompts']
        responses = batch['responses']
        old_log_probs = batch['actor_log_probs']  # 旧策略
        advantages = batch['advantages']
        returns = batch['total_rewards']

        # ━━━━━ 更新Actor(PPO的核心)━━━━━
        # 1. 新策略前向传播
        new_log_probs = actor(prompts, responses)

        # 2. 计算ratio
        ratio = torch.exp(new_log_probs - old_log_probs)
        # ratio = P_new / P_old

        # 3. PPO Clip(核心创新!)
        epsilon = 0.2
        ratio_clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)

        # 4. PPO损失
        surr1 = ratio * advantages
        surr2 = ratio_clipped * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        # 5. Actor反向传播
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        # ━━━━━ 更新Critic ━━━━━
        # 1. Critic预测
        predicted_values = critic(prompts, responses)

        # 2. MSE损失
        critic_loss = F.mse_loss(predicted_values, returns)

        # 3. Critic反向传播
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

print(f"训练完成!Avg Reward = {total_rewards.mean():.2f}")

4.3 可视化时间线

markdown 复制代码
一次完整迭代(假设256个prompts):

时间    阶段               说明
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
0s      开始

0-10s   Rollout阶段       收集数据(无梯度)
        ├─ Actor生成      256个responses
        ├─ RM打分         256个rewards
        ├─ Reference      计算KL
        ├─ Critic预测     256个values
        └─ 计算Advantage

10-40s  训练阶段          更新参数(有梯度)
        ├─ Epoch 1
        │  ├─ Batch 1-8   每次32个样本
        │  └─ 更新Actor和Critic
        ├─ Epoch 2        重复使用数据
        ├─ Epoch 3
        └─ Epoch 4

40s     本轮完成

        下一轮            Actor已更新,生成不同responses

🎯 Part 5: PPO的核心创新 ------ Clip机制

5.1 问题:为什么需要Clip?

没有Clip的训练(传统Policy Gradient):

erlang 复制代码
问题:什么是黑洞?

第1轮:
Actor生成"黑洞"的概率 = 10%
得分 = 8分

第2轮(梯度更新):
系统发现:"黑洞"这个开头得分高
梯度更新后,概率直接跳到 80%(增加8倍!)

第3轮:
Actor对任何问题都先说"黑洞"
"如何学习编程?" → "黑洞..."(乱答)
训练崩溃 ❌

为什么会崩溃?

diff 复制代码
原因1:梯度爆炸
- ratio = 80% / 10% = 8倍
- 梯度 ∝ ratio × advantage
- 梯度太大 → 参数突变

原因2:过度拟合某个pattern
- 发现一个高分trick
- 疯狂强化这个trick
- 忘记其他能力

原因3:策略突变
- 概率分布剧变
- 模型行为完全改变
- 之前学的全忘了

5.2 PPO的解决方案:Clip

核心思想:限制概率变化幅度

python 复制代码
# 计算ratio(新旧概率的比值)
ratio = new_prob / old_prob

# Clip:限制ratio在[0.8, 1.2]
ratio_clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)

# 取min(保守更新)
loss = -torch.min(ratio * advantage, ratio_clipped * advantage)

具体例子:

erlang 复制代码
第1轮:
"黑洞"的概率 = 10%

第2轮(想更新):
梯度计算后,想变成 80%
ratio = 80% / 10% = 8倍

PPO Clip:
ratio > 1.2 → 强制限制为1.2
实际新概率 = 10% × 1.2 = 12%(只涨20%)

第3轮:
"黑洞"的概率 = 12%
想变成 96%(12% × 8)
PPO限制:只能变成14.4%(12% × 1.2)

第10轮:
慢慢涨到 40%
模型稳定 ✅

5.3 Clip的工作机制

数学表达(用大白话):

python 复制代码
# 损失函数的两项

# 项1:未裁剪(原始目标)
surr1 = ratio * advantage

# 项2:裁剪(保守目标)
surr2 = clamp(ratio, 1-ε, 1+ε) * advantage

# 取min(哪个更保守用哪个)
loss = -min(surr1, surr2)

为什么取min?

ini 复制代码
Advantage > 0(好的动作):
─────────────────────────────
目标:增加概率

如果ratio > 1.2(增加太多):
surr1 = 1.5 × 0.5 = 0.75(鼓励增加)
surr2 = 1.2 × 0.5 = 0.6(限制增加)
min(surr1, surr2) = 0.6 ← 选保守的

结果:增加概率,但不超过20%


Advantage < 0(坏的动作):
─────────────────────────────
目标:降低概率

如果ratio < 0.8(降低太多):
surr1 = 0.6 × (-0.5) = -0.3(鼓励降低)
surr2 = 0.8 × (-0.5) = -0.4(限制降低)
min(surr1, surr2) = -0.4 ← 选保守的

结果:降低概率,但不低于20%

5.4 可视化:概率变化轨迹

erlang 复制代码
生成"黑洞"的概率变化(10轮训练):

概率
 50%│                            ╱──●  稳定
    │                        ╱───
 40%│                    ╱───
    │                ╱───
 30%│            ╱───
    │        ╱───
 20%│    ╱───
    │╱───
 10%│●                        ← 有PPO Clip
    │
  0%├──┬──┬──┬──┬──┬──┬──┬──┬──┬──
    1  2  3  4  5  6  7  8  9 10  轮次


对比(没有Clip):
概率
100%│
    │
 80%│  ●                       ← 直接跳这
    │  ×(然后崩溃)
    │
 50%│
    │
 10%│●
    │
  0%├──┬──────────────────────
    1  2  3  4  5  6  7  8  9 10  轮次

5.5 PPO vs 其他算法

算法 方法 优点 缺点
PG 无限制 简单 不稳定,易崩溃
TRPO KL约束(二阶优化) 稳定 复杂,计算慢
A3C 多进程异步 并行快 需要多核,还是可能不稳定
PPO 简单Clip 稳定、简单、快 需要调参

PPO的创新:

  • TRPO用复杂数学(Hessian矩阵)解决问题
  • PPO说:我用一个简单的clip就够了
  • 结果:效果一样好,但实现简单10倍

5.6 为什么Clip有效?

直觉1:物理类比

css 复制代码
训练AI = 开车

没有Clip:
油门踩多少算多少
→ 不小心踩太多 → 冲出赛道 ❌

有Clip(限速器):
最多只能加速20%
→ 踩太多也只加速20% → 安全 ✅

直觉2:信任域

css 复制代码
旧策略:我知道这样做是对的
新策略:我想尝试新方法

没有Clip:
可以跳很远 → 可能跳到悬崖
有Clip:
只能走一小步 → 在安全范围内探索

这就是"信任域"的概念:
只在你信任的范围内更新

💻 Part 6: 代码实现

6.1 核心代码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

# ========== 模型定义 ==========

class Actor(nn.Module):
    """策略网络:生成回答"""
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model  # GPT-2/LLaMA等
        self.lm_head = nn.Linear(768, 50257)  # vocab_size

    def forward(self, input_ids):
        hidden = self.base(input_ids).last_hidden_state
        logits = self.lm_head(hidden)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs


class Critic(nn.Module):
    """价值网络:预测累积奖励"""
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.value_head = nn.Linear(768, 1)

    def forward(self, input_ids):
        hidden = self.base(input_ids).last_hidden_state
        last_hidden = hidden[:, -1, :]  # 取最后token
        value = self.value_head(last_hidden)
        return value.squeeze(-1)


class RewardModel(nn.Module):
    """奖励模型:给回答打分"""
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.reward_head = nn.Linear(768, 1)

    def forward(self, input_ids):
        hidden = self.base(input_ids).last_hidden_state
        last_hidden = hidden[:, -1, :]
        reward = self.reward_head(last_hidden)
        return reward.squeeze(-1)


# ========== PPO核心函数 ==========

def ppo_loss(old_log_probs, new_log_probs, advantages, epsilon=0.2):
    """
    PPO的核心损失函数

    Args:
        old_log_probs: 旧策略的log概率
        new_log_probs: 新策略的log概率
        advantages: 优势函数
        epsilon: 裁剪阈值(通常0.2)
    """
    # 1. 计算ratio(概率比值)
    ratio = torch.exp(new_log_probs - old_log_probs)
    # ratio = P_new / P_old

    # 2. 未裁剪项
    surr1 = ratio * advantages

    # 3. 裁剪项(核心创新!)
    ratio_clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    surr2 = ratio_clipped * advantages

    # 4. 取min(保守更新)
    policy_loss = -torch.min(surr1, surr2).mean()

    return policy_loss


def compute_advantages(rewards, values, gamma=0.99):
    """
    计算Advantage = 实际收益 - 预期收益

    Args:
        rewards: 每个step的奖励
        values: Critic预测的价值
        gamma: 折扣因子
    """
    advantages = []
    returns = []
    R = 0

    # 从后往前计算
    for t in reversed(range(len(rewards))):
        R = rewards[t] + gamma * R
        advantage = R - values[t]
        advantages.insert(0, advantage)
        returns.insert(0, R)

    return torch.tensor(advantages), torch.tensor(returns)


# ========== 训练主循环 ==========

def train_ppo(actor, critic, reward_model, reference_model, prompts):
    """
    完整的PPO训练流程
    """
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=1e-5)
    critic_optimizer = torch.optim.Adam(critic.parameters(), lr=1e-4)

    for iteration in range(1000):
        # ━━━━━ Rollout阶段 ━━━━━
        rollout_data = []

        for prompt in prompts:
            # 1. Actor生成
            response = actor.generate(prompt)
            old_log_probs = actor.get_log_probs(prompt, response)

            # 2. RM打分
            with torch.no_grad():
                rm_reward = reward_model(prompt, response)

            # 3. Reference计算KL
            with torch.no_grad():
                ref_log_probs = reference_model(prompt, response)
                kl_penalty = 0.1 * (old_log_probs - ref_log_probs).sum()

            # 4. 总奖励
            total_reward = rm_reward - kl_penalty

            # 5. Critic预测
            with torch.no_grad():
                value = critic(prompt, response)

            # 6. 保存数据
            rollout_data.append({
                'prompt': prompt,
                'response': response,
                'old_log_probs': old_log_probs,
                'reward': total_reward,
                'value': value
            })

        # 计算Advantage
        rewards = [d['reward'] for d in rollout_data]
        values = [d['value'] for d in rollout_data]
        advantages, returns = compute_advantages(rewards, values)

        # ━━━━━ 训练阶段 ━━━━━
        for epoch in range(4):  # PPO epochs
            for i, data in enumerate(rollout_data):
                # 更新Actor
                new_log_probs = actor(data['prompt'], data['response'])
                actor_loss = ppo_loss(
                    data['old_log_probs'],
                    new_log_probs,
                    advantages[i]
                )

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()

                # 更新Critic
                predicted_value = critic(data['prompt'], data['response'])
                critic_loss = F.mse_loss(predicted_value, returns[i])

                critic_optimizer.zero_grad()
                critic_loss.backward()
                critic_optimizer.step()

        print(f"Iter {iteration}: Avg Reward = {sum(rewards)/len(rewards):.2f}")

6.2 关键点总结

python 复制代码
# PPO的核心就这3行:
ratio = torch.exp(new_log_probs - old_log_probs)  # 计算比值
ratio_clipped = torch.clamp(ratio, 1-ε, 1+ε)      # 裁剪
loss = -torch.min(ratio * adv, ratio_clipped * adv)  # 取min

🤔 Part 7: 常见问题

Q1: PPO、RLHF、Actor-Critic的关系?

markdown 复制代码
RLHF(整体框架)
  └─ 阶段3用强化学习
      ├─ Actor-Critic(标准RL框架,1990s就有)
      │   ├─ Actor:策略网络
      │   └─ Critic:价值网络
      ├─ RLHF特有组件
      │   ├─ RM:替代人类打分
      │   └─ Reference:防止偏离
      └─ PPO(训练算法,2017年)
          └─ Clip机制 ← 核心创新

总结:
- Actor-Critic是框架
- RM和Reference是RLHF的设计
- PPO是训练Actor-Critic的算法
- Clip是PPO的核心创新

Q2: Critic和RM有什么区别?

diff 复制代码
模型结构:
- 几乎一样(Base + Linear(768, 1))
- 都输出一个标量

训练时机:
- Critic:PPO阶段训练
- RM:PPO之前(阶段2)就训练好了

训练数据:
- Critic:用PPO训练时的reward作为监督信号
- RM:用人类偏好数据(A vs B,哪个更好)

训练目标:
- Critic:预测累积奖励(V(s) = E[Σ r_t])
- RM:学习人类偏好(P(A > B) = σ(R(A) - R(B)))

使用方式:
- Critic:计算Advantage,辅助Actor训练
- RM:提供奖励信号,代表人类偏好

Q3: Reference为什么要固定?

arduino 复制代码
如果Reference也训练:
Actor改进 → Reference跟着改进
→ KL散度始终很小 → 没有惩罚作用
→ Actor可以随意改变 ❌

Reference固定:
Actor改进 → Reference不变(记住起点)
→ KL散度变大 → 有惩罚
→ Actor被迫在起点附近探索 ✅

类比:
减肥时的"初始体重"
- 不能随着减肥而"调整初始体重"
- 要记住真正的起点

Q4: epsilon=0.2是怎么来的?

diff 复制代码
经验值:
- Schulman论文(2017)测试了多个值
- 0.1、0.2、0.3都可以
- 0.2效果最稳定

含义:
- 每次最多改变20%
- 不是绝对的,可以调整

调整建议:
- 训练不稳定 → 减小ε(如0.1)
- 训练太慢 → 增大ε(如0.3)
- 大部分情况0.2就很好

Q5: 为什么不直接降低学习率?

diff 复制代码
降低学习率:
- 控制参数更新速度
- 但控制不了"输出概率"的变化
- 例如:学习率很小,但某个参数变化可能导致
  某个token的概率从10%跳到80%

PPO的Clip:
- 直接控制"输出概率"的变化
- ratio = P_new / P_old被限制在0.8-1.2
- 不管参数怎么变,输出概率变化有上限

更精确的控制 = Clip > 学习率

🎓 总结

对齐 vs RLHF vs PPO

diff 复制代码
┌─────────────────────────────────────────┐
│ 对齐(Alignment)                       │
│ 目标:让AI符合人类价值观                │
│                                          │
│   ↓ 实现方法之一                        │
│                                          │
│ RLHF(从人类反馈中学习)                │
│ 方法:用人类偏好训练模型                │
│   ├─ 阶段1: SFT                         │
│   ├─ 阶段2: 训练RM                      │
│   └─ 阶段3: RL微调 ← 用PPO             │
│                                          │
│   ↓ 训练算法                            │
│                                          │
│ PPO(Proximal Policy Optimization)     │
│ 核心:Clip机制限制更新幅度              │
└─────────────────────────────────────────┘

总结:
- 对齐是目标(make AI good)
- RLHF是方法(learn from humans)
- PPO是算法(train with clip)

核心要点

0. 概念层次

diff 复制代码
对齐(目标) > RLHF(方法) > PPO(算法) > Clip(技术)

例子:
- ChatGPT的目标:对齐人类价值观
- 使用方法:RLHF
- 训练算法:PPO
- 核心技术:Clip机制

1. RLHF三阶段

复制代码
阶段1(SFT):学会基本对话
阶段2(RM):训练裁判
阶段3(PPO):根据裁判反馈优化

2. 四个组件

复制代码
Actor:要训练的主模型
Critic:辅助训练(计算Advantage)
RM:提供奖励(固定)
Reference:记住起点(固定)

3. 模型结构

diff 复制代码
共享Base模型(Transformer)
只有Head不同:
- Actor/Reference: LM Head (vocab_size)
- Critic: Value Head (1)
- RM: Reward Head (1)

4. PPO的核心创新

python 复制代码
# 就这一行!
ratio_clipped = torch.clamp(ratio, 1-ε, 1+ε)

限制概率变化在±20%以内
→ 训练稳定
→ 不会崩溃

快速记忆

一句话记住RLHF:

先教AI说话(SFT),再训练裁判(RM),最后让AI根据裁判反馈慢慢进步(PPO)

一句话记住PPO:

让AI慢慢变好,每次最多改20%(Clip机制)

一行代码记住PPO:

python 复制代码
ratio_clipped = torch.clamp(ratio, 0.8, 1.2)

实际应用

使用PPO的著名项目:

  1. ChatGPT(OpenAI)
  2. Claude(Anthropic)
  3. Llama2-Chat(Meta)
  4. AlphaStar(DeepMind,星际争霸AI)

为什么都用PPO?

  • 简单:容易实现
  • 稳定:Clip机制防止崩溃
  • 有效:效果好于其他算法
  • 高效:不需要二阶优化
相关推荐
沸点小助手1 小时前
「AI 能力提升场」沸点获奖名单公示|本周互动话题上新🎊
aigc·openai·vibecoding
Meepo_haha1 小时前
创建Spring Initializr项目
java·后端·spring
Memory_荒年1 小时前
SpringBoot事务源码深度游:从注解到数据库的“奇幻漂流”
java·后端·spring
编码忘我1 小时前
为什么要用SpringBoot
java·后端
Memory_荒年2 小时前
SpringBoot事务:从“一键开关”到“踩坑大全”的生存指南
java·后端·spring
PFinal社区_南丞2 小时前
一文讲透 .trae 文件夹 - Trae IDE 配置指南和最佳实践
后端
段小二2 小时前
Spring AI Agent 完整实战:Function Calling + RAG + Memory + SafeGuard 构建机票助手
后端
编码忘我2 小时前
Spring源码又看了一遍
后端
希望永不加班2 小时前
SpringBoot 主启动类解释:@SpringBootApplication 到底做了什么
java·spring boot·后端·spring