LoRA(Low-Rank Adaptation)的原理和实现

一、LoRA 的原理

LoRA 是一种高效微调技术,它的核心思想非常巧妙:冻结预训练模型原有的参数,然后在模型原有的参数矩阵(主要是权重层)旁边,添加一个可训练的、低秩的分解矩阵,通过仅优化这个低秩矩阵来模拟模型参数的更新。

下面我们逐步拆解这个原理。

1. 背景动机
  • 微调大模型的痛点:传统的全量微调需要为每个下游任务更新所有参数,计算成本和显存开销巨大。

  • 内在低维假设:有研究指出,虽然预训练模型拥有海量参数,但模型在下游任务上做适应时,其参数更新的过程往往具有一个"极低的内在维度"(intrinsic dimension)。这意味着,参数的改变量 ΔW 虽然位于一个极高维的空间中,但其有效变化可以被限制在一个非常低维的子空间中。

2. 核心思想:低秩分解

LoRA 正是基于这个假设。对于一个预训练好的线性层,其权重矩阵为 W0∈Rd×kW0​∈Rd×k(d 是输出维度,k 是输入维度)。

  • 传统微调:我们需要学习一个参数更新量 ΔWΔW,最终的参数变为 W=W0+ΔWW=W0​+ΔW。ΔWΔW 的维度与 W0W0​ 相同,即 d×kd×k。

  • LoRA 微调:LoRA 不对 W0W0​ 做任何更新,而是将更新量 ΔWΔW 分解为两个远小于原始矩阵的矩阵相乘:

    h=W0x+ΔWx=W0x+(BA)xh=W0x+ΔWx=W0x+(BA)x

    其中:

    • B∈Rd×rB∈Rd×r,A∈Rr×kA∈Rr×k。

    • rr 是秩(rank),是 LoRA 中最重要的超参数,且 r≪min⁡(d,k)r≪min(d,k)。

    • 这样一来,可训练的参数总量就从 d×kd×k 变成了 d×r+r×kd×r+r×k,参数量通常能减少几千甚至上万倍。

3. 为什么是低秩?

通过使用一个很小的 rr,我们强制让 ΔWΔW 变得"低秩"。这相当于在说:我们相信对于特定的下游任务,模型权重的有效改变可以被压缩到一个非常低维的空间中表达。

打个比方:ΔWΔW 原本像一本包含 d×kd×k 个字的厚书。LoRA 认为这本书的主要内容可以被压缩成一本小册子(AA)和一个解读指南(BB),两者相乘就能基本还原出原书的内容。

4. 训练与推理过程
  • 训练阶段

    • 冻结:原始权重矩阵 W0W0​ 被冻结,不计算梯度。

    • 训练:只有新增的矩阵 AA 和 BB 是可训练的。

    • 初始化:这是 LoRA 设计的关键细节之一。

      • 矩阵 AA 通常使用随机高斯分布初始化。

      • 矩阵 BB 通常初始化为零。这样,在训练开始时,ΔW=BA=0ΔW=BA=0,模型的输出与原始预训练模型完全一致,保证了训练的稳定性。

  • 推理阶段

    • 合并:因为 W0W0​ 是固定的,且 BB 和 AA 训练完成后也是固定的,我们可以将训练好的低秩矩阵合并回原始权重中,形成一个全新的权重矩阵:Wmerged=W0+BAWmerged​=W0​+BA。

    • 无损速度 :合并后,WmergedWmerged​ 的维度与 W0W0​ 完全相同。在推理时,我们直接使用这个合并后的矩阵进行前向传播。因此,相比原始模型,推理速度完全没有损失

二、LoRA 的实现

我们以 PyTorch 为例,展示如何为一个普通的 nn.Linear 层添加 LoRA。

