机器学习|从0开始大模型之模型LoRA训练

继续《从0开发大模型》系列文章,上一篇用全量数据做微调,训练时间太长,参数比较大,但是有一种高效的微调方式LoRA。

1、LoRA是如何实现的?

在深入了解 LoRA 之前,我们先回顾一下一些基本的线性代数概念。

1.1、秩

给定矩阵中线性独立的列(或行)的数量,称为矩阵的秩,记为 rank(A)

  • 矩阵的秩小于或等于列(或行)的数量,rank(A) ≤ min{m, n}
  • 满秩矩阵是所有的行或者列都独立,rank(A) = min{m, n}
  • 不满秩矩阵是满秩矩阵的反面是不满秩,即 rank(A) < min(m, n),矩阵的列(或行)不是彼此线性独立的

举个两个秩的例子:

不满秩

满秩

1.2、秩相关属性

从上面的秩的介绍中可以看出,矩阵的秩可以被理解为它所表示的特征空间的维度,在这种情况下,特定大小的低秩矩阵比相同维度的满秩矩阵封装更少的特征(或更低维的特征空间)。与之相关的属性如下:

  • 矩阵的秩受其行数和列数中最小值的约束,rank(A) ≤ min{m, n}
  • 两个矩阵的乘积的秩受其各自秩的最小值的约束,给定矩阵 AB,其中 rank(A) = mrank(A) = n,则 rank(AB) ≤ min{m, n}

1.3、LoRA

LoRA(Low rand adaption) 是微软研究人员提出的一种高效的微调技术,用于使大型模型适应特定任务和数据集。
LoRA 的背后的主要思想是模型微调期间权重的变化也具有较低的内在维度,具体来说,如果Wₙₖ代表单层的权重,ΔWₙₖ代表模型自适应过程中权重的变化,作者提出ΔWₙₖ是一个低秩矩阵,即:rank(ΔWₙₖ) << min(n,k)

为什么?

模型有了基座以后,如果强调学习少量的特征,那么就可以大大减少参数的更新量,而ΔWₙₖ就可以实现,这样就可以认为ΔWₙₖ是一个低秩矩阵。

实现原理

ΔWₙₖ是一个更新矩阵,然后ΔWₙₖ根据秩的属性,又可以拆分两个低秩矩阵的乘积,即:BₙᵣAᵣₖ ,其中 r << min{n,k}

这意味着网络中权重 Wx = Wx + ΔWx = Wx + BₙᵣAᵣₖx,由于 r 很小,所以 BₙᵣAᵣₖ 的参数数量非常少,所以只需要更新很少的参数。

LoRA

2、peft库

LoRA 训练非常方便,只需要借助 https://huggingface.co/blog/zh/peft 库,这是 huggingface 提供的,使用方法如下:

ini 复制代码
# 引入库
from peft import get_peft_model, LoraConfig, TaskType

# 创建对应的配置
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.01,
    bias="none"
    task_type="SEQ_2_SEQ_LM",
)

# 包装模型
model = AutoModelForSeq2SeqLM.from_pretrained(
    "t5-small",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

LoraConfig 详细参数如下:

  • r:秩,即上面的r,默认为8;
  • target_modules:对特定的模块进行微调,默认为None,支持nn.Linear、nn.Embedding和nn.Conv2d;
  • lora_alpha:ΔW 按 α / r 缩放,其中 α 是常数,默认为8;
  • task_type:任务类型,支持包括 CAUSAL_LM、FEATURE_EXTRACTION、QUESTION_ANS、SEQ_2_SEQ_LM、SEQ_CLS 和 TOKEN_CLS 等;
  • lora_dropout:Dropout 概率,默认为0,通过在训练过程中以 dropout 概率随机选择要忽略的神经元来减少过度拟合的技术;
  • bias:是否添加偏差,默认为 "none";

3、训练

使用 peft 库对SFT全量训练修改如下:

ini 复制代码
def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    def find_all_linear_names(model):
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])

        return list(lora_module_names)

    model = Transformer(lm_config)
    ckp = f'./out/pretrain_{lm_config.dim}.pth.{batch_size}'
    state_dict = torch.load(ckp, map_location=device_type)
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)

    target_modules = find_all_linear_names(model)
    peft_config = LoraConfig(
        r=8,
        target_modules=target_modules
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    print(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
    model = model.to(device_type)
    return model

只需要修改模型初始化部分,其他不变,训练过程和之前一样,这里不再赘述。

参考

(1)cloud.tencent.com/developer/a...

(2)www.bimant.com/blog/lora-d...

(3)blog.csdn.net/shebao3333/...

相关推荐
十有久诚38 分钟前
SVL-Adapter: Self-Supervised Adapter for Vision-Language Pretrained Models
人工智能·深度学习·计算机视觉·视觉语言模型·适配器微调
学习前端的小z1 小时前
【AI绘画】Midjourney进阶:色调详解(上)
人工智能·ai作画·aigc·midjourney
cv2016_DL1 小时前
BERT相关知识
人工智能·算法·transformer
搏博1 小时前
神经网络中的损失函数(Loss Function)
人工智能·深度学习·神经网络
大大大反派1 小时前
深入了解决策树---机器学习中的经典算法
算法·决策树·机器学习
分才1 小时前
用天翼云搭建一个HivisionIDPhoto证件照处理网站
人工智能·学习·html5
数据猎手小k1 小时前
FineTuneBench:由斯坦福大学创建,包含625个训练问题和1075个测试问题,覆盖4个领域。目的评估商业微调API在不同泛化任务中的知识注入能力。
人工智能·深度学习·机器学习·数据集·机器学习数据集·ai大模型应用
傻瓜学习2 小时前
机器学习知识点
人工智能·机器学习
Seeklike2 小时前
11.26 深度学习-初始化
人工智能·深度学习
池央2 小时前
深度学习模型:循环神经网络(RNN)
人工智能·rnn·深度学习