RL 系统 Infra 笔记:区分不同模型

强化学习系统(RLHF/PPO)Infra 学习笔记,从 Infra 视角梳理各模块职责、数据流与训练循环,持续更新。


1. 🧭 全局感知:RL 系统在干什么?

大模型强化学习(以 PPO 为代表)的目标是:让模型学会生成"更好"的回答

整个系统可以用一句话概括:

模型自己生成回答 → 评分器打分 → 算出每步好坏 → 更新模型 → 循环

用"老师批改作文"来类比:

现实场景 RL 系统对应
学生写作文 Actor 做 Rollout,生成 response
老师打分 Reward Model 给 response 评分
"这句话比平均水平好多少" Advantages 计算
学生对照批改意见改进写法 Actor 根据 Loss 更新参数
原始答题模板,防止学生跑偏 Reference Model 提供 KL 约束

2. 🏗️ 系统里有几个模型?

整个系统只有 2~3 个模型实体,但在流程中有多个"叫法",容易混淆:

复制代码
┌─────────────────────────────────────────────────────────┐
│  模型实体 1:Actor 模型(✅ 会被训练更新)                  │
│     ├── Rollout 阶段用它生成文本   → 叫 "Rollout"         │
│     └── 训练阶段用它重算概率       → 叫 "Actor_fwd"       │
│         ⚠️ 这两个叫法本质是同一个模型,只是使用方式不同!    │
├─────────────────────────────────────────────────────────┤
│  模型实体 2:Reference 模型(❌ 冻结,永远不更新)           │
│     └── Actor 的初始参数副本,整个训练过程参数不变           │
├─────────────────────────────────────────────────────────┤
│  模型实体 3:Reward 模型(❌ 通常冻结)                     │
│     └── 专门用来打分的独立模型,与 Actor 架构可以不同        │
└─────────────────────────────────────────────────────────┘

💡 Advantages 不是模型,它是一个算法计算步骤(公式),输入是 Reward 分数和各种 log_prob,输出是每个 token 的优势值。


3. 🔬 各模块职责详解

3.1 🎬 Rollout(采样/生成阶段)

本质:Actor 模型在"生成模式"下运行,对一批 prompt 自回归地生成完整 response。

关键点

  • 生成每个 token 时,顺手记录该 token 的生成概率 (这就是 old_log_prob),不需要额外计算
  • 这是整个流程最耗时的步骤(串行逐 token 生成)
  • Rollout 结束后,这批数据固定不变,供后续多个 mini batch 复用

输出示例

python 复制代码
{
  "prompt":        "苹果是什么颜色?",
  "response":      "苹果是红色的",
  "input_ids":     [3928, 374, 1226, 1284],
  "old_log_probs": [-0.10, -0.05, -0.80, -0.10],  # 每个 token 的 log 概率,生成时顺手存的
  "attention_mask":[1, 1, 1, 1]
}

3.2 🏛️ Reference Model(参考模型)

本质 :训练开始前,把 Actor 的初始参数复制一份并冻结,整个训练过程中参数永远不变。

为什么需要它:防止 Actor 训练跑偏太远。通过计算 KL 散度施加惩罚:

复制代码
KL = old_log_prob - ref_log_prob
r_final = reward_score - β × KL

输出示例

python 复制代码
{
  "ref_log_probs": [-0.12, -0.06, -0.90, -0.11]
  # 格式与 old_log_probs 相同,但来自冻结的初始模型,整个训练过程值不变
}

3.3 🏆 Reward Model(奖励模型)

本质:独立的打分模型,通常在 RL 训练开始前已训练好并冻结。

输出示例

python 复制代码
{
  "reward_score": 1.5   # 一个标量,越高越好
}

3.4 📊 Advantages(优势值)

本质不是模型,是计算步骤。衡量某个 token 比"预期"好多少,是 Actor 更新的直接驱动力。

计算原料reward_score + ref_log_probs + old_log_probs → 经 GAE 等算法 → advantages

输出示例

