在指令微调(SFT)大模型时,你是否遇到过:训练 Loss 迅速降到 0.0x,准确率飙升到 99%,但模型实际推理效果却很差?
这通常是因为模型在"作弊"------它背下了固定的 System Prompt 和 User Instruction,而不是在学习如何回答问题。本文介绍如何使用 completion_only_loss 强制模型只学习"回答"部分,并重点对比 trl 库在 0.20.0 版本 前后的巨大用法差异。
1. 为什么要用 completion_only_loss?
默认情况下,Hugging Face 的 SFTTrainer 对输入的所有文本计算 Loss。
-
如果不加限制:
- 计算范围 :
[System Prompt] + [User Instruction] + [Assistant Answer]全选。 - 后果:如果 Prompt 很长且固定,模型会优先背诵这段长文本,导致 Loss 虚低,掩盖了模型在 Answer 部分的无能。
- 计算范围 :
-
加上
completion_only_loss:- 计算范围 :仅计算
[Assistant Answer]部分。 - 效果:模型必须全力以赴学习生成逻辑,训练更加有效、真实。
- 计算范围 :仅计算
2. 核心区别:0.20.0 版本的分界线
trl 库从 0.20.0 开始,极大地简化了这一流程。
| 特性 | 旧版 (trl < 0.20.0) |
新版 (trl >= 0.20.0) 🚀 |
|---|---|---|
| 实现方式 | 手动导入 DataCollator 类 |
直接配置 SFTConfig 参数 |
| 数据格式 | 返回拼接好的长字符串 | 拆分为 prompt 和 completion 字段 |
| 处理时机 | 传 formatting_func 给 Trainer |
训练前用 dataset.map 预处理 |
| 复杂度 | 高(需手动指定分隔符) | 低(自动识别,不易出错) |
3. 旧版写法 (trl < 0.20.0)
特点 :需要手动引入 DataCollatorForCompletionOnlyLM 并指定分隔符。在新版本中,该类的 import 路径经常变动,容易报错。
python
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
# 1. 定义分隔符
response_template = "<|im_start|>assistant\n"
# 2. 初始化 Collator
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
# 3. 传给 Trainer
trainer = SFTTrainer(..., formatting_func=my_text_func, data_collator=collator)
4. 新版写法 (trl >= 0.20.0) ------ 推荐!
特点:配置化,逻辑清晰。
第一步:修改数据处理函数
不再返回单一字符串,而是返回字典,区分 prompt (题目) 和 completion (答案)。
python
def format_prompts(examples):
output_dicts = {"prompt": [], "completion": []}
for i in range(len(examples['instruction'])):
# 1. 构建 Prompt (System + User)
# 技巧:add_generation_prompt=True 会自动加上 assistant 的开头 tag
prompt_text = tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
# 2. 构建 Completion (Assistant 回答 + EOS)
completion_text = output_json_str + tokenizer.eos_token
output_dicts["prompt"].append(prompt_text)
output_dicts["completion"].append(completion_text)
return output_dicts
第二步:提前处理 Dataset
注意:不要把函数传给 Trainer,要在外面先 map 好!
python
# 必须移除旧列,否则 Trainer 会报错
dataset = dataset.map(format_prompts, batched=True, remove_columns=dataset.column_names)
第三步:配置 SFTConfig
直接开启开关,并关闭冲突参数。
python
from trl import SFTConfig, SFTTrainer
args = SFTConfig(
output_dir="./output",
# === 核心配置 ===
completion_only_loss=True, # ✅ 开启只计算回答 Loss
packing=False, # ❌ 必须关闭!packing 与此模式互斥
# ================
)
trainer = SFTTrainer(
model=model,
train_dataset=dataset, # 传入处理好的 dataset
# formatting_func=... # ❌ 千万别传这个,前面已经 map 过了
args=args
)
5. 避坑指南
- 互斥参数 :
completion_only_loss=True时,必须设置packing=False。 - 报错 ValueError :如果你在新版中同时开启了
completion_only_loss又传入了formatting_func,会直接报错。请务必使用dataset.map提前处理。 - 指标变化 :开启后,Loss 变高、Accuracy 变低是正常的。这代表模型开始啃"硬骨头"了,而不是在"背书"。