面试-LoRA

1. LoRA 介绍

目的: 本质就是低秩矩阵分解,将训练的方阵拆解成两个低秩矩阵,目的是降低模型训练参数,加速训练。
核心过程: 在微调开始时,LoRA 模块的加入不会改变原模型的输出,让模型先保持原有能力,再通过训练低秩矩阵逐步学习任务相关的微调信息。
组成部分: 其中,核心部分分为 秩 的大小、矩阵 A、矩阵 B。其中 秩 控制矩阵 A 和 B 的大小,秩越大,分解的矩阵就越大。一般来说 rank = 8/16。
LoRA 的核心公式是: W ′ = W + Δ W = W + B × A W' = W + \Delta W = W + B \times A W′=W+ΔW=W+B×A

python 复制代码
# 定义Lora网络结构
class LoRA(nn.Module):
    def __init__(self, in_features, out_features, rank):
        super().__init__()
        self.rank = rank  # LoRA的秩(rank),控制低秩矩阵的大小
        self.A = nn.Linear(in_features, rank, bias=False)  # 低秩矩阵A
        self.B = nn.Linear(rank, out_features, bias=False)  # 低秩矩阵B
        # 矩阵A高斯初始化
        self.A.weight.data.normal_(mean=0.0, std=0.02)
        # 矩阵B全0初始化
        self.B.weight.data.zero_()

    def forward(self, x):
        return self.B(self.A(x))

B 全 0 的原因:

  • 目标:在训练刚开始(t=0)时,我们希望 LoRA 分支不产生任何影响,模型的输出应该完全等于预训练模型的输出。
  • 实现:如果 B B B 是全 0 矩阵,那么无论 A A A 是什么, B × A B \times A B×A 都等于 0。
    Δ W = 0    ⟹    y = W x \Delta W = 0 \implies y = Wx ΔW=0⟹y=Wx
  • 好处:这保证了微调是从预训练模型的分布平滑开始的,避免了初始阶段引入随机噪声破坏预训练知识,提高了训练的稳定性。

A 采用高斯初始化的原因:

如果 A A A 和 B B B 都初始化为 0,虽然满足了 Δ W = 0 \Delta W = 0 ΔW=0,但会导致 梯度消失 ,模型无法学习。我们需要通过链式法则来看梯度流动:

  • 打破对称性:如果 A 也全 0,那么B⋅A⋅x 永远是 0,根据链式求导法则,LoRA 模块的梯度会一直为 0(对 B 矩阵的更新就会一直为 0),参数永远无法更新(梯度消失)。
  • 提供学习起点:高斯初始化(均值 0,标准差 0.02 是 Transformer 类模型的常用值)能给 A 赋予微小的随机值,当训练开始后,B 会从 0 开始学习,逐步和 A 配合生成有意义的增量,让 LoRA 模块能正常更新参数。

举个例子:

假设损失函数为 L L L,输出为 y = B ( A x ) y = B(Ax) y=B(Ax)。

  • 对 B 的梯度
    ∂ L ∂ B = ∂ L ∂ y ⋅ ( A x ) T \frac{\partial L}{\partial B} = \frac{\partial L}{\partial y} \cdot (Ax)^T ∂B∂L=∂y∂L⋅(Ax)T

    • 如果 A A A 是全 0 :那么 A x = 0 Ax = 0 Ax=0,导致 ∂ L ∂ B = 0 \frac{\partial L}{\partial B} = 0 ∂B∂L=0。B B B 永远无法更新,训练直接失败。
    • 如果 A A A 是高斯分布 :那么 A x ≠ 0 Ax \neq 0 Ax=0,梯度可以流向 B B B, B B B 可以从 0 开始更新。
  • 对 A 的梯度
    ∂ L ∂ A = B T ⋅ ∂ L ∂ y ⋅ x T \frac{\partial L}{\partial A} = B^T \cdot \frac{\partial L}{\partial y} \cdot x^T ∂A∂L=BT⋅∂y∂L⋅xT

    • 在第一步时,因为 B = 0 B=0 B=0,确实 ∂ L ∂ A = 0 \frac{\partial L}{\partial A} = 0 ∂A∂L=0。
    • 但是 :由于 A ≠ 0 A \neq 0 A=0, B B B 在第一步就能接收到梯度并更新( B B B 变为非 0)。
    • 在第二步时, B B B 已经非 0 了, A A A 就能接收到梯度并开始更新。
      结论:必须有一个矩阵是非 0 的,才能让另一个矩阵收到梯度。而为了让初始输出为 0,必须让其中一个矩阵为 0。
  • 若 A = 0 , B ≠ 0 A=0, B \neq 0 A=0,B=0 → \rightarrow → 初始输出为 0,但 B B B 收不到梯度 → \rightarrow → 失败

  • 若 A ≠ 0 , B = 0 A \neq 0, B=0 A=0,B=0 → \rightarrow → 初始输出为 0,且 B B B 能收到梯度 → \rightarrow → 成功

