【LLM微调】拒绝“假装聪明”:SFTTrainer 中 completion_only_loss 新旧版本用法详解

在指令微调(SFT)大模型时,你是否遇到过:训练 Loss 迅速降到 0.0x,准确率飙升到 99%,但模型实际推理效果却很差?

这通常是因为模型在"作弊"------它背下了固定的 System Prompt 和 User Instruction,而不是在学习如何回答问题。本文介绍如何使用 completion_only_loss 强制模型只学习"回答"部分,并重点对比 trl 库在 0.20.0 版本 前后的巨大用法差异。


1. 为什么要用 completion_only_loss

默认情况下,Hugging Face 的 SFTTrainer 对输入的所有文本计算 Loss。

  • 如果不加限制

    • 计算范围[System Prompt] + [User Instruction] + [Assistant Answer] 全选。
    • 后果:如果 Prompt 很长且固定,模型会优先背诵这段长文本,导致 Loss 虚低,掩盖了模型在 Answer 部分的无能。
  • 加上 completion_only_loss

    • 计算范围 :仅计算 [Assistant Answer] 部分。
    • 效果:模型必须全力以赴学习生成逻辑,训练更加有效、真实。

2. 核心区别:0.20.0 版本的分界线

trl 库从 0.20.0 开始,极大地简化了这一流程。

特性 旧版 (trl < 0.20.0) 新版 (trl >= 0.20.0) 🚀
实现方式 手动导入 DataCollator 直接配置 SFTConfig 参数
数据格式 返回拼接好的长字符串 拆分为 promptcompletion 字段
处理时机 formatting_func 给 Trainer 训练前用 dataset.map 预处理
复杂度 高(需手动指定分隔符) 低(自动识别,不易出错)

3. 旧版写法 (trl < 0.20.0)

特点 :需要手动引入 DataCollatorForCompletionOnlyLM 并指定分隔符。在新版本中,该类的 import 路径经常变动,容易报错。

python 复制代码
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# 1. 定义分隔符
response_template = "<|im_start|>assistant\n"
# 2. 初始化 Collator
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
# 3. 传给 Trainer
trainer = SFTTrainer(..., formatting_func=my_text_func, data_collator=collator)

4. 新版写法 (trl >= 0.20.0) ------ 推荐!

特点:配置化,逻辑清晰。

第一步:修改数据处理函数

不再返回单一字符串,而是返回字典,区分 prompt (题目) 和 completion (答案)。

python 复制代码
def format_prompts(examples):
    output_dicts = {"prompt": [], "completion": []}
    for i in range(len(examples['instruction'])):
        # 1. 构建 Prompt (System + User)
        # 技巧:add_generation_prompt=True 会自动加上 assistant 的开头 tag
        prompt_text = tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
        # 2. 构建 Completion (Assistant 回答 + EOS)
        completion_text = output_json_str + tokenizer.eos_token
        
        output_dicts["prompt"].append(prompt_text)
        output_dicts["completion"].append(completion_text)
    return output_dicts

第二步:提前处理 Dataset

注意:不要把函数传给 Trainer,要在外面先 map 好!

python 复制代码
# 必须移除旧列,否则 Trainer 会报错
dataset = dataset.map(format_prompts, batched=True, remove_columns=dataset.column_names)

第三步:配置 SFTConfig

直接开启开关,并关闭冲突参数。

python 复制代码
from trl import SFTConfig, SFTTrainer

args = SFTConfig(
    output_dir="./output",
    # === 核心配置 ===
    completion_only_loss=True,  # ✅ 开启只计算回答 Loss
    packing=False,              # ❌ 必须关闭!packing 与此模式互斥
    # ================
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,      # 传入处理好的 dataset
    # formatting_func=...       # ❌ 千万别传这个,前面已经 map 过了
    args=args
)

5. 避坑指南

  1. 互斥参数completion_only_loss=True 时,必须设置 packing=False
  2. 报错 ValueError :如果你在新版中同时开启了 completion_only_loss 又传入了 formatting_func,会直接报错。请务必使用 dataset.map 提前处理。
  3. 指标变化 :开启后,Loss 变高、Accuracy 变低是正常的。这代表模型开始啃"硬骨头"了,而不是在"背书"。
相关推荐
ariesjzj1 小时前
DeepSeek时代的Large-scale LLM推理
大模型·llm·deepseek·推理优化·大规模ep
智泊AI3 小时前
长上下文、Agent记忆、Text2SQL中,谁会取代RAG?
llm
赋范大模型技术社区11 小时前
大模型训练的“最后一公里”:为什么强化学习(RL)不可或缺?
大模型·微调·sft·模型训练·rl
CoderJia程序员甲12 小时前
GitHub 热榜项目 - 日榜(2025-12-7)
git·ai·开源·llm·github
大模型教程13 小时前
小猫都能懂的大模型原理 3 - 自注意力机制
程序员·llm·agent
大模型教程13 小时前
小猫都能懂的大模型原理 2 - 初见大语言模型
程序员·llm·agent
leo030813 小时前
深度解析Hugging Face Accelerate:`Trainer`背后的“隐形”分布式引擎
pytorch·大模型·llm·ddp
AI大模型13 小时前
2025最新大模型技术学习路线:从入门到精通,一篇文章全掌握
程序员·llm·agent
AI大模型14 小时前
AI大模型学习路线,带你6周成为大模型工程师!
程序员·llm·agent