【RL】中Token级策略梯度损失

https://www.xiaoqiedun.com/posts/2025-03-18-dapo/

您提的问题非常精准,直击了GRPO(Group-wise Reward Policy Optimization)算法与标准PPO在实现上的一个核心区别。您的描述是完全正确的:GRPO在样本(序列)级别计算目标,它首先计算每个生成序列的平均损失,然后对所有样本的损失取平均。

这个机制主要体现在代码的两个地方:

  1. 损失聚合函数 agg_loss: 这是最直接的体现。
  2. ActorWorker.loss_func 中对 agg_loss 的调用: 这里决定了使用哪种聚合模式。

下面我们来详细解析。


1. 核心机制:agg_loss 函数中的 "seq-mean-token-mean" 模式

roll/utils/functionals.py 文件中,agg_loss 函数是实现这一逻辑的关键。让我们聚焦于 loss_agg_mode="seq-mean-token-mean" 这个分支:

python 复制代码
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str,
             weights: Optional[torch.Tensor] = None, loss_scale: Optional[float] = None):
    """
    ...
    Args:
        loss_mat: `(torch.Tensor)`
            shape: (bs, response_length)  <-- 这是每个token的损失矩阵
    ...
    """
    # ... 其他模式的代码 ...

    elif loss_agg_mode == "seq-mean-token-mean":
        # -------------------------------------------------------------
        # 第一步:计算每个生成序列内的平均损失 (seq-mean)
        # -------------------------------------------------------------
        # masked_mean 在 dim=-1 (序列维度) 上操作
        # 它计算 loss_mat 中每个序列(行)的有效token损失的平均值
        seq_losses = masked_mean(loss_mat, loss_mask, dim=-1) # (bs,)

        # 此时 seq_losses 是一个形状为 (bs,) 的张量,每个元素代表一个样本(序列)的平均损失
        
        # -------------------------------------------------------------
        # 第二步:对不同样本的损失取平均值 (token-mean or batch-mean)
        # -------------------------------------------------------------
        valid_samples = torch.any(loss_mask > 0, dim=-1).float() # 找出哪些是有效样本
        if weights is None:
            weights = torch.ones(loss_mask.shape[0], device=loss_mask.device)
        
        # 将所有样本的平均损失(seq_losses)加权求和,再除以有效样本的数量
        # 这就是对不同样本的损失取平均值
        loss = (seq_losses * weights * valid_samples).sum() / (valid_samples.sum() + 1e-8) # 标量

    # ... 其他模式的代码 ...
    
    return loss * loss_scale if loss_scale else loss

代码解读:

  • 输入 : loss_mat 是一个形状为 (batch_size, sequence_length) 的张量,其中 loss_mat[i, j] 代表第 i 个样本的第 j 个token的PPO损失。
  • 第一步 (对应您的描述 "首先计算每个生成序列内的平均损失") : seq_losses = masked_mean(loss_mat, loss_mask, dim=-1)masked_mean 函数沿着最后一个维度 (dim=-1,即序列维度) 计算加权平均。这会得到一个形状为 (batch_size,) 的张量 seq_losses,其中 seq_losses[i] 就是第 i 个生成序列所有有效token的平均损失。
  • 第二步 (对应您的描述 "然后对不同样本的损失取平均值") : loss = (seq_losses * ...).sum() / (valid_samples.sum() + 1e-8)。这行代码计算 seq_losses 张量的加权平均值,从而得到一个最终的标量损失值,用于反向传播。

2. 应用场景:ActorWorker.loss_func 中的配置调用

agg_loss 函数只是一个工具,它如何被使用取决于调用它的地方。在 ActorWorker.loss_func 中,这个工具被用来聚合最终的策略梯度损失(pg_loss)和KL损失等。

python 复制代码
class ActorWorker(BaseActorWorker):

    def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
        # ... (计算每个token的PPO损失 `loss`)
        loss = -torch.min(surr1, surr2) # loss 的形状是 (bs, seq_len)

        # ...

        # -------------------------------------------------------------
        # 这里通过配置决定使用哪种聚合模式
        # -------------------------------------------------------------
        weighted_pg_loss = agg_loss(loss_mat=loss, 
                                    loss_mask=final_response_mask,
                                    # 这个参数决定了是否执行GRPO的损失计算方式
                                    loss_agg_mode=self.pipeline_config.loss_agg_mode,
                                    weights=sample_weights, 
                                    loss_scale=loss_scale)
        # ...
        
        return total_loss, pg_metrics

