解密prompt系列27. LLM对齐经验之如何降低通用能力损失

前面我们已经聊过众多指令微调的方案,这一章我们重点讨论下如何注入某一类任务或能力的同时,尽可能不损失模型原有的通用指令理解能力。因为在下游或垂直领域应用中,我们设计的推理任务风格或形式,往往很难通过prompt来稳定实现。这种情况下就会依赖微调来注入稳定的推理风格,但同时在部分场景下我们又需要模型微调前的通用指令理解能力。虽然理论上说有得必有失,但成年人真的不想做选择!这里我们讨论两种方案,来尽可能降低通用能力的损失,一种数据方案,一种训练方案。

Two Stage Finetune - ProMoT

  • Preserving In-Context Learning ability in Large Language Model Fine-tuning
  • TWO-STAGE LLM FINE-TUNING WITH LESS SPECIALIZATION AND MORE GENERALIZATION

先说训练方案,论文先是分析了模型在微调过程中能力损失的一个主要原因是Format Specialization,也就是模型过拟合了微调任务的输出格式。举几个例子

  • 下游是二分类任务的,微调之后即便丢掉分类任务的指令,模型的输出还是True/False,损失了对话和其他指令理解能力
  • 下游是摘要任务的,微调之后即便丢掉TL;DL的总结指令,你让模型翻译模型还是给你总结。这个在当前推出的一些长文本无敌的基座模型上可能会看到,他们后期的指令微调中指令分布都比较偏向于长文本的QA和总结,其他能力会略弱
  • 通过微调让模型学习拒绝回答的,你会发现你的模型可能在很多不该拒绝的场景也拒绝你

其实核心就是我们本想让模型学习条件生成能力,也就是在分类指令时分类,摘要指令下摘要,该拒绝的场景下再拒绝。但论文通过分析发现在传统微调任务中,模型是先无脑拟合输出格式,例如二分类的True/False,拒绝里的对不起,在微调后期才开始学习input和ouput之间的关系(Semantic Relation),例如何时该分类,何时该拒绝。

那能否把下游任务的Task Format先行注入到额外的参数中,之后把这部分信息喂给模型,让大模型微调直接学习Semantic Relation,这样在稳定注入新的Format的前提下,尽可能不影响其他input的推理格式。

论文提出了两阶段微调,第一阶段也使用了谷歌提出的Prompt Tuning用来学习Format,第二阶段再进行全量微调。如下

第一阶段Prompt Tuning,简单说就是冻结大模型,只微调Embedding层的一组虚拟Token,这一组虚拟Token学习的就是下游推理任务Format的任务表征。

这里可以反过来思考,之前有论文提出任务指令(prompt)其实可以压缩成一个Task Vector的任务表征用来引导模型给出不同的推理输出;那反过来我们想让模型学习一种推理风格/任务,其实就是构建该Format对应的Task Vector,以及Task Vector对应的任务指令的过程。只不过prompt tunning的prompt使用的是虚拟Token。想更多了解Prompt Tuning的童鞋看这里解密Prompt系列3. 冻结LM微调Prompt: Prefix-tuning & Prompt-tuning & P-tuning

第二阶段Fine-tuning默认在输入层Embedding前拼接Prompt Embedding,b并冻结这部分Embedding,然后全量微调大模型, 让模型在已知输出格式的前提下,学习Input和Output格式之间的Semantic联系。之前有些疑惑这里为何要冻结prompt,后来又想了想应该是为了避免模型再把Task Format相关信息更新到模型内部参数中,但感觉不冻结的方式也值得测试下。

几个值得聊聊的细节

  1. 第一阶段微调能否用Lora,从Prompt实际学习的是推理格式的任务表征这个逻辑原理来说其实Adapter类的微调方案,似乎并不合理。论文测试后也发现Lora的效果并不好
  2. 能否把两个阶段合二为一,既加上一个虚拟Prompt,同时微调模型和prompt,论文测试后发现效果和SFT相差不多,都会有过拟合。毕竟这种微调方式无法引导模型把格式学到Prompt Embedding上。

效果上,论文在mT5模型上对比了SFT,Prompt-Tuning,和ProMoT在下游微调任务,和微调任务之外其他通用任务的能力对比。发现ProMoTe可以在分类,翻译,NLI,QA等任务微调上对比全量微调都有更好的效果。同时以分类任务为例,在分类任务上进行微调后,在QA等其他任务上相比基座模型能力也没有显著下降,作为对照组的SFT会有显著的通用能力的下降。

Dual-Stage Mixed Finetuning - DMT

  • How Abilities in Large Language Models are Affected by Supervised Fine-tuning Data Composition
  • Scaling Relationship on Learning Mathematical Reasoning with Large Language Models

DMT的论文主要探究了不同领域数据混合,以及混合训练策略对模型在多领域能力的影响。

1. 单领域Scaling curve

