https://www.xiaoqiedun.com/posts/2025-03-18-dapo/
您提的问题非常精准,直击了GRPO(Group-wise Reward Policy Optimization)算法与标准PPO在实现上的一个核心区别。您的描述是完全正确的:GRPO在样本(序列)级别计算目标,它首先计算每个生成序列的平均损失,然后对所有样本的损失取平均。
这个机制主要体现在代码的两个地方:
- 损失聚合函数
agg_loss: 这是最直接的体现。 ActorWorker.loss_func中对agg_loss的调用: 这里决定了使用哪种聚合模式。
下面我们来详细解析。
1. 核心机制:agg_loss 函数中的 "seq-mean-token-mean" 模式
在 roll/utils/functionals.py 文件中,agg_loss 函数是实现这一逻辑的关键。让我们聚焦于 loss_agg_mode="seq-mean-token-mean" 这个分支:
python
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str,
weights: Optional[torch.Tensor] = None, loss_scale: Optional[float] = None):
"""
...
Args:
loss_mat: `(torch.Tensor)`
shape: (bs, response_length) <-- 这是每个token的损失矩阵
...
"""
# ... 其他模式的代码 ...
elif loss_agg_mode == "seq-mean-token-mean":
# -------------------------------------------------------------
# 第一步:计算每个生成序列内的平均损失 (seq-mean)
# -------------------------------------------------------------
# masked_mean 在 dim=-1 (序列维度) 上操作
# 它计算 loss_mat 中每个序列(行)的有效token损失的平均值
seq_losses = masked_mean(loss_mat, loss_mask, dim=-1) # (bs,)
# 此时 seq_losses 是一个形状为 (bs,) 的张量,每个元素代表一个样本(序列)的平均损失
# -------------------------------------------------------------
# 第二步:对不同样本的损失取平均值 (token-mean or batch-mean)
# -------------------------------------------------------------
valid_samples = torch.any(loss_mask > 0, dim=-1).float() # 找出哪些是有效样本
if weights is None:
weights = torch.ones(loss_mask.shape[0], device=loss_mask.device)
# 将所有样本的平均损失(seq_losses)加权求和,再除以有效样本的数量
# 这就是对不同样本的损失取平均值
loss = (seq_losses * weights * valid_samples).sum() / (valid_samples.sum() + 1e-8) # 标量
# ... 其他模式的代码 ...
return loss * loss_scale if loss_scale else loss
代码解读:
- 输入 :
loss_mat是一个形状为(batch_size, sequence_length)的张量,其中loss_mat[i, j]代表第i个样本的第j个token的PPO损失。 - 第一步 (对应您的描述 "首先计算每个生成序列内的平均损失") :
seq_losses = masked_mean(loss_mat, loss_mask, dim=-1)。masked_mean函数沿着最后一个维度 (dim=-1,即序列维度) 计算加权平均。这会得到一个形状为(batch_size,)的张量seq_losses,其中seq_losses[i]就是第i个生成序列所有有效token的平均损失。 - 第二步 (对应您的描述 "然后对不同样本的损失取平均值") :
loss = (seq_losses * ...).sum() / (valid_samples.sum() + 1e-8)。这行代码计算seq_losses张量的加权平均值,从而得到一个最终的标量损失值,用于反向传播。
2. 应用场景:ActorWorker.loss_func 中的配置调用
agg_loss 函数只是一个工具,它如何被使用取决于调用它的地方。在 ActorWorker.loss_func 中,这个工具被用来聚合最终的策略梯度损失(pg_loss)和KL损失等。
python
class ActorWorker(BaseActorWorker):
def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
# ... (计算每个token的PPO损失 `loss`)
loss = -torch.min(surr1, surr2) # loss 的形状是 (bs, seq_len)
# ...
# -------------------------------------------------------------
# 这里通过配置决定使用哪种聚合模式
# -------------------------------------------------------------
weighted_pg_loss = agg_loss(loss_mat=loss,
loss_mask=final_response_mask,
# 这个参数决定了是否执行GRPO的损失计算方式
loss_agg_mode=self.pipeline_config.loss_agg_mode,
weights=sample_weights,
loss_scale=loss_scale)
# ...
return total_loss, pg_metrics
关键点:
self.pipeline_config.loss_agg_mode这个配置项是"开关"。- 当您在训练配置中设置
loss_agg_mode: "seq-mean-token-mean"时,ActorWorker就会指示agg_loss函数采用GRPO描述的聚合方式。
总结与对比
为了让这个概念更清晰,我们来对比一下GRPO的聚合方式和标准的PPO(通常按token平均)的聚合方式。
| 聚合方式 | 描述 | agg_loss 中的模式 |
物理意义 |
|---|---|---|---|
| 标准PPO (Token-level) | 将所有样本的所有有效token的损失放在一起,直接计算总平均值。 | "token-mean" |
每个token对最终损失的贡献是平等的。长序列会比短序列有更大的影响力。 |
| GRPO (Sample-level) | 先计算每个样本的平均token损失,再计算所有样本的平均损失。 | "seq-mean-token-mean" |
每个样本(序列)对最终损失的贡献是平等的,无论其长短。 这更符合GRPO的思想,即优化整个序列的质量,而不是单个token。 |
因此,**代码通过 agg_loss 函数中的 "seq-mean-token-mean" 逻辑,以及在 loss_func 中通过配置文件 self.pipeline_config.loss_agg_mode 来调用该逻辑,完美地体现了GRPO在样本级别计算目标的思想。**好的,我们来通过一个非常具体和详细的例子来解释 masked_mean(loss_mat, loss_mask, dim=-1) 的输入和输出形状。
假设我们的 batch_size 为 3,响应序列的最大长度(response_length)为 5。
输入:loss_mat 和 loss_mask
1. loss_mat: 损失矩阵 (Tensor)
loss_mat 是一个形状为 (3, 5) 的张量,代表每个样本在每个token位置上的损失。
loss_mat = tensor([
[0.1, 0.5, 0.2, 0.0, 0.0], # 样本 1 的逐token损失
[0.8, 0.3, 0.6, 0.4, 0.1], # 样本 2 的逐token损失
[0.9, 0.2, 0.0, 0.0, 0.0] # 样本 3 的逐token损失
])
# 形状: torch.Size([3, 5])
loss_mat[0, 1]的值是0.5,代表第一个样本的第二个token的损失。- 我们看到有些位置的损失是0,这可能是因为这些位置是填充(padding)的,但
loss_mat本身并不包含这个信息。哪个位置是有效的,由loss_mask决定。
2. loss_mask: 掩码矩阵 (Tensor)
loss_mask 也是一个形状为 (3, 5) 的张量,但它的值只有0和1。1 代表这个位置是有效的(是真实的token),0 代表这个位置是无效的(是padding)。
loss_mask = tensor([
[1, 1, 1, 0, 0], # 样本 1: 有效长度为 3
[1, 1, 1, 1, 1], # 样本 2: 有效长度为 5
[1, 1, 0, 0, 0] # 样本 3: 有效长度为 2
], dtype=torch.float)
# 形状: torch.Size([3, 5])
- 样本 1: 真实响应长度是3个token,后面2个是padding。
- 样本 2: 真实响应长度是5个token,占满了整个序列长度。
- 样本 3: 真实响应长度是2个token,后面3个是padding。
计算过程:masked_mean(loss_mat, loss_mask, dim=-1)
这个函数会沿着最后一个维度(dim=-1,即序列维度)进行计算。对于批次中的每一个样本(每一行),它会执行以下操作:
-
元素相乘 : 将
loss_mat与loss_mask按元素相乘。这会把所有无效位置(mask为0)的损失清零。loss_mat * loss_mask = tensor([ [0.1, 0.5, 0.2, 0.0, 0.0], # 样本 1: [0.1*1, 0.5*1, 0.2*1, 0.0*0, 0.0*0] [0.8, 0.3, 0.6, 0.4, 0.1], # 样本 2: [0.8*1, 0.3*1, 0.6*1, 0.4*1, 0.1*1] [0.9, 0.2, 0.0, 0.0, 0.0] # 样本 3: [0.9*1, 0.2*1, 0.0*0, 0.0*0, 0.0*0] ]) -
求和: 对上一步结果的每一行(每个样本)进行求和。
- 样本 1 和 :
0.1 + 0.5 + 0.2 = 0.8 - 样本 2 和 :
0.8 + 0.3 + 0.6 + 0.4 + 0.1 = 2.2 - 样本 3 和 :
0.9 + 0.2 = 1.1
得到一个临时的和向量
sum_vec = tensor([0.8, 2.2, 1.1]) - 样本 1 和 :
-
计算有效长度 : 对
loss_mask的每一行(每个样本)进行求和,得到每个样本的有效token数量。- 样本 1 有效长度 :
1 + 1 + 1 = 3 - 样本 2 有效长度 :
1 + 1 + 1 + 1 + 1 = 5 - 样本 3 有效长度 :
1 + 1 = 2
得到一个临时的长度向量
len_vec = tensor([3.0, 5.0, 2.0]) - 样本 1 有效长度 :
-
相除 : 将步骤2得到的和向量
sum_vec与步骤3得到的长度向量len_vec按元素相除。- 样本 1 平均损失 :
0.8 / 3 = 0.2667 - 样本 2 平均损失 :
2.2 / 5 = 0.44 - 样本 3 平均损失 :
1.1 / 2 = 0.55
- 样本 1 平均损失 :
输出:seq_losses
经过上述计算,masked_mean 函数返回一个一维张量,其长度等于 batch_size。这个张量就是我们所说的 seq_losses。
python
seq_losses = masked_mean(loss_mat, loss_mask, dim=-1)
# seq_losses 的值:
# tensor([0.2667, 0.4400, 0.5500])
# 形状: torch.Size([3])
总结
- 输入
loss_mat形状 :(batch_size, sequence_length)->(3, 5) - 输入
loss_mask形状 :(batch_size, sequence_length)->(3, 5) - 操作 : 沿着
dim=-1(序列维度) 计算每个样本的平均损失。 - 输出
seq_losses形状 :(batch_size,)->(3,)
这个输出 seq_losses 的每一个元素都代表了一个完整的生成序列(样本)的平均损失。接下来的步骤(如GRPO中的第二步)就是对这个 seq_losses 张量再做一次聚合(例如,取所有元素的平均值),从而得到最终的、用于反向传播的单个损失值。
好的,您提供的描述和公式非常清晰,它们精确地指向了**Token级平均(Token-level Mean)**的损失聚合方式。我们来详细解析这个概念以及它在代码中的具体实现。
概念解析
您提供的描述和公式(DAPO损失函数)完美地阐述了Token级策略梯度损失的核心思想:

-
公式解读:
Σ_{i=1}^{G} Σ_{t=1}^{|o_i|} min(...):这是一个双重求和。它首先遍历一个批次中的所有样本(i从 1到G),然后遍历每个样本中的所有token(t从 1 到|o_i|)。这表示我们将批次中每一个有效token 的PPO损失(即min(...)部分)全部加起来。1 / Σ_{i=1}^{G} |o_i|:这是归一化因子。Σ_{i=1}^{G} |o_i|表示计算批次中所有样本的长度之和,即批次中有效token的总数。- 整体 : 整个公式的含义是:计算批次中所有有效token的PPO损失的总和,然后除以有效token的总数,得到所有token的平均损失。
-
与您的描述对应:
- "较长序列相比较短序列对整体梯度更新的影响可能更大": 这是因为在Token级平均中,一个长度为100的序列会贡献100个token的损失项到总和中,而一个长度为10的序列只贡献10个。因此,长序列在计算总平均损失时有更大的"话语权"。
- "从单个 Token 的角度来看...无论它出现在哪种长度的响应中,都将被同等地促进或抑制" : 这是因为每个token的损失
min(...)在被加到总和之前,没有经过任何与序列长度相关的加权或归一化。它的值只取决于它自身的ratio和advantage。所有token最终被"一视同仁"地扔进一个大池子里求平均。
代码实现
这种Token级平均的损失计算方式在您的代码中对应的是 agg_loss 函数的 "token-mean" 模式,以及 masked_mean 函数在没有指定 dim 参数时的行为。
1. 核心代码:agg_loss 与 masked_mean
在 roll/utils/functionals.py 文件中:
python
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, ...):
...
# 当配置为 "token-mean" 时,会进入这个分支
if loss_agg_mode == "token-mean":
if weights is None:
weights = torch.ones(loss_mask.shape[0], device=loss_mask.device)
# 关键调用:这里调用 masked_mean 时没有传递 dim 参数
loss = masked_mean(loss_mat * weights.unsqueeze(-1), loss_mask)
...
return loss
现在,我们看 masked_mean 函数在 dim=None 时的实现:
python
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor:
if dim is not None:
# 这是我们之前讨论的 GRPO (seq-mean) 的情况
...
else:
# 这就是 Token-level mean 的实现!
# (tensor * mask).sum() -> 对应公式的 Σ_i Σ_t min(...)
# (mask.sum()) -> 对应公式的 Σ_i |o_i|
return (tensor * mask).sum() / (mask.sum() + 1e-8)
代码解读:
- 当
ActorWorker.loss_func调用agg_loss并将loss_agg_mode设置为"token-mean"时,它最终会调用masked_mean且不带dim参数。 (tensor * mask).sum(): 这里的tensor就是PPO损失矩阵loss_mat。.sum()在没有维度参数时执行全局求和 ,将整个张量中所有元素的值加在一起。这完全对应了公式中的Σ_{i=1}^{G} Σ_{t=1}^{|o_i|},即把所有有效token的损失加起来。mask.sum(): 同样,这也是一个全局求和,计算mask张量中所有1的数量,这正好是批次中有效token的总数,对应公式中的Σ_{i=1}^{G} |o_i|。- 最后的除法操作,完成了求平均的过程。
2. 举例说明
我们再次使用之前的例子 (batch_size=3, sequence_length=5):
loss_mat(形状[3, 5])loss_mask(形状[3, 5])
计算过程 masked_mean(loss_mat, loss_mask):
-
loss_mat * loss_mask得到掩码后的损失矩阵:tensor([ [0.1, 0.5, 0.2, 0.0, 0.0], [0.8, 0.3, 0.6, 0.4, 0.1], [0.9, 0.2, 0.0, 0.0, 0.0] ]) -
.sum()(全局求和):
0.1 + 0.5 + 0.2 + 0.8 + 0.3 + 0.6 + 0.4 + 0.1 + 0.9 + 0.2 = 4.1这就是所有有效token的损失总和。
-
mask.sum()(全局求和):
loss_mask中1的总数是3 (样本1) + 5 (样本2) + 2 (样本3) = 10。这就是批次中有效token的总数。
-
最终结果:
loss = 4.1 / 10 = 0.41
这是一个单一的标量值,代表了整个批次中所有token的平均损失。在这个计算中,长度为5的样本2贡献了5个损失项,而长度为2的样本3只贡献了2个,因此样本2的影响力更大。
总结
您描述的Token级策略梯度损失在所提供的代码中是通过以下组合实现的:
- 在训练配置中设置
loss_agg_mode: "token-mean"。 ActorWorker.loss_func调用agg_loss函数。agg_loss进而调用masked_mean函数,且不指定dim参数。masked_mean函数内部执行(tensor * mask).sum() / (mask.sum() + 1e-8),这在数学上完全等价于您提供的DAPO损失函数中的聚合方式。
好的,我们来详细举例说明 loss_mat * weights.unsqueeze(-1) 这个操作。这个操作的目的是对每个样本(序列)的损失进行加权。
假设我们有一个批次(batch_size=3),响应序列的最大长度为 5。
输入
1. loss_mat: 损失矩阵 (Tensor)
和之前的例子一样,loss_mat 是一个形状为 (3, 5) 的张量,代表每个token的损失。
python
import torch
loss_mat = torch.tensor([
[0.1, 0.5, 0.2, 0.0, 0.0], # 样本 1 的逐token损失
[0.8, 0.3, 0.6, 0.4, 0.1], # 样本 2 的逐token损失
[0.9, 0.2, 0.0, 0.0, 0.0] # 样本 3 的逐token损失
])
# loss_mat 形状: torch.Size([3, 5])
2. weights: 样本权重向量 (Tensor)
weights 是一个一维张量,其长度等于 batch_size。每个元素代表对应样本的权重。这个权重可能基于样本的难度、长度或其他自定义逻辑计算得出。
python
weights = torch.tensor([
0.8, # 样本 1 的权重
1.5, # 样本 2 的权重 (可能更重要)
0.7 # 样本 3 的权重 (可能不那么重要)
])
# weights 形状: torch.Size([3])
计算过程
步骤 1: weights.unsqueeze(-1) - 扩展维度
这个操作会在weights张量的最后一个维度(-1)增加一个新的维度。这是一种**广播(Broadcasting)**机制的准备步骤。
python
weights_expanded = weights.unsqueeze(-1)
# weights_expanded 的值:
# tensor([
# [0.8],
# [1.5],
# [0.7]
# ])
# weights_expanded 形状: torch.Size([3, 1])
现在,weights_expanded 是一个列向量。
步骤 2: loss_mat * weights_expanded - 广播乘法
现在我们将形状为 (3, 5) 的 loss_mat 与形状为 (3, 1) 的 weights_expanded 相乘。
PyTorch的广播机制会起作用:
- 比较两个张量的维度,从后往前看。
loss_mat:(3, 5)weights_expanded:(3, 1)- PyTorch发现
weights_expanded的第二个维度是1,而loss_mat的第二个维度是5。它会自动将weights_expanded在第二个维度上"复制"5次,使其形状在逻辑上变为(3, 5)。
逻辑上扩展后的 weights_expanded 看起来像这样:
tensor([
[0.8, 0.8, 0.8, 0.8, 0.8],
[1.5, 1.5, 1.5, 1.5, 1.5],
[0.7, 0.7, 0.7, 0.7, 0.7]
])
现在,进行元素级别的乘法:
python
weighted_loss_mat = loss_mat * weights_expanded
# weighted_loss_mat 的值:
# tensor([
# [0.1*0.8, 0.5*0.8, 0.2*0.8, 0.0*0.8, 0.0*0.8], # 样本1的所有token损失都乘以0.8
# [0.8*1.5, 0.3*1.5, 0.6*1.5, 0.4*1.5, 0.1*1.5], # 样本2的所有token损失都乘以1.5
# [0.9*0.7, 0.2*0.7, 0.0*0.7, 0.0*0.7, 0.0*0.7] # 样本3的所有token损失都乘以0.7
# ])
# 计算结果:
# tensor([
# [0.08, 0.40, 0.16, 0.00, 0.00],
# [1.20, 0.45, 0.90, 0.60, 0.15],
# [0.63, 0.14, 0.00, 0.00, 0.00]
# ])
# weighted_loss_mat 形状: torch.Size([3, 5])
总结与目的
- 输入
loss_mat形状 :(batch_size, sequence_length)->(3, 5) - 输入
weights形状 :(batch_size,)->(3,) weights.unsqueeze(-1)后形状 :(batch_size, 1)->(3, 1)- 操作 : 通过广播,将每个样本的权重(一个标量)乘到该样本对应的整行(所有token)的损失上。
- 输出
weighted_loss_mat形状 :(batch_size, sequence_length)->(3, 5)
这个操作的核心目的 :在计算总损失之前,根据预先定义的样本重要性(weights),调整每个样本内所有token损失的相对大小。权重高的样本,其所有token的损失都会被放大,从而在后续的梯度计算中产生更大的影响;反之,权重低的样本,其影响力会被削弱。
这个加权后的 weighted_loss_mat 随后会作为 masked_mean 的输入,进行最终的损失聚合。