机器学习|从0开始大模型之模型LoRA训练

继续《从0开发大模型》系列文章,上一篇用全量数据做微调,训练时间太长,参数比较大,但是有一种高效的微调方式LoRA。

1、LoRA是如何实现的?

在深入了解 LoRA 之前,我们先回顾一下一些基本的线性代数概念。

1.1、秩

给定矩阵中线性独立的列(或行)的数量,称为矩阵的秩,记为 rank(A)

  • 矩阵的秩小于或等于列(或行)的数量,rank(A) ≤ min{m, n}
  • 满秩矩阵是所有的行或者列都独立,rank(A) = min{m, n}
  • 不满秩矩阵是满秩矩阵的反面是不满秩,即 rank(A) < min(m, n),矩阵的列(或行)不是彼此线性独立的

举个两个秩的例子:

不满秩

满秩

1.2、秩相关属性

从上面的秩的介绍中可以看出,矩阵的秩可以被理解为它所表示的特征空间的维度,在这种情况下,特定大小的低秩矩阵比相同维度的满秩矩阵封装更少的特征(或更低维的特征空间)。与之相关的属性如下:

  • 矩阵的秩受其行数和列数中最小值的约束,rank(A) ≤ min{m, n}
  • 两个矩阵的乘积的秩受其各自秩的最小值的约束,给定矩阵 AB,其中 rank(A) = mrank(A) = n,则 rank(AB) ≤ min{m, n}

1.3、LoRA

LoRA(Low rand adaption) 是微软研究人员提出的一种高效的微调技术,用于使大型模型适应特定任务和数据集。
LoRA 的背后的主要思想是模型微调期间权重的变化也具有较低的内在维度,具体来说,如果Wₙₖ代表单层的权重,ΔWₙₖ代表模型自适应过程中权重的变化,作者提出ΔWₙₖ是一个低秩矩阵,即:rank(ΔWₙₖ) << min(n,k)

为什么?

模型有了基座以后,如果强调学习少量的特征,那么就可以大大减少参数的更新量,而ΔWₙₖ就可以实现,这样就可以认为ΔWₙₖ是一个低秩矩阵。

实现原理

ΔWₙₖ是一个更新矩阵,然后ΔWₙₖ根据秩的属性,又可以拆分两个低秩矩阵的乘积,即:BₙᵣAᵣₖ ,其中 r << min{n,k}

这意味着网络中权重 Wx = Wx + ΔWx = Wx + BₙᵣAᵣₖx,由于 r 很小,所以 BₙᵣAᵣₖ 的参数数量非常少,所以只需要更新很少的参数。

LoRA

2、peft库

LoRA 训练非常方便,只需要借助 https://huggingface.co/blog/zh/peft 库,这是 huggingface 提供的,使用方法如下:

ini 复制代码
# 引入库
from peft import get_peft_model, LoraConfig, TaskType

# 创建对应的配置
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.01,
    bias="none"
    task_type="SEQ_2_SEQ_LM",
)

# 包装模型
model = AutoModelForSeq2SeqLM.from_pretrained(
    "t5-small",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

LoraConfig 详细参数如下:

  • r:秩,即上面的r,默认为8;
  • target_modules:对特定的模块进行微调,默认为None,支持nn.Linear、nn.Embedding和nn.Conv2d;
  • lora_alpha:ΔW 按 α / r 缩放,其中 α 是常数,默认为8;
  • task_type:任务类型,支持包括 CAUSAL_LM、FEATURE_EXTRACTION、QUESTION_ANS、SEQ_2_SEQ_LM、SEQ_CLS 和 TOKEN_CLS 等;
  • lora_dropout:Dropout 概率,默认为0,通过在训练过程中以 dropout 概率随机选择要忽略的神经元来减少过度拟合的技术;
  • bias:是否添加偏差,默认为 "none";

3、训练

使用 peft 库对SFT全量训练修改如下:

ini 复制代码
def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    def find_all_linear_names(model):
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])

        return list(lora_module_names)

    model = Transformer(lm_config)
    ckp = f'./out/pretrain_{lm_config.dim}.pth.{batch_size}'
    state_dict = torch.load(ckp, map_location=device_type)
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)

    target_modules = find_all_linear_names(model)
    peft_config = LoraConfig(
        r=8,
        target_modules=target_modules
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    print(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
    model = model.to(device_type)
    return model

只需要修改模型初始化部分,其他不变,训练过程和之前一样,这里不再赘述。

参考

(1)cloud.tencent.com/developer/a...

(2)www.bimant.com/blog/lora-d...

(3)blog.csdn.net/shebao3333/...

相关推荐
好评笔记2 小时前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc
算家云2 小时前
TangoFlux 本地部署实用教程:开启无限音频创意脑洞
人工智能·aigc·模型搭建·算家云、·应用社区·tangoflux
叫我:松哥4 小时前
基于Python django的音乐用户偏好分析及可视化系统设计与实现
人工智能·后端·python·mysql·数据分析·django
熊文豪5 小时前
深入解析人工智能中的协同过滤算法及其在推荐系统中的应用与优化
人工智能·算法
Vol火山5 小时前
AI引领工业制造智能化革命:机器视觉与时序数据预测的双重驱动
人工智能·制造
tuan_zhang6 小时前
第17章 安全培训筑牢梦想根基
人工智能·安全·工业软件·太空探索·战略欺骗·算法攻坚
Antonio9156 小时前
【opencv】第10章 角点检测
人工智能·opencv·计算机视觉
互联网资讯6 小时前
详解共享WiFi小程序怎么弄!
大数据·运维·网络·人工智能·小程序·生活
helianying557 小时前
AI赋能零售:ScriptEcho如何提升效率,优化用户体验
前端·人工智能·ux·零售
坐吃山猪7 小时前
机器学习10-解读CNN代码Pytorch版
pytorch·机器学习·cnn