DataCollator

总体对比表

特性 DataCollatorForSeq2Seq DataCollatorWithPadding DataCollatorForLanguageModeling
适用模型 Encoder-Decoder(如 T5, BART, Qwen2-Audio) 任意(分类、回归、Encoder-only) Decoder-only(GPT, Llama, Qwen)或 MLM(BERT)
任务类型 序列到序列(翻译、摘要、对话生成) 分类 / 回归 / 特征提取 语言建模(Causal LM 或 Masked LM)
是否自动创建 labels 是(从 labels 字段复制并 padding) 否(需 dataset 提供 labels) 是(Causal LM: labels = input_ids;MLM: 随机 mask)
是否支持 mlm 不支持 不支持 支持(mlm=True/False
典型输入字段 input_ids, labels input_ids, labels(可选) input_ids(自动生成 labels)
padding 方向 input 和 labels 分别 padding 仅对 input 做 padding 对 input_ids 和 labels 同步 padding
是否适合 SFT(指令微调) Encoder-Decoder 模型 不适合生成任务 Decoder-only 模型(但需注意 loss 范围)

逐个详解

1.DataCollatorForSeq2Seq(tokenizer=tokenizer)

用途:

专为 Encoder-Decoder 架构(如 T5、BART、Flan-T5、Qwen2-VL 的文本部分)设计。

工作流程:
  • 输入 batch 包含:

    复制代码
    {
      "input_ids": [...],      # 编码器输入(如 "Translate to English: 你好")
      "labels": [...]          # 解码器目标(如 "Hello")
    }
  • 自动对 input_idslabels 分别做 padding(因为长度通常不同)

  • labels 中的 pad token 替换为 -100(PyTorch CrossEntropyLoss 忽略 -100)

典型参数:
复制代码
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,             # 可选,用于获取 decoder_start_token_id
    padding=True,
    max_length=512,
    pad_to_multiple_of=8
)

注意 :如果你用 Qwen3(纯 Decoder)不要用这个 collator


2.DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=512, pad_to_multiple_of=8)

用途:

通用 padding 工具,适用于 非生成式任务,如:

  • 文本分类(sentiment)
  • NER(token classification)
  • 句子对匹配(STS)
工作流程:
  • 仅对 input_idsattention_masktoken_type_ids 等做 padding
  • 不会修改或生成 labels ------ labels 必须由 dataset 提供,且 collator 只是"透传"并可能 padding(如果是序列标注)
关键限制:
  • 不适用于语言建模或生成任务 ,因为它不会设置 labels = input_ids,也不会处理自回归掩码。
示例(文本分类):
复制代码
dataset = [{"input_ids": [101, 234, 567, 102], "labels": 1}, ...]
batch = collator(dataset)
# output: {"input_ids": padded_tensor, "attention_mask": ..., "labels": [1, 0, 1, ...]}

3.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

用途:

专为 语言模型训练 设计:

  • mlm=False → Causal LM(GPT-style,如 Llama, Qwen)
  • mlm=True → Masked LM(BERT-style)
mlm=False(你的情况):
  • 输入只需 input_ids
  • 自动设置 labels = input_ids.clone()
  • padding 时,input_idslabels 同步填充 ,pad 位置在 labels 中设为 -100
  • 但注意 :它会对整个序列计算 loss(包括 user 输入),不适合标准 SFT
SFT 场景下的问题:
  • 在指令微调中,我们通常只希望模型学习 assistant 的回复
  • 此 collator 会让模型也去预测 "user\n你好" 这部分,导致训练目标混乱

解决方案 :预处理时手动构造 labels,将 user/system 部分设为 -100,然后仍可用此 collator(因它会保留已有 labels)


如何选择?------ 根据你的模型和任务

你的模型 你的任务 推荐 Collator
Qwen3 / Llama / Phi(纯 Decoder) 指令微调(SFT) DataCollatorForLanguageModeling(mlm=False) 但需预处理 labels(user 部分 = -100)
T5 / BART / UL2 翻译、摘要、问答生成 DataCollatorForSeq2Seq
BERT / RoBERTa 文本分类、NER DataCollatorWithPadding
Qwen3(Decoder) 全序列语言建模(非 SFT) DataCollatorForLanguageModeling(mlm=False)

最佳实践建议(针对 Qwen3 SFT)

复制代码
# 1. 预处理:构造 input_ids 和 labels(labels 中 user 部分为 -100)
def preprocess(example):
    messages = example["messages"]
    input_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    labels = input_ids.copy()
    
    # 找到 assistant 内容的位置,其余设为 -100
    # (可通过模板解析或简单规则:从最后一个 <|im_start|>assistant 开始保留)
    ...
    return {"input_ids": input_ids, "labels": labels}

# 2. 使用 DataCollatorForLanguageModeling(它会正确 padding labels)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8
)

这样既利用了 HF 的高效 padding,又确保 loss 只在 assistant 回复上计算。

相关推荐
西柚小萌新17 小时前
【人工智能:Agent】--9.1.Langchain内置中间件
langchain
小王努力学编程20 小时前
LangChain——AI应用开发框架(核心组件1)
linux·服务器·前端·数据库·c++·人工智能·langchain
孤狼warrior1 天前
图像生成 Stable Diffusion模型架构介绍及使用代码 附数据集批量获取
人工智能·python·深度学习·stable diffusion·cnn·transformer·stablediffusion
小王努力学编程1 天前
LangChain——AI应用开发框架(核心组件2)
linux·服务器·c++·人工智能·python·langchain·信号
GatiArt雷1 天前
AI 赋能 Python:基于 LLM + Pandas 的自动化数据清洗实操AI赋能Python数据清洗:基于LLM+Pandas的自动化实操
人工智能·langchain
GatiArt雷1 天前
AI自动化测试落地指南:基于LangChain+TestGPT的实操实现与效能验证
人工智能·langchain
楚来客1 天前
AI基础概念之十三:Transformer 算法结构相比传统神经网络的改进
深度学习·神经网络·transformer
老蒋每日coding1 天前
AI Agent 设计模式系列(十六)—— 资源感知优化设计模式
人工智能·设计模式·langchain
AI即插即用1 天前
即插即用系列 | AAAI 2025 Mesorch:CNN与Transformer的双剑合璧:基于频域增强与自适应剪枝的篡改定位
人工智能·深度学习·神经网络·计算机视觉·cnn·transformer·剪枝
i02081 天前
langChain
langchain