【LLM学习笔记】第三篇:模型微调及LoRA介绍(附PyTorch实例)

文章目录

      • [1. 模型微调](#1. 模型微调)
        • [1.1 为什么要进行模型微调?](#1.1 为什么要进行模型微调?)
        • [1.2 模型微调的分类](#1.2 模型微调的分类)
      • [2. 高效模型微调LoRA](#2. 高效模型微调LoRA)
        • [2.1 LoRA概述](#2.1 LoRA概述)
        • [2.2 LoRA的机理](#2.2 LoRA的机理)
      • [3. LoRA的变体](#3. LoRA的变体)
        • [3.1 分布式LoRA (Distributed LoRA)](#3.1 分布式LoRA (Distributed LoRA))
        • [3.2 动态LoRA (Dynamic LoRA)](#3.2 动态LoRA (Dynamic LoRA))
        • [3.3 多任务LoRA (Multi-Task LoRA)](#3.3 多任务LoRA (Multi-Task LoRA))
        • [3.4 层级LoRA (Hierarchical LoRA)](#3.4 层级LoRA (Hierarchical LoRA))
        • [3.5 自适应LoRA (Adaptive LoRA)](#3.5 自适应LoRA (Adaptive LoRA))
        • [3.6 集成LoRA (Ensemble LoRA)](#3.6 集成LoRA (Ensemble LoRA))
      • [4. `peft`库](#4. peft库)
      • 5.结论

1. 模型微调

1.1 为什么要进行模型微调?

模型微调(Fine-Tuning)是一种在深度学习中广泛应用的技术,旨在通过在特定任务上对预训练模型进行进一步训练 ,以提高模型在该任务上的性能。这一过程不仅仅是简单的训练调整,而是通过利用预训练模型已有的知识和特征表示,使其更好地适应新的任务需求。

预训练模型通常是在大规模数据集上训练的,这些模型已经学到了丰富的、通用的特征表示。这些特征对于许多任务都是非常有价值的。例如,在自然语言处理任务中,预训练的语言模型(如BERT、GPT等)已经学会了词义、语法结构和语境信息;在计算机视觉任务中,预训练的卷积神经网络(如ResNet、VGG等)已经学会了边缘、纹理、形状等低级特征,以及物体、场景等高级特征。通过微调,可以将这些学到的特征迁移到新的任务中,从而提高模型的初始性能。

尽管预训练模型具有强大的通用特征提取能力,但它们可能不完全适合特定任务的数据分布和特点。通过微调,可以针对特定任务的数据进行进一步训练,使模型更好地适应任务需求。例如,在医学影像分类任务中,预训练模型可能已经学会了识别一般的图像特征,但通过微调,可以使模型更专注于识别特定的医学影像特征,从而提高分类准确率。

1.2 模型微调的分类

模型微调分为全局微调(全量微调)和局部微调:

  • 全局微调

    • 定义:全局微调是指对整个预训练模型的所有参数进行微调。这种方法适用于数据量较大且计算资源充足的情况。
    • 优点:可以充分利用所有参数的调整空间,使模型更好地适应特定任务。
    • 缺点:计算成本较高,容易导致过拟合,尤其是在数据量较少的情况下。
  • 局部微调

    • 定义:局部微调是指只对预训练模型的部分参数进行微调,通常是最后几层或特定的模块。这种方法适用于数据量较小或计算资源有限的情况。
    • 优点:计算成本较低,可以有效避免过拟合。
    • 缺点:可能无法充分利用预训练模型的全部潜力。

本文要介绍的LoRA即属于一种高效的局部微调方法。

2. 高效模型微调LoRA

2.1 LoRA概述

LoRA(Low-Rank Adaptation)是一种高效的模型微调技术,旨在通过在预训练模型中插入低秩矩阵来减少微调所需的参数量,从而提高训练效率并避免过拟合。LoRA 的核心思想是在保持预训练模型大部分参数不变的情况下,通过添加低秩矩阵来模拟参数的变化量,从而实现对特定任务的适应。

关于"低秩近似"的概念,我在此前的文章有介绍过深度学习中的常用线性代数知识汇总------第一篇:基础概念、秩、奇异值

研究表明:语言模型针对特定任务微调之后,权重矩阵通常具有很低的本征秩(Intrinsic Rank),参数更新量即便投影到较小的子空间中,也不会影响学习的有效性。因此,提出固定预训练模型参数不变,在原权重矩阵旁路添加低秩矩阵的乘积作为可训练参数,用以模拟参数的变化量。

2.2 LoRA的机理

具体来说,假设预训练权重为 W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0∈Rd×k,可训练参数为 Δ W = B A \Delta W = BA ΔW=BA,其中 A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r, B ∈ R r × k B \in \mathbb{R}^{r \times k} B∈Rr×k。初始化时,矩阵A通过高斯函数初始化,矩阵B为零初始化,使得训练开始之前旁路对原模型不造成影响,即参数变化量为0。对于该权重的输入 x ∈ R d x \in \mathbb{R}^d x∈Rd 来说,输出如下:

h = W 0 x + Δ W x = W 0 x + B A x h = W_0 x + \Delta W x = W_0 x + BA x h=W0x+ΔWx=W0x+BAx

3. LoRA的变体

3.1 分布式LoRA (Distributed LoRA)

分布式LoRA通过将低秩矩阵分布在多个层或模块中,进一步提高参数效率和模型性能。

  • 多层低秩矩阵 :在多个层中插入低秩矩阵,每个层的低秩矩阵可以有不同的秩 r i r_i ri。
  • 参数更新 :每个层的权重矩阵 W i W_i Wi 更新为 W i + B i A i W_i + B_i A_i Wi+BiAi。
  • 前向传播 :每个层的输出 h i h_i hi 计算为 h i = W i x + B i A i x h_i = W_i x + B_i A_i x hi=Wix+BiAix。
3.2 动态LoRA (Dynamic LoRA)

动态LoRA 通过在训练过程中动态调整低秩矩阵的秩 r r r,以适应不同的训练阶段和任务需求。

  • 动态秩调整 :在训练初期使用较小的秩 r r r,随着训练的进行逐渐增加秩 r r r。
  • 参数更新 :根据当前的秩 r r r动态计算$ \Delta W$。
  • 前向传播 : h = W 0 x + B ( r ) A ( r ) x h = W_0 x + B(r) A(r) x h=W0x+B(r)A(r)x,其中 B ( r ) B(r) B(r) 和 A ( r ) A(r) A(r)是根据当前秩 r r r 动态调整的矩阵。
3.3 多任务LoRA (Multi-Task LoRA)

多任务LoRA通过在多任务学习中共享低秩矩阵,提高模型的泛化能力和资源利用率。

  • 共享低秩矩阵 :多个任务共享同一个低秩矩阵 A A A和 B B B,但每个任务可以有自己的权重矩阵 W 0 W_0 W0。
  • 参数更新 :每个任务的权重矩阵 W 0 , i W_{0,i} W0,i 更新为 W 0 , i + B A W_{0,i} + BA W0,i+BA。
  • 前向传播 :每个任务的输出 h i h_i hi 计算为 h i = W 0 , i x + B A x h_i = W_{0,i} x + BA x hi=W0,ix+BAx。
3.4 层级LoRA (Hierarchical LoRA)

层级LoRA通过在不同层级的模块中插入低秩矩阵,实现更细粒度的参数控制和优化。

  • 多层级低秩矩阵 :在不同层级的模块中插入低秩矩阵,每个模块的低秩矩阵可以有不同的秩 r i r_i ri。
  • 参数更新 :每个模块的权重矩阵 W i W_i Wi 更新为 W i + B i A i W_i + B_i A_i Wi+BiAi。
  • 前向传播 :每个模块的输出 h i h_i hi 计算为 h i = W i x + B i A i x h_i = W_i x + B_i A_i x hi=Wix+BiAix。
3.5 自适应LoRA (Adaptive LoRA)

自适应LoRA通过引入自适应机制,使低秩矩阵根据输入数据动态调整,从而提高模型的适应性和性能。

  • 自适应低秩矩阵 :低秩矩阵 A A A 和 B B B根据输入数据 x x x动态调整。
  • 参数更新 : Δ W = B ( x ) A ( x ) \Delta W = B(x) A(x) ΔW=B(x)A(x)。
  • 前向传播 : h = W 0 x + B ( x ) A ( x ) x h = W_0 x + B(x) A(x) x h=W0x+B(x)A(x)x。
3.6 集成LoRA (Ensemble LoRA)

集成LoRA通过组合多个低秩矩阵,提高模型的表达能力和鲁棒性。

  • 多个低秩矩阵 :使用多个低秩矩阵 A i A_i Ai 和 B i B_i Bi,每个矩阵可以有不同的秩 r i r_i ri。
  • 参数更新 :权重矩阵 W W W更新为 W = W 0 + ∑ i B i A i W = W_0 + \sum_i B_i A_i W=W0+∑iBiAi。
  • 前向传播 : h = W 0 x + ∑ i B i A i x h = W_0 x + \sum_i B_i A_i x h=W0x+∑iBiAix。

LoRA的这些变体在不同的应用场景中各有优势,可以根据具体任务的需求选择合适的变体。基本LoRA适用于大多数微调任务,而分布式LoRA、动态LoRA、多任务LoRA、层级LoRA、自适应LoRA和集成LoRA则在特定场景下提供了更高的灵活性和性能。通过这些变体,LoRA技术可以更好地适应各种复杂任务,提高模型的性能和效率。

4. peft

peft(Parameter-Efficient Fine-Tuning)是一个用于参数高效微调的库,支持多种方法,包括LoRA。下面是一个基于PyTorch和peft库的LoRA微调示例代码。这个示例将展示如何使用Hugging Face的transformers库和peft库来微调一个预训练的BERT模型。

python 复制代码
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import LoraConfig, get_peft_model

# 加载预训练模型和分词器
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# 定义LoRA配置
lora_config = LoraConfig(
    r=8,  # 低秩矩阵的秩
    lora_alpha=16,  # LoRA的缩放因子
    target_modules=["query", "value"]  # 需要应用LoRA的模块
)

# 将LoRA配置应用到模型
model = get_peft_model(model, lora_config)

# 准备一些简单的数据
train_texts = ["I love this movie.", "This movie is terrible."]
train_labels = [1, 0]  # 1 表示正类,0 表示负类

# 对数据进行编码
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128, return_tensors='pt')

# 定义损失函数和优化器
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# 训练模型
model.train()
for epoch in range(3):  # 训练3个epoch
    optimizer.zero_grad()
    outputs = model(input_ids=train_encodings['input_ids'], attention_mask=train_encodings['attention_mask'])
    loss = loss_fn(outputs.logits, torch.tensor(train_labels))
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# 保存模型
model.save_pretrained('./simple_lora_model')
tokenizer.save_pretrained('./simple_lora_model')

print("Training complete!")

5.结论

本文介绍了模型微调的目的及方法,详细说明了局部微调中的LoRA方法。

相关推荐
__lost1 分钟前
Python 使用 OpenCV 将 MP4 转换为 GIF图
开发语言·python·opencv
Ciderw14 分钟前
AI 在软件开发流程中的优势、挑战及应对策略
人工智能·ai
霍夫曼vx_helloworld735216 分钟前
(二)手势识别——动作模型训练【代码+数据集+python环境(免安装)+GUI系统】
开发语言·python
神仙别闹21 分钟前
基于Python实现三种不同类型BP网络及分析
开发语言·python
Struart_R1 小时前
Edify 3D: Scalable High-Quality 3D Asset Generation 论文解读
人工智能·深度学习·3d·扩散模型·三维生成·三维资产
陈健平1 小时前
2024最新YT-DLP使用demo网页端渲染
python·fastapi·jinja2·yt-dlp·yt_dlp
声网1 小时前
Runway 新增视频扩展画面功能;Anthropic 再获亚马逊投资 40 亿美元,聚焦 AI 芯片研发丨 RTE 开发者日报
人工智能
量子位1 小时前
将活体神经元植入大脑,他和马斯克闹掰后开辟脑机接口新路线
人工智能
forestsea1 小时前
【Java 解释器模式】实现高扩展性的医学专家诊断规则引擎
java·人工智能·设计模式·解释器模式
程序员奇奥2 小时前
CentOS中使用Python将文本中的IP地址替换为外网地址
python·tcp/ip·centos