LoRA是用于训练自定义LLM的高效参数微调技术。本文作者Sebastian Raschka通过成百上千次实验,他为使用LoRA和QLoRA对LLM进行微调提供了实用见解,包括节省内存、选择最佳配置等。
Sebastia是威斯康星大学麦迪逊分校的统计学助理教授,也是人工智能平台Lightning AI的LLM研究员。
(本文由OneFlow编译发布,转载请联系授权。原文:
lightning.ai/pages/commu...
作者 | Sebastian Raschka
OneFlow编译
翻译|宛子琳、杨婷
过去几个月里,我进行了成百上千次有关LoRA的实验。几周前,我花时间更深入地研究了一些超参数选择问题。
本文更像是一篇按时间顺序呈现的实验日记,我希望它对某些人能够有所帮助。具体而言,本文旨在回答一些关于QLoRA价值的问题:如是否应该用SGD取代AdamW,使用调度器(scheduler)的潜在价值,以及如何调整LoRA的超参数。
关于实验有很多需要讨论的内容,因此我对LoRA的介绍会尽可能简明扼要。
简而言之,LoRA(Low-Rank Adaptation的缩写)(Hu等,2021)在保持原始模型参数不变的同时,在模型中添加了一小部分可训练参数。
如下图所示,LoRA将权重矩阵分解为两个较小的权重矩阵,以便用更高效参数的方式近似完全监督微调。
1
评估任务与数据集
本文的重点是如何选择最佳设置。为保持在合理范围内,我将保持数据集不变,仅关注LLM的监督指令微调。(可能会在后续文章中讨论有关数据集或分类微调的修改。)
关于模型评估,我从Eleuther AI提供的标准化评估工具包中选择了一小部分数据集,包括TruthfulQA、BLiMP Causative和MMLU Global Facts,以及两位和四位数的简单计算任务。
在每个基准测试中,模型的性能得分被归一化到0到1之间,1表示满分。TruthfulQA报告了两项得分,定义如下:
- MC1 (单选题):给定一个问题和4-5个候选答案,选择唯一的正确答案。分数是所有问题的简单准确率。
- MC2(多选题):给定一个问题和多个正确/错误参考答案,得分为模型赋予正确答案集合的归一化总概率。
175B GPT-3模型的TruthfulQA的MC1和MC2值分别为0.21和0.33(供参考)。
以下是两个用于说明算数2ds和算数4ds之间区别的例子。
- 算数2ds:"59减去38等于多少。" "21。"
- 算数4ds:"2762加上2751等于多少。" "5513。"
如上所述,在保持数据集不变的情况下,我使用了经广泛研究或常用的Alpaca数据集进行监督指令微调。当然,还有许多其他适用于指令微调的数据集,如LIMA、Dolly、LongForm和FLAN等等。然而,在多个数据集和数据集混合上进行训练将是未来研究的一个有趣课题。
Alpaca数据集包含约50000个指令-回应对用于训练,输入的文本长度中位数为110个词元(使用LLaMA 2 SentencePiece分词器),如以下直方图所示。
数据集任务本身可以按照下图所示的方式进行结构化
2
代码框架
在本文,我使用了基于开源Lit-GPT存储库的自定义LLM微调代码。为保持本文前言简洁明了,可参阅Lit-GPT教程部分的详细指南,以了解使用细节。
简而言之,使用方法如下:
1) 复制代码库并安装需求
bash
git clone https://github.com/Lightning-AI/lit-gpt
cd lit-gpt
pip install -r requirements.txt
2) 下载并准备模型checkpoint
bash
python scripts/download.py \
--repo_id mistralai/Mistral-7B-Instruct-v0.1
# there are many other supported models
python scripts/convert_hf_checkpoint.py \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1
3) 准备数据集
bash
python scripts/prepare_alpaca.py \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1
# or from a custom CSV file
python scripts/prepare_csv.py \
--csv_dir MyDataset.csv \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1
4) 微调
bash
python finetune/lora.py \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1/ \
--precision bf16-true
5) 合并LoRA权重
bash
python scripts/merge_lora.py \
--checkpoint_dir "checkpoints/mistralai/Mistral-7B-Instruct-v0.1" \
--lora_path "out/lora/alpaca/Mistral-7B-Instruct-v0.1/lit_model_lora_finetuned.pth" \
--out_dir "out/lora_merged/Mistral-7B-Instruct-v0.1/"
cp checkpoints/mistralai/Mistral-7B-Instruct-v0.1/*.json \
out/lora_merged/Mistral-7B-Instruct-v0.1/
6) 评估
css
python eval/lm_eval_harness.py \
--checkpoint_dir "out/lora_merged/Mistral-7B-Instruct-v0.1/" \
--eval_tasks "[arithmetic_2ds, ..., truthfulqa_mc]" \
--precision "bf16-true" \
--batch_size 4 \
--num_fewshot 0 \
--save_filepath "results.json"
7) 使用
bash
python chat/base.py \
--checkpoint_dir "out/lora_merged/Mistral-7B-Instruct-v0.1/"
3
选择合适的基准模型
第一项任务是为LoRA实验选择一个足以胜任的基础模型。因此我主要关注那些尚未进行指令微调的模型:phi-1.5 1.3B、Mistral 7B、LLaMA 2 7B、LLaMA 2 13B和Falcon 40B。请注意,所有实验都是在单个A100 GPU上运行的。
如图所示,Mistral 7B模型在数学基准测试中表现出色。同时,相对较小的phi-1.5 1.3B模型在TruthfulQA MC2中的性能令人印象深刻。出于某些原因,LLaMA 2-13B在算术基准测试中表现欠佳,而较小的LLaMA 2-7B在这方面的表现则明显更为优秀。
目前,由于研究人员和从业者推测phi-1.5 1.3B和Mistral 7B可能是在基准测试数据集上进行训练的,所以我在实验中排除了这两个模型。此外,我认为选择剩余模型中最小的模型可以提供最大的改进空间,同时保持较低的硬件要求。因此,本文的剩余部分将重点关注LLaMA 2-7B。
4
评估LoRA的默认设置
首先,我使用以下默认设置评估了LoRA微调(以下设置可在finetune/lora.py脚本中进行更改):
ini
# Hyperparameters
learning_rate = 3e-4
batch_size = 128
micro_batch_size = 1
max_iters = 50000 # train dataset size
weight_decay = 0.01
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
lora_query = True
lora_key = False
lora_value = True
lora_projection = False
lora_mlp = False
lora_head = False
warmup_steps = 100
(请注意,批大小(batch size)为128,但我们使用梯度累积并设置微批大小为1以节省内存;这导致了与批大小为128的常规训练等效的训练轨迹。
该配置在我的机器上使用单个A100训练了4194304个LoRA参数,总共可训练的参数为6738415616个。整个训练过程大约耗时1.8小时,最大的内存使用量为21.33 GB。
为评估模型的性能波动,我进行了三次重复实验,以观察每次运行间的性能变化。
从上表可看出,多次运行的性能十分一致且稳定。值得注意的是,默认LoRA模型在算术方面表现很差,但这是意料之中的,据我所知,Alpaca数据集并不包含(或包含很少的)算术任务。
此外,我还观察了由Meta使用RLHF进行指令微调的7B版本的LLaMA 2。根据以下表格,可以看到Meta的LLaMA 2 Chat模型的算术性能也较差。Chat模型在其他基准测试中的表现获得了显著改善(除了BLiMP),可以将其作为参考,希望通过LoRA微调来接近这个水平。
5
使用QLoRA节省内存
在开始调整LoRA超参数之前,我想探索QLoRA(Dettmers等人提出的流行LoRA量化技术)在建模性能和内存节省之间的权衡。
通过QLoRA(通过Lit-GPT中的--quantize标志启用,这里使用4位普通浮点类型),我们可以节省内存,具体操作如下:
此外,我还尝试了将4位浮点精度作为对照组。以下是对训练时间和最大内存使用的影响:
默认LoRA(bfloat-16):
- 训练时间:6685.75秒
- 内存占用:21.33 GB
QLoRA via ---quantize "bnb.nf4":
- 训练时间:10059.53秒
- 内存占用:14.18 GB
QLoRA via --quantize "bnb.fp4":
- Memory used: 14.19 GB
- 内存占用:14.19 GB
可以看到,QLoRA几乎减少了6 GB的内存需求,但代价是训练时间增加了30%。因为量化和反量化增加了额外步骤,所以这是意料之中的。
接下来,让我们看看QLoRA训练对模型性能的影响:
如表格所示,与常规QLoRA相比,QLoRA对模型性能的影响非常小。模型在算术基准测试中的表现有所提升,但在MMLU Global Facts数据集基准测试中的表现有所下降。
由于内存节省相当可观(这通常比更长的训练时间更重要,因为它允许用户在较小的GPU上运行模型),所以我会在本文剩余部分继续使用QLoRA。
6
学习率调度器与SGD
在之前的所有实验中,我都使用了AdamW优化器,它是LLM训练的常用选择。然而,众所周知,Adam优化器对内存的消耗很大,因为它为每个模型参数都引入和跟踪了两个额外参数(m和v)。LLM拥有许多模型参数,如LLaMA 2模型有70亿个模型参数。
本节探讨了是否值得将AdamW优化器换成SGD优化器。然而,对于SGD优化器,引入一个学习率调度器也尤为重要。我选择了余弦退火调度(cosine annealing schedule),它会在每个批次更新后降低学习率。
不幸的是,将AdamW替换为SGD只节省了少量内存。
- AdamW:14.18 GB
- SGD:14.15 GB
这可能是因为大部分内存都消耗在了大规模矩阵乘法上,而并未用于存储额外参数。
但这一细微差异是可预料的。根据当前选择的LoRA配置(r=8),我们有4194304个可训练参数。如果Adam为每个模型参数添加了2个额外的值,并以16位浮点数的形式存储,那么内存占用将为4194304 * 2 * 16位 = 134.22兆比特 = 16.78兆字节。
当我们将LoRA的r值增加到256时(随后我们会这样做),会观察到更显著的差异。当r=256时,我们有648871936个可训练参数,根据上面的计算公式,这相当于2.6 GB的内存,但实际测量结果相差3.4GB,这可能是存储和复制优化器状态时的一些额外开销所致。
底线是,在可训练参数较少的情况下,例如LoRA和较低的r(rank)值时,与预训练相比(预训练中,我们训练了大量参数),将AdamW替换为SGD节省的内存可能非常小。
SGD并不能显著节省内存,下图是用SGD替换AdamW后的模型性能:
根据上图,SGD优化器的性能与AdamW相当。有趣的是,当为AdamW添加调度器时,TruthfulQA MC2和MMLU Global Facts数据集上的性能有所提高,但算术表现却有所下降。(注:TruthfulQA MC2是其他公共排行榜中备受认可的基准测试。)我们暂时不会过多强调算术表现,并将继续使用带调度器的AdamW进行剩余实验。
如果要复现这些实验,我发现最佳的AdamW学习率是3e-4,衰减率为0.01;最佳的SGD学习率是0.1,动量为0.9。我在这两种情况下都使用了额外的100步学习率进行预热。
(基于这些实验,余弦调度器现已被添加到Lit-GPT中,并默认启用。)
7
多次迭代数据集
迄今为止,我已对所有模型迭代训练了50k次------Alpaca数据集包含了50k个训练样本。现在,最重要的问题是:我们是否可以通过多次迭代训练集来提升模型性能。为探究这个问题,我重新对之前的实验进行了100k次迭代,即迭代次数增加到两倍:
有趣的是,当迭代次数增加时,模型的整体性能却下降了,其中模型在算术基准测试中的性能下降最为明显。据我猜测,原因可能是Alpaca数据集并未包含任何与算术相关的任务,当模型更加关注其他任务时,就会主动停止对基础算数知识的学习。
无论如何,这一结果是可喜的。有了这一结果,我就可以继续在剩余的实验中使用较少的50k次迭代进行实验。
8
LoRA超参数微调:适用于所有层的LoRA
现在,我们已经探索了围绕LoRA微调脚本的基本设置,让我们将注意力转向LoRA超参数本身。默认情况下,LoRA仅在多头自注意力块中的键(Key)和查询(Query)矩阵上启用。现在,我们还将其应用于值(Value)矩阵、投影层和线性层:
9
LoRA超参数微调:增加R值
"r"是LoRA最重要的参数之一,它决定了LoRA矩阵的秩(rank)或维度(dimension),直接影响了模型的复杂度和容量。较高的"r"值意味着更强的表达能力,但可能导致过拟合;较低的"r"值可以减少过拟合,但代价是降低了表达能力。在保持所有层都启用LoRA的情况下,我们将"r"值从8增加到了16,以观察其对性能的影响。
我们可以看到,仅增加"r"值会使结果变得更糟,那么到底发生了什么呢?答案将在下一部分揭晓。
10
LoRA超参数调优:更改Alpha
在前一部分,我们在保持LoRA的alpha参数不变的情况下,增加了矩阵秩r,较高的alpha更强调低秩结构或正则化,较低的alpha则减少了其影响,使模型更依赖于原始参数。调整"alpha"有助于在拟合数据和通过正则化防止过拟合之间保持平衡。
一般来说,微调LLM时选择的alpha值是秩的两倍(请注意,当使用扩散模型时,情况可能不同)。接下来让我们来看看将alpha增加两倍会发生什么:
正如我们所看到的,将alpha增加到32产生了迄今为止最好的模型!但为此我们也训练了更多参数:
r=8:
- 可训练参数:20277248个
- 不可训练参数:6738415616个
- 内存占用:16.42 GB
r=16:
- 可训练参数:40554496个
- 不可训练参数:6738415616个
- 内存占用:16.47 GB
然而,可训练参数仍然相对较少,不足以显著影响峰值内存需求。
无论如何,现在我们终于取得了一些进展,并通过更大幅度的改进提升了模型性能。接下来我们将继续实验,看看通过增加秩和alpha,能推进到什么程度:
我还用超大秩(512、1024 和 2048)进行了额外实验,但结果更不理想。其中一些运行甚至在训练过程中无法收敛到接近零的损失,这也是我没有将它们添加到表格中的原因。
目前为止,我们注意到在最后一行中,当r=256,alpha=512时,模型整体表现最好。作为额外的对照实验,我重复了alpha为1的运行,并注意到较大的alpha值确实对良好的性能至关重要:
我还重复了alpha值为16和32的实验,并观察到:相比将alpha值设定为秩的两倍,当alpha的值为16和32时,模型性能同样糟糕。
11
LoRA超参数调优:超大R值
在本文最后的调优实验中,我希望进一步优化前面部分的最佳模型(r=256,最后一行)的alpha值,我怀疑这个值可能略大。
如上图所示,在增加秩时,选择一个较大的alpha值至关重要。
对于QLoRA模型,当r=256和alpha=512时,与基准模型相比,我们的模型性能取得了显著提升。与基准模型相比,微调模型唯一一个表现不佳的领域是4位数算术。不过,考虑到Alpaca数据集可能不包含这样的训练示例,这是可以理解的。
根据上表我们可以看到,当alpha值为秩的两倍(例如,r=256,alpha=512)时,的确产生了最佳结果,而较小的alpha值则会产生较差的结果。但如果将alpha值增加到秩的两倍之上,会发生什么呢?
基于上表显示的结果,当alpha值超出秩的两倍时,基准结果也会变差。
12
排行榜提交
众所周知,在机器学习中,不应该多次使用测试集,否则可能导致模型针对特定任务被过度优化。因此,我们最好在一个独立的最终数据集上验证模型。
巧合的是,目前正在进行NeurIPS LLM Efficiency挑战赛,该挑战赛专注于在单个GPU上对LLM进行微调。我很好奇与在Alpaca数据集上微调的最佳LoRA模型相比,基于LLaMA-2 7B的模型会有怎样的表现,所以我将基准模型和微调模型都提交到他们的排行榜上。
可以看到,(Q)LoRA微调训练(r=256,训练时长为10522.77秒(约3小时),所需GPU内存为19.2GB)在几个指标上提高了性能。考虑到除Alpaca之外的其他微调数据集和RLHF等对齐技术,模型性能还能获得潜在提升,对此,我在这篇文章(magazine.sebastianraschka.com/p/llm-train... )中进行了更为详细的解释。
13
结论
本文探讨了使用LoRA训练自定义LLM时可以调整的各种参数。我们发现,尽管QLoRA会增加运行时间成本,但它是一种很好的节省内存的方法。此外,虽然学习率调度器可能有益,但AdamW和SGD优化器之间几乎没有区别。而且,多次迭代数据集可能会导致结果变得更糟。通过优化包括秩在内的LoRA设置,可以实现更好的效果。增加秩会增加可训练参数的数量,但可能导致更高程度的过拟合,增加运行时间成本。增加秩时,我们需要选择适当的alpha值,这非常重要。
不过,由于时间和资源所限,本文并未探索所有可能的配置。利用其他数据集和模型,未来还可能获得进一步提升。
欢迎 Star、试用 OneFlow 最新版本: