为什么LLM中KL散度需要近似计算

在LLM中KL散度一般为带上下文条件 x x x的条件概率形式:
KL ( q , p ) = ∑ x q ( o ∣ x ) log ⁡ q ( o ∣ x ) p ( o ∣ x ) = E o ∼ q [ log ⁡ q ( o ∣ x ) p ( o ∣ x ) ] \text{KL}(q, p) = \sum_{x} q(o \mid x)\log \frac{q(o \mid x)}{p(o \mid x)} = E_{o\sim q}\left[\log\frac{q(o \mid x)}{p(o \mid x)}\right] KL(q,p)=x∑q(o∣x)logp(o∣x)q(o∣x)=Eo∼q[logp(o∣x)q(o∣x)]
o o o表示生成的某个token

对于PPO、GRPO等颗粒度为token的算法来说, 对每个token严格计算在上下文条件 s s s下分布的KL散度都需要遍历词表(动作空间)进行求和或积分. 以qwen2.5为例, 词表大小为15665, 一个长度为1000的序列, 严格计算KL散度需要遍历15665000次, 而每次训练又包含batch_size个序列, 计算消耗极大.

为解决计算消耗问题, 用以下方式近似估计KL散度
KL ( q , p ) = q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 \text{KL}(q, p) = \frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1 KL(q,p)=p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1
x i x_i xi表示上下文条件 s s s下采样得到的token, 仅需获取token x i x_i xi的生成概率即可估算KL散度, 不再需要遍历词表. 这种估计方法是无偏且稳定的. (一些数学上的证明和推导可以参考博客1博客2)

因此在工程中常用蒙特卡洛方法近似估计给定prompt下生成序列概率分布的KL散度:
KL ^ ( q , p ) = 1 N ∑ i = 1 N [ q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 ] \widehat{\text{KL}}(q, p) = \frac{1}{N}\sum_{i=1}^N\left[\frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1\right] KL (q,p)=N1i=1∑N[p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1]
N N N表示生成序列的长度,

观察GRPO论文中给出的KL散度的形式

花括号外面的 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ \frac{1}{|o_i|}\sum_{t=1}^{|o_i|} ∣oi∣1∑t=1∣oi∣就是在进行蒙特卡洛采样.

GSPO

GSPO用的是序列级重要性采样和reward, 若计算KL散度同样也需要序列级. 对于序列来说动作空间维度为vocab_size^seq_len, 几乎无法遍历严格计算. 也可以参考上述思路, 用单个序列的生成概率来估算训练模型和参考模型之间的KL散度.
KL ( π θ , π r e f ) = ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − log ⁡ ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − 1 \text{KL}(\pi_\theta, \pi_{ref}) = \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - \log\left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - 1 KL(πθ,πref)=(πθ(yi∣x)πref(yi∣x))∣yi∣1−log(πθ(yi∣x)πref(yi∣x))∣yi∣1−1
y i y_i yi表示给定上下文提示词 x x x下生成的序列, 序列生成概率或序列生成概率的比值为:

( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ = exp ⁡ ( 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ log ⁡ π r e f ( y i , t ∣ x , y i , < t ) π θ ( y i , t ∣ x , y i , < t ) ) \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}}=\exp \left(\frac{1}{|y_i|}\sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_{ref}\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) (πθ(yi∣x)πref(yi∣x))∣yi∣1=exp ∣yi∣1t=1∑∣yi∣logπθ(yi,t∣x,yi,<t)πref(yi,t∣x,yi,<t)

其中 1 y i \frac{1}{y_i} yi1表示对序列生成概率进行长度归一化, 防止不同序列的长短造成序列生成概率量级不一致.

相关推荐
AI_小站5 小时前
6个GitHub爆火的免费大模型教程,助你快速进阶AI编程
人工智能·langchain·github·知识图谱·agent·llama·rag
xindoo5 小时前
GitHub Trending霸榜!深度解析AI Coding辅助神器 Superpowers
人工智能·github
时间之里5 小时前
【深度学习】:RF-DETR与yolo对比
人工智能·深度学习·yolo
北京阿法龙科技有限公司5 小时前
数智化升级:AR 智能眼镜驱动工业运维效能革新
人工智能
风落无尘5 小时前
《智能重生:从垃圾堆到AI工程师》——第二章 概率与生存
大数据·人工智能
j_xxx404_5 小时前
Linux:静态链接与动态链接深度解析
linux·运维·服务器·c++·人工智能
收获不止数据库5 小时前
达梦9发布会归来:AI 时代,我们需要一款什么样的数据库?
数据库·人工智能·ai·语言模型·数据分析
hhb_6185 小时前
AI全栈编程生存指南
人工智能
AI-Frontiers5 小时前
transformer进阶之路:#2 工作原理详解
人工智能·深度学习·transformer
科研前沿5 小时前
2026 数字孪生前沿科技:全景迭代报告 —— 镜像视界生成式孪生(Generative DT)技术白皮书
大数据·人工智能·科技·算法·音视频·空间计算