DataCollator

总体对比表

特性 DataCollatorForSeq2Seq DataCollatorWithPadding DataCollatorForLanguageModeling
适用模型 Encoder-Decoder(如 T5, BART, Qwen2-Audio) 任意(分类、回归、Encoder-only) Decoder-only(GPT, Llama, Qwen)或 MLM(BERT)
任务类型 序列到序列(翻译、摘要、对话生成) 分类 / 回归 / 特征提取 语言建模(Causal LM 或 Masked LM)
是否自动创建 labels 是(从 labels 字段复制并 padding) 否(需 dataset 提供 labels) 是(Causal LM: labels = input_ids;MLM: 随机 mask)
是否支持 mlm 不支持 不支持 支持(mlm=True/False
典型输入字段 input_ids, labels input_ids, labels(可选) input_ids(自动生成 labels)
padding 方向 input 和 labels 分别 padding 仅对 input 做 padding 对 input_ids 和 labels 同步 padding
是否适合 SFT(指令微调) Encoder-Decoder 模型 不适合生成任务 Decoder-only 模型(但需注意 loss 范围)

逐个详解

1.DataCollatorForSeq2Seq(tokenizer=tokenizer)

用途:

专为 Encoder-Decoder 架构(如 T5、BART、Flan-T5、Qwen2-VL 的文本部分)设计。

工作流程:
  • 输入 batch 包含:

    复制代码
    {
      "input_ids": [...],      # 编码器输入(如 "Translate to English: 你好")
      "labels": [...]          # 解码器目标(如 "Hello")
    }
  • 自动对 input_idslabels 分别做 padding(因为长度通常不同)

  • labels 中的 pad token 替换为 -100(PyTorch CrossEntropyLoss 忽略 -100)

典型参数:
复制代码
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,             # 可选,用于获取 decoder_start_token_id
    padding=True,
    max_length=512,
    pad_to_multiple_of=8
)

注意 :如果你用 Qwen3(纯 Decoder)不要用这个 collator


2.DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=512, pad_to_multiple_of=8)

用途:

通用 padding 工具,适用于 非生成式任务,如:

  • 文本分类(sentiment)
  • NER(token classification)
  • 句子对匹配(STS)
工作流程:
  • 仅对 input_idsattention_masktoken_type_ids 等做 padding
  • 不会修改或生成 labels ------ labels 必须由 dataset 提供,且 collator 只是"透传"并可能 padding(如果是序列标注)
关键限制:
  • 不适用于语言建模或生成任务 ,因为它不会设置 labels = input_ids,也不会处理自回归掩码。
示例(文本分类):
复制代码
dataset = [{"input_ids": [101, 234, 567, 102], "labels": 1}, ...]
batch = collator(dataset)
# output: {"input_ids": padded_tensor, "attention_mask": ..., "labels": [1, 0, 1, ...]}

3.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

用途:

专为 语言模型训练 设计:

  • mlm=False → Causal LM(GPT-style,如 Llama, Qwen)
  • mlm=True → Masked LM(BERT-style)
mlm=False(你的情况):
  • 输入只需 input_ids
  • 自动设置 labels = input_ids.clone()
  • padding 时,input_idslabels 同步填充 ,pad 位置在 labels 中设为 -100
  • 但注意 :它会对整个序列计算 loss(包括 user 输入),不适合标准 SFT
SFT 场景下的问题:
  • 在指令微调中,我们通常只希望模型学习 assistant 的回复
  • 此 collator 会让模型也去预测 "user\n你好" 这部分,导致训练目标混乱

解决方案 :预处理时手动构造 labels,将 user/system 部分设为 -100,然后仍可用此 collator(因它会保留已有 labels)


如何选择?------ 根据你的模型和任务

你的模型 你的任务 推荐 Collator
Qwen3 / Llama / Phi(纯 Decoder) 指令微调(SFT) DataCollatorForLanguageModeling(mlm=False) 但需预处理 labels(user 部分 = -100)
T5 / BART / UL2 翻译、摘要、问答生成 DataCollatorForSeq2Seq
BERT / RoBERTa 文本分类、NER DataCollatorWithPadding
Qwen3(Decoder) 全序列语言建模(非 SFT) DataCollatorForLanguageModeling(mlm=False)

最佳实践建议(针对 Qwen3 SFT)

复制代码
# 1. 预处理:构造 input_ids 和 labels(labels 中 user 部分为 -100)
def preprocess(example):
    messages = example["messages"]
    input_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    labels = input_ids.copy()
    
    # 找到 assistant 内容的位置,其余设为 -100
    # (可通过模板解析或简单规则:从最后一个 <|im_start|>assistant 开始保留)
    ...
    return {"input_ids": input_ids, "labels": labels}

# 2. 使用 DataCollatorForLanguageModeling(它会正确 padding labels)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8
)

这样既利用了 HF 的高效 padding,又确保 loss 只在 assistant 回复上计算。

相关推荐
Hcoco_me1 小时前
大模型面试题40:结合RoPE位置编码、优秀位置编码的核心特性
人工智能·深度学习·lstm·transformer·word2vec
Hcoco_me1 小时前
大模型面试题37:Scaling Law完全指南
人工智能·深度学习·学习·自然语言处理·transformer
高洁011 小时前
10分钟了解向量数据库(1)
python·深度学习·机器学习·transformer·知识图谱
vibag1 小时前
RAG向量数据库
python·语言模型·langchain·大模型
Hcoco_me2 小时前
大模型面试题41:RoPE改进的核心目标与常见方法
开发语言·人工智能·深度学习·自然语言处理·transformer·word2vec
Zyx20072 小时前
让大模型“记住你”:LangChain 中的对话记忆机制实战
langchain
Hcoco_me2 小时前
大模型面试题39:KV Cache 完全指南
人工智能·深度学习·自然语言处理·transformer·word2vec
斯外戈的小白3 小时前
【NLP】Transformer在pytorch 的实现+情感分析案例+生成式任务案例
pytorch·自然语言处理·transformer
vibag3 小时前
RAG项目实践
python·语言模型·langchain·大模型
Fuxiao___18 小时前
Transformer知识点答疑
transformer