LLM代码实现-Qwen(模型微调)

简介

LLM 通用模型在各种任务上表现良好,我们可以将它们用作对目标任务进行微调的基础。 微调允许我们使模型适应目标域和目标任务,使其可以更好地完成我们所需要的特定的任务。

目前模型微调方法主要有 Full(全参微调)、Freeze、P-tuning、LoRA、QLora,这些这些方法都各有优势,关于它们的介绍也有很多,本篇主要讲解代码实现,原理方面就不赘述了,同样考虑到不是所有读者都有足够的算力,因此使用占用资源最少的 QLora 对模型进行微调,各个微调方式占用资源如下图所示(batch_size 设置为 1)。

这里推荐使用 firefly(流萤)项目来实现模型微调,这个项目主要是为了微调多轮对话数据集,不过单轮对话也同样适用。

Firefly 项目训练多轮对话模型时,采取了一种更加充分高效的方法。如下图所示,将一条多轮对话数据拼接之后,输入模型,并行计算每个位置的 loss,只有 Assistant 部分的 loss 参与权重更新。 为什么这种做法是可行的?答案在于因果语言模型的 attention mask。以 GPT 为代表的 Causal Language Model(因果语言模型),这种模型的 attention mask 是一个对角掩码矩阵,每个 token 在编码的时候,只能看到它之前的 token,看不到它之后的 token。 所以 User1 部分的编码输出,只能感知到 User1 的内容,无法感知到它之后的文本,可以用来预测 Assistant1 的内容。而 User2 部分的编码输出,只能看到 User1、Assistant1、User2 的内容,可以用来预测 Assistant2 的内容,依此类推。对于整个序列,只需要输入模型一次,便可并行获得每个位置的 logits,从而用来计算 loss。

训练环境配置

首先 pull 项目并配置环境:

python 复制代码
git clone https://github.com/yangjianxin1/Firefly.git

cd Firefly
pip install -r requirements.txt

然后找到 1. Firefly/train_args/sft/qlora 路径下的 qwen-7b-sft-qlora.json(虽然文件名是 7b 但是 qwen 都可以通用),对里面的内容进行修改,主要修改模型路径和训练文件,其余参数可以看着修改。

json 复制代码
{
    "output_dir": "output/firefly-qwen-1_8b-sft-qlora",
    "model_name_or_path": "Qwen/Qwen-1_8B-Chat",
    "train_file": "./data/dummy_data.jsonl",
    "template_name": "qwen",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 16,
    "learning_rate": 2e-4,
    "max_seq_length": 1024,
    "logging_steps": 100,
    "save_steps": 100,
    "save_total_limit": 1,
    "lr_scheduler_type": "constant_with_warmup",
    "warmup_steps": 100,
    "lora_rank": 64,
    "lora_alpha": 128,
    "lora_dropout": 0.05,

    "gradient_checkpointing": true,
    "disable_tqdm": false,
    "optim": "paged_adamw_32bit",
    "seed": 42,
    "fp16": true,
    "report_to": "tensorboard",
    "dataloader_num_workers": 0,
    "save_strategy": "steps",
    "weight_decay": 0,
    "max_grad_norm": 0.3,
    "remove_unused_columns": false
}

数据准备

这里的数据集是官方提供的一个测试数据集,用于跑通流程,以下是其中一条数据: 这个 json 中只用 "conversation_id" 和 "conversation" 两个是训练需要的,其他的不必在意。"conversation_id" 表示对话的序号,"conversation" 对应的是一个列表,元素是字典,每个字典中有 "human" 和 "assistant" 两个键,分别表示用户和模型的说话内容。如果要用自己的数据集也要按照这种格式进行修改。

开始训练

利用以下代码开始训练:

python 复制代码
python train.py --train_args_file train_args/sft/qlora/qwen-7b-sft-qlora.json

训练完成后会在配置文件中设置的 output_dir 生成对应的 QLora 文件,可利用 Firefly/script/chat 路径下的两个脚本进行调用来测试模型微调的效果。

相关推荐
fengfuyao98517 分钟前
MATLAB的加权K-means(Warp-KMeans)聚类算法
算法·matlab·kmeans
Elastic 中国社区官方博客32 分钟前
Elasticsearch:如何为 Elastic Stack 部署 E5 模型 - 下载及隔离环境
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
xier_ran32 分钟前
深度学习:神经网络中的参数和超参数
人工智能·深度学习
8Qi842 分钟前
伪装图像生成之——GAN与Diffusion
人工智能·深度学习·神经网络·生成对抗网络·图像生成·伪装图像生成
循环过三天1 小时前
3.1、Python-列表
python·算法
青青草原羊村懒大王1 小时前
python基础知识三
开发语言·python
阿里云大数据AI技术1 小时前
PAI Physical AI Notebook详解2:基于Cosmos世界模型的操作动作数据扩增与模仿学习
人工智能
傻啦嘿哟1 小时前
Python高效实现Word转HTML:从基础到进阶的全流程方案
人工智能·python·tensorflow
该用户已不存在1 小时前
Gemini CLI 核心命令指南,让工作从从容容游刃有余
人工智能·程序员·aigc
思通数科多模态大模型2 小时前
扑灭斗殴的火苗:AI智能守护如何为校园安全保驾护航
大数据·人工智能·深度学习·安全·目标检测·计算机视觉·数据挖掘