文章目录
背景介绍
本文将一步一步地,介绍如何使用llamafactory框架利用开源大语言模型完成文本分类的实验,以 LoRA微调 qwen/Qwen2.5-7B-Instruct
为例。
文本分类数据集
按照 alpaca 样式构建数据集,并在将其添加到 LLaMA-Factory/data/dataset_info.json
文件中。如此方便直接根据自定义数据集的名字,获取到数据集的数据。
python
[
{
"instruction": "",
"input": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:\n\n要求}}\nreason: \nlabel:",
"output": "reason: 该文本主要讨论的是xxx。因此,该文本最符合"社会管理"这一类别。\n\nlabel: 社会管理"
},
...
]
Lora 微调
llamafactory 框架支持网页端训练,但本文选择在终端使用命令行微调模型。
模型微调训练的参数较多,将模型训练的参数都存储在 yaml 文件中。
qwen_train_cls.yaml
的文件内容如下:
yaml
### model
model_name_or_path: qwen/Qwen2.5-7B-Instruct
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
### dataset
# dataset_dir: data
dataset_dir: LLaMA-Factory/data/ 填写相应路径
dataset: 数据集名
template: qwen
cutoff_len: 2048
# max_samples: 1000 若数据集较大,可随机筛选一部分数据微调模型
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: output/qwen2.5-7B/cls_epoch2 训练的LoRA权重输出路径
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
使用下述命令启动模型训练:
bash
nohup llamafactory-cli train qwen_train_cls.yaml > qwen_train_cls.log 2>&1 &
命令分解介绍:
nohup, 全称为 "no hangup"(不要挂起)。它的作用是让命令在退出终端后仍然运行,防止因关闭终端或会话中断导致进程被终止。
默认情况下,nohup 会将输出重定向到 nohup.out 文件,但这里已经显式指定了输出位置。
llamafactory-cli train qwen_train_cls.yaml 运行 llamafactory-cli 工具,用于执行训练任务。
train 是子命令,表示进行训练。
qwen_train_cls.yaml 是一个配置文件,包含训练所需的超参数、数据路径、模型结构等。
qwen_train_cls.log
将标准输出 (stdout) 重定向到 qwen_train_cls.log 文件中。
即运行过程中的正常日志信息会被记录到这个文件。
2>&1: 将标准错误输出 (stderr) 重定向到标准输出 (stdout)。
这样,所有错误信息也会被写入到 qwen_train_cls.log 文件中。
&: 表示将整个命令放到后台运行。终端会立即返回,您可以继续进行其他操作,而不用等待命令完成。
模型部署与推理
模型训练完成后得到 Lora 权重。相关微调模型部署与推理,请浏览下述两篇文章,相比llamafactory原本的模型推理速度更快。
- 基于 LLamafactory 的异步API高效调用实现与速度对比.https://blog.csdn.net/sjxgghg/article/details/144176645
- 基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现
目前llamafactory已经支持 vllm_infer 推理,这个PR是笔者提交的:
期待模型的输出结果
下述是使用 llamafactory 推理出的数据格式,建议大家在做推理评估时,也做成这个样式,方便统一评估。
json
{
"prompt": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:...",
"predict": "\nreason: 该文本主要讨论了改革创新发展、行政区划调整、行政管理体制等方面的内容,涉及到体制机制的改革与完善,旨在推动高质量发展和提升生活品质。这些内容与社会管理和经济管理密切相关,但更侧重于行政管理和社会治理的改革,因此更符合"社会管理"这一类别。\n\nlabel: 社会管理",
"label": "reason: 该文本主要讨论的是改革创新、行政区划调整、体制机制障碍的破除以及行政管理体制等与政府治理和社会管理相关的内容,强调了与高质量发展和生活品质的关系。这些内容显示出对社会管理和行政管理的关注,尤其是在推动城乡一体化和适应高质量发展要求方面。因此,该文本最符合"社会管理"这一类别。\n\nlabel: 社会管理"
}
文本分类评估代码
python
import os
import re
import json
from sklearn.metrics import classification_report, confusion_matrix
# 文本类别
CLASS_NAME = [
"产业相关",
...
"法律法规与行政事务",
"其他",
]
def load_jsonl(file_path):
"""
加载指定路径的 JSON 文件并返回解析后的数据。
:param file_path: JSON 文件的路径
:return: 解析后的数据(通常是字典或列表)
:raises FileNotFoundError: 如果文件未找到
:raises json.JSONDecodeError: 如果 JSON 格式不正确
"""
data = []
try:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
tmp = json.loads(line)
data.append(tmp)
except FileNotFoundError as e:
print(f"文件未找到:{file_path}")
raise e
except json.JSONDecodeError as e:
print(f"JSON 格式错误:{e}")
raise e
return data
def parser_label(text: str):
pattern = r"label[::\s\.\d\*]*([^\s^\*]+)"
matches = re.findall(pattern, text, re.DOTALL)
if len(matches) == 1:
return matches[0]
return None
def trans2num(item):
predict = parser_label(item["predict"])
label = parser_label(item["label"])
predict_idx = -1
label_idx = -1
for idx, cls_name in enumerate(CLASS_NAME):
if predict == cls_name:
predict_idx = idx
if label == cls_name:
label_idx = idx
return predict_idx, label_idx
def cls_eval(input_file):
data = load_jsonl(file_path=input_file)
predicts = []
labels = []
for item in data:
predict, label = trans2num(item)
if label == -1:
continue
predicts.append(predict)
labels.append(label)
return classification_report(predicts, labels, output_dict=False)
本文使用了大模型生成式预测文本类别,我没有使用结构化输出的方式,大家可以使用结构化的json格式输出,这样在提取大模型预测结果的时候会方便很多。
大家按照自己模型的输出结果,修改parser_label
函数,这个函数用于从大模型的输出结果提取label。
bash
cls_eval("xxx/generated_predictions.jsonl")
就会得到下述的输出结果:
-1
代表模型预测的类别不在给定的类别中。