【RL】op_compute_log_probs 计算过程

cpp 复制代码
    def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
        """
        loss func接口定义:
            data: DataProto, 由train_step透传
            output_tensor: torch.Tensor, model.forward()的输出Tensor
        """

        response_mask = data.batch["response_mask"][:, 1:].long()
        ref_log_probs = data.batch["ref_log_probs"]
        old_log_probs = data.batch["old_log_probs"]
        advantages = data.batch["advantages"]

        log_probs = self.strategy.op_compute_log_probs(
            logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"]
        )

        ratio = (log_probs - old_log_probs).exp()

        pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
        pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip  
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages
        pg_loss = -torch.min(surr1, surr2)
        if self.pipeline_config.dual_clip_loss:
            dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages)
            pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss)

        pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode)

        kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask,
                                    kl_penalty="k3")
        kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode)

        approxkl = compute_approx_kl(
            log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse"
        )
        policykl = compute_approx_kl(
            log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="kl"
        )
        clipped_low = (ratio < 1 - pg_clip_low).float()
        clipped_high = (ratio > 1 + pg_clip_high).float()
        clipped = (clipped_low + clipped_high).float()

        if self.pipeline_config.use_kl_loss:
            total_loss = pg_loss + kl_loss * self.pipeline_config.kl_loss_coef
        else:
            total_loss = pg_loss
        if self.pipeline_config.entropy_loss_coef > 0:
            entropy = self.strategy.op_compute_entropy(logits=output_tensor, attention_mask=data.batch["response_mask"])
            entropy_loss = agg_loss(
                loss_mat=entropy,
                loss_mask=response_mask,
                loss_agg_mode=self.pipeline_config.loss_agg_mode,
            )
            total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef

        pg_metrics = {
            "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(),
            "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(),
            "actor/ppo_ratio_clipfrac": clipped.mean().detach().item(),
            "actor/ratio_mean": masked_mean(ratio, response_mask, dim=-1).mean().detach().item(),
            "actor/ratio_max": torch.max(ratio * response_mask).detach().item(),
            "actor/ratio_min": torch.min(ratio * response_mask + (1 - response_mask) * 1e10).detach().item(),
            "actor/clipfrac": agg_loss(loss_mat=torch.lt(surr2, surr1).float(), loss_mask=response_mask,
                                       loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),
            "actor/pg_loss": pg_loss.detach().item(),
            "actor/kl_loss": kl_loss.detach().item(),
            "actor/total_loss": total_loss.detach().item(),
            "actor/approxkl": agg_loss(loss_mat=approxkl, loss_mask=response_mask,
                                       loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),
            "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask,
                                       loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),
        }


        return total_loss, pg_metrics

你提出了一个非常好的问题,这个问题涉及到模型输入、输出和损失计算之间如何对齐的关键细节。

data.batch["response_mask"][:, 1:].long() 这么做的原因是为了 response_mask 与模型的输出 logits 以及需要计算损失的目标 log_probs 在序列长度维度上对齐

让我们一步步来分解这个原因。

1. 语言模型的输入和输出

假设我们有一个输入序列 input_ids

[BOS, "你好", "吗", "?", EOS]

  • BOS: Begin of Sentence token (起始符)
  • EOS: End of Sentence token (结束符)

当这个序列被送入一个自回归语言模型(比如 GPT)时,模型的工作方式是 预测下一个 token

  • 输入 [BOS],模型预测 "你好"
  • 输入 [BOS, "你好"],模型预测 "吗"
  • 输入 [BOS, "你好", "吗"],模型预测 "?"
  • 输入 [BOS, "你好", "吗", "?"],模型预测 EOS

因此,如果输入序列的长度是 L,那么模型输出的 logits(预测每个位置的下一个 token 的概率分布)的序列长度通常也是 L

但是,logits 在位置 i 的输出,是基于 input_ids 在位置 0i 的输入得到的,它是用来预测 input_ids 在位置 i+1 的 token 的。

