基于人类反馈的强化学习框架-RLHF&PPO

RLHF以及基于RLHF发展的PPO强化学习调优,是目前LLM超强能力的基础。

这里尝试从伪代码角度,尝试深入解析RLHF损失函数与训练数据的关系。

所用示例参考和修改自网络资料。

1 RLHF

RLHF(Reinforcement Learning from Human Feedback)的训练过程通常分为三个主要阶段:

1.1 预训练语言模型(LM)

使用标准的语言模型,比如由多层Transformer构成的一个大模型,在一个大规模的文本语料库上预训练,得到模型通常称为基座模型(base model)。

1.2 奖励模型(RM)

收集人类反馈并训练奖励模型,具体为

1)收集人类对模型生成文本的反馈,通常让基座模型生成多个输出,然后让人类对这些输出进行排序,比如,对模型输出进行打分,选择哪个更好。

2)使用这些排序数据来训练一个奖励模型,该模型能够对任意生成的文本给出一个分数,反映其质量(如是否符合人类偏好)。

1.3 策略模型(PM)

使用强化学习(RL)微调基于语言模型的策略模型,使用训练好的奖励模型作为奖励信号,通过强化学习通常是PPO来微调策略模型,使得模型生成更符合人类偏好的文本。

2 训练数据&过程解析

这里通过训练数据结合训练过程,分析RLHF的三个主要阶段。

2.1 SFT阶段

首先是SFT训练数据,用于训练LM模型,需要高质量、多样化的数据。

示例如下。

提示 (Prompt)

用户:我持续低烧和干咳已经三天了,有可能感染了新冠病毒吗?我应该怎么办?

理想回答 (Chosen Response)

【医学知识】 持续低烧和干咳是新冠病毒感染(COVID-19)的常见症状之一,但也可能与其他呼吸道感染(如普通流感)有关。

【行动建议】

  1. 立即检测:建议您尽快进行新冠病毒抗原快速检测或核酸检测,这是确认是否感染的最直接方法。

  2. 自我隔离:在等待检测结果和症状缓解期间,请尽量留在家中,佩戴口罩,与其他同住者保持距离,以降低潜在传播风险。

  3. 健康监测:密切监测体温和症状变化。如果出现呼吸困难、胸痛、意识模糊或高烧不退等严重症状,请立即就医。

  4. 常规护理:多休息、多喝水,可以使用一些非处方药来缓解发热和咳嗽症状。

【免责声明】以上信息仅供参考,不能替代专业医疗建议。请根据自身情况及时咨询医生。

这条SFT数据为模型提供了一个高质量的回答范例,涵盖了知识解释、结构化建议和必要的免责声明,为后续的偏好排序树立了明确的标准。

2.2 奖励训练

其次是训练RM奖励模型的样本数据,需要明确的偏好判断,避免模糊对。示例如下所示。

提示 (Prompt):与SFT阶段完全相同

用户:我持续低烧和干咳已经三天了,有可能感染了新冠病毒吗?我应该怎么办?

回答A (Response A)

低烧和干咳可能是新冠症状,建议你去做个检测看看。多喝水,多休息。

回答B (Response B):为简化分析,这里直接采用SFT阶段的理想回答。

人工标注的偏好与评分

评估维度 回答A 得分 (1-5) 回答B 得分 (1-5) 偏好说明
帮助性 2 5 回答B提供了具体、可操作的步骤(检测、隔离、监测),而A的建议过于笼统。
诚实性 4 5 两者均正确,但B的信息更全面,指出了其他可能性。
无害性 4 5 B包含了关键的免责声明,更严谨。
条理性 2 5 B使用标题和列表,结构清晰易读;A是简单句堆砌。
总体偏好 拒绝 (Rejected) 优选 (Chosen) 回答B在所有维度上都显著优于A,因为它更详细、更专业、更具建设性。

对于同一个prompt,可以生成多个回答,比如回答A、回答B、回答C、回答D等。

然后采用上述形式,两两对比标注,比如A<B, A>C, A>D, B>C, B>D, C>D,如此形成6个样本。

需要注意的时,在训练时需要将上述表主的6个样本放在一个batch,目的是维持训练的稳定性。

通过这种多维度、带解释的标注,奖励模型不仅能学会"B比A好",更能精细化地理解"好在哪里"

例如,回答B相比A,在"帮助性"和"条理性"上差距巨大。

成千上万条这样的标注数据,就构成了奖励模型的"价值观"。

2.3 PPO阶段

PPO阶段的数据,多维度奖励可以提供更丰富的训练信号。

在此阶段,初始策略模型(经过SFT的模型)根据RM的反馈进行迭代更新。

1)初始策略模型生成回答

提示"我接触了新冠确诊者,但还没症状,需要担心吗?"

初始模型生成(第1轮)"你存在被感染的风险,建议进行核酸检测并观察症状。"

2)奖励模型进行评估

奖励模型收到这个回答后,会对照其学习的"好回答"标准(来自RM训练数据)进行打分。假设它给出一个中等分数(例如+0.5)。这个分数不高,因为回答虽然正确,但不够详细和主动。

