RLHF以及基于RLHF发展的PPO强化学习调优,是目前LLM超强能力的基础。
这里尝试从伪代码角度,尝试深入解析RLHF损失函数与训练数据的关系。
所用示例参考和修改自网络资料。
1 RLHF
RLHF(Reinforcement Learning from Human Feedback)的训练过程通常分为三个主要阶段:
1.1 预训练语言模型(LM)
使用标准的语言模型,比如由多层Transformer构成的一个大模型,在一个大规模的文本语料库上预训练,得到模型通常称为基座模型(base model)。
1.2 奖励模型(RM)
收集人类反馈并训练奖励模型,具体为
1)收集人类对模型生成文本的反馈,通常让基座模型生成多个输出,然后让人类对这些输出进行排序,比如,对模型输出进行打分,选择哪个更好。
2)使用这些排序数据来训练一个奖励模型,该模型能够对任意生成的文本给出一个分数,反映其质量(如是否符合人类偏好)。
1.3 策略模型(PM)
使用强化学习(RL)微调基于语言模型的策略模型,使用训练好的奖励模型作为奖励信号,通过强化学习通常是PPO来微调策略模型,使得模型生成更符合人类偏好的文本。
2 训练数据&过程解析
这里通过训练数据结合训练过程,分析RLHF的三个主要阶段。
2.1 SFT阶段
首先是SFT训练数据,用于训练LM模型,需要高质量、多样化的数据。
示例如下。
提示 (Prompt):
用户:我持续低烧和干咳已经三天了,有可能感染了新冠病毒吗?我应该怎么办?
理想回答 (Chosen Response):
【医学知识】 持续低烧和干咳是新冠病毒感染(COVID-19)的常见症状之一,但也可能与其他呼吸道感染(如普通流感)有关。
【行动建议】
立即检测:建议您尽快进行新冠病毒抗原快速检测或核酸检测,这是确认是否感染的最直接方法。
自我隔离:在等待检测结果和症状缓解期间,请尽量留在家中,佩戴口罩,与其他同住者保持距离,以降低潜在传播风险。
健康监测:密切监测体温和症状变化。如果出现呼吸困难、胸痛、意识模糊或高烧不退等严重症状,请立即就医。
常规护理:多休息、多喝水,可以使用一些非处方药来缓解发热和咳嗽症状。
【免责声明】以上信息仅供参考,不能替代专业医疗建议。请根据自身情况及时咨询医生。
这条SFT数据为模型提供了一个高质量的回答范例,涵盖了知识解释、结构化建议和必要的免责声明,为后续的偏好排序树立了明确的标准。
2.2 奖励训练
其次是训练RM奖励模型的样本数据,需要明确的偏好判断,避免模糊对。示例如下所示。
提示 (Prompt):与SFT阶段完全相同
用户:我持续低烧和干咳已经三天了,有可能感染了新冠病毒吗?我应该怎么办?
回答A (Response A):
低烧和干咳可能是新冠症状,建议你去做个检测看看。多喝水,多休息。
回答B (Response B):为简化分析,这里直接采用SFT阶段的理想回答。
人工标注的偏好与评分:
| 评估维度 | 回答A 得分 (1-5) | 回答B 得分 (1-5) | 偏好说明 |
|---|---|---|---|
| 帮助性 | 2 | 5 | 回答B提供了具体、可操作的步骤(检测、隔离、监测),而A的建议过于笼统。 |
| 诚实性 | 4 | 5 | 两者均正确,但B的信息更全面,指出了其他可能性。 |
| 无害性 | 4 | 5 | B包含了关键的免责声明,更严谨。 |
| 条理性 | 2 | 5 | B使用标题和列表,结构清晰易读;A是简单句堆砌。 |
| 总体偏好 | 拒绝 (Rejected) | 优选 (Chosen) | 回答B在所有维度上都显著优于A,因为它更详细、更专业、更具建设性。 |
对于同一个prompt,可以生成多个回答,比如回答A、回答B、回答C、回答D等。
然后采用上述形式,两两对比标注,比如A<B, A>C, A>D, B>C, B>D, C>D,如此形成6个样本。
需要注意的时,在训练时需要将上述表主的6个样本放在一个batch,目的是维持训练的稳定性。
通过这种多维度、带解释的标注,奖励模型不仅能学会"B比A好",更能精细化地理解"好在哪里"
例如,回答B相比A,在"帮助性"和"条理性"上差距巨大。
成千上万条这样的标注数据,就构成了奖励模型的"价值观"。
2.3 PPO阶段
PPO阶段的数据,多维度奖励可以提供更丰富的训练信号。
在此阶段,初始策略模型(经过SFT的模型)根据RM的反馈进行迭代更新。
1)初始策略模型生成回答:
提示 :"我接触了新冠确诊者,但还没症状,需要担心吗?"
初始模型生成(第1轮) :"你存在被感染的风险,建议进行核酸检测并观察症状。"
2)奖励模型进行评估:
奖励模型收到这个回答后,会对照其学习的"好回答"标准(来自RM训练数据)进行打分。假设它给出一个中等分数(例如+0.5)。这个分数不高,因为回答虽然正确,但不够详细和主动。
3)策略模型迭代与优化:
强化学习优化(PPO)分析这个分数,并调整模型参数,目标是生成能获得更高奖励分数的回答。
经过多轮迭代后,面对同样的提示,优化后的模型可能生成:
【风险评估】 您目前处于暴露后观察期。由于新冠病毒存在潜伏期(通常2-14天),即使暂无症状,也存在感染风险。
【核心建议】
暴露后管理:建议从接触日起,自我健康监测14天。
检测时机:不必立即检测,因为病毒载量可能不足。建议在接触后第5天左右进行核酸检测,此时检出率较高。
预防性措施:在监测期间,请尽量减少与他人密切接触,外出佩戴口罩,这既是保护他人,也是自我保护。
【补充信息】 请留意是否出现发热、咳嗽、乏力等任何轻微症状。部分感染者症状轻微或无症状,但仍具传染性。
以上建议基于当前公共卫生指南,个人情况请咨询疾控部门或医生。
4)最终评估:
奖励模型对这个更详尽、更具前瞻性 的回答可能会给出更高的分数(例如+2.8)。
同时,参考模型 (通常clone自SFT模型)会计算当前回答与初始SFT模型"典型风格"的偏离程度(KL散度惩罚),以防止优化过度、产生胡言乱语。最终收益是 奖励分数 - 偏离惩罚。
3 训练损失和完整流程
3.1 SFT损失函数
首先是SFT,使用高质量的问答数据对预训练模型进行微调。
数据是高质量的问答对,例如(prompt, response)。
1)训练目标
目标是使模型学会生成高质量的答案。
损失函数,即标准的语言模型损失,即交叉熵损失,但只计算答案部分(不包括prompt部分)损失。
假设我们有一个批次的数据,每个数据点由input_ids和labels组成,其中input_ids是prompt和response的拼接,labels是与input_ids等长的序列,但是将prompt部分的标签设置为-100(在计算损失时忽略),只计算response部分的损失。
2)数学表示
SFT损失函数的数学解析如下所示
其中
是mask(1表示有效token,0表示padding)。
prompt部分mask设置为0,response部分mask设置为1,表示只计算response部分的损失。
是batch size,
是序列长度。
SFT损失示例伪码如下所示。
import torch
import torch.nn.functional as F
class DetailedSFTLoss:
def __init__(self, config):
self.config = config
def compute_loss_with_mask(self, logits, labels, attention_mask=None):
"""
详细的SFT损失计算,包含padding mask处理
"""
# 维度: logits: [B, T, V], labels: [B, T]
batch_size, seq_len, vocab_size = logits.shape
# 1. 计算每个位置的log probability
log_probs = F.log_softmax(logits, dim=-1) # [B, T, V]
# 2. 提取目标token的log prob
# labels需要调整为[B, T, 1]用于gather
target_log_probs = torch.gather(
log_probs,
dim=-1,
index=labels.unsqueeze(-1)
).squeeze(-1) # [B, T]
# 3. 创建loss mask
if attention_mask is None:
attention_mask = torch.ones_like(labels)
# 对于语言建模,我们通常预测下一个token
# 所以需要将labels向右shift一位
shift_mask = attention_mask[:, 1:] # 忽略第一个token
shift_labels = labels[:, 1:] # 预测下一个token
shift_log_probs = target_log_probs[:, :-1] # 对应的log probs
# 4. 计算加权损失
# 只计算非padding位置的损失
loss_per_token = -shift_log_probs * shift_mask
total_loss = loss_per_token.sum()
num_valid_tokens = shift_mask.sum()
# 5. 避免除以零
if num_valid_tokens > 0:
average_loss = total_loss / num_valid_tokens
else:
average_loss = torch.tensor(0.0)
# 6. 计算额外统计信息
with torch.no_grad():
# 计算perplexity
exp_loss = torch.exp(average_loss)
# 计算准确率
predictions = torch.argmax(logits[:, :-1, :], dim=-1)
correct = (predictions == shift_labels) * shift_mask
accuracy = correct.sum() / num_valid_tokens if num_valid_tokens > 0 else 0
return {
"loss": average_loss,
"perplexity": exp_loss,
"accuracy": accuracy,
"num_tokens": num_valid_tokens
}
3.2 奖励模型损失
1)训练目标
其次是RM,训练一个奖励模型评估生成内容的质量,通常使用人类对生成内容的偏好数据。
数据是由SFT模型生成多个回答,然后由人类标注员对这些回答进行排序。
如第2节所示,对于同一个prompt,生成两个回答,一个被标注为更好(chosen),一个被标注为更差(rejected),目标是训练一个奖励模型,对于给定的prompt和response,输出一个标量奖励值,使得chosen回答的奖励高于rejected回答。
2)数学表示
RM的损失函数,通常使用Pairwise Ranking Loss,例如Bradley-Terry模型。
奖励函数的数学解析如下所示
成对损失(Pairwise Loss)
其中是温度参数,控制区分度。
成对损失主要用于优化奖励模型对<prompt, response>对的打分。
列表损失(Listwise Loss)
其中是根据人类偏好排序的排列。
列表损失采用类似softmax的方法,通过分数,间接学习人类对不同<prompt, response>的排序。
RM损失的示例代码如下所示。
class DetailedRewardLoss:
def __init__(self, config):
self.config = config
# 温度参数,控制偏好对的区分度
self.temperature = config.temperature
# 是否使用margin
self.use_margin = config.use_margin
self.margin = config.margin
def compute_pairwise_loss(self, chosen_rewards, rejected_rewards,
chosen_mask=None, rejected_mask=None,
pair_weights=None):
"""
计算成对偏好损失,支持不同长度序列
参数:
chosen_rewards: [B, T_c] 或 [B,]
rejected_rewards: [B, T_r] 或 [B,]
"""
batch_size = chosen_rewards.shape[0]
# 1. 处理序列奖励(如果奖励是逐token的)
if chosen_rewards.dim() > 1:
# 对序列奖励进行聚合(例如取最后一个token或平均)
if chosen_mask is not None:
chosen_final_rewards = self.aggregate_sequence_rewards(
chosen_rewards, chosen_mask
)
rejected_final_rewards = self.aggregate_sequence_rewards(
rejected_rewards, rejected_mask
)
else:
chosen_final_rewards = chosen_rewards[:, -1] # 取EOS token
rejected_final_rewards = rejected_rewards[:, -1]
else:
chosen_final_rewards = chosen_rewards
rejected_final_rewards = rejected_rewards
# 2. 计算logits(对数胜率)
# 应用温度参数
logits = (chosen_final_rewards - rejected_final_rewards) / self.temperature
# 3. 计算基础损失
# Bradley-Terry模型: P(chosen > rejected) = σ(r_chosen - r_rejected)
if self.use_margin:
# 带margin的损失,确保chosen比rejected至少高margin
loss = F.relu(self.margin - logits).mean()
losses = F.relu(self.margin - logits)
else:
# 标准Bradley-Terry损失
# loss = -log(σ(r_chosen - r_rejected))
loss = -F.logsigmoid(logits).mean()
losses = -F.logsigmoid(logits)
# 4. 应用样本权重(如果有)
if pair_weights is not None:
loss = (losses * pair_weights).sum() / pair_weights.sum()
# 5. 计算额外统计信息
with torch.no_grad():
# 预测准确率(chosen奖励是否大于rejected)
predictions = (chosen_final_rewards > rejected_final_rewards).float()
accuracy = predictions.mean()
# 平均奖励差
reward_diff = (chosen_final_rewards - rejected_final_rewards).mean()
# 冲突样本比例(奖励差太小)
conflict_ratio = (torch.abs(chosen_final_rewards - rejected_final_rewards) < 0.1).float().mean()
return {
"loss": loss,
"accuracy": accuracy,
"reward_diff": reward_diff,
"conflict_ratio": conflict_ratio,
"chosen_mean": chosen_final_rewards.mean(),
"rejected_mean": rejected_final_rewards.mean()
}
def aggregate_sequence_rewards(self, rewards, mask):
"""聚合序列奖励"""
if self.config.reward_aggregation == "last":
# 取最后一个非padding token的奖励
seq_lengths = mask.sum(dim=1).long() - 1 # 最后一个有效token的索引
batch_indices = torch.arange(rewards.shape[0])
aggregated = rewards[batch_indices, seq_lengths]
elif self.config.reward_aggregation == "mean":
# 取非padding token的平均奖励
aggregated = (rewards * mask).sum(dim=1) / mask.sum(dim=1)
elif self.config.reward_aggregation == "sum":
# 取非padding token的奖励和
aggregated = (rewards * mask).sum(dim=1)
else:
raise ValueError(f"Unknown aggregation: {self.config.reward_aggregation}")
return aggregated
def compute_listwise_loss(self, rewards_list, rankings, temperature=1.0):
"""
列表损失(Listwise loss),用于多个响应排序
参数:
rewards_list: [B, K] K个响应的奖励
rankings: [B, K] 每个响应的排名(1表示最好)
"""
batch_size, num_responses = rewards_list.shape
# 1. 将排名转换为概率分布(Plackett-Luce模型)
# P(ranking) = ∏_{i=1}^{K} exp(r_{π(i)}) / ∑_{j=i}^{K} exp(r_{π(j)})
# 根据排名排序奖励
sorted_indices = torch.argsort(rankings, dim=1)
sorted_rewards = torch.gather(rewards_list, 1, sorted_indices)
# 2. 计算log概率
losses = []
for i in range(num_responses):
# 计算第i个位置的条件概率
remaining_rewards = sorted_rewards[:, i:] / temperature
log_denominator = torch.logsumexp(remaining_rewards, dim=1)
log_prob = sorted_rewards[:, i] / temperature - log_denominator
losses.append(-log_prob)
loss = torch.stack(losses, dim=1).mean()
return loss
3.3 PPO强化损失
强化学习微调,使用奖励模型作为奖励信号,通过强化学习(PPO)进一步微调SFT模型。
1)训练目标
PPO使用当前的策略模型,通常初始化自SFT模型,与环境即用户输入prompt交互,生成回答。
然后使用奖励模型计算奖励,目标是优化策略模型以获取更高奖励。
PPO损失函数,包括三个部分:
策略损失(Policy Loss):使用PPO-Clip算法,限制更新幅度。
价值损失(Value Loss):如果使用价值函数,则需要对价值函数模型进行训练。
熵损失(Entropy Loss):鼓励探索,防止策略过早收敛。
其中策略损失是主要部分,同时加入KL散度惩罚来防止策略模型偏离参考模型(SFT模型)太远。
2)数学表示
PPO-Clip Actor损失:
其中。
这里从稳定性角度出发,使用clip和normalization等技术,限制更新幅度。
KL散度惩罚:
加入KL散度惩罚来防止策略模型偏离参考模型(SFT模型)太远。
价值损失
即通用的MSE损失。
示例代码如下所示
class DetailedPPOLoss:
def __init__(self, config):
self.config = config
def compute_actor_loss(self, logprobs, old_logprobs, advantages, kl_penalty=True):
"""
详细的Actor损失计算
参数:
logprobs: 当前策略的log概率 [B, T] 或 [B,]
old_logprobs: 旧策略的log概率 [B, T] 或 [B,]
advantages: 优势估计 [B, T] 或 [B,]
"""
# 1. 计算概率比
ratio = torch.exp(logprobs - old_logprobs) # [B, T] 或 [B,]
# 2. 计算两种损失
pg_loss1 = ratio * advantages
pg_loss2 = torch.clamp(ratio,
1.0 - self.config.clip_epsilon,
1.0 + self.config.clip_epsilon) * advantages
# 3. PPO-Clip损失
actor_loss = -torch.min(pg_loss1, pg_loss2)
# 4. 如果按token计算,需要取平均
if actor_loss.dim() > 1:
actor_loss = actor_loss.mean(dim=-1)
# 5. 添加KL散度惩罚
if kl_penalty:
kl_div = old_logprobs - logprobs # KL(p_old || p_new)
if kl_div.dim() > 1:
kl_div = kl_div.mean(dim=-1)
# 自适应KL惩罚
if self.config.adaptive_kl:
kl_coef = self.config.kl_coef
kl_mean = kl_div.mean().item()
if kl_mean > self.config.target_kl * 1.5:
kl_coef *= 1.2
elif kl_mean < self.config.target_kl / 1.5:
kl_coef /= 1.2
else:
kl_coef = self.config.kl_coef
actor_loss = actor_loss + kl_coef * kl_div
# 6. 添加熵奖励(鼓励探索)
if self.config.entropy_coef > 0:
# 计算熵: H(p) = -∑ p log p
# 这里我们有logits时可以直接计算,只有logprobs时需要估算
# 简化:假设当前策略是确定的,熵近似为-logprobs的方差
entropy_bonus = -logprobs.var(dim=-1) if logprobs.dim() > 1 else torch.tensor(0.0)
actor_loss = actor_loss - self.config.entropy_coef * entropy_bonus
return actor_loss.mean()
def compute_value_loss(self, values, old_values, returns, clip_value_loss=True):
"""
价值损失计算
参数:
values: 当前价值估计 [B, T] 或 [B,]
old_values: 旧价值估计 [B, T] 或 [B,]
returns: 实际回报 [B, T] 或 [B,]
"""
# 1. 价值预测误差
value_error = returns - values
# 2. 可选:Clipped value loss
if clip_value_loss and old_values is not None:
value_clipped = old_values + torch.clamp(
values - old_values,
-self.config.clip_epsilon,
self.config.clip_epsilon
)
value_loss1 = (value_error ** 2)
value_loss2 = ((returns - value_clipped) ** 2)
value_loss = 0.5 * torch.max(value_loss1, value_loss2)
else:
value_loss = 0.5 * (value_error ** 2)
# 3. 按需取平均
if value_loss.dim() > 1:
value_loss = value_loss.mean(dim=-1)
return value_loss.mean()
def compute_advantages(self, rewards, values, dones=None, gamma=0.99, gae_lambda=0.95):
"""
广义优势估计 (GAE)
参数:
rewards: 奖励序列 [B, T]
values: 价值估计 [B, T+1] (包含终止状态价值)
dones: 终止标志 [B, T]
"""
batch_size, seq_len = rewards.shape
# 初始化优势值
advantages = torch.zeros_like(rewards)
last_gae = torch.zeros(batch_size)
# 反向计算GAE
for t in reversed(range(seq_len)):
if t == seq_len - 1:
next_value = values[:, t+1] if values.shape[1] > seq_len else 0.0
else:
next_value = values[:, t+1]
# 计算TD误差
delta = rewards[:, t] + gamma * next_value * (1.0 if dones is None else (1.0 - dones[:, t])) - values[:, t]
# 更新GAE
last_gae = delta + gamma * gae_lambda * (1.0 if dones is None else (1.0 - dones[:, t])) * last_gae
advantages[:, t] = last_gae
return advantages
3.4 完整训练流程
以下是RLHF的完整训练流程伪码示例,先SFT,再RM,最后PPO。
从平衡性角度出发,PPO需要对不同损失项之间的权重需要仔细调整。
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
import numpy as np
class RLHFDataset(Dataset):
"""RLHF数据集类,处理不同类型的数据"""
def __init__(self, data, stage="sft"):
self.data = data
self.stage = stage
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
if self.stage == "sft":
return {
"input_ids": item["input_ids"],
"attention_mask": item["attention_mask"],
"labels": item.get("labels", item["input_ids"].clone())
}
elif self.stage == "rm":
return {
"prompt_ids": item["prompt_ids"],
"chosen_ids": item["chosen_ids"],
"rejected_ids": item["rejected_ids"],
"prompt_mask": item["prompt_mask"],
"chosen_mask": item["chosen_mask"],
"rejected_mask": item["rejected_mask"]
}
elif self.stage == "ppo":
return {
"query": item["query"],
"query_ids": item["query_ids"],
"response": item.get("response", ""),
"response_ids": item.get("response_ids", None),
"advantages": item.get("advantages", None),
"returns": item.get("returns", None)
}
class DataCollator:
"""数据整理器,处理padding和批处理"""
def __init__(self, tokenizer, stage="sft"):
self.tokenizer = tokenizer
self.stage = stage
def __call__(self, batch):
if self.stage == "sft":
return self.collate_sft(batch)
elif self.stage == "rm":
return self.collate_rm(batch)
elif self.stage == "ppo":
return self.collate_ppo(batch)
def collate_sft(self, batch):
max_len = max(len(item["input_ids"]) for item in batch)
input_ids = []
attention_masks = []
labels = []
for item in batch:
pad_len = max_len - len(item["input_ids"])
input_ids.append(torch.cat([item["input_ids"],
torch.full((pad_len,), self.tokenizer.pad_token_id)]))
attention_masks.append(torch.cat([item["attention_mask"],
torch.zeros(pad_len)]))
labels.append(torch.cat([item["labels"],
torch.full((pad_len,), -100)]))
return {
"input_ids": torch.stack(input_ids),
"attention_mask": torch.stack(attention_masks),
"labels": torch.stack(labels)
}
def collate_rm(self, batch):
# 处理prompt
prompt_max_len = max(len(item["prompt_ids"]) for item in batch)
prompt_ids = []
prompt_masks = []
for item in batch:
pad_len = prompt_max_len - len(item["prompt_ids"])
prompt_ids.append(torch.cat([item["prompt_ids"],
torch.full((pad_len,), self.tokenizer.pad_token_id)]))
prompt_masks.append(torch.cat([item["prompt_mask"],
torch.zeros(pad_len)]))
# 处理chosen和rejected响应
def pad_sequence(sequences, max_len=None):
if max_len is None:
max_len = max(len(seq) for seq in sequences)
padded = []
masks = []
for seq in sequences:
pad_len = max_len - len(seq)
padded.append(torch.cat([seq,
torch.full((pad_len,), self.tokenizer.pad_token_id)]))
masks.append(torch.cat([torch.ones(len(seq)),
torch.zeros(pad_len)]))
return torch.stack(padded), torch.stack(masks)
chosen_ids = [item["chosen_ids"] for item in batch]
rejected_ids = [item["rejected_ids"] for item in batch]
chosen_max_len = max(len(ids) for ids in chosen_ids)
rejected_max_len = max(len(ids) for ids in rejected_ids)
response_max_len = max(chosen_max_len, rejected_max_len)
chosen_ids_padded, chosen_masks = pad_sequence(chosen_ids, response_max_len)
rejected_ids_padded, rejected_masks = pad_sequence(rejected_ids, response_max_len)
return {
"prompt_ids": torch.stack(prompt_ids),
"prompt_mask": torch.stack(prompt_masks),
"chosen_ids": chosen_ids_padded,
"chosen_mask": chosen_masks,
"rejected_ids": rejected_ids_padded,
"rejected_mask": rejected_masks
}
class RLHFTrainer:
"""完整的RLHF训练器"""
def __init__(self, config, models, tokenizer):
self.config = config
self.models = models # {'sft': ..., 'rm': ..., 'ppo': ...}
self.tokenizer = tokenizer
# 初始化损失函数
self.sft_loss_fn = DetailedSFTLoss(config.sft_config)
self.rm_loss_fn = DetailedRewardLoss(config.rm_config)
self.ppo_loss_fn = DetailedPPOLoss(config.ppo_config)
# 训练状态跟踪
self.metrics = defaultdict(list)
def train_sft_epoch(self, dataloader, optimizer, model):
"""SFT训练epoch"""
model.train()
total_loss = 0
total_tokens = 0
for batch_idx, batch in enumerate(dataloader):
optimizer.zero_grad()
# 前向传播
outputs = model(
input_ids=batch["input_ids"].to(self.config.device),
attention_mask=batch["attention_mask"].to(self.config.device),
labels=batch["labels"].to(self.config.device)
)
# 计算详细损失
loss_dict = self.sft_loss_fn.compute_loss_with_mask(
logits=outputs.logits,
labels=batch["labels"].to(self.config.device),
attention_mask=batch["attention_mask"].to(self.config.device)
)
loss = loss_dict["loss"]
# 反向传播
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), self.config.max_grad_norm)
optimizer.step()
# 记录指标
total_loss += loss.item() * loss_dict["num_tokens"].item()
total_tokens += loss_dict["num_tokens"].item()
if batch_idx % self.config.log_interval == 0:
print(f"SFT Batch {batch_idx}: Loss={loss.item():.4f}, "
f"PPL={loss_dict['perplexity'].item():.2f}, "
f"Acc={loss_dict['accuracy'].item():.4f}")
avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
return avg_loss
def train_rm_epoch(self, dataloader, optimizer, reward_model):
"""奖励模型训练epoch"""
reward_model.train()
total_loss = 0
total_samples = 0
for batch_idx, batch in enumerate(dataloader):
optimizer.zero_grad()
# 计算chosen奖励
chosen_inputs = {
"input_ids": torch.cat([batch["prompt_ids"], batch["chosen_ids"]], dim=1).to(self.config.device),
"attention_mask": torch.cat([batch["prompt_mask"], batch["chosen_mask"]], dim=1).to(self.config.device)
}
chosen_rewards = reward_model(**chosen_inputs)
# 计算rejected奖励
rejected_inputs = {
"input_ids": torch.cat([batch["prompt_ids"], batch["rejected_ids"]], dim=1).to(self.config.device),
"attention_mask": torch.cat([batch["prompt_mask"], batch["rejected_mask"]], dim=1).to(self.config.device)
}
rejected_rewards = reward_model(**rejected_inputs)
# 计算损失
loss_dict = self.rm_loss_fn.compute_pairwise_loss(
chosen_rewards=chosen_rewards,
rejected_rewards=rejected_rewards,
chosen_mask=batch["chosen_mask"].to(self.config.device),
rejected_mask=batch["rejected_mask"].to(self.config.device)
)
loss = loss_dict["loss"]
# 反向传播
loss.backward()
torch.nn.utils.clip_grad_norm_(reward_model.parameters(), self.config.max_grad_norm)
optimizer.step()
# 记录指标
total_loss += loss.item()
total_samples += 1
if batch_idx % self.config.log_interval == 0:
print(f"RM Batch {batch_idx}: Loss={loss.item():.4f}, "
f"Acc={loss_dict['accuracy'].item():.4f}, "
f"Chosen-Reward={loss_dict['chosen_mean'].item():.4f}, "
f"Rejected-Reward={loss_dict['rejected_mean'].item():.4f}")
avg_loss = total_loss / total_samples if total_samples > 0 else 0
return avg_loss
def collect_ppo_samples(self, model, prompts, reward_model, num_samples=1):
"""收集PPO训练样本"""
model.eval()
all_samples = []
with torch.no_grad():
for prompt in prompts:
# 编码prompt
prompt_enc = self.tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.config.max_prompt_length
).to(self.config.device)
# 生成多个响应
for _ in range(num_samples):
# 采样生成
output_sequences = model.generate(
input_ids=prompt_enc["input_ids"],
attention_mask=prompt_enc["attention_mask"],
max_length=self.config.max_length,
do_sample=True,
temperature=self.config.temperature,
top_p=self.config.top_p,
pad_token_id=self.tokenizer.pad_token_id
)
# 解码响应
response = self.tokenizer.decode(
output_sequences[0, prompt_enc["input_ids"].shape[1]:],
skip_special_tokens=True
)
# 计算奖励
full_input = torch.cat([
prompt_enc["input_ids"],
output_sequences[:, prompt_enc["input_ids"].shape[1]:]
], dim=1)
full_mask = torch.cat([
prompt_enc["attention_mask"],
torch.ones_like(output_sequences[:, prompt_enc["input_ids"].shape[1]:])
], dim=1)
reward = reward_model(
input_ids=full_input,
attention_mask=full_mask
)
# 计算log probs
outputs = model(
input_ids=full_input,
attention_mask=full_mask,
return_dict=True
)
logits = outputs.logits
log_probs = F.log_softmax(logits, dim=-1)
# 收集生成token的log probs
response_ids = output_sequences[:, prompt_enc["input_ids"].shape[1]:]
response_log_probs = torch.gather(
log_probs[:, prompt_enc["input_ids"].shape[1]-1:-1, :],
dim=-1,
index=response_ids.unsqueeze(-1)
).squeeze(-1)
all_samples.append({
"prompt": prompt,
"response": response,
"response_ids": response_ids[0],
"reward": reward.item(),
"log_probs": response_log_probs[0],
"attention_mask": full_mask[0]
})
return all_samples
def train_ppo_epoch(self, samples, policy_model, value_model, optimizer, reward_model):
"""PPO训练epoch"""
policy_model.train()
value_model.train() if value_model is not None else None
# 准备数据
batch_size = self.config.ppo_batch_size
num_batches = (len(samples) + batch_size - 1) // batch_size
total_actor_loss = 0
total_value_loss = 0
total_kl = 0
for batch_idx in range(num_batches):
batch_samples = samples[batch_idx*batch_size:(batch_idx+1)*batch_size]
if not batch_samples:
continue
# 准备batch数据
batch_data = self.prepare_ppo_batch(batch_samples)
# 前向传播计算当前策略的log probs
outputs = policy_model(
input_ids=batch_data["full_input_ids"],
attention_mask=batch_data["full_attention_mask"],
return_dict=True
)
current_logits = outputs.logits
current_log_probs = F.log_softmax(current_logits, dim=-1)
# 提取响应部分的log probs
response_log_probs = torch.gather(
current_log_probs[:, batch_data["prompt_lengths"][0]-1:-1, :],
dim=-1,
index=batch_data["response_ids"].unsqueeze(-1)
).squeeze(-1)
# 计算actor损失
actor_loss = self.ppo_loss_fn.compute_actor_loss(
logprobs=response_log_probs,
old_logprobs=batch_data["old_log_probs"],
advantages=batch_data["advantages"],
kl_penalty=True
)
# 计算value损失(如果有价值模型)
if value_model is not None:
values = value_model(batch_data["full_input_ids"])
value_loss = self.ppo_loss_fn.compute_value_loss(
values=values,
old_values=batch_data["old_values"],
returns=batch_data["returns"]
)
loss = actor_loss + self.config.value_coef * value_loss
total_value_loss += value_loss.item()
else:
loss = actor_loss
# 反向传播
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(policy_model.parameters(), self.config.max_grad_norm)
optimizer.step()
# 记录指标
total_actor_loss += actor_loss.item()
total_kl += (batch_data["old_log_probs"] - response_log_probs).mean().item()
if batch_idx % self.config.log_interval == 0:
print(f"PPO Batch {batch_idx}: Actor Loss={actor_loss.item():.4f}, "
f"Value Loss={value_loss.item() if value_model else 0:.4f}, "
f"KL={total_kl/(batch_idx+1):.4f}")
metrics = {
"actor_loss": total_actor_loss / num_batches,
"value_loss": total_value_loss / num_batches if value_model else 0,
"kl": total_kl / num_batches
}
return metrics
def prepare_ppo_batch(self, samples):
"""准备PPO训练batch"""
# 这里实现批处理逻辑
# 实际实现需要考虑不同长度的序列padding
pass
4 高级技巧
实际训练中可能不会简单复用RLHF,而是会采用各种高级技巧。
比如多目标奖励平衡、课程学习策略、离线偏好优化。
4.1 多目标奖励平衡
以下是多目标奖励平衡的伪码示例。
class MultiObjectiveReward:
"""多目标奖励平衡"""
def __init__(self, objectives, weights=None):
self.objectives = objectives # 字典: {name: reward_model}
self.weights = weights or {name: 1.0 for name in objectives}
def __call__(self, prompt, response):
total_reward = 0
individual_rewards = {}
for name, model in self.objectives.items():
reward = model(prompt, response)
weighted_reward = reward * self.weights[name]
total_reward += weighted_reward
individual_rewards[name] = reward.item()
return total_reward, individual_rewards
4.2 课程学习策略
以下是课程学习策略的的伪码示例。
class CurriculumLearning:
"""课程学习,逐步增加难度"""
def __init__(self, curriculum_config):
self.stages = curriculum_config["stages"]
self.current_stage = 0
def get_current_data(self, full_dataset):
"""根据当前阶段获取数据子集"""
stage_config = self.stages[self.current_stage]
# 根据难度过滤数据
filtered_data = []
for item in full_dataset:
if self.meets_criteria(item, stage_config):
filtered_data.append(item)
return filtered_data
def advance_stage(self, metrics):
"""根据性能指标决定是否进入下一阶段"""
if self.current_stage >= len(self.stages) - 1:
return False
stage_config = self.stages[self.current_stage]
advancement_criteria = stage_config.get("advancement_criteria", {})
# 检查是否满足进阶条件
for metric_name, threshold in advancement_criteria.items():
if metric_name in metrics and metrics[metric_name] >= threshold:
continue
else:
return False
self.current_stage += 1
return True
4.3 离线偏好优化
以下是离线偏好优化,比如DPO、KTO优化的伪码示例。
class OfflinePreferenceOptimization:
"""离线偏好优化(如DPO、KTO)"""
def dpo_loss(self, policy_logps, ref_logps, beta=0.1):
"""
DPO损失函数
policy_logps: 当前策略的log概率 [B]
ref_logps: 参考策略的log概率 [B]
"""
log_ratios = policy_logps - ref_logps
ratios = torch.exp(log_ratios)
# DPO损失
losses = -F.logsigmoid(beta * log_ratios)
return losses.mean()
def kto_loss(self, rewards, kl_penalty, alpha=0.1):
"""
KTO损失函数
rewards: 奖励 [B]
kl_penalty: KL散度惩罚 [B]
"""
# 简单实现
return (rewards - alpha * kl_penalty).mean()
reference
--
PPO优势函数的学习和解读
https://blog.csdn.net/liliang199/article/details/148875214
PPO在强化学习中的应用
https://blog.csdn.net/liliang199/article/details/148840758
PPO偏好策略函数的学习解读
https://blog.csdn.net/liliang199/article/details/148811167
从蒙特卡洛的角度探索和示例重要性采样
https://blog.csdn.net/liliang199/article/details/154842020
DPO直接偏好函数的学习解读
https://blog.csdn.net/liliang199/article/details/148797585
KTO偏好效用函数的学习解读
https://blog.csdn.net/liliang199/article/details/148772543
RL偏好数据集