2. 如何注入 LoRA 权重

python 复制代码
import torch.nn as nn
# 假设 LoRA 类已经定义,且包含正确的初始化 (A 高斯,B 全 0) 和 forward 逻辑

def apply_lora(model, rank=32):
    """
    遍历模型,为符合条件的线性层注入 LoRA 模块。
    """
    # 1. 遍历模型的所有子模块
    # named_modules() 会递归遍历所有层级,返回 (名称,模块对象)
    for name, module in model.named_modules():
        
        # 2. 筛选目标层
        # 条件 1: 必须是线性层 (nn.Linear)
        # 条件 2: 权重矩阵必须是方阵 (输入维度 == 输出维度)
        # 注意:这是一个较强的限制,标准 LoRA 通常不需要方阵,这里可能是针对特定架构(如 Transformer 的 MLP 层)
        if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
            
            # 3. 创建 LoRA 模块
            # 输入/输出维度与原层一致,rank 控制低秩大小
            # .to(model.device): 关键!确保 LoRA 参数与原模型在同一设备 (CPU/GPU)
            lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
            
            # 4. 注册参数
            # 使用 setattr 将 lora 模块绑定为当前层的属性
            # 作用:让 PyTorch 的优化器 (optimizer) 能追踪到 lora 的参数,否则无法训练
            setattr(module, "lora", lora)
            
            # 5. 备份原始方法
            # 保存该层原本的 forward 函数引用,以便在新函数中调用
            original_forward = module.forward

            # 6. 定义新的前向传播函数 (闭包)
            # 关键语法:使用默认参数 (layer1=..., layer2=...) 进行"显式绑定"
            # 原因:Python 闭包变量是"延迟绑定"的。如果直接引用 original_forward,
            #      循环结束后,所有层的 forward 都会指向最后一个层的 original_forward。
            #      通过默认参数,在函数定义时就将当前的对象"固化"下来。
            def forward_with_lora(x, layer1=original_forward, layer2=lora):
                # 新逻辑:原输出 + LoRA 分支输出
                # 对应公式:h = Wx + BAx
                return layer1(x) + layer2(x)

            # 7. 猴子补丁 (Monkey Patching)
            # 将层的 forward 方法替换为新定义的函数
            # 此后调用 module(x) 时,实际执行的是 forward_with_lora(x)
            module.forward = forward_with_lora

问题一:为什么要用 setattr?

把 LoRA 模块"注册"到原模型层上,让它成为模型的一部分。如果不加这行,LoRA 的参数无法被训练。类似于:

python 复制代码
class Person:
    pass

p = Person()

# 方法 1:直接赋值
p.name = "张三"

# 方法 2:用 setattr(效果一样)
setattr(p, "name", "张三")

print(p.name)  # 输出:张三

**原代码:**

setattr(module, "lora", lora)

问题二:什么是猴子补丁?
猴子补丁 是指在 运行时(Runtime)动态修改 类、模块或对象的行为,而不需要修改原始的源代码。