3)策略模型迭代与优化

强化学习优化(PPO)分析这个分数,并调整模型参数,目标是生成能获得更高奖励分数的回答。

经过多轮迭代后,面对同样的提示,优化后的模型可能生成:

【风险评估】 您目前处于暴露后观察期。由于新冠病毒存在潜伏期(通常2-14天),即使暂无症状,也存在感染风险。

【核心建议】

  1. 暴露后管理:建议从接触日起,自我健康监测14天。

  2. 检测时机:不必立即检测,因为病毒载量可能不足。建议在接触后第5天左右进行核酸检测,此时检出率较高。

  3. 预防性措施:在监测期间,请尽量减少与他人密切接触,外出佩戴口罩,这既是保护他人,也是自我保护。

【补充信息】 请留意是否出现发热、咳嗽、乏力等任何轻微症状。部分感染者症状轻微或无症状,但仍具传染性。

以上建议基于当前公共卫生指南,个人情况请咨询疾控部门或医生。

4)最终评估

奖励模型对这个更详尽、更具前瞻性 的回答可能会给出更高的分数(例如+2.8)

同时,参考模型 (通常clone自SFT模型)会计算当前回答与初始SFT模型"典型风格"的偏离程度(KL散度惩罚),以防止优化过度、产生胡言乱语。最终收益是 奖励分数 - 偏离惩罚

3 训练损失和完整流程

3.1 SFT损失函数

首先是SFT,使用高质量的问答数据对预训练模型进行微调。

数据是高质量的问答对,例如(prompt, response)

1)训练目标

目标是使模型学会生成高质量的答案。

损失函数,即标准的语言模型损失,即交叉熵损失,但只计算答案部分(不包括prompt部分)损失。

假设我们有一个批次的数据,每个数据点由input_idslabels组成,其中input_ids是prompt和response的拼接,labels是与input_ids等长的序列,但是将prompt部分的标签设置为-100(在计算损失时忽略),只计算response部分的损失。

2)数学表示

SFT损失函数的数学解析如下所示

其中

是mask(1表示有效token,0表示padding)。

prompt部分mask设置为0,response部分mask设置为1,表示只计算response部分的损失。

是batch size,是序列长度。

SFT损失示例伪码如下所示。

复制代码
import torch
import torch.nn.functional as F

class DetailedSFTLoss:
    def __init__(self, config):
        self.config = config
        
    def compute_loss_with_mask(self, logits, labels, attention_mask=None):
        """
        详细的SFT损失计算,包含padding mask处理
        """
        # 维度: logits: [B, T, V], labels: [B, T]
        batch_size, seq_len, vocab_size = logits.shape
        
        # 1. 计算每个位置的log probability
        log_probs = F.log_softmax(logits, dim=-1)  # [B, T, V]
        
        # 2. 提取目标token的log prob
        # labels需要调整为[B, T, 1]用于gather
        target_log_probs = torch.gather(
            log_probs, 
            dim=-1, 
            index=labels.unsqueeze(-1)
        ).squeeze(-1)  # [B, T]
        
        # 3. 创建loss mask
        if attention_mask is None:
            attention_mask = torch.ones_like(labels)
        
        # 对于语言建模,我们通常预测下一个token
        # 所以需要将labels向右shift一位
        shift_mask = attention_mask[:, 1:]  # 忽略第一个token
        shift_labels = labels[:, 1:]  # 预测下一个token
        shift_log_probs = target_log_probs[:, :-1]  # 对应的log probs
        
        # 4. 计算加权损失
        # 只计算非padding位置的损失
        loss_per_token = -shift_log_probs * shift_mask
        total_loss = loss_per_token.sum()
        num_valid_tokens = shift_mask.sum()
        
        # 5. 避免除以零
        if num_valid_tokens > 0:
            average_loss = total_loss / num_valid_tokens
        else:
            average_loss = torch.tensor(0.0)
            
        # 6. 计算额外统计信息
        with torch.no_grad():
            # 计算perplexity
            exp_loss = torch.exp(average_loss)
            
            # 计算准确率
            predictions = torch.argmax(logits[:, :-1, :], dim=-1)
            correct = (predictions == shift_labels) * shift_mask
            accuracy = correct.sum() / num_valid_tokens if num_valid_tokens > 0 else 0
            
        return {
            "loss": average_loss,
            "perplexity": exp_loss,
            "accuracy": accuracy,
            "num_tokens": num_valid_tokens
        }

3.2 奖励模型损失

1)训练目标

其次是RM,训练一个奖励模型评估生成内容的质量,通常使用人类对生成内容的偏好数据。

数据是由SFT模型生成多个回答,然后由人类标注员对这些回答进行排序。

如第2节所示,对于同一个prompt,生成两个回答,一个被标注为更好(chosen),一个被标注为更差(rejected),目标是训练一个奖励模型,对于给定的prompt和response,输出一个标量奖励值,使得chosen回答的奖励高于rejected回答。

