上下文压缩

去年我第一次微调大模型,拿着全量微调的代码就往上冲。结果呢?一块 A100 跑了一个周末没跑完,中途还 OOM 了三次。

那时候我就知道,全量微调这条路,穷人走不起。

后来同事推荐了 LoRA,说"很省"。我一开始不信------参数少这么多,效果能好?结果试完闭嘴了,真香。

LoRA 到底在做什么?

一句话说清楚:全量微调是修改所有参数,LoRA 是插一些小模块只改这些模块的参数。

具体原理是这样的:

大模型的权重矩阵 W 是固定的(冻结,不训练)。LoRA 在 W 旁边并联了两个小矩阵 A 和 B,只训练 A 和 B。

复制代码
原始前向:y = Wx
LoRA前向:y = Wx + BAx

W 是 d×d(几万×几万),A 是 d×r,B 是 r×d,r 通常只有 8 或 16。

全量参数量:d²

LoRA 参数量:2dr

当 d=4096, r=8 时,全量是 1678 万参数,LoRA 只有 6.5 万参数------前者是后者的 256 倍。

能省多少显存,自己体会。

r 值怎么选?我踩过的坑

LoRA 最重要的超参数就是 rank r------它决定了 LoRA 模块的容量。

官方论文推荐 r 设 8 或 16。但我实际试下来:

r=1 到 r=2

参数量少,微调速度快。但表达能力弱,简单任务(比如情感分类)还行,复杂任务(对话生成、代码生成)效果明显差。

r=8 到 r=16

黄金区间。参数适中,效果和全量微调差距在 5% 以内。我 90% 的场景都用 r=8。

r=32 以上

参数量上去了,效果提升并不明显。有论文说 r 超过 128 甚至开始过拟合。

复制代码
我的经验法则:
- 简单分类任务:r=4
- 对话/指令微调:r=8
- 代码/数学推理:r=16
- 再大就上 QLoRA

QLoRA:不给 GPU 活路了

QLoRA 是 LoRA 的升级版,核心思路:模型量化到 4-bit,然后在量化后的模型上跑 LoRA

python 复制代码
from transformers import BitsAndBytesConfig
import torch

# 4-bit 量化配置
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    "model_name",
    quantization_config=bnb_config,
    device_map="auto"
)

from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none"
)
model = get_peft_model(model, lora_config)

QLoRA 的显存数据:

模型 全量微调 LoRA QLoRA
LLaMA-7B 112GB 24GB 6GB
LLaMA-13B 208GB 40GB 10GB
LLaMA-70B --- --- 48GB

用 QLoRA,一张 RTX 3090(24GB)就能微调 LLaMA-7B。我去年配的机器是两张 3090,跑 13B 模型 QLoRA 毫无压力。

注意:QLoRA 训练速度比 LoRA 慢 30-50%,因为需要反复量化和反量化。但考虑到显存节省,这个代价完全能接受。

Target Modules 怎么配?

这是另一个让我踩坑的地方------到底对哪些层做 LoRA?

论文只改了 Q 和 V 矩阵。但实际应用中有更优选择:

python 复制代码
# 最佳实践:全连接层也加上
lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"]
)

我对比过几种配置:

  • 只改 QV:效果最差,loss 降得最慢
  • 改 QKV + O:效果和全量差距不大,推荐
  • 全改(6个模块):效果最好,但训练速度慢 20%

折中方案是改 QKVO 四个模块。

实际微调:从数据到保存

LoRA 微调的全流程大概是这样的:

1. 准备数据

格式最简单的就是 Alpaca 格式:

json 复制代码
[
  {
    "instruction": "写一首关于春天的诗",
    "input": "",
    "output": "春风拂面来,百花次第开..."
  }
]

数据量:500-2000 条效果就很好了。我用 1000 条数据微调 LLaMA-7B,效果提升肉眼可见。

2. 训练配置

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./lora-llama-7b",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="epoch",
    fp16=True,
)

关键参数说明:

  • learning_rate:LoRA 的 lr 要比全量微调高。全量用 2e-5,LoRA 用 2e-4 左右
  • batch_size:用梯度累积模拟大 batch,防止显存溢出

3. 合并和部署

训练完的 LoRA 权重很小(几十 MB),但推理时要加载基座模型再插入 LoRA 权重:

python 复制代码
from peft import PeftModel

# 推理时加载
base_model = AutoModelForCausalLM.from_pretrained("base_model")
model = PeftModel.from_pretrained(base_model, "./lora-checkpoint")

如果需要部署到生产,可以合并权重:

python 复制代码
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged-model")

合并后的模型可以直接用,和加载完整模型一样。缺点是文件变大了,从几十 MB 变成几个 GB。

LoRA 的局限性

说完了好处也得说说缺点,不然显得不客观。

当 LoRA 不够用的时候:

1. 领域差异太大

如果预训练模型和你任务的数据分布差异极大(比如通用模型 → 医疗影像分析),LoRA 的表达能力可能不够。

2. 微调数据量大

超过 10 万条高质量微调数据,LoRA 的效果可能不如全量微调。不过大多数人接触不到这个量级的数据。

3. 需要极端精确

数学推理、代码生成等要求极高精度的任务,LoRA 和全量微调还是有差距。

在这些场景下,还是老老实实上全量微调吧。

写在最后

LoRA 最大的价值不是"多么创新",而是让更多人能以可接受的成本微调大模型。

说实话,我觉得 LoRA 之所以成为工业标准,不是因为它比全量微调强,而是它让不可能变成了可能。以前需要 8 张 A100 才能做的事,现在 1 张 3090 就可以。

对于大多数实际场景,LoRA 的效果已经完全够用了。如果还不够,上 QLoRA。如果 QLoRA 还不够...那可能不是微调方式的问题,而是你的基座模型选错了。

下一步我要试试 DoRA(Weight-Decomposed Low-Rank Adaptation),据说在 LoRA 的基础上加了一个"方向"和"幅度"的解耦,效果更好。有结果了再来汇报。

相关推荐
AI技术增长3 小时前
Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程
pytorch·python·深度学习
日取其半万世不竭3 小时前
Minecraft Java版社区服搭建教程(Windows版)
java·开发语言·windows
wjs20243 小时前
HTML 文本格式化
开发语言
本地化文档3 小时前
setuptools-docs-l10n
python·github·gitcode
白夜11173 小时前
C++任务调度与状态机
开发语言·c++·笔记
南宫萧幕3 小时前
MATLAB/Simulink 从零打通:HEV 能量管理 GA 联合仿真保姆级建模指南
开发语言·算法·matlab·汽车·控制·pid
梦想不只是梦与想3 小时前
Python 属性访问的 MRO 规则
python·mro规则
Ulyanov4 小时前
基于 Python 的三维动态导弹攻防演示系统设计与实现:从架构到实战的深度剖析
开发语言·python·qt·架构·雷达电子对抗
苍煜4 小时前
Java自定义注解-SpringBoot实战
java·开发语言·spring boot