python 复制代码
{
  "advantages": [+0.1, +0.1, +1.2, +0.1]
  #               "苹果" "是"  "红色"  "的"
  # 正数 → 这个 token 选得好,强化它;负数 → 选得差,抑制它
  # "红色"贡献最大,是回答质量的关键
}

3.5 🔁 Actor_fwd(Actor 前向重算)

本质 :用当前最新版本的 Actor ,对 Rollout 已生成的 response 重新做前向传播,得到 new_log_prob

关键:不生成新文本,只重算概率。同一批 Rollout 数据训练多个 mini batch,每次 Actor 参数更新后都需要重算。

输出示例

python 复制代码
{
  "new_log_probs": [-0.09, -0.04, -0.60, -0.09]
  # 格式与 old_log_probs 完全相同,但值不同(Actor 参数已更新)
}

4. 🔗 完整单步数据流

下图展示一次完整的 PPO 迭代中,数据如何在各模块间流动:
存入缓冲区
下一轮
📥 Prompt

'苹果是什么颜色?'
🎬 Rollout

【Actor 模型,生成模式】

自回归生成文本

✅ 产出:

response = '苹果是红色的'

old_log_prob = [-0.10,-0.05,-0.80,-0.10]
💾 经验缓冲区

prompt + response + old_log_prob
🏆 Reward Model

【独立模型,冻结】

看 prompt+response 打分

✅ 产出:

reward_score = 1.5

一个标量
🏛️ Reference Model

【Actor初始副本,冻结】

对 response 做前向

✅ 产出:

ref_log_prob = [-0.12,-0.06,-0.90,-0.11]

每个token一个值
📊 Advantages

【不是模型!是公式】

= f(reward_score, ref_log_prob, old_log_prob)

✅ 产出:

advantages = [+0.1,+0.1,+1.2,+0.1]

每个token一个值
🔁 Actor_fwd

【Actor 模型,前向模式】

对已有response重算概率

✅ 产出:

new_log_prob = [-0.09,-0.04,-0.60,-0.09]

每个token一个值
📉 PPO Loss

ratio = exp(new - old)

Loss = ratio × Advantages
✅ 反向传播,更新 Actor 参数

🔵 蓝色 = 同一个 Actor 模型的两种用法

🟠 橙色 = Reference(冻结副本)

🔴 红色 = Reward Model(独立模型)

🟢 绿色 = Advantages(公式计算,非模型)

🟣 紫色 = Loss 计算与参数更新


5. ⏱️ 三种 log_prob 的更新频率

这是理解 PPO 训练节奏最关键的一点,也是最容易混淆的地方。

5.1 三者的本质区别

来自哪个模型 更新时机 直觉含义
ref_log_prob Reference(冻结) ❌ 永远不变 "出生时的我"的概率
old_log_prob Rollout 时的 Actor 每个 Global Batch 刷新一次 "这轮开始时的我"的概率
new_log_prob 当前最新 Actor 每个 Mini Batch 都在变 "此刻的我"的概率

5.2 更新频率时间轴

复制代码
训练过程时间轴 ──────────────────────────────────────────────────▶

ref_log_prob:
████████████████████████████████████████████████  永远不变