2)数学表示

RM的损失函数,通常使用Pairwise Ranking Loss,例如Bradley-Terry模型。

奖励函数的数学解析如下所示

成对损失(Pairwise Loss)

其中是温度参数,控制区分度。

成对损失主要用于优化奖励模型对<prompt, response>对的打分。

列表损失(Listwise Loss)

其中是根据人类偏好排序的排列。

列表损失采用类似softmax的方法,通过分数,间接学习人类对不同<prompt, response>的排序。

RM损失的示例代码如下所示。

复制代码
class DetailedRewardLoss:
    def __init__(self, config):
        self.config = config
        # 温度参数,控制偏好对的区分度
        self.temperature = config.temperature
        # 是否使用margin
        self.use_margin = config.use_margin
        self.margin = config.margin
        
    def compute_pairwise_loss(self, chosen_rewards, rejected_rewards, 
                            chosen_mask=None, rejected_mask=None,
                            pair_weights=None):
        """
        计算成对偏好损失,支持不同长度序列
        
        参数:
            chosen_rewards: [B, T_c] 或 [B,]
            rejected_rewards: [B, T_r] 或 [B,]
        """
        batch_size = chosen_rewards.shape[0]
        
        # 1. 处理序列奖励(如果奖励是逐token的)
        if chosen_rewards.dim() > 1:
            # 对序列奖励进行聚合(例如取最后一个token或平均)
            if chosen_mask is not None:
                chosen_final_rewards = self.aggregate_sequence_rewards(
                    chosen_rewards, chosen_mask
                )
                rejected_final_rewards = self.aggregate_sequence_rewards(
                    rejected_rewards, rejected_mask
                )
            else:
                chosen_final_rewards = chosen_rewards[:, -1]  # 取EOS token
                rejected_final_rewards = rejected_rewards[:, -1]
        else:
            chosen_final_rewards = chosen_rewards
            rejected_final_rewards = rejected_rewards
        
        # 2. 计算logits(对数胜率)
        # 应用温度参数
        logits = (chosen_final_rewards - rejected_final_rewards) / self.temperature
        
        # 3. 计算基础损失
        # Bradley-Terry模型: P(chosen > rejected) = σ(r_chosen - r_rejected)
        if self.use_margin:
            # 带margin的损失,确保chosen比rejected至少高margin
            loss = F.relu(self.margin - logits).mean()
            losses = F.relu(self.margin - logits)
        else:
            # 标准Bradley-Terry损失
            # loss = -log(σ(r_chosen - r_rejected))
            loss = -F.logsigmoid(logits).mean()
            losses = -F.logsigmoid(logits)
        
        # 4. 应用样本权重(如果有)
        if pair_weights is not None:
            loss = (losses * pair_weights).sum() / pair_weights.sum()
        
        # 5. 计算额外统计信息
        with torch.no_grad():
            # 预测准确率(chosen奖励是否大于rejected)
            predictions = (chosen_final_rewards > rejected_final_rewards).float()
            accuracy = predictions.mean()
            
            # 平均奖励差
            reward_diff = (chosen_final_rewards - rejected_final_rewards).mean()
            
            # 冲突样本比例(奖励差太小)
            conflict_ratio = (torch.abs(chosen_final_rewards - rejected_final_rewards) < 0.1).float().mean()
            
        return {
            "loss": loss,
            "accuracy": accuracy,
            "reward_diff": reward_diff,
            "conflict_ratio": conflict_ratio,
            "chosen_mean": chosen_final_rewards.mean(),
            "rejected_mean": rejected_final_rewards.mean()
        }
    
    def aggregate_sequence_rewards(self, rewards, mask):
        """聚合序列奖励"""
        if self.config.reward_aggregation == "last":
            # 取最后一个非padding token的奖励
            seq_lengths = mask.sum(dim=1).long() - 1  # 最后一个有效token的索引
            batch_indices = torch.arange(rewards.shape[0])
            aggregated = rewards[batch_indices, seq_lengths]
        elif self.config.reward_aggregation == "mean":
            # 取非padding token的平均奖励
            aggregated = (rewards * mask).sum(dim=1) / mask.sum(dim=1)
        elif self.config.reward_aggregation == "sum":
            # 取非padding token的奖励和
            aggregated = (rewards * mask).sum(dim=1)
        else:
            raise ValueError(f"Unknown aggregation: {self.config.reward_aggregation}")
        
        return aggregated
    
    def compute_listwise_loss(self, rewards_list, rankings, temperature=1.0):
        """
        列表损失(Listwise loss),用于多个响应排序
        
        参数:
            rewards_list: [B, K] K个响应的奖励
            rankings: [B, K] 每个响应的排名(1表示最好)
        """
        batch_size, num_responses = rewards_list.shape
        
        # 1. 将排名转换为概率分布(Plackett-Luce模型)
        # P(ranking) = ∏_{i=1}^{K} exp(r_{π(i)}) / ∑_{j=i}^{K} exp(r_{π(j)})
        
        # 根据排名排序奖励
        sorted_indices = torch.argsort(rankings, dim=1)
        sorted_rewards = torch.gather(rewards_list, 1, sorted_indices)
        
        # 2. 计算log概率
        losses = []
        for i in range(num_responses):
            # 计算第i个位置的条件概率
            remaining_rewards = sorted_rewards[:, i:] / temperature
            log_denominator = torch.logsumexp(remaining_rewards, dim=1)
            log_prob = sorted_rewards[:, i] / temperature - log_denominator
            losses.append(-log_prob)
        
        loss = torch.stack(losses, dim=1).mean()
        
        return loss