就像你买了一个成品玩具,发现它少个功能,你没有拆开重造,而是直接用胶带在外面粘了一个新按钮上去,让它有了新功能。

python 复制代码
**原始代码:**
# 我们通过这行代码,强行把它替换成了我们自己的函数
module.forward = forward_with_lora  

**示例:**
class A:
    def say_hello(self):
        print("Hello")

obj = A()
obj.say_hello()  # 输出:Hello

# 【猴子补丁】:运行时动态修改方法
def new_say_hello(self):
    print("Hi, I'm patched!")

A.say_hello = new_say_hello  # 直接替换类的方法

obj.say_hello()  # 输出:Hi, I'm patched! (行为变了)

问题三:闭包 (Closure) 与 延迟绑定陷阱
是什么?

闭包 是指一个函数定义在另一个函数内部,并且内部函数引用了外部函数的变量。即使外部函数执行结束了,内部函数依然能"记住"并使用那些变量。

而 Python 在函数里使用外部变量时,有两种策略:

  • 延迟绑定(Late Binding):函数里写的是变量名。等到函数被调用时,才去外面找这个变量现在的值。

比喻:你告诉朋友"去用那张桌子上的笔"。(朋友跑过去时,桌子上的笔可能已经被换掉了)

  • 默认参数绑定函数定义 时,把变量的值复制一份存进函数里。

比喻:你直接买了一支笔送给朋友。(不管后来桌子上的笔怎么变,朋友手里那支不变)

演示"陷阱"(延迟绑定):

python 复制代码
functions = []

for i in range(3):  # i 会依次变成 0, 1, 2
    def f():
        return i    # ⚠️ 注意:这里没有把 i 存下来,只是记了个名字 "i"
    functions.append(f)

# 循环结束后,i 变成了 2
print(functions[0]())  # 你以为是 0?实际输出 2
print(functions[1]())  # 你以为是 1?实际输出 2
print(functions[2]())  # 输出 2

为什么全变成了 2?

  1. 循环创建 f 时,Python 没有把 i 当时的值(0 或 1)存进 f 里。
  2. f 只是记住了:"我要去外面找一个叫 i 的变量"。
  3. 循环跑完后,内存里的 i 定格在 2。
  4. 当你调用 functions0 时,它去外面找 i,发现已经是 2 了。
  5. 所有函数共享同一个 i。

这就是延迟绑定陷阱:用的时候才去找,找到的都是最后的值。

在 LoRA 代码中的体现

python 复制代码
for name, module in model.named_modules():
    # ...
    original_forward = module.forward  # 外部变量
    
    # 内部函数引用了外部变量 original_forward
    def forward_with_lora(x, layer1=original_forward, ...): 
        return layer1(x) + ...
相关推荐
Dr.AE1 小时前
AI+金融 行业分析报告
大数据·人工智能·金融·产品经理
2501_945318491 小时前
非技术背景转型AI产品经理的可行性分析与详细路径图
人工智能·产品经理
星爷AG I1 小时前
12-5 共情(AGI基础理论)
人工智能·agi
小真zzz1 小时前
ChatPPT Nano Banana Pro · Magic模式深度解析 ——重新定义“所想即所得”的PPT智能编辑
人工智能·ai·powerpoint·ppt·aippt
进阶的鱼1 小时前
一文了解RAG———检索增强生成
人工智能·python·ai编程
Points1 小时前
飞哥学习人工智能之路第三讲:CNN、RNN与Transformer
人工智能
这儿有一堆花1 小时前
OpenAI 和 Paradigm 推出 EVMbench:AI 帮智能合约把关的新工具
人工智能·智能合约
一路往蓝-Anbo1 小时前
第 4 章:串口驱动进阶——GPDMA + Idle 中断实现变长数据流接收
linux·人工智能·stm32·单片机·嵌入式硬件
shangjian0071 小时前
AI-大模型应用开发-大模型生成参数调优速查表
人工智能