DPO 直接偏好优化
论文地址: https://arxiv.org/pdf/2305.18290
2.7.1 为什么需要 DPO?(RLHF 的缺点)
传统的 RLHF(基于人类反馈的强化学习)通常包含三个步骤:SFT(有监督微调)、Reward Model(奖励模型训练)、PPO(强化学习优化)。
该方案的优点自然不需多说,但缺点也非常明显:训练流程繁琐,需要同时加载 4 个模型(Actor, Critic, Reward, Reference),导致占用显存极大,计算量也十分惊人。因此,寻找 RLHF 的平替方案成为了业界的热点。
DPO (Direct Preference Optimization, 直接偏好优化) 就是一种非常高效的 RLHF 替代算法。它巧妙地绕过了构建奖励模型和进行强化学习这两个繁琐的过程,直接通过偏好数据进行微调。由于无需加载 Reward 和 Critic 这两个模型,DPO 效果简单粗暴,在使模型输出更符合人类偏好的同时,极大地缩短了训练时间和难度。
2.7.2 DPO 训练的数据格式
DPO 需要的数据与 RLHF 一致,都是经过人工排序后的 QA 语料对。不同的是,DPO 训练数据的核心是"输入 - 优选回答 - 次选回答"的三元组结构。
其核心目的是让模型通过对比好回答 (chosen) 和差回答 (rejected),直接学习人类的偏好倾向。以下是具体的格式说明和示例:
json
{
"input": "写一段关于"春天"的抒情文字,要求50字左右。",
"chosen": "春风拂过,柳枝抽芽,樱花飘落如粉雪。田野里草色渐青,燕子衔泥筑巢,空气里满是湿润的花香,这是春天最温柔的呼吸。",
"rejected": "春天来了,天气变暖了,花儿开了,草儿绿了,大家可以出去玩了。"
}
2.7.3 DPO 的损失函数 (Loss Function)
这种数据格式的核心作用是让 DPO 的损失函数(通过对比 chosen 和 rejected 的概率差异)有效优化模型,使其更倾向于生成 chosen 级别的回答。其损失函数设计如下:
L D P O ( π θ ; π r e f ) = − E ( x , y c , y r ) ∼ D [ log σ ( β log π θ ( y c ∣ x ) π r e f ( y c ∣ x ) − β log π θ ( y r ∣ x ) π r e f ( y r ∣ x ) ) ] \mathcal{L}{DPO}(\pi{\theta};\pi_{ref})=-\mathbb{E}{(x,y{c},y_{r})\sim\mathcal{D}}\left[\log \sigma\left(\beta \log\frac{\pi_{\theta}(y_{c}|x)}{\pi_{ref}(y_{c}|x)}-\beta \log\frac{\pi_{\theta}(y_{r}|x)}{\pi_{ref}(y_{r}|x)}\right)\right] LDPO(πθ;πref)=−E(x,yc,yr)∼D[logσ(βlogπref(yc∣x)πθ(yc∣x)−βlogπref(yr∣x)πθ(yr∣x))]
公式参数解析:
- y c y_{c} yc 是偏好数据对中好的回答 (chosen) , y r y_{r} yr 则是坏的回答 (rejected)。
- π θ ( y ∣ x ) \pi_{\theta}(y|x) πθ(y∣x) 是给定输入 x x x 时,当前策略 (Actor) 生成答案的概率。
- π r e f ( y ∣ x ) \pi_{ref}(y|x) πref(y∣x) 是给定输入 x x x 时,原始策略 (Reference) 生成答案的概率。
原理解析:
− log σ -\log\sigma −logσ 函数里面的部分越大时,整体的 loss 就越小。所以对于 DPO 的 loss,我们只需要将 − log σ -\log\sigma −logσ 里面的部分最大化即可。提取出来简化后得到:
β ( [ log π θ ( y c ∣ x ) − log π θ ( y r ∣ x ) ] − [ log π r e f ( y c ∣ x ) − log π r e f ( y r ∣ x ) ] ) \beta \left( \left[\log \pi_{\theta}(y_{c}|x) - \log \pi_{\theta}(y_{r}|x)\right] - \left[\log \pi_{ref}(y_{c}|x) - \log \pi_{ref}(y_{r}|x)\right] \right) β([logπθ(yc∣x)−logπθ(yr∣x)]−[logπref(yc∣x)−logπref(yr∣x)])
由此可以看出,DPO 期望最大化的其实就是策略模型相对于参考模型,在 chosen 数据和 rejected 数据上的概率差值。通过这种对比拉扯,达到使模型的回答更偏向于人类排序靠前回答的目标。
直接策略优化 (DPO) 算法巧妙地将 reward model 和强化学习两个步骤合并,使得训练更加快速高效。在它的训练过程中,Reference 参数固定,只对目标语言模型 (Actor) 进行参数更新,调试更加简单。
2.7.4 DPO 的缺点 (过拟合风险)
虽然 DPO 的推导结果看似非常完美,但实际使用过程中与 PPO 优化算法仍有差距。主要原因是 DPO 的训练目标可能会导致过拟合。
在优化过程中,如果 π θ ( y r ∣ x ) → 0 \pi_{\theta}(y_{r}|x) \rightarrow 0 πθ(yr∣x)→0(即模型极度降低生成坏回答的概率),会导致:
− β log π θ ( y r ∣ x ) π r e f ( y r ∣ x ) → + ∞ -\beta \log\frac{\pi_{\theta}(y_{r}|x)}{\pi_{ref}(y_{r}|x)} \rightarrow +\infty −βlogπref(yr∣x)πθ(yr∣x)→+∞
这意味着,损失能一下子降得很低。直观的理解是:模型即便随便胡说八道,只要能把 rejected 的概率压到足够低,就能使得 Loss 下降。 这给模型留下了"钻空子"的空间。
相比之下,PPO 的损失函数考虑了结果整体的分值(霸总逻辑:除非你能拿到高分,否则必须给我守规矩保持结果合理分布),因此在对齐的稳健性上 PPO 通常更胜一筹。
2.7.5 DPO Loss 核心代码实现
以下是 DPO 损失函数的 PyTorch 代码实现(包含标准 DPO 与 IPO 变体):
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class DPOLoss(nn.Module):
"""
DPO (Direct Preference Optimization) 损失函数
无需显式中间奖励模型,直接通过偏好数据(chosen vs rejected)优化策略。
"""
def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
"""
参数:
beta: 温度参数,控制策略更新的强度
label_smoothing: 标签平滑系数,用于缓解过拟合,默认0.0(不使用平滑)
ipo: 是否使用IPO (Implicit Preference Optimization)变体,默认False
"""
super().__init__()
self.beta = beta
self.label_smoothing = label_smoothing
self.ipo = ipo
def forward(
self,
policy_chosen_logps: torch.Tensor, # 策略模型对选中响应的对数概率
policy_rejected_logps: torch.Tensor, # 策略模型对被拒绝响应的对数概率
reference_chosen_logps: torch.Tensor, # 参考模型对选中响应的对数概率
reference_rejected_logps: torch.Tensor # 参考模型对被拒绝响应的对数概率
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# 计算策略模型下选中响应与被拒绝响应的对数概率比
pi_logratios = policy_chosen_logps - policy_rejected_logps
# 计算参考模型下选中响应与被拒绝响应的对数概率比
ref_logratios = reference_chosen_logps - reference_rejected_logps
# 计算logits: 策略比与参考比的差值,衡量策略相对参考模型的改进
logits = pi_logratios - ref_logratios
if self.ipo:
# 使用IPO损失(平方损失)
# 鼓励logits接近 1/(2*beta),实现更稳定的优化
losses = (logits - 1 / (2 * self.beta)) ** 2
else:
# 标准DPO损失,结合标签平滑,平衡选中和拒绝样本的损失权重
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
# 计算批次平均损失
loss = losses.mean()
# 计算选中和拒绝响应的奖励估计(基于策略与参考模型的差异,detach()确保不影响梯度计算)
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
return loss, chosen_rewards, rejected_rewards
2.7.6 DPO 与 PPO 的核心区别汇总
| 特性 | PPO(近端策略优化) | DPO(直接偏好优化) |
|---|---|---|
| 形象比喻 | 先造尺子,再用尺子量着学走路。 | 直接看别人走的对错,调整步伐。 |
| 优化方式 | 借助策略比值、优势函数及 clip 裁剪操作限制策略更新幅度,追求长期回报。兼顾新旧分布合理与绝对分值高。 | 利用打标数据,最大化生成 chosen 的概率、减少生成 rejected 的概率。 |
| 资源消耗 | 需要 Reward + RL 两次训练,RL 阶段要加载 4 个模型并大量在线采样,耗时且耗显存。 | 无需明确奖励函数,将 Reward 和 RL 步骤合并,训练仅需加载 2 个模型。 |
| 效果上限 | Reward 模型的质量直接决定了最终大模型的上限。 | 依赖静态数据训练,更接近监督学习,偏好数据的质量决定了模型的最终质量。 |
