总体对比表
| 特性 | DataCollatorForSeq2Seq |
DataCollatorWithPadding |
DataCollatorForLanguageModeling |
|---|---|---|---|
| 适用模型 | Encoder-Decoder(如 T5, BART, Qwen2-Audio) | 任意(分类、回归、Encoder-only) | Decoder-only(GPT, Llama, Qwen)或 MLM(BERT) |
| 任务类型 | 序列到序列(翻译、摘要、对话生成) | 分类 / 回归 / 特征提取 | 语言建模(Causal LM 或 Masked LM) |
| 是否自动创建 labels | 是(从 labels 字段复制并 padding) |
否(需 dataset 提供 labels) | 是(Causal LM: labels = input_ids;MLM: 随机 mask) |
是否支持 mlm |
不支持 | 不支持 | 支持(mlm=True/False) |
| 典型输入字段 | input_ids, labels |
input_ids, labels(可选) |
input_ids(自动生成 labels) |
| padding 方向 | input 和 labels 分别 padding | 仅对 input 做 padding | 对 input_ids 和 labels 同步 padding |
| 是否适合 SFT(指令微调) | Encoder-Decoder 模型 | 不适合生成任务 | Decoder-only 模型(但需注意 loss 范围) |
逐个详解
1.DataCollatorForSeq2Seq(tokenizer=tokenizer)
用途:
专为 Encoder-Decoder 架构(如 T5、BART、Flan-T5、Qwen2-VL 的文本部分)设计。
工作流程:
-
输入 batch 包含:
{ "input_ids": [...], # 编码器输入(如 "Translate to English: 你好") "labels": [...] # 解码器目标(如 "Hello") } -
自动对
input_ids和labels分别做 padding(因为长度通常不同) -
将
labels中的 pad token 替换为-100(PyTorch CrossEntropyLoss 忽略 -100)
典型参数:
collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model, # 可选,用于获取 decoder_start_token_id
padding=True,
max_length=512,
pad_to_multiple_of=8
)
注意 :如果你用 Qwen3(纯 Decoder) ,不要用这个 collator!
2.DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=512, pad_to_multiple_of=8)
用途:
通用 padding 工具,适用于 非生成式任务,如:
- 文本分类(sentiment)
- NER(token classification)
- 句子对匹配(STS)
工作流程:
- 仅对
input_ids、attention_mask、token_type_ids等做 padding - 不会修改或生成
labels------ labels 必须由 dataset 提供,且 collator 只是"透传"并可能 padding(如果是序列标注)
关键限制:
- 不适用于语言建模或生成任务 ,因为它不会设置 labels = input_ids,也不会处理自回归掩码。
示例(文本分类):
dataset = [{"input_ids": [101, 234, 567, 102], "labels": 1}, ...]
batch = collator(dataset)
# output: {"input_ids": padded_tensor, "attention_mask": ..., "labels": [1, 0, 1, ...]}
3.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
用途:
专为 语言模型训练 设计:
mlm=False→ Causal LM(GPT-style,如 Llama, Qwen)mlm=True→ Masked LM(BERT-style)
当 mlm=False(你的情况):
- 输入只需
input_ids - 自动设置
labels = input_ids.clone() - padding 时,
input_ids和labels同步填充 ,pad 位置在 labels 中设为-100 - 但注意 :它会对整个序列计算 loss(包括 user 输入),不适合标准 SFT
SFT 场景下的问题:
- 在指令微调中,我们通常只希望模型学习 assistant 的回复
- 此 collator 会让模型也去预测
"user\n你好"这部分,导致训练目标混乱
解决方案 :预处理时手动构造
labels,将 user/system 部分设为-100,然后仍可用此 collator(因它会保留已有 labels)
如何选择?------ 根据你的模型和任务
| 你的模型 | 你的任务 | 推荐 Collator |
|---|---|---|
| Qwen3 / Llama / Phi(纯 Decoder) | 指令微调(SFT) | DataCollatorForLanguageModeling(mlm=False) 但需预处理 labels(user 部分 = -100) |
| T5 / BART / UL2 | 翻译、摘要、问答生成 | DataCollatorForSeq2Seq |
| BERT / RoBERTa | 文本分类、NER | DataCollatorWithPadding |
| Qwen3(Decoder) | 全序列语言建模(非 SFT) | DataCollatorForLanguageModeling(mlm=False) |
最佳实践建议(针对 Qwen3 SFT)
# 1. 预处理:构造 input_ids 和 labels(labels 中 user 部分为 -100)
def preprocess(example):
messages = example["messages"]
input_ids = tokenizer.apply_chat_template(messages, tokenize=True)
labels = input_ids.copy()
# 找到 assistant 内容的位置,其余设为 -100
# (可通过模板解析或简单规则:从最后一个 <|im_start|>assistant 开始保留)
...
return {"input_ids": input_ids, "labels": labels}
# 2. 使用 DataCollatorForLanguageModeling(它会正确 padding labels)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
pad_to_multiple_of=8
)
这样既利用了 HF 的高效 padding,又确保 loss 只在 assistant 回复上计算。