old_log_prob:
[══════ GB1 固定 ══════][══════ GB2 固定 ══════][══ GB3...
 每个 Global Batch 做一次 Rollout,存下新的 old_log_prob

new_log_prob:
[mb1][mb2][mb3][mb4][mb5][mb6][mb7][mb8][mb9]...
 每个 Mini Batch 训练后 Actor 参数更新,下次 Actor_fwd 就会算出不同的值

5.3 具体数值示例

以一个 Global Batch 内部为例,追踪 "红色" 这个 token 的概率变化:

复制代码
Global Batch N:

  Rollout(Actor_v0 生成):
    old_log_prob["红色"] = -0.80   ← 固定,本轮不再变

  ref_log_prob["红色"]  = -0.90   ← 永远不变

  Mini Batch 1(Actor_fwd 用 Actor_v0):
    new_log_prob["红色"] = -0.80   ratio = exp(-0.80 - (-0.80)) = 1.00
    → 刚开始,新旧一样,ratio = 1

  Mini Batch 2(Actor_fwd 用 Actor_v1):
    new_log_prob["红色"] = -0.70   ratio = exp(-0.70 - (-0.80)) = 1.11
    → Actor 开始增大"红色"的概率

  Mini Batch 3(Actor_fwd 用 Actor_v2):
    new_log_prob["红色"] = -0.62   ratio = exp(-0.62 - (-0.80)) = 1.20
    → 接近 clip 上界(通常是 1.2),PPO 开始截断

  Mini Batch 4(Actor_fwd 用 Actor_v3):
    new_log_prob["红色"] = -0.55   ratio 被 clip 到 1.2
    → 超出范围,梯度被截断,防止更新过猛

6. 🔄 两层训练循环结构

PPO 训练由两层嵌套循环组成,理解这个结构对 Infra 实现至关重要。
🔄 外层循环:Global Batch(每轮做一次 Rollout)
🔁 内层循环:Mini Batch(共 K 步)
复制并冻结
下一个 mini batch
ref_log_prob

用于 KL 惩罚
old_log_prob 固定
K 步后得到 Actor_v(N+1)
🚀 初始化

Actor_v0
🏛️ Reference

参数永远 = v0
🎬 Rollout

用当前 Actor_vN 生成一批数据

存下 old_log_prob(本轮固定不变)
🔁 Actor_fwd

用当前最新 Actor

重算 new_log_prob
📉 PPO Loss

ratio = exp(new-old)

clip(ratio, 1-ε, 1+ε) × Adv
✅ 更新 Actor 参数

Actor_vN → Actor_v(N+1/K)

6.1 两层循环的关键参数

参数 含义 典型值 影响
Global Batch Size 每次 Rollout 生成多少条数据 512~4096 条 越大数据越多样,但 Rollout 越慢
Mini Batch Size 每次梯度更新用多少条数据 32~256 条 影响显存占用和梯度噪声
K(内层步数) 同一批数据训练几轮 1~4 越大复用率越高,但 ratio 偏离越大
ε(clip 范围) ratio 的截断范围 0.2(即 [0.8, 1.2]) 防止单步更新过猛

7. 🧮 PPO Loss 计算全流程

把所有模块的输出汇聚到一起,看 Loss 是怎么算出来的:
old_log_prob

-0.10,-0.05,-0.80,-0.10

来自 Rollout 缓冲区
new_log_prob

-0.09,-0.04,-0.60,-0.09

来自 Actor_fwd
advantages

+0.1,+0.1,+1.2,+0.1

来自 Advantages 计算
ratio

= exp(new - old)

= [1.01, 1.01, 1.22, 1.01]
clip(ratio, 0.8, 1.2)

= [1.01, 1.01, 1.20, 1.01]

超出范围的截断
PPO Loss

= -mean( min(ratio, clipped) × Adv )

负号:因为要最大化奖励
ref_log_prob

-0.12,-0.06,-0.90,-0.11

来自 Reference Model
KL 散度

= old - ref

= [+0.02,+0.01,+0.10,+0.01]
Total Loss

= PPO Loss + β × KL

逐步解释

python 复制代码
# 1. 算新旧概率比(ratio)
ratio = exp(new_log_prob - old_log_prob)
# ratio > 1 → Actor 提高了这个 token 的概率
# ratio < 1 → Actor 降低了这个 token 的概率

# 2. clip:防止单步更新过猛
clipped_ratio = clip(ratio, 1 - ε, 1 + ε)   # ε 通常 = 0.2

# 3. PPO Loss(取两者中更保守的那个)
ppo_loss = -mean( min(ratio, clipped_ratio) × advantages )
# 负号:梯度下降最小化 Loss = 最大化期望奖励

# 4. 加上 KL 惩罚(防止跑偏)
kl = mean(old_log_prob - ref_log_prob)
total_loss = ppo_loss + β × kl

8. 📦 各模块输出格式速查

prompt = "苹果是什么颜色?" response = "苹果是红色的" 为例,汇总所有模块的输出:

复制代码
Token 序列:  ["苹果",  "是",   "红色",  "的"]
Token ID:    [3928,    374,    1226,    1284]
                ↓        ↓        ↓        ↓
old_log_prob: [-0.10,  -0.05,  -0.80,  -0.10]   ← Rollout 时 Actor 顺手存的
ref_log_prob: [-0.12,  -0.06,  -0.90,  -0.11]   ← Reference 模型算的,永远不变
new_log_prob: [-0.09,  -0.04,  -0.60,  -0.09]   ← Actor_fwd 用最新参数重算的

reward_score:  1.5                               ← Reward Model 输出,一个标量

advantages:   [+0.10,  +0.10,  +1.20,  +0.10]  ← 公式计算,每 token 一个值

ratio:        [ 1.01,   1.01,   1.22,   1.01]   ← exp(new - old),每 token 一个值

格式规律总结

输出类型 形状 说明
old_log_prob [seq_len] 每个 token 一个 float
ref_log_prob [seq_len] 每个 token 一个 float
new_log_prob [seq_len] 每个 token 一个 float
advantages [seq_len] 每个 token 一个 float
reward_score scalar 整条 response 一个 float
ratio [seq_len] 每个 token 一个 float

💡 除了 reward_score 是标量,其余所有中间产物都是与 token 序列等长的向量。


9. 🗂️ 模块关系总览

① 复制冻结
② 存入缓冲区
③ 打分
④ 算参考概率
ref_log_prob
reward_score
⑤ 重算概率

每 mini batch 都在变
advantages
new_log_prob
⑥ 反向传播
⑦ 下一轮 Rollout
🚀 训练开始

Actor_v0(SFT 模型)
🏛️ Reference Model

❌ 参数永不更新

整个训练只用来算 ref_log_prob
🎬 Rollout

= Actor 生成模式

每 Global Batch 执行一次

✅ 产出 old_log_prob
💾 经验缓冲区

prompt + response

  • old_log_prob
    🏆 Reward Model

❌ 参数不更新

✅ 产出 reward_score(标量)
📊 Advantages 计算

不是模型,是公式

✅ 产出每 token 的优势值
🔁 Actor_fwd

= Actor 前向模式

✅ 产出 new_log_prob
📉 PPO Loss

ratio × Adv + β×KL
✅ 更新 Actor 参数

Actor_vN → Actor_v(N+1)


10. 💡 常见疑问 Q&A

Q:old_log_prob 是 Rollout 时算的,还是 Reference 算的?

A:Rollout 时 Actor 顺手算的,跟 Reference 没有任何关系 。生成 token 本身就需要算概率(从词表采样),存下来就是 old_log_prob。Reference 算的是 ref_log_prob,是完全独立的另一个东西。

Q:Actor 和 Actor_fwd 是同一个模型吗?

A:是的,完全同一个模型。区别只是使用方式:Rollout 时是"生成模式"(自回归采样),Actor_fwd 时是"前向模式"(给定文本算概率)。就像同一个人,一次是"自由发挥写作文",另一次是"对着已写好的作文重新估算每个字的概率"。

Q:为什么需要 Actor_fwd 重新算概率,直接用 Rollout 时存的 old_log_prob 不行吗?

A:不行。PPO 的核心是用 ratio = new/old 来衡量"Actor 更新了多少"。如果不重算,new == oldratio 永远是 1,Loss 就没有意义了。重算是为了捕捉每次 mini batch 更新后 Actor 的变化。

Q:Advantages 和 new_log_prob 有关系吗?

A:计算 Advantages 时不需要 new_log_prob 。Advantages 的原料是 reward_score + ref_log_prob + old_log_prob,在 Rollout 结束后就可以算好,与 Actor_fwd 无关。但两者都是计算 PPO Loss 的原料,在 Loss 计算时汇聚到一起:Loss = ratio × Advantages,其中 ratio 才用到 new_log_prob

Q:Reference Model 和 Reward Model 有什么区别?

A:两者职责完全不同:

  • Reward Model(奖励模型) :对整条生成序列打一个标量分数 ,衡量"这个回答有多好"。它只在 Rollout 阶段的末尾调用一次,输出 reward_score,是外部对生成质量的评判。
  • Reference Model(参考模型) :是 Actor 的"冻结副本",逐 token 输出 log 概率,即 ref_log_prob。它的作用是充当 KL 惩罚的基准------防止 Actor 训练过程中偏离原始分布太远,避免模型"钻空子"刷高奖励但输出退化。

一句话总结:Reward Model 告诉你"结果好不好",Reference Model 告诉你"走得偏不偏"。


Q:KL 惩罚是怎么加进去的?

A:KL 惩罚以逐 token 的形式叠加到奖励信号里,而不是作为一个单独的 Loss 项(也有实现把它加到 Loss 里,但主流 RLHF 实现如 TRL、OpenRLHF 是加到 reward 里)。

具体公式:

rt={reward_score−β⋅KLtt=T(最后一个 token)−β⋅KLtt<Tr_t = \begin{cases} \text{reward\_score} - \beta \cdot \text{KL}_t & t = T \text{(最后一个 token)} \\ -\beta \cdot \text{KL}_t & t < T \end{cases}rt={reward_score−β⋅KLt−β⋅KLtt=T(最后一个 token)t<T

其中:

KLt=log⁡πθ(at∣st)−log⁡πref(at∣st)=new_log_probt−ref_log_probt\text{KL}t = \log \pi{\theta}(a_t | s_t) - \log \pi_{\text{ref}}(a_t | s_t) = \text{new\_log\_prob}_t - \text{ref\_log\_prob}_tKLt=logπθ(at∣st)−logπref(at∣st)=new_log_probt−ref_log_probt

β\betaβ 是控制 KL 惩罚强度的超参数。这样每一步都有一个"软约束",让 Actor 不敢偏离 Reference 太远。


Q:Critic 和 Reward Model 是同一个东西吗?

A:不是,这是 RLHF 里最容易混淆的概念之一:

Reward Model Critic(价值网络)
输入 完整序列(prompt + response) 当前 token 的隐状态
输出 一个标量:整条序列的得分 每个时间步的状态价值 V(st)V(s_t)V(st)
调用时机 Rollout 结束后,打一次分 每次训练时,逐步估计
是否更新 ❌ 冻结(已训练好) ✅ 随 PPO 一起更新
作用 提供最终奖励信号 计算 Advantage(基线减方差)

Critic 通常用 Reward Model 的权重初始化,但之后独立训练。它的目标是学会预测"从当前状态出发,未来能拿到多少累计奖励",从而为 Advantage 估计提供基线,降低梯度方差。


Q:Advantage 具体是怎么计算的?GAE 是什么?

A:最朴素的 Advantage 定义是:

At=Q(st,at)−V(st)A_t = Q(s_t, a_t) - V(s_t)At=Q(st,at)−V(st)

即"实际获得的回报"减去"基线预期"。但直接用 Monte Carlo 回报估计 QQQ 方差很大,所以实践中用 GAE(Generalized Advantage Estimation)

AtGAE=∑l=0∞(γλ)lδt+lA_t^{\text{GAE}} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}AtGAE=l=0∑∞(γλ)lδt+l

其中 TD 误差为:

δt=rt+γV(st+1)−V(st)\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)δt=rt+γV(st+1)−V(st)

  • γ\gammaγ:折扣因子,控制对未来奖励的重视程度(RLHF 中常设为 1)
  • λ\lambdaλ:GAE 平滑系数,在**偏差(bias)和方差(variance)**之间权衡
    • λ=0\lambda = 0λ=0:退化为单步 TD,低方差高偏差
    • λ=1\lambda = 1λ=1:退化为 Monte Carlo,无偏但高方差

实践中 λ\lambdaλ 常取 0.95 左右。


Q:PPO Loss 的完整形式是什么?

A:PPO 的完整训练目标包含三项:

LPPO=Lclip+c1LVF−c2Lentropy\mathcal{L}{\text{PPO}} = \mathcal{L}{\text{clip}} + c_1 \mathcal{L}{\text{VF}} - c_2 \mathcal{L}{\text{entropy}}LPPO=Lclip+c1LVF−c2Lentropy

① Policy Loss(核心):

Lclip=−Et[min⁡(rtAt, clip(rt,1−ϵ,1+ϵ)At)]\mathcal{L}_{\text{clip}} = -\mathbb{E}_t \left[ \min\left( r_t A_t,\ \text{clip}(r_t, 1-\epsilon, 1+\epsilon) A_t \right) \right]Lclip=−Et[min(rtAt, clip(rt,1−ϵ,1+ϵ)At)]

其中 rt=πθ(at∣st)πθold(at∣st)=enew_log_prob−old_log_probr_t = \dfrac{\pi_\theta(a_t|s_t)}{\pi_{\theta_\text{old}}(a_t|s_t)} = e^{\text{new\_log\_prob} - \text{old\_log\_prob}}rt=πθold(at∣st)πθ(at∣st)=enew_log_prob−old_log_prob

Clip 机制限制每次更新的步长,防止策略突变。

② Value Loss:

LVF=Et[(Vθ(st)−Vttarget)2]\mathcal{L}_{\text{VF}} = \mathbb{E}t \left[ \left( V\theta(s_t) - V_t^{\text{target}} \right)^2 \right]LVF=Et[(Vθ(st)−Vttarget)2]

监督 Critic 更准确地预测状态价值。

③ Entropy Bonus(可选):

鼓励策略保持一定随机性,防止过早收敛到确定性策略。


Q:一次 PPO 训练循环的完整流程是什么?

A:标准流程可以分为两个大阶段:

阶段一:Rollout(数据收集,不更新参数)

  1. Actor 对 prompt 做自回归生成,得到 response
  2. Reference Model 计算 ref_log_prob
  3. Actor 同时记录 old_log_prob
  4. Reward Model 对完整序列打分,得到 reward_score
  5. Critic 估计每步 V(st)V(s_t)V(st)
  6. 用上述数据计算 KL 惩罚、合并奖励、计算 GAE Advantages

阶段二:Training(多轮 mini-batch 更新)

  1. Actor 前向传播,用当前参数重新计算 new_log_prob
  2. 计算 ratio = exp(new_log_prob - old_log_prob)
  3. 计算 Clip Policy Loss
  4. Critic 前向传播,计算 Value Loss
  5. 合并 Loss,反向传播,更新 Actor 和 Critic 参数
  6. 重复多个 epoch(PPO epoch,通常 1~4 次)

之后回到阶段一,开始下一轮 Rollout。

相关推荐
Narrastory1 天前
Note:强化学习(三)
人工智能·深度学习·强化学习
Robot_Nav3 天前
RL-Driven MPPI:基于离线策略加速在线控制律计算的模型预测路径积分控制
rl·learning_based·mppi
盼小辉丶4 天前
PyTorch强化学习实战(1)——强化学习(Reinforcement Learning,RL)详解
人工智能·pytorch·深度学习·强化学习
可编程芯片开发5 天前
基于Qlearning强化学习的源荷扰动下交直流微电网负荷频率控制算法matlab仿真
matlab·强化学习·交直流微电网·qlearning·负荷频率控制
星马梦缘6 天前
强化学习实战-2——Keras-DoubleDQN解决Predator【图像输入】
人工智能·python·jupyter·cnn·keras·强化学习·dqn
阿杰学AI6 天前
AI核心知识121—大语言模型之 基于人类反馈的强化学习 (简洁且通俗易懂版)
人工智能·深度学习·ai·语言模型·强化学习·奖励模型·rm
Narrastory6 天前
Note:强化学习(二)
人工智能·深度学习·强化学习
星马梦缘6 天前
强化学习实战8.1——用PPO打赢星际争霸【环境配置与下位机代码】
人工智能·python·jupyter·强化学习·星际争霸·stablebaseline3·starcraft2
阿杰学AI6 天前
AI核心知识122—大语言模型之 直接偏好优化(简洁且通俗易懂版)
人工智能·算法·机器学习·ai·强化学习·dpo·直接优化偏好