为什么LLM中KL散度需要近似计算

在LLM中KL散度一般为带上下文条件 x x x的条件概率形式:
KL ( q , p ) = ∑ x q ( o ∣ x ) log ⁡ q ( o ∣ x ) p ( o ∣ x ) = E o ∼ q [ log ⁡ q ( o ∣ x ) p ( o ∣ x ) ] \text{KL}(q, p) = \sum_{x} q(o \mid x)\log \frac{q(o \mid x)}{p(o \mid x)} = E_{o\sim q}\left[\log\frac{q(o \mid x)}{p(o \mid x)}\right] KL(q,p)=x∑q(o∣x)logp(o∣x)q(o∣x)=Eo∼q[logp(o∣x)q(o∣x)]
o o o表示生成的某个token

对于PPO、GRPO等颗粒度为token的算法来说, 对每个token严格计算在上下文条件 s s s下分布的KL散度都需要遍历词表(动作空间)进行求和或积分. 以qwen2.5为例, 词表大小为15665, 一个长度为1000的序列, 严格计算KL散度需要遍历15665000次, 而每次训练又包含batch_size个序列, 计算消耗极大.

为解决计算消耗问题, 用以下方式近似估计KL散度
KL ( q , p ) = q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 \text{KL}(q, p) = \frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1 KL(q,p)=p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1
x i x_i xi表示上下文条件 s s s下采样得到的token, 仅需获取token x i x_i xi的生成概率即可估算KL散度, 不再需要遍历词表. 这种估计方法是无偏且稳定的. (一些数学上的证明和推导可以参考博客1博客2)

因此在工程中常用蒙特卡洛方法近似估计给定prompt下生成序列概率分布的KL散度:
KL ^ ( q , p ) = 1 N ∑ i = 1 N [ q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 ] \widehat{\text{KL}}(q, p) = \frac{1}{N}\sum_{i=1}^N\left[\frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1\right] KL (q,p)=N1i=1∑N[p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1]
N N N表示生成序列的长度,

观察GRPO论文中给出的KL散度的形式

花括号外面的 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ \frac{1}{|o_i|}\sum_{t=1}^{|o_i|} ∣oi∣1∑t=1∣oi∣就是在进行蒙特卡洛采样.

GSPO

GSPO用的是序列级重要性采样和reward, 若计算KL散度同样也需要序列级. 对于序列来说动作空间维度为vocab_size^seq_len, 几乎无法遍历严格计算. 也可以参考上述思路, 用单个序列的生成概率来估算训练模型和参考模型之间的KL散度.
KL ( π θ , π r e f ) = ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − log ⁡ ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − 1 \text{KL}(\pi_\theta, \pi_{ref}) = \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - \log\left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - 1 KL(πθ,πref)=(πθ(yi∣x)πref(yi∣x))∣yi∣1−log(πθ(yi∣x)πref(yi∣x))∣yi∣1−1
y i y_i yi表示给定上下文提示词 x x x下生成的序列, 序列生成概率或序列生成概率的比值为:

( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ = exp ⁡ ( 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ log ⁡ π r e f ( y i , t ∣ x , y i , < t ) π θ ( y i , t ∣ x , y i , < t ) ) \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}}=\exp \left(\frac{1}{|y_i|}\sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_{ref}\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) (πθ(yi∣x)πref(yi∣x))∣yi∣1=exp ∣yi∣1t=1∑∣yi∣logπθ(yi,t∣x,yi,<t)πref(yi,t∣x,yi,<t)

其中 1 y i \frac{1}{y_i} yi1表示对序列生成概率进行长度归一化, 防止不同序列的长短造成序列生成概率量级不一致.

相关推荐
小润nature4 分钟前
Moltbot/OpenClaw Gateway 命令和交互
人工智能
tongxianchao5 分钟前
TOKEN MERGING YOUR VIT BUT FASTER
人工智能
自可乐9 分钟前
LangGraph从入门到精通:构建智能Agent的完整指南
人工智能·python·机器学习
下午写HelloWorld9 分钟前
差分隐私深度学习(DP-DL)简要理解
人工智能·深度学习
码农垦荒笔记10 分钟前
OpenClaw 实战 #02-1:新手一把过(原Clawdbot )保姆级安装教程-Mac版
人工智能·macos·openclaw
冀博13 分钟前
LangGraph实操-干中学
人工智能·ai
玉梅小洋20 分钟前
手机 App 云端存储云服务选型指南
人工智能·智能手机·手机·工具开发·手机app开发
deephub20 分钟前
让 AI 智能体学会自我进化:Agent Lightning 实战入门
人工智能·深度学习·大语言模型·agent
Loo国昌24 分钟前
【垂类模型数据工程】第四阶段:高性能 Embedding 实战:从双编码器架构到 InfoNCE 损失函数详解
人工智能·后端·深度学习·自然语言处理·架构·transformer·embedding
yunhuibin29 分钟前
VideoPipe环境搭建及编译ubuntu240403
前端·人工智能