要想设计更合理的多领域混合训练策略,首先要确认不同领域样本模型学习的scaling curve。这个问题之前已经有很多论文讨论过,这里简单回顾下,如下图所示

  • 数学和代码等领域能力,会随样本量上升而持续提升,并且模型规模越大scaling curve越单调且陡峭。 这一点和我们的测试效果相似,数学和代码样本你就可劲加,加一点模型好一点,更多细节看上面Scaling的论文。
  • 通用指令能力,基本在1K(1/256的样本)的样本上效果就很好了,后续能力提升会比较慢,并且在不同规模的模型上差异相对有限。 这一点我们在前文讨论过详见LLM对齐经验之数据越少越好?

2. 多领域混合Scaling curve

明确单一领域的scling curve之后,我们来看多领域的数据混合,这里会分别讨论数据混合中的两个要点:整体量级和混合比例

  1. 整体量级:和以上单领域实验相同的5种不同采样比例,直接对三个领域的数据进行混合,和上面的单领域实验结果进行对比。观察下图会发现在低资源上领域混合会有提升,但在更大的样本量级上单领域微调效果会略好 一个可能的解释是在小量级样本上会有彼此的能力迁移,而当单领域信息逐步提升后信息冲突会逐渐显现
  1. 混合比例:为了进一步探究以上全样本混合训练中出现的信息冲突的来源,作者进一步做了控制变量的实验。固定一个领域(math和code合成一个领域)的样本量改变另一个领域的样本量,看不同比例数据混合的影响。主要结论有
  • 主领域样本还是越多越好
  • 当领域样本差异(输出格式/输入分布)较大时,通用领域数据对特殊领域影响有限
  • 当样本存在相似性时混合会带来冲突,但冲突和数据比例没有显著单调性

3. 训练策略影响

论文实验了不同训练策略的影响,包括多领域联合训练,有序训练(Code->Math->General),以及先训练Math+Code再训练general的有序混合训练,。这几种策略之前也有很多论文做过测试,这里简单说下结论

  • 多领域联合训练:会更多学到特殊领域(Math+code),更多损伤通用能力。这块可以更多借用ProMoT的逻辑,因为特殊领域输出风格一致模型更容易学到,而通用领域输出风格更多样些
  • 有序训练和有序混合训练:只要是先训练领域能力再训练通用能力,因为灾难遗忘的原因,最终模型会把先学到的领域能力遗忘

在以上三种训练方案的基础上,论文提出了两阶段混合训练(DMT)如下

第一阶段是领域数据的训练,按照单领域scaling curve,这一部分的数据量越大效果越好,所以使用全量级的数学和代码进行训练。

第二阶段用于恢复通用能力,同时尽量避免有序训练带来的灾难遗忘。 这里使用了上面多领域混合的insight,领域数据的混合比例对通用能力影响较小;同时低资源混合带来的冲突较小。因为论文使用了1/256的领域数据和通用数据进行混合进行第二阶段的训练。在尽量避免第一阶段模型学到的能力丢失的基础上,帮助模型恢复通用能力。

效果上在LLaMA7B,13B,和33B的模型上,DMT的训练方案能在保留单领域训练绝大多数领域能力的基础上,保证模型通用能力不受损失,甚至略微有所提升 。如果想要保留更多的领域能力,允许更多的通用能力损失,则可以适当提高第二阶段的领域数据占比,具体要数据集上case by case的测试。

想看更全的大模型相关论文梳理·微调及预训练数据和框架·AIGC应用,移步Github >> DecryPrompt

相关推荐
fly-974 小时前
LLM大模型微调入门Lora(LlamaFactory)
chatgpt·nlp
Just Jump8 小时前
大语言模型LLM综述
llm·大语言模型
数据智能老司机12 小时前
LLM工程师手册——RAG 推理管道
人工智能·llm·aiops
ApiHug15 小时前
ApiSmart-QWen2.5 coder vs GPT-4o 那个更强? ApiSmart 测评
java·人工智能·ai·llm·通义千问·apihug·apismart
Hamm18 小时前
先别急着喷,没好用的iOS-Ollama客户端那就自己写个然后开源吧
人工智能·llm·swift
小森( ﹡ˆoˆ﹡ )1 天前
词嵌入方法(Word Embedding)
人工智能·机器学习·自然语言处理·nlp·word·embedding
人工智障调包侠2 天前
Pytorch从0复现worc2vec skipgram模型及fasttext训练维基百科语料词向量演示
人工智能·pytorch·自然语言处理·nlp·word2vec·词向量·skipgram
数据智能老司机2 天前
机器学习生产系统——可解释性
人工智能·机器学习·llm
SpikeKing2 天前
LLM - 使用 LLaMA-Factory 微调大模型 Qwen2-VL SFT(LoRA) 图像数据集 教程 (2)
人工智能·lora·llm·sft·多模态大模型·llama-factory·qwen2-vl
Power20246663 天前
NLP论文速读(NeurIPS 2024)|大语言模型在评估的时候更倾向于自己生成的内容
人工智能·深度学习·机器学习·计算机视觉·语言模型·自然语言处理·nlp