关键点:

  • self.pipeline_config.loss_agg_mode 这个配置项是"开关"。
  • 当您在训练配置中设置 loss_agg_mode: "seq-mean-token-mean" 时,ActorWorker 就会指示 agg_loss 函数采用GRPO描述的聚合方式。

总结与对比

为了让这个概念更清晰,我们来对比一下GRPO的聚合方式和标准的PPO(通常按token平均)的聚合方式。

聚合方式 描述 agg_loss 中的模式 物理意义
标准PPO (Token-level) 将所有样本的所有有效token的损失放在一起,直接计算总平均值。 "token-mean" 每个token对最终损失的贡献是平等的。长序列会比短序列有更大的影响力。
GRPO (Sample-level) 先计算每个样本的平均token损失,再计算所有样本的平均损失。 "seq-mean-token-mean" 每个样本(序列)对最终损失的贡献是平等的,无论其长短。 这更符合GRPO的思想,即优化整个序列的质量,而不是单个token。

因此,**代码通过 agg_loss 函数中的 "seq-mean-token-mean" 逻辑,以及在 loss_func 中通过配置文件 self.pipeline_config.loss_agg_mode 来调用该逻辑,完美地体现了GRPO在样本级别计算目标的思想。**好的,我们来通过一个非常具体和详细的例子来解释 masked_mean(loss_mat, loss_mask, dim=-1) 的输入和输出形状。

假设我们的 batch_size 为 3,响应序列的最大长度(response_length)为 5。


输入:loss_matloss_mask

1. loss_mat: 损失矩阵 (Tensor)

loss_mat 是一个形状为 (3, 5) 的张量,代表每个样本在每个token位置上的损失。

复制代码
loss_mat = tensor([
    [0.1, 0.5, 0.2, 0.0, 0.0],  # 样本 1 的逐token损失
    [0.8, 0.3, 0.6, 0.4, 0.1],  # 样本 2 的逐token损失
    [0.9, 0.2, 0.0, 0.0, 0.0]   # 样本 3 的逐token损失
])

# 形状: torch.Size([3, 5])
  • loss_mat[0, 1] 的值是 0.5,代表第一个样本的第二个token的损失。
  • 我们看到有些位置的损失是0,这可能是因为这些位置是填充(padding)的,但 loss_mat 本身并不包含这个信息。哪个位置是有效的,由 loss_mask 决定。
2. loss_mask: 掩码矩阵 (Tensor)

loss_mask 也是一个形状为 (3, 5) 的张量,但它的值只有0和1。1 代表这个位置是有效的(是真实的token),0 代表这个位置是无效的(是padding)。

复制代码
loss_mask = tensor([
    [1, 1, 1, 0, 0],  # 样本 1: 有效长度为 3
    [1, 1, 1, 1, 1],  # 样本 2: 有效长度为 5
    [1, 1, 0, 0, 0]   # 样本 3: 有效长度为 2
], dtype=torch.float)

# 形状: torch.Size([3, 5])
  • 样本 1: 真实响应长度是3个token,后面2个是padding。
  • 样本 2: 真实响应长度是5个token,占满了整个序列长度。
  • 样本 3: 真实响应长度是2个token,后面3个是padding。

计算过程:masked_mean(loss_mat, loss_mask, dim=-1)

这个函数会沿着最后一个维度(dim=-1,即序列维度)进行计算。对于批次中的每一个样本(每一行),它会执行以下操作:

  1. 元素相乘 : 将 loss_matloss_mask 按元素相乘。这会把所有无效位置(mask为0)的损失清零。

    复制代码
    loss_mat * loss_mask = tensor([
        [0.1, 0.5, 0.2, 0.0, 0.0],  # 样本 1: [0.1*1, 0.5*1, 0.2*1, 0.0*0, 0.0*0]
        [0.8, 0.3, 0.6, 0.4, 0.1],  # 样本 2: [0.8*1, 0.3*1, 0.6*1, 0.4*1, 0.1*1]
        [0.9, 0.2, 0.0, 0.0, 0.0]   # 样本 3: [0.9*1, 0.2*1, 0.0*0, 0.0*0, 0.0*0]
    ])
  2. 求和: 对上一步结果的每一行(每个样本)进行求和。

    • 样本 1 和 : 0.1 + 0.5 + 0.2 = 0.8
    • 样本 2 和 : 0.8 + 0.3 + 0.6 + 0.4 + 0.1 = 2.2
    • 样本 3 和 : 0.9 + 0.2 = 1.1

    得到一个临时的和向量 sum_vec = tensor([0.8, 2.2, 1.1])

  3. 计算有效长度 : 对 loss_mask 的每一行(每个样本)进行求和,得到每个样本的有效token数量。

    • 样本 1 有效长度 : 1 + 1 + 1 = 3
    • 样本 2 有效长度 : 1 + 1 + 1 + 1 + 1 = 5
    • 样本 3 有效长度 : 1 + 1 = 2

    得到一个临时的长度向量 len_vec = tensor([3.0, 5.0, 2.0])

  4. 相除 : 将步骤2得到的和向量 sum_vec 与步骤3得到的长度向量 len_vec 按元素相除。

    • 样本 1 平均损失 : 0.8 / 3 = 0.2667
    • 样本 2 平均损失 : 2.2 / 5 = 0.44
    • 样本 3 平均损失 : 1.1 / 2 = 0.55

