高效大语言模型微调:LoRA 和 QLoRA

大语言模型微调存在的问题

问题1:微调参数量

对于微调模型来说,常见的做法是根据损失函数,对预训练好的模型所有的参数进行更新。目的是希望能够通过少量的例子,让已经具备"一定基础知识"的模型能够快速学习新的任务。

例如我们已经使用大量的中文文本和英文文本训练了一个 GPT 模型,因此可以认为模型对于两种语言都有了一定的了解。那么在此基础上,可以使用少量的高质量中英双语平行语料对模型进行再次训练,就可以获得一个表现优异的翻译模型。但是对于 LLM 来说,这种方式存在以下两个主要难点:

  1. 如果想要微调所有参数,对计算资源(主要是 GPU)的需求量巨大
  2. 微调之后的模型和原始模型不在是同一个模型,也就是原始模型的一些通用能力可能在微调过程中遗忘。如果下次需要在另一个任务上微调,则需要再存储/部署一个全量参数的模型

只考虑第二点的话,一种解决方法是使用两个或多个任务的一起对模型进行微调。这样模型可以同时学习多个任务,但这个过程中可能会出现任务冲突等问题,导致一个或多个任务的表现不如对每个任务单独微调的效果。

解决方案: LoRA

LoRA 的全称是 Low-Rank Adaptation,也就是使用两个低秩矩阵对原始权重矩阵进行近似,在极大减少参数量的同时,实现对模型的微调。以最简单的模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = x W y = xW </math>y=xW 为例:在微调的过程中保持 W 不变(只参与计算,不涉及更新),额外增加一个参数 W' (W' 的参数量 << W 的参数量),在微调过程过程中只对 W' 进行更新。也就是将模型变为: <math xmlns="http://www.w3.org/1998/Math/MathML"> y = x W + x W ′ y = xW + xW^{'} </math>y=xW+xW′ 。 同样的简化过程可以应用在 transformer 的每层的权重矩阵(例如注意力机制的 QKV 矩阵)上,就是实现了额外增加少量参数实现对模型的微调。

如何实现 LoRA

假设原始参数 W 是一个 512 * 128 的矩阵,对于输入为 1 * 512 的 x,如果想要获得和 xW 形状一致的 1 * 128 的输出,可以使用两个矩阵来近似 W' ,在降低参数量的同时获得一致的输出,也就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> W ′ = A B W' = AB </math>W′=AB。其中 A 的维度为 512 * 2;B 的维度为 2 * 128,这样就在输出一致的情况下参数变为原始参数的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 512 ∗ 2 + 2 ∗ 128 ) / ( 512 ∗ 128 ) ≈ 0.02 (512 * 2 + 2 * 128) / (512 * 128) \approx 0.02 </math>(512∗2+2∗128)/(512∗128)≈0.02 。 除此以外,在更新的过程中,需要存储的参数数量也大幅降低。以 Adam 优化器为例,由于需要记录梯度的一阶和二阶的信息,因此参数量大约是原始参数量的三倍 。但是 LoRA 不需要对原始参数进行微调,所以需要在反向传播过程中需要存储的参数量大幅下降: LoRA 模型对 x 求导得到的结果是 <math xmlns="http://www.w3.org/1998/Math/MathML"> W + A B W + AB </math>W+AB。在上面的例子中,如果不是使用 LoRA 则在优化过程中需要保存的参数量约为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 ∗ 512 ∗ 3 ≈ 786 k 512 * 512 * 3 \approx 786k </math>512∗512∗3≈786k。相比之下 LoRA 只需要保存 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 ∗ 512 + ( 512 ∗ 2 ∗ 3 ) + ( 2 ∗ 128 ∗ 3 ) ≈ 266 K 512 * 512 + (512 * 2 * 3) + (2 * 128 * 3) \approx 266K </math>512∗512+(512∗2∗3)+(2∗128∗3)≈266K

如何使用 LoRA

由于 LoRA 部分的输出形状和原始输出一致,因此只要将原始结果和 LoRA 模块输出的结果相加就得到了最终结果。因此,相比于原始的微调方法,需要将微调后全量的模型参数进行部署,现在对于具体的微调任务只需要额外部署 LoRA 部分的参数(数量很小,远远小于全量模型从桉树),全量模型参数可以实现多个微调模型之间的共享。

LoRA 的超参数选择

  1. 训练至少需要 100+ steps

<math xmlns="http://www.w3.org/1998/Math/MathML"> s t e p s = 数据量 b a t c h _ s i z e ∗ e p o c h s steps = \frac{数据量}{batch\_size} * epochs </math>steps=batch_size数据量∗epochs

  1. rank 需要根据任务来选

一般来说,原始模型的能力越强,rank (rank 越大 LoRA 的参数越多)可以相对小一些。一定程度上来说,rank 需要根据微调任务的难度选取。对于原始模型已经表现相对不错(预训练的时候任务相关数据比较多)的任务,rank 可以选择的小一些(比如 16 以下)。反之,rank 可以设置的更大。

  1. Alpha 和数据量相关

