通义千问模型微调——swift框架

1.创建环境

服务器CUDA Version: 12.2

bash 复制代码
conda create -n lora_qwen python=3.10 -y 
conda activate lora_qwen 
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y

1.1环境搭建

本文使用swift进行微调,所以先下载swift,以及一些必要的packages

bash 复制代码
git clone https://github.com/modelscope/ms-swift.git
pip install transformers==4.49.0 
pip install pyav qwen_vl_utils 
pip install numpy==1.22.4 
pip install modelscope

1.2模型下载

使用modelscope下载指定模型,其中:

--model表示模型名称,可在modelscope官网找到

--local_dir代表模型下载地址

运行下面的命令,模型会下载到:./Qwen/Qwen2.5-VL-7B-Instruct目录下

bash 复制代码
modelscope download --model Qwen/Qwen2.5-VL-7B-Instruct --local_dir ./

下面脚本用于和模型进行对话,可以简单测试一下模型是否能够使用

bash 复制代码
CUDA_VISIBLE_DEVICES=1 swift infer --model_type qwen2_5_vl --ckpt_dir ./Qwen/Qwen2.5-VL-7B-Instruct

1.3数据集准备

下方是数据集格式,保存类型为.jsonl

bash 复制代码
[
    {
        "query": "OCR一下<image>",
        "response": "朵拉童衣",
        "images": [
            "datasets/lora_qwen/train/billboard_00001_010_朵拉童衣.jpg"
        ]
    },
    {
        "query": "OCR一下<image>",
        "response": "童衣雜貨舖",
        "images": [
            "datasets/lora_qwen/train/billboard_00002_010_童衣雜貨舖.jpg"
        ]
    },...
]

2.微调

2.1采用LoRA进行微调

对文件夹中之前下载的ms-swift-main/examples/train/multimodal/ocr.sh进行修改

bash 复制代码
# 20GB
CUDA_VISIBLE_DEVICES=0,1 \
MAX_PIXELS=1003520 \
swift sft \
    --model ./Qwen/Qwen2.5-VL-7B-Instruct \
    --model_type qwen2_5_vl \
    --dataset ./datatsets/train.jsonl \
    --val_dataset ./datatsets/val.jsonl \
    --train_type lora \
    --torch_dtype bfloat16 \
    --num_train_epochs 100 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 64 \
    --lora_alpha 64 \
    --target_modules all-linear \
    --freeze_vit true \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 10 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4
  • 常用参数解释:
bash 复制代码
--model:原模型的权重地址

--dataset:训练集的数据地址

--val_dataset:验证集的数据地址

--train_type:全参数训练(full) 或 LoRA微调训练(lora)

--num_train_epochs:总共要训练的轮数

--per_device_train_batch_size:训练阶段batchsize大小,根据显存大小来设置

--per_device_eval_batch_size:验证阶段batchsize大小,根据显存大小来设置

--learning_rate:学习率,一般设为0.0001或0.00001

--target_modules:需要做微调的目标模块,all-linear表示所有的线形层,也就是Attention和FeedForward层

--freeze_vit:一般设为true,不微调视觉编码器,只微调LLM部分

2.2使用Transformer进行推理

python 复制代码
import os
import re
import torch
from PIL import Image

from datasets import Dataset
from modelscope import AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Trainer, TrainingArguments,
    Seq2SeqTrainer, Seq2SeqTrainingArguments, 
    DataCollatorForSeq2Seq,
)
from qwen_vl_utils import process_vision_info

rewrite_print = print
def print(save_txt, *arg, **kwargs):
    rewrite_print(*arg, **kwargs)
    rewrite_print(*arg, **kwargs, file=open(save_txt, "a+", encoding="utf-8"))

def process_func(model, img_path, input_content):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {"type": "text", "text": input_content},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    generated_ids = model.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    print(save_txt_path, img_path)
    print(save_txt_path, output_text[0])
    print(save_txt_path, '\n')

def get_lora_model(model_path, lora_model_path):
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    model.enable_input_require_grads()

    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules="model\..*layers\.\d+\.(self_attn\.(q_proj|k_proj|v_proj|o_proj)|mlp\.(gate_proj|up_proj|down_proj))",
        inference_mode=True,
        r=64,
        lora_alpha=64,
        lora_dropout=0.05,
        bias="none",
    )

    peft_model = PeftModel.from_pretrained(model, model_id=lora_model_path, config=config)
    return peft_model

if __name__ == '__main__':
    save_txt_path = 'log.txt'

    model_path = "./Qwen2.5-VL-7B-Instruct"
    lora_model_path = "./output/v2-20250228-202446/checkpoint-900"
    lora_model = get_lora_model(model_path, lora_model_path)
    processor = AutoProcessor.from_pretrained(model_path)
    
    img_path = "图片路径"
    prompt = "OCR一下"
    process_func(lora_model, img_path, prompt)

3.实验参数情况

模型微调显存:30G左右(主要看数据集,图片越大,prompt,answer越多,占用显存越多);

模型微调后推理:20G左右;

相关推荐
聆风吟º1 天前
CANN算子开发:ops-nn神经网络算子库的技术解析与实战应用
人工智能·深度学习·神经网络·cann
觉醒大王1 天前
强女思维:着急,是贪欲外显的相。
java·论文阅读·笔记·深度学习·学习·自然语言处理·学习方法
笔画人生1 天前
# 探索 CANN 生态:深入解析 `ops-transformer` 项目
人工智能·深度学习·transformer
亓才孓1 天前
[Class类的应用]反射的理解
开发语言·python
灰灰勇闯IT1 天前
领域制胜——CANN 领域加速库(ascend-transformer-boost)的场景化优化
人工智能·深度学习·transformer
小白狮ww1 天前
要给 OCR 装个脑子吗?DeepSeek-OCR 2 让文档不再只是扫描
人工智能·深度学习·机器学习·ocr·cpu·gpu·deepseek
小镇敲码人1 天前
深入剖析华为CANN框架下的Ops-CV仓库:从入门到实战指南
c++·python·华为·cann
island13141 天前
CANN GE(图引擎)深度解析:计算图优化管线、内存静态规划与异构任务的 Stream 调度机制
开发语言·人工智能·深度学习·神经网络
艾莉丝努力练剑1 天前
深度学习视觉任务:如何基于ops-cv定制图像预处理流程
人工智能·深度学习
禁默1 天前
大模型推理的“氮气加速系统”:全景解读 Ascend Transformer Boost (ATB)
人工智能·深度学习·transformer·cann