3.3 PPO强化损失

强化学习微调,使用奖励模型作为奖励信号,通过强化学习(PPO)进一步微调SFT模型。

1)训练目标

PPO使用当前的策略模型,通常初始化自SFT模型,与环境即用户输入prompt交互,生成回答。

然后使用奖励模型计算奖励,目标是优化策略模型以获取更高奖励。

PPO损失函数,包括三个部分:

策略损失(Policy Loss):使用PPO-Clip算法,限制更新幅度。

价值损失(Value Loss):如果使用价值函数,则需要对价值函数模型进行训练。

熵损失(Entropy Loss):鼓励探索,防止策略过早收敛。

其中策略损失是主要部分,同时加入KL散度惩罚来防止策略模型偏离参考模型(SFT模型)太远。

2)数学表示

PPO-Clip Actor损失:

其中

这里从稳定性角度出发,使用clip和normalization等技术,限制更新幅度。

KL散度惩罚:

加入KL散度惩罚来防止策略模型偏离参考模型(SFT模型)太远。

价值损失

即通用的MSE损失。

示例代码如下所示

复制代码
class DetailedPPOLoss:
    def __init__(self, config):
        self.config = config
        
    def compute_actor_loss(self, logprobs, old_logprobs, advantages, kl_penalty=True):
        """
        详细的Actor损失计算
        
        参数:
            logprobs: 当前策略的log概率 [B, T] 或 [B,]
            old_logprobs: 旧策略的log概率 [B, T] 或 [B,]
            advantages: 优势估计 [B, T] 或 [B,]
        """
        # 1. 计算概率比
        ratio = torch.exp(logprobs - old_logprobs)  # [B, T] 或 [B,]
        
        # 2. 计算两种损失
        pg_loss1 = ratio * advantages
        pg_loss2 = torch.clamp(ratio, 
                             1.0 - self.config.clip_epsilon,
                             1.0 + self.config.clip_epsilon) * advantages
        
        # 3. PPO-Clip损失
        actor_loss = -torch.min(pg_loss1, pg_loss2)
        
        # 4. 如果按token计算,需要取平均
        if actor_loss.dim() > 1:
            actor_loss = actor_loss.mean(dim=-1)
        
        # 5. 添加KL散度惩罚
        if kl_penalty:
            kl_div = old_logprobs - logprobs  # KL(p_old || p_new)
            if kl_div.dim() > 1:
                kl_div = kl_div.mean(dim=-1)
            
            # 自适应KL惩罚
            if self.config.adaptive_kl:
                kl_coef = self.config.kl_coef
                kl_mean = kl_div.mean().item()
                
                if kl_mean > self.config.target_kl * 1.5:
                    kl_coef *= 1.2
                elif kl_mean < self.config.target_kl / 1.5:
                    kl_coef /= 1.2
            else:
                kl_coef = self.config.kl_coef
            
            actor_loss = actor_loss + kl_coef * kl_div
        
        # 6. 添加熵奖励(鼓励探索)
        if self.config.entropy_coef > 0:
            # 计算熵: H(p) = -∑ p log p
            # 这里我们有logits时可以直接计算,只有logprobs时需要估算
            # 简化:假设当前策略是确定的,熵近似为-logprobs的方差
            entropy_bonus = -logprobs.var(dim=-1) if logprobs.dim() > 1 else torch.tensor(0.0)
            actor_loss = actor_loss - self.config.entropy_coef * entropy_bonus
        
        return actor_loss.mean()
    
    def compute_value_loss(self, values, old_values, returns, clip_value_loss=True):
        """
        价值损失计算
        
        参数:
            values: 当前价值估计 [B, T] 或 [B,]
            old_values: 旧价值估计 [B, T] 或 [B,]
            returns: 实际回报 [B, T] 或 [B,]
        """
        # 1. 价值预测误差
        value_error = returns - values
        
        # 2. 可选:Clipped value loss
        if clip_value_loss and old_values is not None:
            value_clipped = old_values + torch.clamp(
                values - old_values,
                -self.config.clip_epsilon,
                self.config.clip_epsilon
            )
            value_loss1 = (value_error ** 2)
            value_loss2 = ((returns - value_clipped) ** 2)
            value_loss = 0.5 * torch.max(value_loss1, value_loss2)
        else:
            value_loss = 0.5 * (value_error ** 2)
        
        # 3. 按需取平均
        if value_loss.dim() > 1:
            value_loss = value_loss.mean(dim=-1)
        
        return value_loss.mean()
    
    def compute_advantages(self, rewards, values, dones=None, gamma=0.99, gae_lambda=0.95):
        """
        广义优势估计 (GAE)
        
        参数:
            rewards: 奖励序列 [B, T]
            values: 价值估计 [B, T+1] (包含终止状态价值)
            dones: 终止标志 [B, T]
        """
        batch_size, seq_len = rewards.shape
        
        # 初始化优势值
        advantages = torch.zeros_like(rewards)
        last_gae = torch.zeros(batch_size)
        
        # 反向计算GAE
        for t in reversed(range(seq_len)):
            if t == seq_len - 1:
                next_value = values[:, t+1] if values.shape[1] > seq_len else 0.0
            else:
                next_value = values[:, t+1]
            
            # 计算TD误差
            delta = rewards[:, t] + gamma * next_value * (1.0 if dones is None else (1.0 - dones[:, t])) - values[:, t]
            
            # 更新GAE
            last_gae = delta + gamma * gae_lambda * (1.0 if dones is None else (1.0 - dones[:, t])) * last_gae
            advantages[:, t] = last_gae
        
        return advantages

