使用 LoRA 进行模型微调的步骤

使用 LoRA 进行模型微调的步骤

以下是关于如何在训练过程中应用 LoRA 的详细步骤

1. 安装所需的库

首先,确保已安装 PyTorch 和 Hugging Face 的 transformers 库:

bash 复制代码
pip install torch transformers

2. 定义 LoRA 层

定义一个 LoRA 模块,用于替换 transformer 中的标准线性层,通常在自注意力机制的 query、key 和 value 投影中使用

python 复制代码
import torch
import torch.nn as nn

# 定义 LoRA 模块
class LoRA(nn.Module):
    def __init__(self, input_dim, output_dim, rank=8):
        super(LoRA, self).__init__()
        # LoRA 引入了两个额外的矩阵 W_down 和 W_up
        self.W_down = nn.Linear(input_dim, rank, bias=False)  # 低秩降维
        self.W_up = nn.Linear(rank, output_dim, bias=False)   # 低秩升维

    def forward(self, x):
        # 将低秩适配结果加到原始输出上
        return self.W_up(self.W_down(x))

3. 修改 Transformer 模型以应用 LoRA

将 transformer 中的 query、key 和 value 投影替换为 LoRA 模块

python 复制代码
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载预训练模型
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 定义 LoRA 版的 Transformer 模型
class LoRATransformer(nn.Module):
    def __init__(self, base_model, lora_rank=8):
        super(LoRATransformer, self).__init__()
        self.base_model = base_model

        # 遍历 transformer 层并应用 LoRA
        for name, module in self.base_model.named_modules():
            # 对 self-attention 中的线性层应用 LoRA
            if isinstance(module, nn.Linear) and 'attn' in name:
                input_dim = module.in_features
                output_dim = module.out_features
                lora_module = LoRA(input_dim, output_dim, rank=lora_rank)
                # 替换原始的线性层为 LoRA 模块
                setattr(self.base_model, name, lora_module)

    def forward(self, input_ids, attention_mask=None):
        return self.base_model(input_ids, attention_mask=attention_mask)

# 初始化应用了 LoRA 的模型
lora_model = LoRATransformer(model)

4. 准备训练

在训练之前,冻结预训练模型的参数,仅更新 LoRA 模块中的参数。

python 复制代码
# 冻结原始模型参数
for param in model.parameters():
    param.requires_grad = False

# 仅训练 LoRA 层的参数
for name, param in lora_model.named_parameters():
    if 'W_down' in name or 'W_up' in name:
        param.requires_grad = True

5. 开始训练模型

现在可以开始训练模型了,使用优化器来更新 LoRA 模块的参数

python 复制代码
# 定义优化器,仅更新 LoRA 层的参数
optimizer = torch.optim.Adam([param for param in lora_model.parameters() if param.requires_grad], lr=1e-4)

# 损失函数(交叉熵损失)
loss_fn = nn.CrossEntropyLoss()

# 示例训练循环
for epoch in range(3):  # 调整 epoch 数量
    for batch in dataloader:  # 假设 'dataloader' 是已经准备好的数据加载器
        inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
        labels = inputs.input_ids.clone()  # 使用输入 tokens 作为标签

        # 前向传播
        outputs = lora_model(inputs.input_ids)
        logits = outputs.logits
        
        # 计算损失
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"第 {epoch+1} 轮训练,损失: {loss.item()}")

6. 保存与加载 LoRA 增强的模型

训练完成后,可以保存 LoRA 层的参数,并在未来加载这些参数进行推理或进一步训练。

python 复制代码
# 保存 LoRA 模型的参数
torch.save(lora_model.state_dict(), "lora_model.pth")

# 加载 LoRA 模型的参数
lora_model.load_state_dict(torch.load("lora_model.pth"))

LoRA 专注于注意力层:LoRA 通常应用于 transformer 模型中的自注意力机制(例如 query、key 和 value 投影层)。

参数高效性:与全量微调相比,LoRA 仅引入少量可训练参数,从而显著减少了计算成本。

冻结预训练权重:预训练模型的原始权重保持不变,仅更新 LoRA 模块的低秩矩阵,从而实现高效微调。

相关推荐
Cha0DD18 小时前
【由浅入深探究langchain】第二十集-SQL Agent+Human-in-the-loop
人工智能·python·ai·langchain
Cha0DD18 小时前
【由浅入深探究langchain】第十九集-官方的SQL Agent示例
人工智能·python·ai·langchain
智算菩萨19 小时前
【Tkinter】4 Tkinter Entry 输入框控件深度解析:数据验证、密码输入与现代表单设计实战
python·ui·tkinter·数据验证·entry·输入框
七夜zippoe20 小时前
可解释AI:构建可信的机器学习系统——反事实解释与概念激活实战
人工智能·python·机器学习·可解释性·概念激活
YuanDaima20481 天前
[CrewAI] 第15课|构建一个多代理系统来实现自动化简历定制和面试准备
人工智能·python·面试·agent·crewai
WHS-_-20221 天前
Python 算法题学习笔记一
python·学习·算法
码界筑梦坊1 天前
353-基于Python的大湾区气候数据可视化分析系统
开发语言·python·信息可视化·数据分析·django·vue·毕业设计
如何原谅奋力过但无声1 天前
【chap11-动态规划(上 - 基础题目&背包问题)】用Python3刷《代码随想录》
数据结构·python·算法·动态规划
云姜.1 天前
JSON Schema使用
python·json
Sunshine for you1 天前
使用Flask快速搭建轻量级Web应用
jvm·数据库·python