GPT监督微调SFT:在损失计算中屏蔽指令和填充 Token

1. 概述
在对自回归(Causal)语言模型(如 GPT、Llama)进行监督微调(Supervised Fine-Tuning, SFT)时,一个关键的步骤是正确地构建 labels 张量以计算损失。
一个常见的误区是让模型预测"指令 + 回应"的完整序列。正确的做法是仅让模型预测"回应"部分的 Token。
核心原则: 我们的目标是让模型学习 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( R e s p o n s e ∣ I n s t r u c t i o n ) P(Response | Instruction) </math>P(Response∣Instruction)(在给定指令的条件下生成回应),而不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( I n s t r u c t i o n ) P(Instruction) </math>P(Instruction)(复述指令)或 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( P a d d i n g ) P(Padding) </math>P(Padding)(预测填充)。
2. 为什么必须进行屏蔽?
在 SFT 中,我们使用 torch.nn.functional.cross_entropy (或其 nn.Module 形式) 作为损失函数。这个函数包含一个关键参数 ignore_index,其默认值通常为 -100。当 labels 张量中的某个值为 ignore_index 时,该位置的损失将不被计算,也不会产生梯度。
我们必须利用这一特性屏蔽掉两部分内容:
2.1 屏蔽填充(Padding Tokens)
这是最基础的屏蔽。
- 原因: 在批处理(Batching)中,为了使序列具有统一的长度,我们会用
<PAD>Token 填充较短的序列。 - 后果: 如果不屏蔽填充,模型会被迫学习"在序列末尾预测
<PAD>Token"。这是一个毫无意义且有害的训练目标,它会浪费模型的学习能力。 - 解决方案: 将
labels中所有对应<PAD>Token 的位置设置为-100。
2.2 屏蔽指令(Instruction/Prompt Tokens)
这是 SFT 成功的关键,也是最容易被忽视的。
-
原因 1:避免错误的训练目标
- SFT 的任务是教会模型如何"回答" ,而不是如何"复述问题"。
- 在推理(Inference)时,用户会提供完整的指令(Prompt)。模型唯一的工作就是从指令的末尾开始接着生成回应。
- 如果不屏蔽指令,模型会花费大量的训练资源去学习"复述指令"(即 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( I n s t r u c t i o n ) P(Instruction) </math>P(Instruction))。这是一个完全错误的目标,因为它在推理时毫无用处。
-
原因 2:避免损失稀释(Loss Dilution)
- 模型的总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L o s s total Loss_{\text{total}} </math>Losstotal 是所有 Token 损失的平均值。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> L o s s total = ( L o s s instruction + L o s s response ) / TotalTokens Loss_{\text{total}} = (Loss_{\text{instruction}} + Loss_{\text{response}}) / \text{TotalTokens} </math>Losstotal=(Lossinstruction+Lossresponse)/TotalTokens
- "复述指令"是一个非常简单的任务(输入和目标几乎一致),其损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L o s s instruction Loss_{\text{instruction}} </math>Lossinstruction 会迅速降低并趋近于 0。
- "生成回应"是一个困难的任务,其损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L o s s response Loss_{\text{response}} </math>Lossresponse 才是我们真正关心的。
- 如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> L o s s instruction Loss_{\text{instruction}} </math>Lossinstruction 占了总损失的很大一部分(例如指令很长),它会"稀释"我们真正关心的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L o s s response Loss_{\text{response}} </math>Lossresponse 所占的比重。这会导致梯度信号被无效任务分散,降低了模型学习"如何回答"的效率。
3. 构建一个微型 SFT 示例
在这个示例中,我们的词汇表里只有 4 个 Token,序列长度也只有 6。这样我们就可以把每一步的张量都打印出来,看得一清二楚。
我们将模拟以下设置:
-
VOCAB_SIZE = 4 -
Token IDs:
PAD_ID = 0PROMPT_ID = 1RESPONSE_ID = 2EOS_ID = 3
-
BATCH_SIZE = 2 -
SEQ_LEN = 6 -
IGNORE_INDEX = -100
Python
import torch
import torch.nn.functional as F
# --- 1. 定义我们的"玩具"词汇表和设置 ---
VOCAB_SIZE = 4
BATCH_SIZE = 2
SEQ_LEN = 6
PAD_ID = 0
PROMPT_ID = 1
RESPONSE_ID = 2
EOS_ID = 3
IGNORE_INDEX = -100
# --- 2. 手动构建 Input IDs 和 Prompt 长度 ---
# 假设我们有2个样本:
# 样本 1 (Prompt Len=3): [PROMPT, PROMPT, PROMPT, RESPONSE, RESPONSE, EOS]
# 样本 2 (Prompt Len=2): [PROMPT, PROMPT, RESPONSE, EOS]
# 经过左填充 (Left Padding) 到 SEQ_LEN=6 后:
input_ids = torch.tensor([
[PROMPT_ID, PROMPT_ID, PROMPT_ID, RESPONSE_ID, RESPONSE_ID, EOS_ID], # 样本 1 (无填充)
[PAD_ID, PAD_ID, PROMPT_ID, PROMPT_ID, RESPONSE_ID, EOS_ID] # 样本 2 (填充 2)
], dtype=torch.long)
# 对应的 attention_mask
attention_mask = torch.tensor([
[1, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]
], dtype=torch.long)
# 关键:我们知道每个样本的"指令"长度
prompt_lengths = [3, 2] # 样本1的指令有3个 token, 样本2有2个
print("--- 准备数据 ---")
print(f"Input IDs:\n{input_ids}")
# --- 3. 创建 (A) 正确屏蔽的 Labels ---
labels_correct = input_ids.clone()
for i in range(BATCH_SIZE):
# 屏蔽指令
prompt_len = prompt_lengths[i]
labels_correct[i, :prompt_len] = IGNORE_INDEX
# 屏蔽填充
labels_correct[i, attention_mask[i] == 0] = IGNORE_INDEX
print(f"\n(A) 正确屏蔽的 Labels:\n{labels_correct}")
# --- 4. 创建 (B) 错误_未屏蔽的 Labels ---
# 这个 labels 将计算所有 Token(包括 Padding 和 Prompt)
labels_unmasked = input_ids.clone()
print(f"\n(B) 未屏蔽的 Labels:\n{labels_unmasked}")
# --- 5. 模拟一个"部分训练过"的 Logits ---
# 形状: (BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) -> (2, 6, 4)
# 我们用 randn 初始化,值在 -1 到 1 之间
mock_logits_trained = torch.randn(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
# 关键:模拟模型"学会了"预测 PAD_ID = 0
# 我们手动将所有位置上,预测 PAD_ID 的 logit 设为一个很高的值 (例如 10.0)
# 这代表模型在"疯狂地"预测每个 Token 都应该是 PAD
mock_logits_trained[:, :, PAD_ID] += 10.0
print(f"\n--- 模拟的 Logits (形状: {mock_logits_trained.shape}) ---")
print("Logits (每个位置都在高置信度预测 PAD_ID=0):")
# .softmax(-1) 将 logits 转换为概率,[..., :4] 仅为演示
# 我们只看第一个样本的前3个 token
for i in range(3):
print(f" 样本 0, Token {i} 的概率: {mock_logits_trained[0, i].softmax(-1).numpy().round(2)}")
# --- 6. 比较两种损失 ---
# (A) 计算正确损失 (只计算"回应"部分)
loss_correct = F.cross_entropy(
mock_logits_trained.view(-1, VOCAB_SIZE), # (12, 4)
labels_correct.view(-1), # (12,)
ignore_index=IGNORE_INDEX
)
print(f"\n(A) [正确屏蔽] 的损失: {loss_correct.item():.4f}")
# (B) 计算错误损失 (计算所有 Token)
# 注意:我们必须设置 ignore_index=-1 (一个不存在的ID)
# 否则它会默认使用 -100,这会意外地屏蔽掉我们 (A) 中设置的 -100
loss_unmasked = F.cross_entropy(
mock_logits_trained.view(-1, VOCAB_SIZE), # (12, 4)
labels_unmasked.view(-1), # (12,)
ignore_index=-1 # 关键:不忽略任何东西
)
print(f"(B) [未屏蔽] 的损失: {loss_unmasked.item():.4f}")
结果分析 (Analysis of Results)
运行上述代码时,会得到类似下面的输出:
lua
--- 准备数据 ---
Input IDs:
tensor([[1, 1, 1, 2, 2, 3],
[0, 0, 1, 1, 2, 3]])
(A) 正确屏蔽的 Labels:
tensor([[-100, -100, -100, 2, 2, 3],
[-100, -100, -100, -100, 2, 3]])
(B) 未屏蔽的 Labels:
tensor([[1, 1, 1, 2, 2, 3],
[0, 0, 1, 1, 2, 3]])
--- 模拟的 Logits (形状: torch.Size([2, 6, 4])) ---
Logits (每个位置都在高置信度预测 PAD_ID=0):
样本 0, Token 0 的概率: [1. 0. 0. 0.]
样本 0, Token 1 的概率: [1. 0. 0. 0.]
样本 0, Token 2 的概率: [1. 0. 0. 0.]
(A) [正确屏蔽] 的损失: 10.9415
(B) [未屏蔽] 的损失: 8.9512
为什么损失值差异如此巨大?
-
Logits 在做什么?
我们的 mock_logits_trained 在所有 12 个位置 (2*6) 上都在高置信度地预测 Token 0 (即 PAD_ID)。你可以从 [1. 0. 0. 0.] 的概率中看到这一点。
-
(A)
loss_correct(10.94) 是如何计算的?- 它查看
labels_correct,发现只有 4 个有效标签([2, 2, 3]和[2, 3])。 - 在这些位置上 ,模型预测
0,但真实标签是2或3。 - 这是严重的不匹配 !损失函数对这种"指鹿为马"的行为给予了极高的惩罚(损失
~10.94)。 - 这是正确的:它告诉模型:"你在回应部分预测 PAD 是大错特错的!"
- 它查看
-
(B)
loss_unmasked(8.95) 是如何计算的?-
它查看
labels_unmasked,计算所有 12 个位置的损失。 -
对于填充位置(样本 2 的前 2 个 Token):
- 模型预测
0(PAD_ID)。 - 真实标签也是
0(PAD_ID)。 - 完美匹配! 这些位置的损失几乎为 0。
- 模型预测
-
对于指令位置(例如样本 1 的前 3 个 Token):
- 模型预测
0(PAD_ID)。 - 真实标签是
1(PROMPT_ID)。 - 不匹配,损失很高。
- 模型预测
-
对于回应位置:
- 模型预测
0(PAD_ID)。 - 真实标签是
2或3。 - 不匹配,损失很高。
- 模型预测
-
总损失 :
loss_unmasked是mean(loss_padding + loss_prompt + loss_response)。 -
因为
loss_padding几乎为 0,这几个 0 **拉低(稀释)**了总体的平均损失。
-
我们可以清晰地看到:
loss_correct(高损失) 准确地反映了模型在**"回应"**任务上的糟糕表现。loss_unmasked(低损失) 被模型在**"填充"**任务上的"良好"表现(猜对了 PAD)所污染,给出了一个虚假的、偏低的损失值,这会严重误导梯度更新。
4. 总结
在 SFT 中,正确屏蔽 labels 不是一个可选项,而是保证模型在正确的目标上进行优化的必要步骤。
- 始终屏蔽 Padding Token:避免模型学习预测填充符。
- 始终屏蔽 Instruction Token :避免模型"复述问题",防止"损失稀释",强制模型专注于 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( R e s p o n s e ∣ I n s t r u c t i o n ) P(Response | Instruction) </math>P(Response∣Instruction)。
使用 ignore_index=-100 是实现这一目标最直接和最高效的方法。