看下图示:

复制代码
input_ids:   [ token_0, token_1, token_2, token_3 ]  (长度 L=4)
             /        /        /        /
logits:      [ logit_0, logit_1, logit_2, logit_3 ]  (长度 L=4)
               |        |        |        |
预测目标:      (token_1) (token_2) (token_3) (token_4)

2. 计算损失时的对齐问题

我们要计算的损失,是模型预测的 log_probs 和真实 token 之间的交叉熵。

  • logit_0 用来预测 token_1
  • logit_1 用来预测 token_2
  • logit_2 用来预测 token_3

我们注意到,logit 的最后一个位置 logit_3 是用来预测 token_4 的,但我们的输入序列里没有 token_4。因此,logit_3 在计算损失时通常是被忽略的。

反过来看,token_0 (通常是 BOSPAD) 从来不作为被预测的目标,所以与它对应的 logit 也不需要。

所以,为了计算损失,我们需要:

  • logits : 取 [logit_0, logit_1, logit_2],即 logits[:, :-1, :]
  • 目标 input_ids : 取 [token_1, token_2, token_3],即 input_ids[:, 1:]

两者现在长度都是 L-1,完美对齐。

3. response_mask 的对齐

现在回到你的问题 response_mask

response_mask 的原始形状与 input_ids 是一样的,长度为 L。它标记了哪些 token 属于需要计算损失的 "response" 部分。

例如,对于输入 [prompt_1, prompt_2, resp_1, resp_2, resp_3, EOS]

原始 response_mask 可能是 [0, 0, 1, 1, 1, 1]

当我们计算损失时,我们用的是 logits[:, :-1]input_ids[:, 1:]。这意味着我们实际上是在处理一个长度为 L-1 的序列。

为了让 response_mask 能够正确地屏蔽掉这个长度为 L-1 的序列中不需要计算损失的部分,它自身也需要被截断,以匹配新的序列长度。

我们截断的方式必须和目标 input_ids 的截断方式一样,即 从第二个元素开始取,去掉第一个元素

  • 原始 response_mask (长度 L): [mask_0, mask_1, mask_2, ..., mask_{L-1}]
  • 截断后的 response_mask (长度 L-1): [mask_1, mask_2, ..., mask_{L-1}]
    • 这个操作就是 response_mask[:, 1:]

这样,截断后的 mask_1 就会对应 logit_0 和 目标 token_1mask_2 对应 logit_1 和 目标 token_2,以此类推,对齐关系就完全正确了。

总结

data.batch["response_mask"][:, 1:].long() 这行代码的目的是:

  1. [:, 1:] : 为了与用于计算损失的 logits[:, :-1] 和目标 input_ids[:, 1:] 在序列长度上保持对齐。它通过移除 mask 的第一个元素来实现这一点。
  2. .long() : 将 mask 的数据类型转换为长整型 (torch.int64)。这通常是为了兼容某些 PyTorch 操作,比如用 mask 进行索引或者作为某些损失函数的权重时,可能要求是整型而不是布尔型或浮点型。

这是一个在实现语言模型训练时非常常见且至关重要的细节,确保了计算的正确性。

好的,我们来详细分析你提供的这个特定版本的 op_compute_log_probs 函数,并用一个具体的例子来走一遍它的执行过程。

这个版本的实现比我之前给出的通用版本更简洁,它依赖一个名为 log_probs_from_logits 的辅助函数(HuggingFace TRL 库中常见)。我们假设 log_probs_from_logits 的功能就是我之前描述的第 2 步(Log-Softmax)和第 3 步(Gather)的结合。

函数代码分析

python 复制代码
# 假设 log_probs_from_logits 的实现如下:
def log_probs_from_logits(logits, labels):
    log_probs = F.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