3.4 完整训练流程

以下是RLHF的完整训练流程伪码示例,先SFT,再RM,最后PPO。

从平衡性角度出发,PPO需要对不同损失项之间的权重需要仔细调整。

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
import numpy as np

class RLHFDataset(Dataset):
    """RLHF数据集类,处理不同类型的数据"""
    
    def __init__(self, data, stage="sft"):
        self.data = data
        self.stage = stage
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        if self.stage == "sft":
            return {
                "input_ids": item["input_ids"],
                "attention_mask": item["attention_mask"],
                "labels": item.get("labels", item["input_ids"].clone())
            }
        elif self.stage == "rm":
            return {
                "prompt_ids": item["prompt_ids"],
                "chosen_ids": item["chosen_ids"],
                "rejected_ids": item["rejected_ids"],
                "prompt_mask": item["prompt_mask"],
                "chosen_mask": item["chosen_mask"],
                "rejected_mask": item["rejected_mask"]
            }
        elif self.stage == "ppo":
            return {
                "query": item["query"],
                "query_ids": item["query_ids"],
                "response": item.get("response", ""),
                "response_ids": item.get("response_ids", None),
                "advantages": item.get("advantages", None),
                "returns": item.get("returns", None)
            }
    
class DataCollator:
    """数据整理器,处理padding和批处理"""
    
    def __init__(self, tokenizer, stage="sft"):
        self.tokenizer = tokenizer
        self.stage = stage
        
    def __call__(self, batch):
        if self.stage == "sft":
            return self.collate_sft(batch)
        elif self.stage == "rm":
            return self.collate_rm(batch)
        elif self.stage == "ppo":
            return self.collate_ppo(batch)
    
    def collate_sft(self, batch):
        max_len = max(len(item["input_ids"]) for item in batch)
        
        input_ids = []
        attention_masks = []
        labels = []
        
        for item in batch:
            pad_len = max_len - len(item["input_ids"])
            input_ids.append(torch.cat([item["input_ids"], 
                                      torch.full((pad_len,), self.tokenizer.pad_token_id)]))
            attention_masks.append(torch.cat([item["attention_mask"], 
                                            torch.zeros(pad_len)]))
            labels.append(torch.cat([item["labels"], 
                                   torch.full((pad_len,), -100)]))
        
        return {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_masks),
            "labels": torch.stack(labels)
        }
    
    def collate_rm(self, batch):
        # 处理prompt
        prompt_max_len = max(len(item["prompt_ids"]) for item in batch)
        
        prompt_ids = []
        prompt_masks = []
        
        for item in batch:
            pad_len = prompt_max_len - len(item["prompt_ids"])
            prompt_ids.append(torch.cat([item["prompt_ids"],
                                       torch.full((pad_len,), self.tokenizer.pad_token_id)]))
            prompt_masks.append(torch.cat([item["prompt_mask"],
                                         torch.zeros(pad_len)]))
        
        # 处理chosen和rejected响应
        def pad_sequence(sequences, max_len=None):
            if max_len is None:
                max_len = max(len(seq) for seq in sequences)
            padded = []
            masks = []
            for seq in sequences:
                pad_len = max_len - len(seq)
                padded.append(torch.cat([seq, 
                                       torch.full((pad_len,), self.tokenizer.pad_token_id)]))
                masks.append(torch.cat([torch.ones(len(seq)), 
                                      torch.zeros(pad_len)]))
            return torch.stack(padded), torch.stack(masks)
        
        chosen_ids = [item["chosen_ids"] for item in batch]
        rejected_ids = [item["rejected_ids"] for item in batch]
        
        chosen_max_len = max(len(ids) for ids in chosen_ids)
        rejected_max_len = max(len(ids) for ids in rejected_ids)
        response_max_len = max(chosen_max_len, rejected_max_len)
        
        chosen_ids_padded, chosen_masks = pad_sequence(chosen_ids, response_max_len)
        rejected_ids_padded, rejected_masks = pad_sequence(rejected_ids, response_max_len)
        
        return {
            "prompt_ids": torch.stack(prompt_ids),
            "prompt_mask": torch.stack(prompt_masks),
            "chosen_ids": chosen_ids_padded,
            "chosen_mask": chosen_masks,
            "rejected_ids": rejected_ids_padded,
            "rejected_mask": rejected_masks
        }

