大模型面试必备17-重要性采样


算法面试高频点:一文搞懂"重要性采样"及 Token-Level 损失为何难收敛?

在研究大模型强化学习对齐(如 PPO、GRPO)时,我们经常会遇到一个核心概念:重要性采样(Importance Sampling)

面试官经常会问:"为什么要引入重要性采样?"或者更深入一点:"PPO/GRPO 在 Token 级别计算重要性权重,会带来什么问题?"

今天,我们就用一个最简单的"造句"例子,把这两个硬核问题彻底讲透。


一、 什么是"重要性采样"?(通俗举例篇)

假设我们要训练一个语言模型生成一句话:"猫喜欢鱼" 。这句话包含 3 个 Token: t 1 = t_1= t1="猫"、 t 2 = t_2= t2="喜欢"、 t 3 = t_3= t3="鱼"。

在强化学习中,我们通常面临两个策略(模型):

  1. 目标策略 P (Target Policy): 这是我们最终想要训练出的聪明模型。
  2. 行为策略 Q (Behavior Policy): 这是我们在训练过程中用来生成采样数据(也就是过去收集到的经验数据)的旧模型。

假设这两个模型生成各个词的条件概率如下:

动作 (Token) 目标策略 P 的概率 (聪明) 行为策略 Q 的概率 (笨)
t1 = "猫" (句首) 0.8 0.5
t2 = "喜欢" (接在"猫"后) 0.7 0.2
t3 = "鱼" (接在"猫喜欢"后) 0.9 0.3

核心计算:序列级的重要性权重

重要性采样的核心思想是:用旧模型 Q 采样的样本,来等效评估新模型 P 的表现

怎么"等效"呢?就是计算一个权重(Ratio)。对于整句"猫喜欢鱼",序列级的权重是条件概率的连乘比值:

序列权重 = P ( 全句 ) Q ( 全句 ) = P ( t 1 ) ⋅ P ( t 2 ∣ t 1 ) ⋅ P ( t 3 ∣ t 1 , t 2 ) Q ( t 1 ) ⋅ Q ( t 2 ∣ t 1 ) ⋅ Q ( t 3 ∣ t 1 , t 2 ) \text{序列权重} = \frac{P(\text{全句})}{Q(\text{全句})} = \frac{P(t_1) \cdot P(t_2|t_1) \cdot P(t_3|t_1, t_2)}{Q(t_1) \cdot Q(t_2|t_1) \cdot Q(t_3|t_1, t_2)} 序列权重=Q(全句)P(全句)=Q(t1)⋅Q(t2∣t1)⋅Q(t3∣t1,t2)P(t1)⋅P(t2∣t1)⋅P(t3∣t1,t2)

代入数值计算:

序列权重 = 0.8 × 0.7 × 0.9 0.5 × 0.2 × 0.3 = 0.504 0.03 = 16.8 \text{序列权重} = \frac{0.8 \times 0.7 \times 0.9}{0.5 \times 0.2 \times 0.3} = \frac{0.504}{0.03} = 16.8 序列权重=0.5×0.2×0.30.8×0.7×0.9=0.030.504=16.8

通俗解释:

旧模型 Q 想要生成"猫喜欢鱼"这句话的概率,比新模型 P 低了整整 16.8 倍。因此,当我们用旧模型 Q 生成的数据来更新新模型 P 时,必须给这个样本放大 16.8 倍的权重才能起到正确的纠偏作用。这就是重要性采样的魔力


二、 官方定义与数学公式

在数学上,重要性采样是在无法直接从目标分布 P P P 采样时,通过从一个更容易获取的行为分布 Q Q Q 中采样,再用"重要性权重"矫正偏差,最终估算出目标分布的期望。

公式表达为:

E P f ( x ) = 1 n ∑ i = 1 n ( P ( x i ) Q ( x i ) ) ⋅ f ( x i ) \mathbb{E}Pf(x) = \frac{1}{n}\sum{i=1}^{n}\left(\frac{P(x_i)}{Q(x_i)}\right) \cdot f(x_i) EPf(x)=n1i=1∑n(Q(xi)P(xi))⋅f(xi)

在 PPO/GRPO 中的应用:

为了节约成本,我们不可能每次更新参数都让模型重新生成一遍数据。我们会把旧模型生成的数据存在库里(Replay Buffer),更新时直接拿来用。

为了弥补新旧模型产生的偏差,定义了每个 Token 的重要性比例 r t r_t rt:

r t = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} rt=πθold(at∣st)πθ(at∣st)

然后在损失函数中,使用 Clip 函数限制更新幅度,防止新模型偏离旧模型太远导致崩溃:

L ( θ ) = E t min ⁡ ( r t A t , CLIP ( r t , 1 − ϵ , 1 + ϵ ) A t ) L(\theta) = \mathbb{E}_t\\min(r_t A_t, \\text{CLIP}(r_t, 1-\\epsilon, 1+\\epsilon)A_t) L(θ)=Etmin(rtAt,CLIP(rt,1−ϵ,1+ϵ)At)


三、 面试深挖:为什么 Token-Level 的损失难以收敛?

目前主流的 PPO 和 GRPO 算法,都是在 Token-Level(逐个词) 去计算重要性权重的。这会带来一个致命的缺陷。

回到刚才的例子,如果我们逐 Token 去算权重:

  • t 1 t_1 t1 的权重 = 0.8 0.5 = 1.6 \frac{0.8}{0.5} = 1.6 0.50.8=1.6
  • t 2 t_2 t2 的权重 = 0.7 0.2 = 3.5 \frac{0.7}{0.2} = 3.5 0.20.7=3.5
  • t 3 t_3 t3 的权重 = 0.9 0.3 = 3 \frac{0.9}{0.3} = 3 0.30.9=3

痛点 1:无法反映真实的整体偏差

我们刚才算过,整个句子(序列)的真实偏差是 16.8 倍 。但单个 Token 算出来的权重(1.6、3.5、3)都远远小于 16.8。

如果用单个 Token 的权重去更新,会严重低估需要放大的程度,导致分布矫正不足。模型总是感觉"差了那么点意思",怎么都无法完美收敛到目标状态。

痛点 2:高方差导致的"抓错重点"

逐 Token 的权重会把每个词局部的波动直接放大带入训练。

在例子中, t 2 t_2 t2 的权重(3.5)最高,模型可能会错误地以为"哦!'喜欢'这个词是最关键的,我要花大力气去关注它!"

但实际上,整句话表现不好是三个词共同拉胯导致的。这种局部的视野会导致单次采样的方差被急剧放大,模型会过度纠结于某个特定的 Token,而忽略了整个句子上下文的合理性


四、 面试总结 (Cheat Sheet)

如果在面试中遇到相关问题,你可以用以下逻辑清晰作答:

  1. 什么是重要性采样? 为了复用历史策略(旧模型)产生的数据来评估和更新当前策略(新模型),我们需要计算新旧策略在同一动作上的概率比( r t r_t rt)作为修正权重。它是降低 RLHF 采样成本的核心机制。
  2. Token-Level 的缺陷是什么? PPO/GRPO 逐词计算权重存在两大问题:一是矫正不足 ,单个 Token 的概率比远小于整个序列的联合概率比,无法真实反映新旧策略在长文本上的整体偏差;二是高方差,容易让模型"抓错重点",过度关注局部波动的 Token 而忽视句子整体语义,导致大模型在长文本强化学习时极难稳定收敛。这也是为什么阿里近期提出了基于句子级别的 GSPO 算法来尝试破局。