LoRA 实际的实现中还有一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> s c a l i n g _ f a c t o r = a l p h a r a n k scaling\_factor = \frac{alpha}{rank} </math>scaling_factor=rankalpha ,和 LoRA 输出的结果相乘,用于控制 LoRA 模块的结果如何影响最终结果。因此 alpha 越大,LoRA 的输出对模型最终输出的影响也越大。在微调的数据量很小的情况下,可以尝试调大 alpha 为 rank 的两到四倍。但如果数据量比较大,可以尝试更小的 alpha,一般和 rank 保持一致(也就是 scaling_factor 为 1)即可。

问题2:模型参数量

通过使用 LoRA 对模型进行微调,较好的解决了问题 1 中提到的两个难点。但是计算资源需求变小是相对于对全量参数进行微调而言(主要是可以省掉存储优化器需要的梯度信息的存储开销),原始模型的参数仍需要加载到显存中,对于计算资源的需求仍然相对较大。

解决方案: QLoRA

LoRA 带来的额外参数量是十分小的,如果将量化后的原始模型和 LoRA 配合使用,就解决了 LoRA 微调时显存使用的问题。可以使用更少的 GPU 对大语言模型进行微调。这就是QLoRA 的核心思想。

对于模型过大常用的方法包括参数剪裁(pruning)和模型量化(quantization)。二者都会带来一定程度的效果下降,但是相比之下,量化对模型的改动比较小,更易操作。通过量化可以使得模型在结构和参数基本保持不变的情况下,可以使用更少的显存。例如之前 512 * 128 的参数矩阵的例子。如果使用使用 float32 存储,那么参数需要使用 512 * 128 * 4 bytes 的显存。但是如果在损失一定精度的情况下使用 float16 存储,那么参数所需的显存则变为 512 * 128 * 2 bytes ,减少了一半。这个压缩效果可以进一步优化,如果使用 int8 则降低 4 倍, int4 可以降低 8 倍。

如何实现 QLoRA

量化存在的问题

量化看似提供了一个更为高效的微调模型的方案。虽然 int4 量化可以极大的压缩模型使用的显存,但是精度也带来了巨大的损失。使用 4 bits 的 int4 的取值仅有 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 4 = 16 2^{4} = 16 </math>24=16 个。这对于模型微调来说肯定是远远不够的。具体来说,可以考虑下面的例子:

如果想要使用 int8(值域为 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − 127 , 127 ] [-127, 127] </math>[−127,127]) 来量化 float32 的向量:

<math xmlns="http://www.w3.org/1998/Math/MathML"> x = [ − 1024.0 , 5.0 , 2048.0 , 256.0 ] x = [-1024.0, 5.0, 2048.0, 256.0] </math>x=[−1024.0,5.0,2048.0,256.0]

那么量化参数为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> c = 127 a b s m a x ( x ) = 127 2048 ≈ 0.062 c = \frac{127}{absmax(x)} = \frac{127}{2048} \approx 0.062 </math>c=absmax(x)127=2048127≈0.062

使用 c 和 x 相乘,得到:

<math xmlns="http://www.w3.org/1998/Math/MathML"> c x = 0.062 ∗ [ − 1024.0 , 5.0 , 2048.0 , 256.0 ] = [ − 63.488 , 0.31 , 126.976 , 15.872 ] cx = 0.062 * [-1024.0, 5.0, 2048.0, 256.0] = [-63.488, 0.31, 126.976, 15.872] </math>cx=0.062∗[−1024.0,5.0,2048.0,256.0]=[−63.488,0.31,126.976,15.872]

最后对 cx 取整,就得到了量化后的结果:

<math xmlns="http://www.w3.org/1998/Math/MathML"> q u a n t ( x ) = r o u n d ( c x ) = [ − 63 , 0 , 127 , 16 ] quant(x) = round(cx) = [-63, 0, 127, 16] </math>quant(x)=round(cx)=[−63,0,127,16]

只要保存了 c 就能够将量化后的结果近似的映射会原始值。但是如果 x 的值域很广,例如 <math xmlns="http://www.w3.org/1998/Math/MathML"> x = [ 1 , 100 , 100000 ] x = [1, 100, 100000] </math>x=[1,100,100000],那么按照上述的计算方式,量化后的 x 则变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0 , 0 , 127 ] [0, 0, 127] </math>[0,0,127]。将 1 和 100 都变为 0,这样的精度损失肯定会带来很大的问题的。也就是一个离群点可以能会导致整个量化过程失败。

NormalFloat Quantization 和 Double Quantization

为了解决量化存在的问题,QLoRA 使用了 NormalFloat Quantization。如果是使用 4 bits 的量化,则称为 NF4(4-bit NormalFloat Quantization)。简单来说,NF4 的假设是模型的参数基本符合正态分布,那么对于值域只有 16 值的 int4 来说,可以根据向量的正态分布平均划分出 16 个桶。随后将原始输入的值压缩到这些桶中。