输出:seq_losses

经过上述计算,masked_mean 函数返回一个一维张量,其长度等于 batch_size。这个张量就是我们所说的 seq_losses

python 复制代码
seq_losses = masked_mean(loss_mat, loss_mask, dim=-1)

# seq_losses 的值:
# tensor([0.2667, 0.4400, 0.5500])

# 形状: torch.Size([3])

总结

  • 输入 loss_mat 形状 : (batch_size, sequence_length) -> (3, 5)
  • 输入 loss_mask 形状 : (batch_size, sequence_length) -> (3, 5)
  • 操作 : 沿着 dim=-1 (序列维度) 计算每个样本的平均损失。
  • 输出 seq_losses 形状 : (batch_size,) -> (3,)

这个输出 seq_losses 的每一个元素都代表了一个完整的生成序列(样本)的平均损失。接下来的步骤(如GRPO中的第二步)就是对这个 seq_losses 张量再做一次聚合(例如,取所有元素的平均值),从而得到最终的、用于反向传播的单个损失值。

好的,您提供的描述和公式非常清晰,它们精确地指向了**Token级平均(Token-level Mean)**的损失聚合方式。我们来详细解析这个概念以及它在代码中的具体实现。

概念解析

您提供的描述和公式(DAPO损失函数)完美地阐述了Token级策略梯度损失的核心思想:

  1. 公式解读:

    • Σ_{i=1}^{G} Σ_{t=1}^{|o_i|} min(...):这是一个双重求和。它首先遍历一个批次中的所有样本(i 从 1到 G),然后遍历每个样本中的所有token(t 从 1 到 |o_i|)。这表示我们将批次中每一个有效token 的PPO损失(即min(...)部分)全部加起来。
    • 1 / Σ_{i=1}^{G} |o_i|:这是归一化因子。Σ_{i=1}^{G} |o_i| 表示计算批次中所有样本的长度之和,即批次中有效token的总数
    • 整体 : 整个公式的含义是:计算批次中所有有效token的PPO损失的总和,然后除以有效token的总数,得到所有token的平均损失
  2. 与您的描述对应:

    • "较长序列相比较短序列对整体梯度更新的影响可能更大": 这是因为在Token级平均中,一个长度为100的序列会贡献100个token的损失项到总和中,而一个长度为10的序列只贡献10个。因此,长序列在计算总平均损失时有更大的"话语权"。
    • "从单个 Token 的角度来看...无论它出现在哪种长度的响应中,都将被同等地促进或抑制" : 这是因为每个token的损失min(...)在被加到总和之前,没有经过任何与序列长度相关的加权或归一化。它的值只取决于它自身的ratioadvantage。所有token最终被"一视同仁"地扔进一个大池子里求平均。

代码实现

这种Token级平均的损失计算方式在您的代码中对应的是 agg_loss 函数的 "token-mean" 模式,以及 masked_mean 函数在没有指定 dim 参数时的行为。

1. 核心代码:agg_lossmasked_mean

roll/utils/functionals.py 文件中:

python 复制代码
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, ...):
    ...
    # 当配置为 "token-mean" 时,会进入这个分支
    if loss_agg_mode == "token-mean":
        if weights is None:
            weights = torch.ones(loss_mask.shape[0], device=loss_mask.device)
        # 关键调用:这里调用 masked_mean 时没有传递 dim 参数
        loss = masked_mean(loss_mat * weights.unsqueeze(-1), loss_mask)
    ...
    return loss

现在,我们看 masked_mean 函数在 dim=None 时的实现:

python 复制代码
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor:
    if dim is not None:
        # 这是我们之前讨论的 GRPO (seq-mean) 的情况
        ...
    else:
        # 这就是 Token-level mean 的实现!
        # (tensor * mask).sum()  ->  对应公式的 Σ_i Σ_t min(...)
        # (mask.sum())          ->  对应公式的 Σ_i |o_i|
        return (tensor * mask).sum() / (mask.sum() + 1e-8)

代码解读:

  • ActorWorker.loss_func 调用 agg_loss 并将 loss_agg_mode 设置为 "token-mean" 时,它最终会调用 masked_mean 且不带 dim 参数。
  • (tensor * mask).sum(): 这里的 tensor 就是PPO损失矩阵 loss_mat.sum() 在没有维度参数时执行全局求和 ,将整个张量中所有元素的值加在一起。这完全对应了公式中的 Σ_{i=1}^{G} Σ_{t=1}^{|o_i|},即把所有有效token的损失加起来。
  • mask.sum(): 同样,这也是一个全局求和,计算 mask 张量中所有 1 的数量,这正好是批次中有效token的总数,对应公式中的 Σ_{i=1}^{G} |o_i|
  • 最后的除法操作,完成了求平均的过程。
2. 举例说明

我们再次使用之前的例子 (batch_size=3, sequence_length=5):

  • loss_mat (形状 [3, 5])
  • loss_mask (形状 [3, 5])

计算过程 masked_mean(loss_mat, loss_mask):

  1. loss_mat * loss_mask 得到掩码后的损失矩阵:

    复制代码
    tensor([
        [0.1, 0.5, 0.2, 0.0, 0.0],
        [0.8, 0.3, 0.6, 0.4, 0.1],
        [0.9, 0.2, 0.0, 0.0, 0.0]
    ])
  2. .sum() (全局求和):
    0.1 + 0.5 + 0.2 + 0.8 + 0.3 + 0.6 + 0.4 + 0.1 + 0.9 + 0.2 = 4.1

    这就是所有有效token的损失总和

  3. mask.sum() (全局求和):
    loss_mask1 的总数是 3 (样本1) + 5 (样本2) + 2 (样本3) = 10

    这就是批次中有效token的总数

  4. 最终结果:
    loss = 4.1 / 10 = 0.41

这是一个单一的标量值,代表了整个批次中所有token的平均损失。在这个计算中,长度为5的样本2贡献了5个损失项,而长度为2的样本3只贡献了2个,因此样本2的影响力更大。

总结

您描述的Token级策略梯度损失在所提供的代码中是通过以下组合实现的:

  1. 在训练配置中设置 loss_agg_mode: "token-mean"
  2. ActorWorker.loss_func 调用 agg_loss 函数。
  3. agg_loss 进而调用 masked_mean 函数,且不指定 dim 参数
  4. masked_mean 函数内部执行 (tensor * mask).sum() / (mask.sum() + 1e-8),这在数学上完全等价于您提供的DAPO损失函数中的聚合方式。

好的,我们来详细举例说明 loss_mat * weights.unsqueeze(-1) 这个操作。这个操作的目的是对每个样本(序列)的损失进行加权

假设我们有一个批次(batch_size=3),响应序列的最大长度为 5


输入

1. loss_mat: 损失矩阵 (Tensor)

和之前的例子一样,loss_mat 是一个形状为 (3, 5) 的张量,代表每个token的损失。

python 复制代码
import torch

loss_mat = torch.tensor([
    [0.1, 0.5, 0.2, 0.0, 0.0],  # 样本 1 的逐token损失
    [0.8, 0.3, 0.6, 0.4, 0.1],  # 样本 2 的逐token损失
    [0.9, 0.2, 0.0, 0.0, 0.0]   # 样本 3 的逐token损失
])

# loss_mat 形状: torch.Size([3, 5])
2. weights: 样本权重向量 (Tensor)

weights 是一个一维张量,其长度等于 batch_size。每个元素代表对应样本的权重。这个权重可能基于样本的难度、长度或其他自定义逻辑计算得出。

python 复制代码
weights = torch.tensor([
    0.8,  # 样本 1 的权重
    1.5,  # 样本 2 的权重 (可能更重要)
    0.7   # 样本 3 的权重 (可能不那么重要)
])

# weights 形状: torch.Size([3])

计算过程

步骤 1: weights.unsqueeze(-1) - 扩展维度

这个操作会在weights张量的最后一个维度(-1)增加一个新的维度。这是一种**广播(Broadcasting)**机制的准备步骤。

python 复制代码
weights_expanded = weights.unsqueeze(-1)

# weights_expanded 的值:
# tensor([
#     [0.8],
#     [1.5],
#     [0.7]
# ])

# weights_expanded 形状: torch.Size([3, 1])

现在,weights_expanded 是一个列向量。

步骤 2: loss_mat * weights_expanded - 广播乘法

现在我们将形状为 (3, 5)loss_mat 与形状为 (3, 1)weights_expanded 相乘。

PyTorch的广播机制会起作用:

  1. 比较两个张量的维度,从后往前看。
  2. loss_mat: (3, 5)
  3. weights_expanded: (3, 1)
  4. PyTorch发现 weights_expanded 的第二个维度是1,而 loss_mat 的第二个维度是5。它会自动将 weights_expanded 在第二个维度上"复制"5次,使其形状在逻辑上变为 (3, 5)

逻辑上扩展后的 weights_expanded 看起来像这样:

复制代码
tensor([
    [0.8, 0.8, 0.8, 0.8, 0.8],
    [1.5, 1.5, 1.5, 1.5, 1.5],
    [0.7, 0.7, 0.7, 0.7, 0.7]
])

现在,进行元素级别的乘法:

python 复制代码
weighted_loss_mat = loss_mat * weights_expanded

# weighted_loss_mat 的值:
# tensor([
#     [0.1*0.8, 0.5*0.8, 0.2*0.8, 0.0*0.8, 0.0*0.8],  # 样本1的所有token损失都乘以0.8
#     [0.8*1.5, 0.3*1.5, 0.6*1.5, 0.4*1.5, 0.1*1.5],  # 样本2的所有token损失都乘以1.5
#     [0.9*0.7, 0.2*0.7, 0.0*0.7, 0.0*0.7, 0.0*0.7]   # 样本3的所有token损失都乘以0.7
# ])

# 计算结果:
# tensor([
#     [0.08, 0.40, 0.16, 0.00, 0.00],
#     [1.20, 0.45, 0.90, 0.60, 0.15],
#     [0.63, 0.14, 0.00, 0.00, 0.00]
# ])

# weighted_loss_mat 形状: torch.Size([3, 5])

总结与目的

  • 输入 loss_mat 形状 : (batch_size, sequence_length) -> (3, 5)
  • 输入 weights 形状 : (batch_size,) -> (3,)
  • weights.unsqueeze(-1) 后形状 : (batch_size, 1) -> (3, 1)
  • 操作 : 通过广播,将每个样本的权重(一个标量)乘到该样本对应的整行(所有token)的损失上。
  • 输出 weighted_loss_mat 形状 : (batch_size, sequence_length) -> (3, 5)

这个操作的核心目的 :在计算总损失之前,根据预先定义的样本重要性(weights),调整每个样本内所有token损失的相对大小。权重高的样本,其所有token的损失都会被放大,从而在后续的梯度计算中产生更大的影响;反之,权重低的样本,其影响力会被削弱。

这个加权后的 weighted_loss_mat 随后会作为 masked_mean 的输入,进行最终的损失聚合。

相关推荐
bing.shao9 小时前
AI在电商上架图片领域的应用
开发语言·人工智能·golang
百家方案9 小时前
“十五五”智慧文旅解决方案:以科技为核心,开启沉浸体验与高效治理新篇章
大数据·人工智能·智慧文旅·智慧旅游
●VON9 小时前
绿色 AI:让智能计算与地球共生
人工智能·学习·安全·制造·von
鲨莎分不晴9 小时前
注意力的本质:信息加权而已
人工智能
万俟淋曦9 小时前
【论文速递】2025年第52周(Dec-21-27)(Robotics/Embodied AI/LLM)
人工智能·深度学习·机器学习·机器人·大模型·论文·具身智能
专注数据的痴汉9 小时前
「数据获取」吉林地理基础数据(道路、水系、四级行政边界、地级城市、DEM等)
大数据·人工智能·信息可视化
dagouaofei9 小时前
AI 生成 2026 年工作计划 PPT,内容质量差异在哪里
人工智能·python·powerpoint
ai_top_trends9 小时前
2026 年工作计划汇报 PPT:AI 生成方案实测对比
人工智能·python·powerpoint
创作者mateo9 小时前
PyTorch 入门学习笔记(实战篇)二
pytorch·笔记·学习