class RLHFTrainer:
    """完整的RLHF训练器"""
    
    def __init__(self, config, models, tokenizer):
        self.config = config
        self.models = models  # {'sft': ..., 'rm': ..., 'ppo': ...}
        self.tokenizer = tokenizer
        
        # 初始化损失函数
        self.sft_loss_fn = DetailedSFTLoss(config.sft_config)
        self.rm_loss_fn = DetailedRewardLoss(config.rm_config)
        self.ppo_loss_fn = DetailedPPOLoss(config.ppo_config)
        
        # 训练状态跟踪
        self.metrics = defaultdict(list)
        
    def train_sft_epoch(self, dataloader, optimizer, model):
        """SFT训练epoch"""
        model.train()
        total_loss = 0
        total_tokens = 0
        
        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()
            
            # 前向传播
            outputs = model(
                input_ids=batch["input_ids"].to(self.config.device),
                attention_mask=batch["attention_mask"].to(self.config.device),
                labels=batch["labels"].to(self.config.device)
            )
            
            # 计算详细损失
            loss_dict = self.sft_loss_fn.compute_loss_with_mask(
                logits=outputs.logits,
                labels=batch["labels"].to(self.config.device),
                attention_mask=batch["attention_mask"].to(self.config.device)
            )
            
            loss = loss_dict["loss"]
            
            # 反向传播
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.config.max_grad_norm)
            optimizer.step()
            
            # 记录指标
            total_loss += loss.item() * loss_dict["num_tokens"].item()
            total_tokens += loss_dict["num_tokens"].item()
            
            if batch_idx % self.config.log_interval == 0:
                print(f"SFT Batch {batch_idx}: Loss={loss.item():.4f}, "
                      f"PPL={loss_dict['perplexity'].item():.2f}, "
                      f"Acc={loss_dict['accuracy'].item():.4f}")
        
        avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
        return avg_loss
    
    def train_rm_epoch(self, dataloader, optimizer, reward_model):
        """奖励模型训练epoch"""
        reward_model.train()
        total_loss = 0
        total_samples = 0
        
        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()
            
            # 计算chosen奖励
            chosen_inputs = {
                "input_ids": torch.cat([batch["prompt_ids"], batch["chosen_ids"]], dim=1).to(self.config.device),
                "attention_mask": torch.cat([batch["prompt_mask"], batch["chosen_mask"]], dim=1).to(self.config.device)
            }
            
            chosen_rewards = reward_model(**chosen_inputs)
            
            # 计算rejected奖励
            rejected_inputs = {
                "input_ids": torch.cat([batch["prompt_ids"], batch["rejected_ids"]], dim=1).to(self.config.device),
                "attention_mask": torch.cat([batch["prompt_mask"], batch["rejected_mask"]], dim=1).to(self.config.device)
            }
            
            rejected_rewards = reward_model(**rejected_inputs)
            
            # 计算损失
            loss_dict = self.rm_loss_fn.compute_pairwise_loss(
                chosen_rewards=chosen_rewards,
                rejected_rewards=rejected_rewards,
                chosen_mask=batch["chosen_mask"].to(self.config.device),
                rejected_mask=batch["rejected_mask"].to(self.config.device)
            )
            
            loss = loss_dict["loss"]
            
            # 反向传播
            loss.backward()
            torch.nn.utils.clip_grad_norm_(reward_model.parameters(), self.config.max_grad_norm)
            optimizer.step()
            
            # 记录指标
            total_loss += loss.item()
            total_samples += 1
            
            if batch_idx % self.config.log_interval == 0:
                print(f"RM Batch {batch_idx}: Loss={loss.item():.4f}, "
                      f"Acc={loss_dict['accuracy'].item():.4f}, "
                      f"Chosen-Reward={loss_dict['chosen_mean'].item():.4f}, "
                      f"Rejected-Reward={loss_dict['rejected_mean'].item():.4f}")
        
        avg_loss = total_loss / total_samples if total_samples > 0 else 0
        return avg_loss
    
    def collect_ppo_samples(self, model, prompts, reward_model, num_samples=1):
        """收集PPO训练样本"""
        model.eval()
        all_samples = []
        
        with torch.no_grad():
            for prompt in prompts:
                # 编码prompt
                prompt_enc = self.tokenizer(
                    prompt, 
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=self.config.max_prompt_length
                ).to(self.config.device)
                
                # 生成多个响应
                for _ in range(num_samples):
                    # 采样生成
                    output_sequences = model.generate(
                        input_ids=prompt_enc["input_ids"],
                        attention_mask=prompt_enc["attention_mask"],
                        max_length=self.config.max_length,
                        do_sample=True,
                        temperature=self.config.temperature,
                        top_p=self.config.top_p,
                        pad_token_id=self.tokenizer.pad_token_id
                    )
                    
                    # 解码响应
                    response = self.tokenizer.decode(
                        output_sequences[0, prompt_enc["input_ids"].shape[1]:],
                        skip_special_tokens=True
                    )
                    
                    # 计算奖励
                    full_input = torch.cat([
                        prompt_enc["input_ids"],
                        output_sequences[:, prompt_enc["input_ids"].shape[1]:]
                    ], dim=1)
                    
                    full_mask = torch.cat([
                        prompt_enc["attention_mask"],
                        torch.ones_like(output_sequences[:, prompt_enc["input_ids"].shape[1]:])
                    ], dim=1)
                    
                    reward = reward_model(
                        input_ids=full_input,
                        attention_mask=full_mask
                    )
                    
                    # 计算log probs
                    outputs = model(
                        input_ids=full_input,
                        attention_mask=full_mask,
                        return_dict=True
                    )
                    
                    logits = outputs.logits
                    log_probs = F.log_softmax(logits, dim=-1)
                    
                    # 收集生成token的log probs
                    response_ids = output_sequences[:, prompt_enc["input_ids"].shape[1]:]
                    response_log_probs = torch.gather(
                        log_probs[:, prompt_enc["input_ids"].shape[1]-1:-1, :],
                        dim=-1,
                        index=response_ids.unsqueeze(-1)
                    ).squeeze(-1)
                    
                    all_samples.append({
                        "prompt": prompt,
                        "response": response,
                        "response_ids": response_ids[0],
                        "reward": reward.item(),
                        "log_probs": response_log_probs[0],
                        "attention_mask": full_mask[0]
                    })
        
        return all_samples
    
    def train_ppo_epoch(self, samples, policy_model, value_model, optimizer, reward_model):
        """PPO训练epoch"""
        policy_model.train()
        value_model.train() if value_model is not None else None
        
        # 准备数据
        batch_size = self.config.ppo_batch_size
        num_batches = (len(samples) + batch_size - 1) // batch_size
        
        total_actor_loss = 0
        total_value_loss = 0
        total_kl = 0
        
        for batch_idx in range(num_batches):
            batch_samples = samples[batch_idx*batch_size:(batch_idx+1)*batch_size]
            
            if not batch_samples:
                continue
            
            # 准备batch数据
            batch_data = self.prepare_ppo_batch(batch_samples)
            
            # 前向传播计算当前策略的log probs
            outputs = policy_model(
                input_ids=batch_data["full_input_ids"],
                attention_mask=batch_data["full_attention_mask"],
                return_dict=True
            )
            
            current_logits = outputs.logits
            current_log_probs = F.log_softmax(current_logits, dim=-1)
            
            # 提取响应部分的log probs
            response_log_probs = torch.gather(
                current_log_probs[:, batch_data["prompt_lengths"][0]-1:-1, :],
                dim=-1,
                index=batch_data["response_ids"].unsqueeze(-1)
            ).squeeze(-1)
            
            # 计算actor损失
            actor_loss = self.ppo_loss_fn.compute_actor_loss(
                logprobs=response_log_probs,
                old_logprobs=batch_data["old_log_probs"],
                advantages=batch_data["advantages"],
                kl_penalty=True
            )
            
            # 计算value损失(如果有价值模型)
            if value_model is not None:
                values = value_model(batch_data["full_input_ids"])
                value_loss = self.ppo_loss_fn.compute_value_loss(
                    values=values,
                    old_values=batch_data["old_values"],
                    returns=batch_data["returns"]
                )
                loss = actor_loss + self.config.value_coef * value_loss
                total_value_loss += value_loss.item()
            else:
                loss = actor_loss
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), self.config.max_grad_norm)
            optimizer.step()
            
            # 记录指标
            total_actor_loss += actor_loss.item()
            total_kl += (batch_data["old_log_probs"] - response_log_probs).mean().item()
            
            if batch_idx % self.config.log_interval == 0:
                print(f"PPO Batch {batch_idx}: Actor Loss={actor_loss.item():.4f}, "
                      f"Value Loss={value_loss.item() if value_model else 0:.4f}, "
                      f"KL={total_kl/(batch_idx+1):.4f}")
        
        metrics = {
            "actor_loss": total_actor_loss / num_batches,
            "value_loss": total_value_loss / num_batches if value_model else 0,
            "kl": total_kl / num_batches
        }
        
        return metrics
    
    def prepare_ppo_batch(self, samples):
        """准备PPO训练batch"""
        # 这里实现批处理逻辑
        # 实际实现需要考虑不同长度的序列padding
        pass

