Background
LlamaFactory's LoRA fine-tuning is very convenient, but it does not directly support vLLM inference for the fine-tuned model, so inference is not as fast as it could be.
LlamaFactory does support deploying the model behind a vLLM-backed API, but the response speed of API calls is still slower than vLLM offline batch inference.
If the model was fine-tuned with LlamaFactory, it is recommended to also use LlamaFactory's dataset wrappers at inference time, so the data is processed exactly as it was during fine-tuning.
Overview
Given this background, we use LlamaFactory's native dataset pipeline to run vLLM batch inference with LoRA adapters.
The complete code is as follows:
```python
import json
import os
from typing import List
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


def vllm_infer():
    # Parse model / data / training / finetuning / generating arguments from the
    # YAML config passed on the command line, reusing LlamaFactory's own parser.
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args()
    )

    # Build the tokenizer and chat template exactly as LlamaFactory does for
    # training, so the prompts fed to vLLM match the fine-tuning format.
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    # Reuse LlamaFactory's native dataset pipeline to get the preprocessed eval split.
    eval_dataset = get_dataset(
        template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
    )["eval_dataset"]

    # Decode the tokenized prompts back to text (keeping special tokens such as
    # chat-template markers) so they can be passed to vLLM as plain text prompts.
    prompts = [item["input_ids"] for item in eval_dataset]
    prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)

    # In "labels", prompt positions are masked with IGNORE_INDEX; filter them out
    # so that only the reference answers are decoded.
    labels = [
        list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
        for item in eval_dataset
    ]
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    sampling_params = SamplingParams(
        temperature=generating_args.temperature,
        top_k=generating_args.top_k,
        top_p=generating_args.top_p,
        max_tokens=2048,
    )

    # Build LoRA requests when adapters are configured. Note that vLLM requires
    # lora_int_id to be a positive integer, so the ids start from 1.
    if model_args.adapter_name_or_path:
        if isinstance(model_args.adapter_name_or_path, list):
            lora_requests = []
            for i, _lora_path in enumerate(model_args.adapter_name_or_path):
                lora_requests.append(
                    LoRARequest(f"lora_adapter_{i}", i + 1, lora_path=_lora_path)
                )
        else:
            lora_requests = LoRARequest(
                "lora_adapter_0", 1, lora_path=model_args.adapter_name_or_path
            )
        enable_lora = True
    else:
        lora_requests = None
        enable_lora = False

    llm = LLM(
        model=model_args.model_name_or_path,
        trust_remote_code=True,
        tokenizer=model_args.model_name_or_path,
        enable_lora=enable_lora,
    )

    # Offline batch generation; lora_request applies the adapter(s) to the prompts.
    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)

    os.makedirs(training_args.output_dir, exist_ok=True)
    output_prediction_file = os.path.join(
        training_args.output_dir, "generated_predictions.jsonl"
    )

    # Save prompt / prediction / reference triples as JSON lines.
    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        res: List[str] = []
        for text, pred, label in zip(prompts, outputs, labels):
            res.append(
                json.dumps(
                    {"prompt": text, "predict": pred.outputs[0].text, "label": label},
                    ensure_ascii=False,
                )
            )
        writer.write("\n".join(res))


if __name__ == "__main__":
    # Allows `python vllm_infer.py vllm.yaml` as shown below.
    vllm_infer()
```
An example vllm.yaml:
```yaml
### model
model_name_or_path: qwen/Qwen2.5-7B-Instruct
# adapter_name_or_path: path to your LoRA adapter
### method
stage: sft
do_predict: true
finetuning_type: lora
### dataset
dataset_dir: path to your dataset directory
eval_dataset: your dataset name
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: output/
overwrite_output_dir: true
### eval
predict_with_generate: true
```
Run the script:
```shell
python vllm_infer.py vllm.yaml
```
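If you would rather trigger the same run from another Python script or a notebook instead of the shell, one option is a minimal sketch that simply mimics the CLI call above by pointing sys.argv at the config file before get_train_args() parses it:

```python
import sys

from vllm_infer import vllm_infer  # the script above, saved as vllm_infer.py

# Emulate `python vllm_infer.py vllm.yaml`: get_train_args() reads the config
# path from the command line, so point sys.argv at it before calling.
sys.argv = ["vllm_infer.py", "vllm.yaml"]
vllm_infer()
```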
Inference speed:
```shell
Processed prompts: 100%|█| 1000/1000 [01:56<00:00, 8.60it/s, est. speed input: 5169.35 toks/s, output: 811.57
```
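The script writes one JSON object per line to generated_predictions.jsonl with prompt, predict, and label fields. As a small follow-up sketch (the exact-match metric is only an illustration and assumes output_dir is output/ as in the example config above), the predictions can be read back and compared against the references:

```python
import json

# Assumes output_dir: output/ from the example vllm.yaml above.
path = "output/generated_predictions.jsonl"

total, exact = 0, 0
with open(path, encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        total += 1
        # Naive exact-match between the model prediction and the reference label.
        if record["predict"].strip() == record["label"].strip():
            exact += 1

print(f"exact match: {exact}/{max(total, 1)} ({exact / max(total, 1):.2%})")
```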
Summary
Building on LlamaFactory's native dataset pipeline, this approach supports vLLM batch inference with LoRA adapters and improves inference efficiency.