class YourClass:
    def op_compute_log_probs(self, logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """
        logits: llm logits, 形状 [batch_size, seq_len, vocab_size]
        input_ids [[p, p, r, r, r, 0, 0]], 形状 [batch_size, seq_len]
        attention_mask(response_mask) [[0, 0, 1, 1, 1, 0, 0]], 形状 [batch_size, seq_len]
        """
        # 1. 准备 Labels,并处理无效 Token
        labels: torch.Tensor = input_ids[:, 1:].clone()
        # 将 mask 为 0 的位置的 label 设置为 0
        labels[attention_mask[:, 1:] == 0] = 0  

        # 2. 计算 Log Probs
        # 传入错位对齐的 logits 和处理过的 labels
        log_probs = log_probs_from_logits(logits[:, :-1], labels)

        # 3. 应用 Mask
        # 将不属于 response 的位置的 log_probs 清零
        log_probs = log_probs * attention_mask[:, 1:]

        return log_probs

核心逻辑点

  1. labels[attention_mask[:, 1:] == 0] = 0 : 这是这个实现中最有趣和最关键的一步。它的目的是防止 log_probs_from_logits 访问到无效的 token ID
    • 在 PPO 训练中,input_ids 可能包含 PAD token(其 ID 通常是 0)。
    • 如果一个 PAD token 出现在 labels 中,torch.gather 会尝试去访问词汇表索引为 0 的位置。这本身没问题。
    • 但更重要的是,对于 prompt 部分和 padding 部分,我们根本不关心它们的 log_probs,因为它们不会计入最终的损失。将这些位置的 label 统一设置为 0,可以简化计算。虽然 gather 仍然会为这些位置计算一个值(即词汇表中 token 0 的对数概率),但这没关系,因为 第 3 步会把这些位置的 log_probs 全部清零。这是一个"先计算再丢弃"的策略。

举例说明执行过程

假设有以下微型配置:

  • batch_size = 1
  • seq_len = 7
  • vocab_size = 50000
  • PAD_TOKEN_ID = 0

输入:

  • input_ids : [[101, 102, 201, 202, 203, 0, 0]]
    • [p, p, r, r, r, pad, pad]
  • attention_mask (response_mask) : [[0, 0, 1, 1, 1, 0, 0]]
  • logits : 一个由模型生成的 [1, 7, 50000] 的张量。

执行步骤:

第 1 步: 准备 labels
  1. labels = input_ids[:, 1:].clone()

    • input_ids[:, 1:] 得到 [[102, 201, 202, 203, 0, 0]]
    • labels 的值现在是 [[102, 201, 202, 203, 0, 0]],形状 [1, 6]
  2. 计算 mask for labels

    • attention_mask[:, 1:] 得到 [[0, 1, 1, 1, 0, 0]]
  3. labels[attention_mask[:, 1:] == 0] = 0

    • attention_mask[:, 1:] == 0 会产生一个布尔掩码 [[True, False, False, False, True, True]]
    • 这个掩码会选中 labels 中需要被修改的位置:
      • labels 的第 0 个元素 (对应 prompt 部分)
      • labels 的第 4 个元素 (对应第一个 pad)
      • labels 的第 5 个元素 (对应第二个 pad)
    • labels 被原地修改,修改后的值为:[[0, 201, 202, 203, 0, 0]]
      • 注意:原来的 102 变成了 0

    至此,labels 准备完毕,值为 [[0, 201, 202, 203, 0, 0]]

第 2 步: 计算 log_probs
  1. 准备 logits

    • logits[:, :-1] 得到一个 [1, 6, 50000] 的张量。
  2. 调用 log_probs_from_logits(logits[:, :-1], labels)

    • log_probs_from_logits 内部会:
      a. 对 logits[:, :-1] 在最后一个维度上做 log_softmax
      b. 使用 labels [[0, 201, 202, 203, 0, 0]] 作为索引,通过 torch.gatherlog_softmax 的结果中提取值。

    • log_probs 的计算结果(形状为 [1, 6])会是:

      复制代码
      [[
          logP(token=0 | p),          // prompt 部分,计算了 pad token 的 log_prob
          logP(token=201 | p,p),      // response 部分,正确
          logP(token=202 | p,p,r),    // response 部分,正确
          logP(token=203 | p,p,r,r),  // response 部分,正确
          logP(token=0 | p,p,r,r,r),  // padding 部分,计算了 pad token 的 log_prob
          logP(token=0 | p,p,r,r,r,0) // padding 部分,计算了 pad token 的 log_prob
      ]]
    • 假设计算出的值为 [[ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ]]

第 3 步: 应用 Mask
  1. log_probs = log_probs * attention_mask[:, 1:]
    • log_probs: [[ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ]]

    • attention_mask[:, 1:]: [[0, 1, 1, 1, 0, 0]]

    • 两者进行逐元素相乘:

      复制代码
        [ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ]
      * [    0,    1,    1,    1,    0,    0 ]
      ----------------------------------------
      = [    0, -1.1, -0.8, -1.5,    0,    0 ]

最终返回结果

函数返回的 log_probs 张量为: [[ 0, -1.1, -0.8, -1.5, 0, 0 ]]

这个结果非常完美:

  • 只有 response 部分 (r, r, r) 的 log_probs 被保留了下来。
  • prompt 部分和 padding 部分的 log_probs 都被清零了
  • 这个张量可以直接用于后续的损失计算(例如计算 ratio = (log_probs - old_log_probs).exp()),而不需要担心无效位置的干扰。

这个实现方式非常高效和简洁,巧妙地利用了 PyTorch 的广播和掩码操作。

好的,我们来详细拆解这两行代码,它们是PyTorch中一个非常常见且强大的模式,用于从概率分布中根据标签提取特定的值。

这两行代码的核心是 torch.gather 函数。

torch.gather 的工作原理

torch.gather 就像一个高级的索引工具。它的作用是:沿着指定的维度(dim),根据 index 张量中的值,从输入张量(input)中收集元素。

它的签名是 torch.gather(input, dim, index)

为了让它工作,index 张量需要满足一些条件,最重要的是:

  • index 的维度数量必须和 input 的维度数量相同。
  • 在所有 dim 的维度上,index 的大小必须和 input 的大小相同(或者为 1,可以广播)。

结合你的代码进行分解

我们一步步来看:

python 复制代码
# 假设我们有以下张量(以 batch_size=1, seq_len=3, vocab_size=5 为例)

# log_probs: [1, 3, 5] 的张量,代表了3个位置上,每个词的对数概率
log_probs = torch.tensor([[
    [-1.6, -2.1, -0.9, -3.0, -1.8],  # 位置0的 log_probs
    [-0.5, -1.1, -2.5, -1.3, -4.0],  # 位置1的 log_probs
    [-3.2, -1.9, -1.0, -2.2, -0.8]   # 位置2的 log_probs
]])

# labels: [1, 3] 的张量,代表了3个位置上,正确的 token ID
labels = torch.tensor([[2, 0, 4]])

第 1 步: labels.unsqueeze(-1)
  • 目的 : 增加一个维度,使 labels 的维度数量与 log_probs 相同,从而满足 torch.gather 的要求。
  • 输入 labels :
    • 形状: [1, 3]
    • 值: [[2, 0, 4]]
  • 操作 : unsqueeze(-1) 在最后一个维度(维度索引为-1)上增加一个大小为 1 的新维度。
  • 输出 index :
    • 形状: [1, 3, 1]

    • 值:

      复制代码
      [[[2],
        [0],
        [4]]]

现在,log_probs (3D) 和 index (3D) 的维度数量相同了。


第 2 步: log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
  • input : log_probs (形状 [1, 3, 5])
  • dim : -1 (或 2)。这意味着我们将在最后一个维度------**词汇表维度(vocab_size)**上进行收集。
  • index : labels.unsqueeze(-1) (形状 [1, 3, 1])

gather 的执行过程 (可以想象成一个 for 循环):

  1. gather 会遍历 index 张量的所有位置。
  2. 对于 index 中的每个元素 (batch_idx, seq_idx, 0),它会取出其中的值 v = index[batch_idx, seq_idx, 0]
  3. 然后,它会在 log_probs 张量的对应位置 (batch_idx, seq_idx, ...) 上,沿着 dim=-1 收集索引为 v 的元素。
  4. 它将收集到的值放在输出张量的 (batch_idx, seq_idx, 0) 位置。

让我们手动走一遍:

  • 处理 index[0, 0, 0] 位置:

    • index[0, 0, 0] 的值是 2
    • gatherlog_probs[0, 0, :] 位置,也就是 [-1.6, -2.1, -0.9, -3.0, -1.8]
    • 它从这个向量中取出索引为 2 的元素,即 -0.9
    • 输出张量的 [0, 0, 0] 位置被设置为 -0.9
  • 处理 index[0, 1, 0] 位置:

    • index[0, 1, 0] 的值是 0
    • gatherlog_probs[0, 1, :] 位置,也就是 [-0.5, -1.1, -2.5, -1.3, -4.0]
    • 它从这个向量中取出索引为 0 的元素,即 -0.5
    • 输出张量的 [0, 1, 0] 位置被设置为 -0.5
  • 处理 index[0, 2, 0] 位置:

    • index[0, 2, 0] 的值是 4
    • gatherlog_probs[0, 2, :] 位置,也就是 [-3.2, -1.9, -1.0, -2.2, -0.8]
    • 它从这个向量中取出索引为 4 的元素,即 -0.8
    • 输出张量的 [0, 2, 0] 位置被设置为 -0.8

gather 的输出 log_probs_labels:

  • 形状: [1, 3, 1] (与 index 的形状相同)

  • 值:

    复制代码
    [[[-0.9],
      [-0.5],
      [-0.8]]]

直观理解 : 对于序列中的每个位置,我们都从完整的词汇表概率分布中,只挑选出了正确标签(label)对应的那个对数概率


第 3 步: .squeeze(-1)
  • 目的: 移除多余的、大小为 1 的维度,让张量更易于处理。
  • 输入 log_probs_labels :
    • 形状: [1, 3, 1]
  • 操作 : squeeze(-1) 移除最后一个维度(因为它的大小是 1)。
  • 输出 :
    • 形状: [1, 3]
    • 值: [[-0.9, -0.5, -0.8]]

最终结果

函数最终返回了一个 [1, 3] 的张量 [[-0.9, -0.5, -0.8]]

这个张量的每个元素 output[i, j] 都代表了在批次 i 的序列位置 j,模型赋予正确 label 的对数概率。这正是我们计算交叉熵损失或 PPO 损失时所需要的核心数值。

相关推荐
GISer_Jing2 小时前
Taro打造电商项目实战
前端·javascript·人工智能·aigc·taro
liliangcsdn2 小时前
SD稳定扩散模型理论基础的探索
人工智能·机器学习
Hiweir ·2 小时前
ROS Noetic教程------VSCode创建ROS话题通讯--发布方的简单流程
vscode·python·ros noetic·ros noetic教程
deephub2 小时前
dLLM:复用自回归模型权重快速训练扩散语言模型
人工智能·python·语言模型·大语言模型
中國龍在廣州2 小时前
2025,具身智能正在惩罚“持有者”
人工智能·深度学习·算法·自然语言处理·chatgpt
阿部多瑞 ABU2 小时前
第五章:林心
人工智能·ai·ai写作
itwangyang5202 小时前
AIDD-人工智能药物设计-字节跳动 PXDesign:AI 设计蛋白,82% 命中率惊艳业界
人工智能·python
Biehmltym2 小时前
【AI】01开发环境:Conda_python包/环境管理,10分钟上手
开发语言·python·conda
ToB营销学堂2 小时前
百格活动《大型会议管理 & 执行指南》:大型会议如何从“事务交付”向“资产运营”转型?
人工智能·大会管理系统·会议活动执行·活动管理·大会执行·会议策划