机器学习|从0开始大模型之模型LoRA训练

继续《从0开发大模型》系列文章,上一篇用全量数据做微调,训练时间太长,参数比较大,但是有一种高效的微调方式LoRA。

1、LoRA是如何实现的?

在深入了解 LoRA 之前,我们先回顾一下一些基本的线性代数概念。

1.1、秩

给定矩阵中线性独立的列(或行)的数量,称为矩阵的秩,记为 rank(A)

  • 矩阵的秩小于或等于列(或行)的数量,rank(A) ≤ min{m, n}
  • 满秩矩阵是所有的行或者列都独立,rank(A) = min{m, n}
  • 不满秩矩阵是满秩矩阵的反面是不满秩,即 rank(A) < min(m, n),矩阵的列(或行)不是彼此线性独立的

举个两个秩的例子:

不满秩

满秩

1.2、秩相关属性

从上面的秩的介绍中可以看出,矩阵的秩可以被理解为它所表示的特征空间的维度,在这种情况下,特定大小的低秩矩阵比相同维度的满秩矩阵封装更少的特征(或更低维的特征空间)。与之相关的属性如下:

  • 矩阵的秩受其行数和列数中最小值的约束,rank(A) ≤ min{m, n}
  • 两个矩阵的乘积的秩受其各自秩的最小值的约束,给定矩阵 AB,其中 rank(A) = mrank(A) = n,则 rank(AB) ≤ min{m, n}

1.3、LoRA

LoRA(Low rand adaption) 是微软研究人员提出的一种高效的微调技术,用于使大型模型适应特定任务和数据集。
LoRA 的背后的主要思想是模型微调期间权重的变化也具有较低的内在维度,具体来说,如果Wₙₖ代表单层的权重,ΔWₙₖ代表模型自适应过程中权重的变化,作者提出ΔWₙₖ是一个低秩矩阵,即:rank(ΔWₙₖ) << min(n,k)

为什么?

模型有了基座以后,如果强调学习少量的特征,那么就可以大大减少参数的更新量,而ΔWₙₖ就可以实现,这样就可以认为ΔWₙₖ是一个低秩矩阵。

实现原理

ΔWₙₖ是一个更新矩阵,然后ΔWₙₖ根据秩的属性,又可以拆分两个低秩矩阵的乘积,即:BₙᵣAᵣₖ ,其中 r << min{n,k}

这意味着网络中权重 Wx = Wx + ΔWx = Wx + BₙᵣAᵣₖx,由于 r 很小,所以 BₙᵣAᵣₖ 的参数数量非常少,所以只需要更新很少的参数。

LoRA

2、peft库

LoRA 训练非常方便,只需要借助 https://huggingface.co/blog/zh/peft 库,这是 huggingface 提供的,使用方法如下:

ini 复制代码
# 引入库
from peft import get_peft_model, LoraConfig, TaskType

# 创建对应的配置
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.01,
    bias="none"
    task_type="SEQ_2_SEQ_LM",
)

# 包装模型
model = AutoModelForSeq2SeqLM.from_pretrained(
    "t5-small",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

LoraConfig 详细参数如下:

  • r:秩,即上面的r,默认为8;
  • target_modules:对特定的模块进行微调,默认为None,支持nn.Linear、nn.Embedding和nn.Conv2d;
  • lora_alpha:ΔW 按 α / r 缩放,其中 α 是常数,默认为8;
  • task_type:任务类型,支持包括 CAUSAL_LM、FEATURE_EXTRACTION、QUESTION_ANS、SEQ_2_SEQ_LM、SEQ_CLS 和 TOKEN_CLS 等;
  • lora_dropout:Dropout 概率,默认为0,通过在训练过程中以 dropout 概率随机选择要忽略的神经元来减少过度拟合的技术;
  • bias:是否添加偏差,默认为 "none";

3、训练

使用 peft 库对SFT全量训练修改如下:

ini 复制代码
def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    def find_all_linear_names(model):
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])

        return list(lora_module_names)

    model = Transformer(lm_config)
    ckp = f'./out/pretrain_{lm_config.dim}.pth.{batch_size}'
    state_dict = torch.load(ckp, map_location=device_type)
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)

    target_modules = find_all_linear_names(model)
    peft_config = LoraConfig(
        r=8,
        target_modules=target_modules
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    print(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
    model = model.to(device_type)
    return model

只需要修改模型初始化部分,其他不变,训练过程和之前一样,这里不再赘述。

参考

(1)cloud.tencent.com/developer/a...

(2)www.bimant.com/blog/lora-d...

(3)blog.csdn.net/shebao3333/...

相关推荐
风止何安啊4 分钟前
我一个前端仔,居然用 Python 搞起了 AI?从零到一,撸了个 AI 聊天框小 demo
前端·人工智能·后端
装不满的克莱因瓶6 分钟前
图像尺寸调整:缩放矩阵如何改变像素坐标?
人工智能·线性代数·数学·算法·机器学习·矩阵
GlobalInfo6 分钟前
八旋翼无人机产业洞察与市场占有率演变:2026年趋势分析报告
人工智能·无人机
GISer_Jing6 分钟前
Claude Code插件系统全解析
前端·人工智能·ai·架构
AI前沿资讯9 分钟前
2026年AI 3D赛道新势力崛起:一体化创作平台成主流,V2Fun凭全流程能力突围
人工智能·3d
猫头虎15 分钟前
Cursor推出的Composer 2.5 是什么?从定向 RL 到合成数据,AI 编程智能体再进化
人工智能·开源·prompt·aigc·copilot·ai编程·composer
触底反弹24 分钟前
给 Claude 装上 27 个「外挂」后,我直接起飞了!
人工智能·react.js
KaMeidebaby25 分钟前
卡梅德生物技术快报|peg 修饰调控 MXene/WS2 异质结,氨气传感器制备与机理研究
大数据·前端·人工智能·架构·spark·新浪微博
ydyd2026042127 分钟前
设备管理应用推荐2026深度测评!
大数据·人工智能·机器学习
美狐美颜SDK开放平台28 分钟前
从采集到渲染:直播APP开发与实时美颜SDK技术实现全流程详解
人工智能·美颜sdk·直播美颜sdk·第三方美颜sdk·视频美颜sdk·美颜api