LONGWRITER: UNLEASHING 10,000+ WORD GENERATION FROM LONG CONTEXT LLMS
一句话总结:
加入长输出的数据进行训练微调,即可解锁大模型的长输出能力。
摘要
- 当前的LLMs能够处理100,000个token的输入,但在难以生成超过2,000个词。实验发现模型的有效生成长度本质上受到其在监督式微调(SFT)期间所见过的样本的限制。
- 为了解决这个问题,我们引入了AgentWrite,它将超长生成任务分解为子任务,使现成的LLMs能够生成超过20,000个词的连贯输出。
利用AgentWrite,我们构建了一个包含6,000个SFT数据的LongWriter-6k数据集,输出长度范围在2k到32k个词之间。 - 通过将这个数据集纳入模型训练,我们将现有模型的输出长度扩展到超过10,000个词,同时保持了输出质量。
- 我们还开发了LongBench-Write,一个全面的基准测试,用于评估超长生成能力。
- 9B参数模型,通过DPO,在这个基准上实现了最先进的性能,甚至超过了更大的专有模型。
- 总的来说,现有的长上下文LLM已经拥有更大的输出窗口的潜力------你所需要的只是在模型对齐期间用扩展输出的数据来解锁这个能力。
第1章 引言
介绍了长上下文大型语言模型(LLMs)的最新进展以及它们在生成长文本方面的局限性。
第2章 发现生成长度限制的原因
通过构建LongWrite-Ruler评估来探测LLMs的生成长度限制,并探索了造成它们生成长度限制的原因。
结论:llm输出一般为2k个左右,训练数据的长度增加,输出长度也会增加。
第3章 AGENTWRITE: 自动数据构建
为了使用现成的LLMs自动生成具有更长输出的SFT数据,我们设计了AgentWrite,这是一个分而治之风格的代理流水线。逐章节的去输出内容,再合并即可。
第4章 LONGWRITER: 教会模型生成超长输出
数据分布
- 经过自动选择和人工检查,确保指令确实需要几千个词的响应。
- 数据经过过滤和清理,以确保输出长度和质量。
训练时同时使用SFT180K数据和LongWriter-6k
模型训练
SFT
选择了两个最新的开源模型作为基础模型:GLM-4-9B和Llama-3.1-8B。
- 这些模型支持高达128k tokens的上下文窗口,适合于长输出任务的训练。
- 采用了打包训练(packing training)和损失加权(loss weighting)策略来提高训练效率。
损失加权策略
- 通过按token平均损失而非按batch句子损失,确保长输出数据中每个目标token对损失的贡献是均衡的。因为如果按照句子计算平均损失,在长输出的情况下,会弱化每一个token对target的loss
硬件和配置
- 使用了8个H800 80G GPU的节点进行训练,结合DeepSpeed和ZeRO以及CPU卸载技术。
- 设置了批量大小为8,学习率为1e-5,打包长度为32k。一共训了4 epochs, 差不多2,500-3,000 steps.
训练周期
- 模型训练了4个epoch,大约需要2500到3000步。
DPO
数据来源之一
- 使用了来自GLM-4的聊天DPO数据,以及特别为长形式写作指令构建的4k数据对。另外
数据来源之二 自己构建的4k数据
- 对long-writer输出4个样本打分,选择了得分最高的作为正样本,随机选择其他输出之一作为负样本。
- DPO训练了250步,遵循了Hou等人(Chatglm-rlhf)的DPO训练方法。
结果
- 训练后的模型LongWriter-9B和LongWriter-9B-DPO在LongBench-Write基准测试上表现出色,特别是在生成超过2,000个词的长文本方面。
- DPO显著提高了模型遵循长度指令的能力,并且提升了输出的整体质量。
主要结果
- 对比了4个私有模型和5个开源模型在LongBench-Write上的性能,以及训练的LongWriter模型。
- 观察到大多数现有模型无法满足超过2,000个词的长度要求,而LongWriter模型能够提供更长、更丰富的响应。
结果分析
- LongWriter模型在所有长度范围内的输出长度得分(Sl)和质量得分(Sq)上均表现出色。
- DPO进一步提高了模型的输出质量和遵循长度约束的能力。
消融研究
- 进行了三项数据消融实验,以评估不同数据集对模型性能的影响。
消融LongWriter-6k数据集
- 证明了LongWriter-6k数据集对提高模型处理长文本能力的重要性。
计划增强输出数据
- 探讨了在生成内容前先输出写作计划对模型性能的影响。类似于CoT,但是对质量并没有什么提升。
与指令回传合成数据的比较
分析了使用指令回传方法构建的长输出SFT数据对模型性能的影响。发现对结果是有害的。作者分析两点原因:
- 指令回传数据质量比较低
- 回传的指令与用户指令有gap,不一致