在LLM中KL散度一般为带上下文条件 x x x的条件概率形式:
KL ( q , p ) = ∑ x q ( o ∣ x ) log q ( o ∣ x ) p ( o ∣ x ) = E o ∼ q [ log q ( o ∣ x ) p ( o ∣ x ) ] \text{KL}(q, p) = \sum_{x} q(o \mid x)\log \frac{q(o \mid x)}{p(o \mid x)} = E_{o\sim q}\left[\log\frac{q(o \mid x)}{p(o \mid x)}\right] KL(q,p)=x∑q(o∣x)logp(o∣x)q(o∣x)=Eo∼q[logp(o∣x)q(o∣x)]
o o o表示生成的某个token
对于PPO、GRPO等颗粒度为token的算法来说, 对每个token严格计算在上下文条件 s s s下分布的KL散度都需要遍历词表(动作空间)进行求和或积分. 以qwen2.5为例, 词表大小为15665, 一个长度为1000的序列, 严格计算KL散度需要遍历15665000次, 而每次训练又包含batch_size个序列, 计算消耗极大.
为解决计算消耗问题, 用以下方式近似估计KL散度
KL ( q , p ) = q ( o i ∣ x ) p ( o i ∣ x ) − log q ( o i ∣ x ) p ( o i ∣ x ) − 1 \text{KL}(q, p) = \frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1 KL(q,p)=p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1
x i x_i xi表示上下文条件 s s s下采样得到的token, 仅需获取token x i x_i xi的生成概率即可估算KL散度, 不再需要遍历词表. 这种估计方法是无偏且稳定的. (一些数学上的证明和推导可以参考博客1和博客2)
因此在工程中常用蒙特卡洛方法近似估计给定prompt下生成序列概率分布的KL散度:
KL ^ ( q , p ) = 1 N ∑ i = 1 N [ q ( o i ∣ x ) p ( o i ∣ x ) − log q ( o i ∣ x ) p ( o i ∣ x ) − 1 ] \widehat{\text{KL}}(q, p) = \frac{1}{N}\sum_{i=1}^N\left[\frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1\right] KL (q,p)=N1i=1∑N[p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1]
N N N表示生成序列的长度,
观察GRPO论文中给出的KL散度的形式
花括号外面的 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ \frac{1}{|o_i|}\sum_{t=1}^{|o_i|} ∣oi∣1∑t=1∣oi∣就是在进行蒙特卡洛采样.
GSPO
GSPO用的是序列级重要性采样和reward, 若计算KL散度同样也需要序列级. 对于序列来说动作空间维度为vocab_size^seq_len, 几乎无法遍历严格计算. 也可以参考上述思路, 用单个序列的生成概率来估算训练模型和参考模型之间的KL散度.
KL ( π θ , π r e f ) = ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − log ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − 1 \text{KL}(\pi_\theta, \pi_{ref}) = \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - \log\left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - 1 KL(πθ,πref)=(πθ(yi∣x)πref(yi∣x))∣yi∣1−log(πθ(yi∣x)πref(yi∣x))∣yi∣1−1
y i y_i yi表示给定上下文提示词 x x x下生成的序列, 序列生成概率或序列生成概率的比值为:
( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ = exp ( 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ log π r e f ( y i , t ∣ x , y i , < t ) π θ ( y i , t ∣ x , y i , < t ) ) \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}}=\exp \left(\frac{1}{|y_i|}\sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_{ref}\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) (πθ(yi∣x)πref(yi∣x))∣yi∣1=exp ∣yi∣1t=1∑∣yi∣logπθ(yi,t∣x,yi,<t)πref(yi,t∣x,yi,<t)
其中 1 y i \frac{1}{y_i} yi1表示对序列生成概率进行长度归一化, 防止不同序列的长短造成序列生成概率量级不一致.