Reinforcement Learning / Alignment (Personal Notes)

Bradley-Terry Reward Model

Meaning: given the hidden states of the chosen and rejected responses, project them to scalar rewards and compute a pairwise preference loss.

python
import torch

def reward_model_loss(chosen_hidden, rejected_hidden, reward_head):
    # Project each response's hidden state to a scalar reward with a linear head
    r_chosen = (chosen_hidden @ reward_head).squeeze(-1)     # (B,)
    r_rejected = (rejected_hidden @ reward_head).squeeze(-1) # (B,)
    margin = r_chosen - r_rejected
    # Bradley-Terry pairwise loss: -log sigmoid(margin)
    # manual log-sigmoid: log(1/(1+exp(-x))) = -log(1+exp(-x))
    loss = -torch.log(1.0 / (1.0 + torch.exp(-margin))).mean()
    return loss
  1. The pairwise loss loss = -torch.log(1.0 / (1.0 + torch.exp(-margin))).mean() is the log-sigmoid form, equivalent to binary cross-entropy on the preference label (a numerically stable variant is sketched below).
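
A numerically more stable way to write the same loss is torch.nn.functional.logsigmoid. A minimal usage sketch follows; the batch size and hidden dimension are made up for illustration:

python
import torch
import torch.nn.functional as F

B, H = 4, 16                                  # hypothetical batch size / hidden dim
chosen_hidden = torch.randn(B, H)
rejected_hidden = torch.randn(B, H)
reward_head = torch.randn(H, 1)

# Same margin as in reward_model_loss, computed directly
margin = ((chosen_hidden - rejected_hidden) @ reward_head).squeeze(-1)   # (B,)
stable_loss = -F.logsigmoid(margin).mean()    # equals the manual log-sigmoid, more stable
print(stable_loss.item())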

DPO Loss

Meaning: aligns a language model with human preferences without reinforcement learning, using paired chosen/rejected log-probabilities.

python
import torch

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: beta-scaled log-ratio of the policy to the reference model
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    diff = chosen_rewards - rejected_rewards
    # Bradley-Terry loss on the implicit reward margin: -log sigmoid(diff)
    return -torch.log(torch.sigmoid(diff)).mean()
  1. The reference model keeps the policy from drifting away from its initial language ability and prevents degeneration.
  2. policy_chosen_logps and the other log-prob arguments are sequence-level log-likelihoods of the whole response, obtained by summing per-token log-probs (see the sketch after this list).
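
A minimal sketch of how such sequence-level log-probs can be computed from model logits; the function name sequence_logps and the labels/mask convention are assumptions for illustration:

python
import torch

def sequence_logps(logits, labels, mask):
    """Sum per-token log-probs over the response tokens.

    logits: (B, T, V) model outputs; labels: (B, T) target token ids;
    mask: (B, T) 1.0 for response tokens, 0.0 for prompt/padding.
    """
    logps = torch.log_softmax(logits, dim=-1)                         # (B, T, V)
    token_logps = logps.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (B, T)
    return (token_logps * mask).sum(dim=-1)                           # (B,)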

GRPO Loss

Meaning: rewards are normalized within each prompt's group of sampled responses to compute advantages, and the policy is then optimized with these group-relative advantages.

python
import torch
from torch import Tensor

def grpo_loss(logps: Tensor, rewards: Tensor, group_ids: Tensor,
              eps: float = 1e-5) -> Tensor:
    """Group Relative Policy Optimization (GRPO) loss.

    logps: (B,) policy log-probs for each sampled response
    rewards: (B,) scalar rewards for each response
    group_ids: (B,) integers, same id = same prompt/group
    returns: scalar loss (Tensor)
    """
    # Compute per-group normalized advantages A_i
    unique_ids = group_ids.unique()
    advantages = torch.empty_like(rewards)
    for gid in unique_ids:
        mask = group_ids == gid
        r_g = rewards[mask]
        mean_g = r_g.mean()
        std_g = r_g.std(unbiased=False)
        advantages[mask] = (r_g - mean_g) / (std_g + eps)

    # Stop gradient through advantages
    advantages_detached = advantages.detach()

    # GRPO objective: -E[A_i * logpi_i]
    return -(advantages_detached * logps).mean()
  1. Gradients are not back-propagated through the advantages; they are treated as constants (targets) that only weight the policy gradient.
  2. No critic network is needed: standard PPO trains a value network (critic) to estimate advantages, whereas GRPO replaces it with per-group statistics, simplifying the architecture.
  3. Multiple responses to the same prompt are compared against one another, removing the bias caused by differences in prompt difficulty (see the usage sketch below).
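
A minimal usage sketch of grpo_loss, assuming 2 prompts with 4 sampled responses each (all numbers are made up):

python
import torch

logps = torch.randn(8, requires_grad=True)        # log pi(response) for 8 samples
rewards = torch.tensor([1.0, 0.0, 0.5, 0.0,       # prompt 0: 4 sampled responses
                        2.0, 2.0, 3.0, 1.0])      # prompt 1: 4 sampled responses
group_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

loss = grpo_loss(logps, rewards, group_ids)
loss.backward()                                    # gradients flow only through logps
print(loss.item())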

PPO Loss

Meaning: constrains policy updates by clipping the importance-sampling ratio, preventing destructively large updates during RL training.

python
import torch
from torch import Tensor

def ppo_loss(new_logps: Tensor, old_logps: Tensor, advantages: Tensor,
             clip_ratio: float = 0.2) -> Tensor:
    """PPO clipped surrogate loss.

    new_logps: (B,) current policy log-probs
    old_logps: (B,) old policy log-probs (treated as constant)
    advantages: (B,) advantage estimates (treated as constant)
    returns: scalar loss (Tensor)
    """
    # Detach old_logps and advantages so gradients only flow through new_logps
    old_logps_detached = old_logps.detach()
    adv_detached = advantages.detach()

    # Importance sampling ratio r = pi_new / pi_old in log-space
    ratios = torch.exp(new_logps - old_logps_detached)

    # Unclipped and clipped objectives
    unclipped = ratios * adv_detached
    clipped = torch.clamp(ratios, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv_detached

    # PPO objective: negative mean of the more conservative objective
    return -torch.min(unclipped, clipped).mean()
  1. Clipping the ratio limits the step size of any single update and prevents policy collapse (see the usage sketch below).
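
A minimal usage sketch of ppo_loss with made-up numbers; samples whose ratio leaves the [1 - clip_ratio, 1 + clip_ratio] band fall onto the clipped (constant) branch and contribute no gradient:

python
import torch

new_logps = torch.tensor([0.0, -0.5, -1.0], requires_grad=True)
old_logps = torch.tensor([-0.7, -0.6, -0.4])   # ratios = exp(new - old) ≈ 2.01, 1.11, 0.55
advantages = torch.tensor([1.0, -0.5, 2.0])

loss = ppo_loss(new_logps, old_logps, advantages)
loss.backward()
print(loss.item(), new_logps.grad)             # the first sample is clipped: zero gradient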