算法面试高频点:一文搞懂"重要性采样"及 Token-Level 损失为何难收敛?
在研究大模型强化学习对齐(如 PPO、GRPO)时,我们经常会遇到一个核心概念:重要性采样(Importance Sampling)。
面试官经常会问:"为什么要引入重要性采样?"或者更深入一点:"PPO/GRPO 在 Token 级别计算重要性权重,会带来什么问题?"
今天,我们就用一个最简单的"造句"例子,把这两个硬核问题彻底讲透。
一、 什么是"重要性采样"?(通俗举例篇)
假设我们要训练一个语言模型生成一句话:"猫喜欢鱼" 。这句话包含 3 个 Token: t 1 = t_1= t1="猫"、 t 2 = t_2= t2="喜欢"、 t 3 = t_3= t3="鱼"。
在强化学习中,我们通常面临两个策略(模型):
- 目标策略 P (Target Policy): 这是我们最终想要训练出的聪明模型。
- 行为策略 Q (Behavior Policy): 这是我们在训练过程中用来生成采样数据(也就是过去收集到的经验数据)的旧模型。
假设这两个模型生成各个词的条件概率如下:
| 动作 (Token) | 目标策略 P 的概率 (聪明) | 行为策略 Q 的概率 (笨) |
|---|---|---|
| t1 = "猫" (句首) | 0.8 | 0.5 |
| t2 = "喜欢" (接在"猫"后) | 0.7 | 0.2 |
| t3 = "鱼" (接在"猫喜欢"后) | 0.9 | 0.3 |
核心计算:序列级的重要性权重
重要性采样的核心思想是:用旧模型 Q 采样的样本,来等效评估新模型 P 的表现。
怎么"等效"呢?就是计算一个权重(Ratio)。对于整句"猫喜欢鱼",序列级的权重是条件概率的连乘比值:
序列权重 = P ( 全句 ) Q ( 全句 ) = P ( t 1 ) ⋅ P ( t 2 ∣ t 1 ) ⋅ P ( t 3 ∣ t 1 , t 2 ) Q ( t 1 ) ⋅ Q ( t 2 ∣ t 1 ) ⋅ Q ( t 3 ∣ t 1 , t 2 ) \text{序列权重} = \frac{P(\text{全句})}{Q(\text{全句})} = \frac{P(t_1) \cdot P(t_2|t_1) \cdot P(t_3|t_1, t_2)}{Q(t_1) \cdot Q(t_2|t_1) \cdot Q(t_3|t_1, t_2)} 序列权重=Q(全句)P(全句)=Q(t1)⋅Q(t2∣t1)⋅Q(t3∣t1,t2)P(t1)⋅P(t2∣t1)⋅P(t3∣t1,t2)
代入数值计算:
序列权重 = 0.8 × 0.7 × 0.9 0.5 × 0.2 × 0.3 = 0.504 0.03 = 16.8 \text{序列权重} = \frac{0.8 \times 0.7 \times 0.9}{0.5 \times 0.2 \times 0.3} = \frac{0.504}{0.03} = 16.8 序列权重=0.5×0.2×0.30.8×0.7×0.9=0.030.504=16.8
通俗解释:
旧模型 Q 想要生成"猫喜欢鱼"这句话的概率,比新模型 P 低了整整 16.8 倍。因此,当我们用旧模型 Q 生成的数据来更新新模型 P 时,必须给这个样本放大 16.8 倍的权重 ,才能起到正确的纠偏作用。这就是重要性采样的魔力。
二、 官方定义与数学公式
在数学上,重要性采样是在无法直接从目标分布 P P P 采样时,通过从一个更容易获取的行为分布 Q Q Q 中采样,再用"重要性权重"矫正偏差,最终估算出目标分布的期望。
公式表达为:
E P f ( x ) = 1 n ∑ i = 1 n ( P ( x i ) Q ( x i ) ) ⋅ f ( x i ) \mathbb{E}Pf(x) = \frac{1}{n}\sum{i=1}^{n}\left(\frac{P(x_i)}{Q(x_i)}\right) \cdot f(x_i) EPf(x)=n1i=1∑n(Q(xi)P(xi))⋅f(xi)
在 PPO/GRPO 中的应用:
为了节约成本,我们不可能每次更新参数都让模型重新生成一遍数据。我们会把旧模型生成的数据存在库里(Replay Buffer),更新时直接拿来用。
为了弥补新旧模型产生的偏差,定义了每个 Token 的重要性比例 r t r_t rt:
r t = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} rt=πθold(at∣st)πθ(at∣st)
然后在损失函数中,使用 Clip 函数限制更新幅度,防止新模型偏离旧模型太远导致崩溃:
L ( θ ) = E t min ( r t A t , CLIP ( r t , 1 − ϵ , 1 + ϵ ) A t ) L(\theta) = \mathbb{E}_t\\min(r_t A_t, \\text{CLIP}(r_t, 1-\\epsilon, 1+\\epsilon)A_t) L(θ)=Etmin(rtAt,CLIP(rt,1−ϵ,1+ϵ)At)
三、 面试深挖:为什么 Token-Level 的损失难以收敛?
目前主流的 PPO 和 GRPO 算法,都是在 Token-Level(逐个词) 去计算重要性权重的。这会带来一个致命的缺陷。
回到刚才的例子,如果我们逐 Token 去算权重:
- t 1 t_1 t1 的权重 = 0.8 0.5 = 1.6 \frac{0.8}{0.5} = 1.6 0.50.8=1.6
- t 2 t_2 t2 的权重 = 0.7 0.2 = 3.5 \frac{0.7}{0.2} = 3.5 0.20.7=3.5
- t 3 t_3 t3 的权重 = 0.9 0.3 = 3 \frac{0.9}{0.3} = 3 0.30.9=3
痛点 1:无法反映真实的整体偏差
我们刚才算过,整个句子(序列)的真实偏差是 16.8 倍 。但单个 Token 算出来的权重(1.6、3.5、3)都远远小于 16.8。
如果用单个 Token 的权重去更新,会严重低估需要放大的程度,导致分布矫正不足。模型总是感觉"差了那么点意思",怎么都无法完美收敛到目标状态。
痛点 2:高方差导致的"抓错重点"
逐 Token 的权重会把每个词局部的波动直接放大带入训练。
在例子中, t 2 t_2 t2 的权重(3.5)最高,模型可能会错误地以为"哦!'喜欢'这个词是最关键的,我要花大力气去关注它!"
但实际上,整句话表现不好是三个词共同拉胯导致的。这种局部的视野会导致单次采样的方差被急剧放大,模型会过度纠结于某个特定的 Token,而忽略了整个句子上下文的合理性。
四、 面试总结 (Cheat Sheet)
如果在面试中遇到相关问题,你可以用以下逻辑清晰作答:
- 什么是重要性采样? 为了复用历史策略(旧模型)产生的数据来评估和更新当前策略(新模型),我们需要计算新旧策略在同一动作上的概率比( r t r_t rt)作为修正权重。它是降低 RLHF 采样成本的核心机制。
- Token-Level 的缺陷是什么? PPO/GRPO 逐词计算权重存在两大问题:一是矫正不足 ,单个 Token 的概率比远小于整个序列的联合概率比,无法真实反映新旧策略在长文本上的整体偏差;二是高方差,容易让模型"抓错重点",过度关注局部波动的 Token 而忽视句子整体语义,导致大模型在长文本强化学习时极难稳定收敛。这也是为什么阿里近期提出了基于句子级别的 GSPO 算法来尝试破局。