DataCollator

总体对比表

特性 DataCollatorForSeq2Seq DataCollatorWithPadding DataCollatorForLanguageModeling
适用模型 Encoder-Decoder(如 T5, BART, Qwen2-Audio) 任意(分类、回归、Encoder-only) Decoder-only(GPT, Llama, Qwen)或 MLM(BERT)
任务类型 序列到序列(翻译、摘要、对话生成) 分类 / 回归 / 特征提取 语言建模(Causal LM 或 Masked LM)
是否自动创建 labels 是(从 labels 字段复制并 padding) 否(需 dataset 提供 labels) 是(Causal LM: labels = input_ids;MLM: 随机 mask)
是否支持 mlm 不支持 不支持 支持(mlm=True/False
典型输入字段 input_ids, labels input_ids, labels(可选) input_ids(自动生成 labels)
padding 方向 input 和 labels 分别 padding 仅对 input 做 padding 对 input_ids 和 labels 同步 padding
是否适合 SFT(指令微调) Encoder-Decoder 模型 不适合生成任务 Decoder-only 模型(但需注意 loss 范围)

逐个详解

1.DataCollatorForSeq2Seq(tokenizer=tokenizer)

用途:

专为 Encoder-Decoder 架构(如 T5、BART、Flan-T5、Qwen2-VL 的文本部分)设计。

工作流程:
  • 输入 batch 包含:

    复制代码
    {
      "input_ids": [...],      # 编码器输入(如 "Translate to English: 你好")
      "labels": [...]          # 解码器目标(如 "Hello")
    }
  • 自动对 input_idslabels 分别做 padding(因为长度通常不同)

  • labels 中的 pad token 替换为 -100(PyTorch CrossEntropyLoss 忽略 -100)

典型参数:
复制代码
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,             # 可选,用于获取 decoder_start_token_id
    padding=True,
    max_length=512,
    pad_to_multiple_of=8
)

注意 :如果你用 Qwen3(纯 Decoder)不要用这个 collator


2.DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=512, pad_to_multiple_of=8)

用途:

通用 padding 工具,适用于 非生成式任务,如:

  • 文本分类(sentiment)
  • NER(token classification)
  • 句子对匹配(STS)
工作流程:
  • 仅对 input_idsattention_masktoken_type_ids 等做 padding
  • 不会修改或生成 labels ------ labels 必须由 dataset 提供,且 collator 只是"透传"并可能 padding(如果是序列标注)
关键限制:
  • 不适用于语言建模或生成任务 ,因为它不会设置 labels = input_ids,也不会处理自回归掩码。
示例(文本分类):
复制代码
dataset = [{"input_ids": [101, 234, 567, 102], "labels": 1}, ...]
batch = collator(dataset)
# output: {"input_ids": padded_tensor, "attention_mask": ..., "labels": [1, 0, 1, ...]}

3.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

用途:

专为 语言模型训练 设计:

  • mlm=False → Causal LM(GPT-style,如 Llama, Qwen)
  • mlm=True → Masked LM(BERT-style)
mlm=False(你的情况):
  • 输入只需 input_ids
  • 自动设置 labels = input_ids.clone()
  • padding 时,input_idslabels 同步填充 ,pad 位置在 labels 中设为 -100
  • 但注意 :它会对整个序列计算 loss(包括 user 输入),不适合标准 SFT
SFT 场景下的问题:
  • 在指令微调中,我们通常只希望模型学习 assistant 的回复
  • 此 collator 会让模型也去预测 "user\n你好" 这部分,导致训练目标混乱

解决方案 :预处理时手动构造 labels,将 user/system 部分设为 -100,然后仍可用此 collator(因它会保留已有 labels)


如何选择?------ 根据你的模型和任务

你的模型 你的任务 推荐 Collator
Qwen3 / Llama / Phi(纯 Decoder) 指令微调(SFT) DataCollatorForLanguageModeling(mlm=False) 但需预处理 labels(user 部分 = -100)
T5 / BART / UL2 翻译、摘要、问答生成 DataCollatorForSeq2Seq
BERT / RoBERTa 文本分类、NER DataCollatorWithPadding
Qwen3(Decoder) 全序列语言建模(非 SFT) DataCollatorForLanguageModeling(mlm=False)

最佳实践建议(针对 Qwen3 SFT)

复制代码
# 1. 预处理:构造 input_ids 和 labels(labels 中 user 部分为 -100)
def preprocess(example):
    messages = example["messages"]
    input_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    labels = input_ids.copy()
    
    # 找到 assistant 内容的位置,其余设为 -100
    # (可通过模板解析或简单规则:从最后一个 <|im_start|>assistant 开始保留)
    ...
    return {"input_ids": input_ids, "labels": labels}

# 2. 使用 DataCollatorForLanguageModeling(它会正确 padding labels)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8
)

这样既利用了 HF 的高效 padding,又确保 loss 只在 assistant 回复上计算。

相关推荐
西西弗Sisyphus15 小时前
基于 Transformer 架构的翻译模型实践 - SentencePiece 分词的例子
transformer
li星野16 小时前
Transformer 核心模块详解:多头注意力、前馈网络与词嵌入
人工智能·深度学习·transformer
小新同学^O^16 小时前
简单学习 --> LangChain
python·学习·langchain
晚霞的不甘17 小时前
CANN-ATB加速库:Transformer推理性能密码
人工智能·深度学习·transformer
Restart-AHTCM17 小时前
LangChain学习之提示词模板 (Prompts) - 练习(2/8)
学习·langchain
TechWayfarer17 小时前
AI大模型时代:IP数据云如何适配智能体场景需求
开发语言·人工智能·python·网络协议·tcp/ip·langchain
爱编程的小新☆19 小时前
LangGraph4j工作流框架
前端·数据库·ai·langchain·langgraph4j
高洁0119 小时前
中国人工智能培训网—AI系列录播课
人工智能·机器学习·数据挖掘·transformer·知识图谱
wuxinyan12320 小时前
工业级大模型学习之路020:LangChain零基础入门教程(第三篇):提示词工程与提示模板系统
人工智能·python·学习·langchain