cpp
def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
"""
loss func接口定义:
data: DataProto, 由train_step透传
output_tensor: torch.Tensor, model.forward()的输出Tensor
"""
response_mask = data.batch["response_mask"][:, 1:].long()
ref_log_probs = data.batch["ref_log_probs"]
old_log_probs = data.batch["old_log_probs"]
advantages = data.batch["advantages"]
log_probs = self.strategy.op_compute_log_probs(
logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"]
)
ratio = (log_probs - old_log_probs).exp()
pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages
pg_loss = -torch.min(surr1, surr2)
if self.pipeline_config.dual_clip_loss:
dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages)
pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss)
pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode)
kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask,
kl_penalty="k3")
kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode)
approxkl = compute_approx_kl(
log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse"
)
policykl = compute_approx_kl(
log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="kl"
)
clipped_low = (ratio < 1 - pg_clip_low).float()
clipped_high = (ratio > 1 + pg_clip_high).float()
clipped = (clipped_low + clipped_high).float()
if self.pipeline_config.use_kl_loss:
total_loss = pg_loss + kl_loss * self.pipeline_config.kl_loss_coef
else:
total_loss = pg_loss
if self.pipeline_config.entropy_loss_coef > 0:
entropy = self.strategy.op_compute_entropy(logits=output_tensor, attention_mask=data.batch["response_mask"])
entropy_loss = agg_loss(
loss_mat=entropy,
loss_mask=response_mask,
loss_agg_mode=self.pipeline_config.loss_agg_mode,
)
total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef
pg_metrics = {
"actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(),
"actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(),
"actor/ppo_ratio_clipfrac": clipped.mean().detach().item(),
"actor/ratio_mean": masked_mean(ratio, response_mask, dim=-1).mean().detach().item(),
"actor/ratio_max": torch.max(ratio * response_mask).detach().item(),
"actor/ratio_min": torch.min(ratio * response_mask + (1 - response_mask) * 1e10).detach().item(),
"actor/clipfrac": agg_loss(loss_mat=torch.lt(surr2, surr1).float(), loss_mask=response_mask,
loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),
"actor/pg_loss": pg_loss.detach().item(),
"actor/kl_loss": kl_loss.detach().item(),
"actor/total_loss": total_loss.detach().item(),
"actor/approxkl": agg_loss(loss_mat=approxkl, loss_mask=response_mask,
loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),
"actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask,
loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),
}
return total_loss, pg_metrics
你提出了一个非常好的问题,这个问题涉及到模型输入、输出和损失计算之间如何对齐的关键细节。
data.batch["response_mask"][:, 1:].long() 这么做的原因是为了 让 response_mask 与模型的输出 logits 以及需要计算损失的目标 log_probs 在序列长度维度上对齐。
让我们一步步来分解这个原因。
1. 语言模型的输入和输出
假设我们有一个输入序列 input_ids:
[BOS, "你好", "吗", "?", EOS]
BOS: Begin of Sentence token (起始符)EOS: End of Sentence token (结束符)
当这个序列被送入一个自回归语言模型(比如 GPT)时,模型的工作方式是 预测下一个 token。
- 输入
[BOS],模型预测"你好" - 输入
[BOS, "你好"],模型预测"吗" - 输入
[BOS, "你好", "吗"],模型预测"?" - 输入
[BOS, "你好", "吗", "?"],模型预测EOS
因此,如果输入序列的长度是 L,那么模型输出的 logits(预测每个位置的下一个 token 的概率分布)的序列长度通常也是 L。
但是,logits 在位置 i 的输出,是基于 input_ids 在位置 0 到 i 的输入得到的,它是用来预测 input_ids 在位置 i+1 的 token 的。
看下图示:
input_ids: [ token_0, token_1, token_2, token_3 ] (长度 L=4)
/ / / /
logits: [ logit_0, logit_1, logit_2, logit_3 ] (长度 L=4)
| | | |
预测目标: (token_1) (token_2) (token_3) (token_4)
2. 计算损失时的对齐问题
我们要计算的损失,是模型预测的 log_probs 和真实 token 之间的交叉熵。
logit_0用来预测token_1logit_1用来预测token_2logit_2用来预测token_3
我们注意到,logit 的最后一个位置 logit_3 是用来预测 token_4 的,但我们的输入序列里没有 token_4。因此,logit_3 在计算损失时通常是被忽略的。
反过来看,token_0 (通常是 BOS 或 PAD) 从来不作为被预测的目标,所以与它对应的 logit 也不需要。
所以,为了计算损失,我们需要:
logits: 取[logit_0, logit_1, logit_2],即logits[:, :-1, :]- 目标
input_ids: 取[token_1, token_2, token_3],即input_ids[:, 1:]
两者现在长度都是 L-1,完美对齐。
3. response_mask 的对齐
现在回到你的问题 response_mask。
response_mask 的原始形状与 input_ids 是一样的,长度为 L。它标记了哪些 token 属于需要计算损失的 "response" 部分。
例如,对于输入 [prompt_1, prompt_2, resp_1, resp_2, resp_3, EOS]:
原始 response_mask 可能是 [0, 0, 1, 1, 1, 1]
当我们计算损失时,我们用的是 logits[:, :-1] 和 input_ids[:, 1:]。这意味着我们实际上是在处理一个长度为 L-1 的序列。
为了让 response_mask 能够正确地屏蔽掉这个长度为 L-1 的序列中不需要计算损失的部分,它自身也需要被截断,以匹配新的序列长度。
我们截断的方式必须和目标 input_ids 的截断方式一样,即 从第二个元素开始取,去掉第一个元素。
- 原始
response_mask(长度 L):[mask_0, mask_1, mask_2, ..., mask_{L-1}] - 截断后的
response_mask(长度 L-1):[mask_1, mask_2, ..., mask_{L-1}]- 这个操作就是
response_mask[:, 1:]
- 这个操作就是
这样,截断后的 mask_1 就会对应 logit_0 和 目标 token_1,mask_2 对应 logit_1 和 目标 token_2,以此类推,对齐关系就完全正确了。
总结
data.batch["response_mask"][:, 1:].long() 这行代码的目的是:
[:, 1:]: 为了与用于计算损失的logits[:, :-1]和目标input_ids[:, 1:]在序列长度上保持对齐。它通过移除 mask 的第一个元素来实现这一点。.long(): 将 mask 的数据类型转换为长整型 (torch.int64)。这通常是为了兼容某些 PyTorch 操作,比如用 mask 进行索引或者作为某些损失函数的权重时,可能要求是整型而不是布尔型或浮点型。
这是一个在实现语言模型训练时非常常见且至关重要的细节,确保了计算的正确性。
好的,我们来详细分析你提供的这个特定版本的 op_compute_log_probs 函数,并用一个具体的例子来走一遍它的执行过程。
这个版本的实现比我之前给出的通用版本更简洁,它依赖一个名为 log_probs_from_logits 的辅助函数(HuggingFace TRL 库中常见)。我们假设 log_probs_from_logits 的功能就是我之前描述的第 2 步(Log-Softmax)和第 3 步(Gather)的结合。
函数代码分析
python
# 假设 log_probs_from_logits 的实现如下:
def log_probs_from_logits(logits, labels):
log_probs = F.log_softmax(logits, dim=-1)
return torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
class YourClass:
def op_compute_log_probs(self, logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor):
"""
logits: llm logits, 形状 [batch_size, seq_len, vocab_size]
input_ids [[p, p, r, r, r, 0, 0]], 形状 [batch_size, seq_len]
attention_mask(response_mask) [[0, 0, 1, 1, 1, 0, 0]], 形状 [batch_size, seq_len]
"""
# 1. 准备 Labels,并处理无效 Token
labels: torch.Tensor = input_ids[:, 1:].clone()
# 将 mask 为 0 的位置的 label 设置为 0
labels[attention_mask[:, 1:] == 0] = 0
# 2. 计算 Log Probs
# 传入错位对齐的 logits 和处理过的 labels
log_probs = log_probs_from_logits(logits[:, :-1], labels)
# 3. 应用 Mask
# 将不属于 response 的位置的 log_probs 清零
log_probs = log_probs * attention_mask[:, 1:]
return log_probs
核心逻辑点
labels[attention_mask[:, 1:] == 0] = 0: 这是这个实现中最有趣和最关键的一步。它的目的是防止log_probs_from_logits访问到无效的 token ID 。- 在 PPO 训练中,
input_ids可能包含PADtoken(其 ID 通常是 0)。 - 如果一个
PADtoken 出现在labels中,torch.gather会尝试去访问词汇表索引为 0 的位置。这本身没问题。 - 但更重要的是,对于 prompt 部分和 padding 部分,我们根本不关心它们的
log_probs,因为它们不会计入最终的损失。将这些位置的label统一设置为 0,可以简化计算。虽然gather仍然会为这些位置计算一个值(即词汇表中 token 0 的对数概率),但这没关系,因为 第 3 步会把这些位置的log_probs全部清零。这是一个"先计算再丢弃"的策略。
- 在 PPO 训练中,
举例说明执行过程
假设有以下微型配置:
batch_size = 1seq_len = 7vocab_size = 50000PAD_TOKEN_ID = 0
输入:
input_ids:[[101, 102, 201, 202, 203, 0, 0]][p, p, r, r, r, pad, pad]
attention_mask(response_mask) :[[0, 0, 1, 1, 1, 0, 0]]logits: 一个由模型生成的[1, 7, 50000]的张量。
执行步骤:
第 1 步: 准备 labels
-
labels = input_ids[:, 1:].clone()input_ids[:, 1:]得到[[102, 201, 202, 203, 0, 0]]labels的值现在是[[102, 201, 202, 203, 0, 0]],形状[1, 6]。
-
计算
maskforlabelsattention_mask[:, 1:]得到[[0, 1, 1, 1, 0, 0]]
-
labels[attention_mask[:, 1:] == 0] = 0attention_mask[:, 1:] == 0会产生一个布尔掩码[[True, False, False, False, True, True]]。- 这个掩码会选中
labels中需要被修改的位置:labels的第 0 个元素 (对应 prompt 部分)labels的第 4 个元素 (对应第一个 pad)labels的第 5 个元素 (对应第二个 pad)
labels被原地修改,修改后的值为:[[0, 201, 202, 203, 0, 0]]- 注意:原来的
102变成了0。
- 注意:原来的
至此,
labels准备完毕,值为[[0, 201, 202, 203, 0, 0]]。
第 2 步: 计算 log_probs
-
准备
logitslogits[:, :-1]得到一个[1, 6, 50000]的张量。
-
调用
log_probs_from_logits(logits[:, :-1], labels)-
log_probs_from_logits内部会:
a. 对logits[:, :-1]在最后一个维度上做log_softmax。
b. 使用labels[[0, 201, 202, 203, 0, 0]]作为索引,通过torch.gather从log_softmax的结果中提取值。 -
log_probs的计算结果(形状为[1, 6])会是:[[ logP(token=0 | p), // prompt 部分,计算了 pad token 的 log_prob logP(token=201 | p,p), // response 部分,正确 logP(token=202 | p,p,r), // response 部分,正确 logP(token=203 | p,p,r,r), // response 部分,正确 logP(token=0 | p,p,r,r,r), // padding 部分,计算了 pad token 的 log_prob logP(token=0 | p,p,r,r,r,0) // padding 部分,计算了 pad token 的 log_prob ]] -
假设计算出的值为
[[ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ]]。
-
第 3 步: 应用 Mask
log_probs = log_probs * attention_mask[:, 1:]-
log_probs:[[ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ]] -
attention_mask[:, 1:]:[[0, 1, 1, 1, 0, 0]] -
两者进行逐元素相乘:
[ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ] * [ 0, 1, 1, 1, 0, 0 ] ---------------------------------------- = [ 0, -1.1, -0.8, -1.5, 0, 0 ]
-
最终返回结果
函数返回的 log_probs 张量为: [[ 0, -1.1, -0.8, -1.5, 0, 0 ]]
这个结果非常完美:
- 只有 response 部分 (
r, r, r) 的log_probs被保留了下来。 - prompt 部分和 padding 部分的
log_probs都被清零了。 - 这个张量可以直接用于后续的损失计算(例如计算
ratio = (log_probs - old_log_probs).exp()),而不需要担心无效位置的干扰。
这个实现方式非常高效和简洁,巧妙地利用了 PyTorch 的广播和掩码操作。
好的,我们来详细拆解这两行代码,它们是PyTorch中一个非常常见且强大的模式,用于从概率分布中根据标签提取特定的值。
这两行代码的核心是 torch.gather 函数。
torch.gather 的工作原理
torch.gather 就像一个高级的索引工具。它的作用是:沿着指定的维度(dim),根据 index 张量中的值,从输入张量(input)中收集元素。
它的签名是 torch.gather(input, dim, index)。
为了让它工作,index 张量需要满足一些条件,最重要的是:
index的维度数量必须和input的维度数量相同。- 在所有非
dim的维度上,index的大小必须和input的大小相同(或者为 1,可以广播)。
结合你的代码进行分解
我们一步步来看:
python
# 假设我们有以下张量(以 batch_size=1, seq_len=3, vocab_size=5 为例)
# log_probs: [1, 3, 5] 的张量,代表了3个位置上,每个词的对数概率
log_probs = torch.tensor([[
[-1.6, -2.1, -0.9, -3.0, -1.8], # 位置0的 log_probs
[-0.5, -1.1, -2.5, -1.3, -4.0], # 位置1的 log_probs
[-3.2, -1.9, -1.0, -2.2, -0.8] # 位置2的 log_probs
]])
# labels: [1, 3] 的张量,代表了3个位置上,正确的 token ID
labels = torch.tensor([[2, 0, 4]])
第 1 步: labels.unsqueeze(-1)
- 目的 : 增加一个维度,使
labels的维度数量与log_probs相同,从而满足torch.gather的要求。 - 输入
labels:- 形状:
[1, 3] - 值:
[[2, 0, 4]]
- 形状:
- 操作 :
unsqueeze(-1)在最后一个维度(维度索引为-1)上增加一个大小为 1 的新维度。 - 输出
index:-
形状:
[1, 3, 1] -
值:
[[[2], [0], [4]]]
-
现在,log_probs (3D) 和 index (3D) 的维度数量相同了。
第 2 步: log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
input:log_probs(形状[1, 3, 5])dim:-1(或2)。这意味着我们将在最后一个维度------**词汇表维度(vocab_size)**上进行收集。index:labels.unsqueeze(-1)(形状[1, 3, 1])
gather 的执行过程 (可以想象成一个 for 循环):
gather会遍历index张量的所有位置。- 对于
index中的每个元素(batch_idx, seq_idx, 0),它会取出其中的值v = index[batch_idx, seq_idx, 0]。 - 然后,它会在
log_probs张量的对应位置(batch_idx, seq_idx, ...)上,沿着dim=-1收集索引为v的元素。 - 它将收集到的值放在输出张量的
(batch_idx, seq_idx, 0)位置。
让我们手动走一遍:
-
处理
index的[0, 0, 0]位置:index[0, 0, 0]的值是2。gather去log_probs的[0, 0, :]位置,也就是[-1.6, -2.1, -0.9, -3.0, -1.8]。- 它从这个向量中取出索引为
2的元素,即-0.9。 - 输出张量的
[0, 0, 0]位置被设置为-0.9。
-
处理
index的[0, 1, 0]位置:index[0, 1, 0]的值是0。gather去log_probs的[0, 1, :]位置,也就是[-0.5, -1.1, -2.5, -1.3, -4.0]。- 它从这个向量中取出索引为
0的元素,即-0.5。 - 输出张量的
[0, 1, 0]位置被设置为-0.5。
-
处理
index的[0, 2, 0]位置:index[0, 2, 0]的值是4。gather去log_probs的[0, 2, :]位置,也就是[-3.2, -1.9, -1.0, -2.2, -0.8]。- 它从这个向量中取出索引为
4的元素,即-0.8。 - 输出张量的
[0, 2, 0]位置被设置为-0.8。
gather 的输出 log_probs_labels:
-
形状:
[1, 3, 1](与index的形状相同) -
值:
[[[-0.9], [-0.5], [-0.8]]]
直观理解 : 对于序列中的每个位置,我们都从完整的词汇表概率分布中,只挑选出了正确标签(label)对应的那个对数概率。
第 3 步: .squeeze(-1)
- 目的: 移除多余的、大小为 1 的维度,让张量更易于处理。
- 输入
log_probs_labels:- 形状:
[1, 3, 1]
- 形状:
- 操作 :
squeeze(-1)移除最后一个维度(因为它的大小是 1)。 - 输出 :
- 形状:
[1, 3] - 值:
[[-0.9, -0.5, -0.8]]
- 形状:
最终结果
函数最终返回了一个 [1, 3] 的张量 [[-0.9, -0.5, -0.8]]。
这个张量的每个元素 output[i, j] 都代表了在批次 i 的序列位置 j,模型赋予正确 label 的对数概率。这正是我们计算交叉熵损失或 PPO 损失时所需要的核心数值。