4 高级技巧

实际训练中可能不会简单复用RLHF,而是会采用各种高级技巧。

比如多目标奖励平衡、课程学习策略、离线偏好优化。

4.1 多目标奖励平衡

以下是多目标奖励平衡的伪码示例。

复制代码
class MultiObjectiveReward:
    """多目标奖励平衡"""
    
    def __init__(self, objectives, weights=None):
        self.objectives = objectives  # 字典: {name: reward_model}
        self.weights = weights or {name: 1.0 for name in objectives}
        
    def __call__(self, prompt, response):
        total_reward = 0
        individual_rewards = {}
        
        for name, model in self.objectives.items():
            reward = model(prompt, response)
            weighted_reward = reward * self.weights[name]
            total_reward += weighted_reward
            individual_rewards[name] = reward.item()
            
        return total_reward, individual_rewards

4.2 课程学习策略

以下是课程学习策略的的伪码示例。

复制代码
class CurriculumLearning:
    """课程学习,逐步增加难度"""
    
    def __init__(self, curriculum_config):
        self.stages = curriculum_config["stages"]
        self.current_stage = 0
        
    def get_current_data(self, full_dataset):
        """根据当前阶段获取数据子集"""
        stage_config = self.stages[self.current_stage]
        
        # 根据难度过滤数据
        filtered_data = []
        for item in full_dataset:
            if self.meets_criteria(item, stage_config):
                filtered_data.append(item)
        
        return filtered_data
    
    def advance_stage(self, metrics):
        """根据性能指标决定是否进入下一阶段"""
        if self.current_stage >= len(self.stages) - 1:
            return False
        
        stage_config = self.stages[self.current_stage]
        advancement_criteria = stage_config.get("advancement_criteria", {})
        
        # 检查是否满足进阶条件
        for metric_name, threshold in advancement_criteria.items():
            if metric_name in metrics and metrics[metric_name] >= threshold:
                continue
            else:
                return False
        
        self.current_stage += 1
        return True

