为什么LLM中KL散度需要近似计算

在LLM中KL散度一般为带上下文条件 x x x的条件概率形式:
KL ( q , p ) = ∑ x q ( o ∣ x ) log ⁡ q ( o ∣ x ) p ( o ∣ x ) = E o ∼ q [ log ⁡ q ( o ∣ x ) p ( o ∣ x ) ] \text{KL}(q, p) = \sum_{x} q(o \mid x)\log \frac{q(o \mid x)}{p(o \mid x)} = E_{o\sim q}\left[\log\frac{q(o \mid x)}{p(o \mid x)}\right] KL(q,p)=x∑q(o∣x)logp(o∣x)q(o∣x)=Eo∼q[logp(o∣x)q(o∣x)]
o o o表示生成的某个token

对于PPO、GRPO等颗粒度为token的算法来说, 对每个token严格计算在上下文条件 s s s下分布的KL散度都需要遍历词表(动作空间)进行求和或积分. 以qwen2.5为例, 词表大小为15665, 一个长度为1000的序列, 严格计算KL散度需要遍历15665000次, 而每次训练又包含batch_size个序列, 计算消耗极大.

为解决计算消耗问题, 用以下方式近似估计KL散度
KL ( q , p ) = q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 \text{KL}(q, p) = \frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1 KL(q,p)=p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1
x i x_i xi表示上下文条件 s s s下采样得到的token, 仅需获取token x i x_i xi的生成概率即可估算KL散度, 不再需要遍历词表. 这种估计方法是无偏且稳定的. (一些数学上的证明和推导可以参考博客1博客2)

因此在工程中常用蒙特卡洛方法近似估计给定prompt下生成序列概率分布的KL散度:
KL ^ ( q , p ) = 1 N ∑ i = 1 N [ q ( o i ∣ x ) p ( o i ∣ x ) − log ⁡ q ( o i ∣ x ) p ( o i ∣ x ) − 1 ] \widehat{\text{KL}}(q, p) = \frac{1}{N}\sum_{i=1}^N\left[\frac{q(o_i \mid x)}{p(o_i \mid x)} - \log\frac{q(o_i \mid x)}{p(o_i \mid x)} - 1\right] KL (q,p)=N1i=1∑N[p(oi∣x)q(oi∣x)−logp(oi∣x)q(oi∣x)−1]
N N N表示生成序列的长度,

观察GRPO论文中给出的KL散度的形式

花括号外面的 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ \frac{1}{|o_i|}\sum_{t=1}^{|o_i|} ∣oi∣1∑t=1∣oi∣就是在进行蒙特卡洛采样.

GSPO

GSPO用的是序列级重要性采样和reward, 若计算KL散度同样也需要序列级. 对于序列来说动作空间维度为vocab_size^seq_len, 几乎无法遍历严格计算. 也可以参考上述思路, 用单个序列的生成概率来估算训练模型和参考模型之间的KL散度.
KL ( π θ , π r e f ) = ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − log ⁡ ( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ − 1 \text{KL}(\pi_\theta, \pi_{ref}) = \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - \log\left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}} - 1 KL(πθ,πref)=(πθ(yi∣x)πref(yi∣x))∣yi∣1−log(πθ(yi∣x)πref(yi∣x))∣yi∣1−1
y i y_i yi表示给定上下文提示词 x x x下生成的序列, 序列生成概率或序列生成概率的比值为:

( π r e f ( y i ∣ x ) π θ ( y i ∣ x ) ) 1 ∣ y i ∣ = exp ⁡ ( 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ log ⁡ π r e f ( y i , t ∣ x , y i , < t ) π θ ( y i , t ∣ x , y i , < t ) ) \left(\frac{\pi_{ref}\left(y_i \mid x\right)}{\pi_{\theta}\left(y_i \mid x\right)}\right)^{\frac{1}{|y_i|}}=\exp \left(\frac{1}{|y_i|}\sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_{ref}\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) (πθ(yi∣x)πref(yi∣x))∣yi∣1=exp ∣yi∣1t=1∑∣yi∣logπθ(yi,t∣x,yi,<t)πref(yi,t∣x,yi,<t)

其中 1 y i \frac{1}{y_i} yi1表示对序列生成概率进行长度归一化, 防止不同序列的长短造成序列生成概率量级不一致.

相关推荐
沃达德软件3 小时前
智慧警务图像融合大数据
大数据·图像处理·人工智能·目标检测·计算机视觉·目标跟踪
QxQ么么4 小时前
移远通信(桂林)26校招-助理AI算法工程师-面试纪录
人工智能·python·算法·面试
愤怒的可乐4 小时前
从零构建大模型智能体:统一消息格式,快速接入大语言模型
人工智能·语言模型·自然语言处理
每天一个java小知识6 小时前
AI Agent
人工智能
猫头虎6 小时前
如何解决 pip install 编译报错 fatal error: hdf5.h: No such file or directory(h5py)问题
人工智能·python·pycharm·开源·beautifulsoup·ai编程·pip
龙赤子6 小时前
人工智能AI的大框架
人工智能
比奥利奥还傲.6 小时前
本地+AI+大模型自由用!Cherry+Studio打破局域网限制
人工智能
雪碧聊技术6 小时前
深度学习、机器学习、人工智能三者的关系
人工智能·深度学习·机器学习
β添砖java6 小时前
机器学习初级
人工智能·机器学习
陈奕昆6 小时前
n8n实战营Day3:电商订单全流程自动化·需求分析与流程拆解
大数据·开发语言·人工智能·自动化·需求分析·n8n