1. 核心模块代码
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALayer(nn.Module):
    """
    一个简单的 LoRA 适配层,用于包装一个 nn.Linear 模块。
    """
    def __init__(self, original_layer: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.original_layer = original_layer
        # 冻结原始层的参数
        self.original_layer.weight.requires_grad = False
        # 如果原层有bias,通常不冻结也不修改,这里我们保留原bias
        if self.original_layer.bias is not None:
            self.original_layer.bias.requires_grad = False

        # LoRA 参数: 低秩矩阵 A 和 B
        in_features = original_layer.in_features
        out_features = original_layer.out_features

        # 矩阵 A: 将输入从 in_features 压缩到 rank
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        # 矩阵 B: 将 rank 映射回 out_features
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

        # 缩放因子 alpha
        # 在LoRA论文中,最终的输出是 original_output + (alpha/rank) * (B @ A) @ input
        self.scale = alpha / rank

        # 初始化 A (Kaiming均匀初始化) 和 B (零初始化)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        # 原始路径
        original_output = F.linear(x, self.original_layer.weight, self.original_layer.bias)
        # LoRA 路径: x @ A.T @ B.T
        # 注意: lora_A 和 lora_B 是 Parameter,可以直接参与矩阵乘法
        lora_output = (x @ self.lora_A.T) @ self.lora_B.T
        # 合并输出
        return original_output + self.scale * lora_output

    def merge_weights(self):
        """
        将训练好的 LoRA 权重合并回原始权重,用于推理加速。
        """
        if hasattr(self, 'lora_A') and hasattr(self, 'lora_B'):
            # W_merged = W_original + (B @ A) * scale
            merged_weight = self.original_layer.weight + (self.lora_B @ self.lora_A) * self.scale
            # 创建一个新的 Linear 层,用合并后的权重
            new_layer = nn.Linear(
                self.original_layer.in_features,
                self.original_layer.out_features,
                bias=self.original_layer.bias is not None
            )
            new_layer.weight.data = merged_weight
            if self.original_layer.bias is not None:
                new_layer.bias.data = self.original_layer.bias.data
            return new_layer
        else:
            return self.original_layer
2. 将 LoRA 应用到模型上

通常我们会选择性地将 LoRA 应用到模型的关键层,比如 Transformer 的 Query 和 Value 投影矩阵上。

python 复制代码
import math
from copy import deepcopy

def apply_lora_to_model(model, rank=4, alpha=1.0, target_modules=None):
    """
    递归遍历模型,将指定的 nn.Linear 层替换为 LoRALayer。

    Args:
        model: 原始模型
        rank: LoRA 的秩
        alpha: 缩放因子
        target_modules: 需要替换的层名称列表,例如 ['q_proj', 'v_proj']。
                        如果为 None,则替换所有 Linear 层(通常不推荐)。
    """
    if target_modules is None:
        # 如果没有指定,默认替换 q_proj 和 v_proj(针对 Transformer 模型)
        target_modules = ['q_proj', 'v_proj']

    for name, module in model.named_children():
        # 如果是 Linear 层且名称在目标列表中
        if isinstance(module, nn.Linear) and any(target in name for target in target_modules):
            # 创建 LoRA 层并替换
            lora_layer = LoRALayer(module, rank=rank, alpha=alpha)
            setattr(model, name, lora_layer)
        else:
            # 递归处理子模块
            apply_lora_to_model(module, rank, alpha, target_modules)
    return model

# 使用示例 (伪代码)
# model = YourPreTrainedModel()
# model = apply_lora_to_model(model, rank=8)
# 现在,只有 LoRA 层的参数是可训练的,可以正常进行训练
3. 关键超参数说明
  • r (Rank)

    • 作用:决定了低秩矩阵的维度,也就是可训练参数的数量和表达能力。

    • 典型值 :通常在 1 到 64 之间。对于大多数任务,r=4r=8 就能达到很好的效果。更大的 r 不一定更好,反而可能引入噪声和过拟合。

  • alpha (Scaling Factor)

    • 作用 :一个缩放因子,用于控制 LoRA 分支(BABA)对最终结果的贡献权重。最终输出为 W0x + (alpha/r) * BAx

    • r 的关系 :当使用 Adam 优化器时,调整 alpha 大致相当于调整学习率。为了简化调参,通常将 alpha 设置为第一个使用的 r 值(例如 alpha=8r=8),然后主要调整学习率。如果后续改变 r,可以相应缩放 alpha

  • target_modules (目标模块)

    • 作用:决定将 LoRA 应用到模型的哪些部分。

    • 典型实践 :在 Transformer 架构中,通常只对 Self-Attention 中的 QueryValue 矩阵应用 LoRA。有时也会对 KeyOutput 矩阵应用,或者对 MLP 层应用,但这会增加参数量。

三、LoRA 的优势总结

  1. 极高的训练效率:可训练参数极少,大幅降低显存占用和计算量。

  2. 无损推理速度:训练完成后,可以将 LoRA 权重合并回原模型,推理时零延迟。

  3. 便携的模型分發:对于同一个基础模型,不同的下游任务只需保存对应的、体积很小的 LoRA 权重文件(通常几 MB 到几十 MB),而不是整个庞大的新模型。

  4. 支持任务切换:在推理时,可以通过动态加载不同的 LoRA 权重(而不合并),让一个基础模型服务于多个不同的任务,实现高效的"热插拔"。

LoRA 目前已经是 LLM 微调的事实标准之一,其简洁而强大的思想也在 CV(计算机视觉)、多模态等领域得到了广泛应用。

相关推荐
喵手1 小时前
Python爬虫实战:同名实体消歧 - 店铺/公司名规则合并与标准化等!
爬虫·python·爬虫实战·零基础python爬虫教学·同名实体消歧·店铺/公司名规则合并与标准化
We་ct1 小时前
LeetCode 106. 从中序与后序遍历序列构造二叉树:题解+思路拆解
前端·数据结构·算法·leetcode·typescript
iAkuya1 小时前
(leetcode)力扣100 72每日温度(栈)
算法·leetcode·职场和发展
weixin_477271691 小时前
掾象:援助,辅佐。基于马王堆帛书《周易》原文及甲骨文还原周朝生活活动现象(《函谷门》原创)
算法·图搜索算法
七夜zippoe1 小时前
集成测试实战:构建可靠的测试金字塔体系
python·log4j·e2e·fastapi·持续集成·flask api
yunhuibin1 小时前
VGGNet网络学习
人工智能·python·深度学习·神经网络·学习
hhzz1 小时前
使用Python对MySQL进行数据分析
python·mysql·数据分析
随意起个昵称1 小时前
建图优化小记
c++·算法
逆境不可逃1 小时前
【从零入门23种设计模式04】创建型之原型模式
java·后端·算法·设计模式·职场和发展·开发·原型模式