实现的方式可以参考简化版本:使用 2 bits 对输入进行量化。

2 bits 的取值为 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0 , 1 , 2 , 3 ] [0, 1, 2, 3] </math>[0,1,2,3]。目标是将值域压缩到在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − 1 , 1 ] [-1, 1] </math>[−1,1] 的四个桶中,每个桶的量化参数分别为 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − 1.0 , 0.3 , 0.5 , 1.0 ] [-1.0, 0.3, 0.5, 1.0] </math>[−1.0,0.3,0.5,1.0]。

对于待量化输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 10 , − 3 , 5 , 4 ] [10, -3, 5, 4] </math>[10,−3,5,4],通过以下步骤完成量化:

  1. 对输入进行归一化得到结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 1 , − 0.3 , 0.5 , 0.4 ] [1, -0.3, 0.5, 0.4] </math>[1,−0.3,0.5,0.4]
  2. 找到过一化结果和距离最近的桶均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 1 , 0.3 , 0.5 , 0.5 ] [1, 0.3, 0.5, 0.5] </math>[1,0.3,0.5,0.5]
  3. 保存对应的桶下标就实现了量化 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 3 , 1 , 2 , 2 ] [3, 1, 2, 2] </math>[3,1,2,2]

但是为了能够将结果近似的映射回去,现在每个桶都需要存储对应的 float32 的量化参数。因此论文中提出对这些量化参数进行进一步量化,论文中称为 Double Quantization 。简单来说,Double Quantization 是将 16 个 float32 的参数量化为 float8 ,这样只需要额外存储一个 float32 的新量化参数用于将 float8 映射回 float32 ,进一步降低显存的使用。

如何应用在模型中

还是以 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = x W + x W ′ y = xW + xW^{'} </math>y=xW+xW′ 为例,由于 LoRA 的参数比较少,所以组成 W' 的 A 和 B 矩阵依旧可以使用 float16 存储。W 的部分参数量大,使用的 NF4 量化后结果。那么在计算时,输出和输出也是 float16 ,为了保证数据类型的一致,只需要根据量化参数将 W 映射回 float16 即可。

在 transformer 的计算过程中,由于每层之间有依赖关系,所以每次只需要将当前计算需要的原始参数映射回 float16 即可,在当前计算完成后无需保留映射回 float16 的结果,这样使用了少量的额外显存开销,完成了模型的计算。

文中 table 3 的试验结果验证 NF4 + DQ 的 QLoRA 在节省了微调需要的显存的情况下,获得了基本相同的效果。

模型量化使用

在 huggingface 的 transformers 库中通过 BitsAndBytesConfig 提供了 NF4 和 DQ 的支持。可以在部署模型时减少显存的使用,代码可以参考:

python3 复制代码
import torch

# transformers.__version__ == '4.30.2'
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,  # Double Quantization
    bnb_4bit_quant_type="nf4"  # 4-bit NormalFloat Quantization
)
kwargs = {
    'load_in_4bit': True, 
    'torch_dtype': torch.float16, 
    'trust_remote_code': True, 
    "quantization_config": quantization_config, 
    "device_map":"auto"  # 使用多 GPU
}

model_name = 'meta-llama/Llama-2-7b-hf'
llm_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm_model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)

参考资料

github.com/S-LoRA/S-Lo... www.youtube.com/watch?v=_K3... www.youtube.com/watch?v=C33... www.youtube.com/watch?v=LR3... huggingface.co/blog/4bit-t... developer.nvidia.com/blog/cuda-1... developer.nvidia.com/blog/mixed-... athekunal.medium.com/qlora-quant...

相关推荐
bastgia2 天前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
新智元2 天前
李飞飞谢赛宁:多模态 LLM「空间大脑」觉醒,惊现世界模型雏形!
人工智能·llm
RWKV元始智能3 天前
RWKV-7:极先进的大模型架构,长文本能力极强
人工智能·llm
zaim13 天前
计算机的错误计算(一百八十七)
人工智能·ai·大模型·llm·错误·正弦/sin·误差/error
张拭心3 天前
Google 提供的 Android 端上大模型组件:MediaPipe LLM 介绍
android·人工智能·llm
带电的小王3 天前
whisper.cpp: Android端测试 -- Android端手机部署音频大模型
android·智能手机·llm·whisper·音频大模型·whisper.cpp
带电的小王4 天前
whisper.cpp: PC端测试 -- 电脑端部署音频大模型
llm·whisper·音视频·音频大模型
Ambition_LAO4 天前
LLaMA-Factory QuickStart 流程详解
llm·llama
宇梵文书C4 天前
在CFFF云平台使用llama-factory部署及微调Qwen2.5-7B-Instruct
llm·llama·cfff
zaim14 天前
计算机的错误计算(一百八十六)
人工智能·python·ai·大模型·llm·误差·decimal