4.3 离线偏好优化

以下是离线偏好优化,比如DPO、KTO优化的伪码示例。

复制代码
class OfflinePreferenceOptimization:
    """离线偏好优化(如DPO、KTO)"""
    
    def dpo_loss(self, policy_logps, ref_logps, beta=0.1):
        """
        DPO损失函数
        policy_logps: 当前策略的log概率 [B]
        ref_logps: 参考策略的log概率 [B]
        """
        log_ratios = policy_logps - ref_logps
        ratios = torch.exp(log_ratios)
        
        # DPO损失
        losses = -F.logsigmoid(beta * log_ratios)
        return losses.mean()
    
    def kto_loss(self, rewards, kl_penalty, alpha=0.1):
        """
        KTO损失函数
        rewards: 奖励 [B]
        kl_penalty: KL散度惩罚 [B]
        """
        # 简单实现
        return (rewards - alpha * kl_penalty).mean()

reference

--

PPO优势函数的学习和解读

https://blog.csdn.net/liliang199/article/details/148875214

PPO在强化学习中的应用

https://blog.csdn.net/liliang199/article/details/148840758

PPO偏好策略函数的学习解读

https://blog.csdn.net/liliang199/article/details/148811167

从蒙特卡洛的角度探索和示例重要性采样

https://blog.csdn.net/liliang199/article/details/154842020

DPO直接偏好函数的学习解读

https://blog.csdn.net/liliang199/article/details/148797585

KTO偏好效用函数的学习解读

https://blog.csdn.net/liliang199/article/details/148772543

RL偏好数据集

https://blog.csdn.net/liliang199/article/details/148761379

相关推荐
猫天意1 小时前
YOLOv11魔改高效涨点 | 注意力篇 | 坐标注意力CoordAttention:将位置信息硬核嵌入通道,精准捕获长程空间依赖,即插即用,涨点神器!!!
开发语言·人工智能·深度学习·神经网络·yolo·目标检测·低光照增强
irizhao2 小时前
《高质量数据集 分类指南》解读(TC609-5-2025-03)由全国数据标准化技术委员会发布
大数据·人工智能
观无2 小时前
VisionPro 视觉检测工具基础知识点
人工智能·计算机视觉·视觉检测
min1811234562 小时前
HR人力资源招聘配置流程图制作教程
大数据·网络·人工智能·架构·流程图·求职招聘
ai_xiaogui2 小时前
Stable Diffusion Web UI 绘世版 v4.6.1 整合包:一键极速部署,深度解决 AI 绘画环境配置与 CUDA 依赖难题
人工智能·stable diffusion·环境零配置·高性能内核优化·全功能插件集成·极速部署体验
Elastic 中国社区官方博客2 小时前
使用 Elasticsearch 管理 agentic 记忆
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
升职佳兴2 小时前
从 0 到 1:我做了一个提升 AI 对话效率的浏览器插件(架构+实现+发布)
人工智能·架构
linmoo19862 小时前
Langchain4j 系列之二十二 - Embedding Models
人工智能·langchain·embedding·嵌入模型·langchain4j
三不原则2 小时前
实战:基于 GitOps 实现 AI 应用的自动化部署与发布
运维·人工智能·自动化