为什么LLM中KL散度需要近似计算

在LLM中KL散度一般为带上下文条件 x x x的条件概率形式:
KL ( q , p ) = ∑ x q ( o ∣ x ) log ⁡ q ( o ∣ x ) p ( o ∣ x ) = E o ∼ q [ log ⁡ q ( o ∣ x ) p ( o ∣ x ) ] \text{KL}(q, p) = \sum_{x} q(o \mid x)\log \frac{q(o \mid x)}{p(o \mid x)} = E_{o\sim q}\left[\log\frac{q(o \mid x)}{p(o \mid x)}\right] KL(q,p)=x∑q(o∣x)logp(o∣x)q(o∣x)=Eo∼q[logp(o∣x)q(o∣x)]
o o o表示生成的某个token

对于PPO、GRPO等颗粒度为token的算法来说, 对每个token严格计算在上下文条件 s s s下分布的KL散度都需要遍历词表(动作空间)进行求和或积分. 以qwen2.5为例, 词表大小为15665, 一个长度为1000的序列, 严格计算KL散度需要遍历15665000次, 而每次训练又包含batch_size个序列, 计算消耗极大.

为解决计算消耗问题, 用以下方式近似估计KL散度
KL ( q , p ) = q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 \text{KL}(q, p) = \frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1 KL(q,p)=p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1
x i x_i xi表示上下文条件 s s s下采样得到的token, 仅需获取token x i x_i xi的生成概率即可估算KL散度, 不再需要遍历词表. 这种估计方法是无偏且稳定的. (一些数学上的证明和推导可以参考博客1博客2)

因此在工程中常用蒙特卡洛方法近似估计给定prompt下生成序列概率分布的KL散度:
KL ^ ( q , p ) = 1 N ∑ i = 1 N [ q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 ] \widehat{\text{KL}}(q, p) = \frac{1}{N}\sum_{i=1}^N\left[\frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1\right] KL (q,p)=N1i=1∑N[p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1]
N N N表示生成序列的长度,

观察GRPO论文中给出的KL散度的形式

花括号外面的 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ \frac{1}{|o_i|}\sum_{t=1}^{|o_i|} ∣oi∣1∑t=1∣oi∣就是在进行蒙特卡洛采样.

GSPO

GSPO用的是序列级重要性采样和reward, 若计算KL散度同样也需要序列级. 对于序列来说动作空间维度为vocab_size^seq_len, 几乎无法遍历严格计算. 也可以参考上述思路, 用单个序列的生成概率来估算训练模型和参考模型之间的KL散度.
KL ( π θ , π r e f ) = ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − log ⁡ ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − 1 \text{KL}(\pi_\theta, \pi_{ref}) = \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - \log\left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - 1 KL(πθ,πref)=(πθ(yi∣x)πref(yi∣x))∣yi∣1−log(πθ(yi∣x)πref(yi∣x))∣yi∣1−1
y i y_i yi表示给定上下文提示词 x x x下生成的序列, 序列生成概率或序列生成概率的比值为:

( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ = exp ⁡ ( 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ log ⁡ π r e f ( y i , t ∣ x , y i , < t ) π θ ( y i , t ∣ x , y i , < t ) ) \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}}=\exp \left(\frac{1}{|y_i|}\sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_{ref}\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) (πθ(yi∣x)πref(yi∣x))∣yi∣1=exp ∣yi∣1t=1∑∣yi∣logπθ(yi,t∣x,yi,<t)πref(yi,t∣x,yi,<t)

其中 1 y i \frac{1}{y_i} yi1表示对序列生成概率进行长度归一化, 防止不同序列的长短造成序列生成概率量级不一致.

相关推荐
新智元11 小时前
AI 科学家登场!12 小时抵人类科学家半年工作量,已有 7 项大成果
人工智能·openai
新智元11 小时前
PyTorch 之父闪电离职,AI 半壁江山集体致敬!
人工智能·openai
NON-JUDGMENTAL11 小时前
指令微调(Instruction Tuning)
人工智能·深度学习·机器学习
Funny_AI_LAB11 小时前
深度解析Andrej Karpathy访谈:关于AI智能体、AGI、强化学习与大模型的十年远见
人工智能·计算机视觉·ai·agi
互联科技报11 小时前
AI赋能企业办公:文多多AiPPT以技术创新破解行业痛点
人工智能
番石榴AI11 小时前
视频转ppt/pdf V2.0版(新增转为可编辑PPT功能)
人工智能·pdf·powerpoint
c0d1ng12 小时前
自建督学习——BERT(第二十二周周报)
人工智能·学习·bert
云和数据.ChenGuang12 小时前
SyntaxError: Non-UTF-8 code starting
人工智能·python·numpy
FreeCode12 小时前
LangChain1.0智能体开发:结构化输出
人工智能·langchain·agent
私域实战笔记12 小时前
企业微信SCRM怎么选?工具适配与落地实操指南
人工智能·数据挖掘·企业微信·scrm·企业微信scrm