通义千问模型微调——swift框架

1.创建环境

服务器CUDA Version: 12.2

bash 复制代码
conda create -n lora_qwen python=3.10 -y 
conda activate lora_qwen 
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y

1.1环境搭建

本文使用swift进行微调,所以先下载swift,以及一些必要的packages

bash 复制代码
git clone https://github.com/modelscope/ms-swift.git
pip install transformers==4.49.0 
pip install pyav qwen_vl_utils 
pip install numpy==1.22.4 
pip install modelscope

1.2模型下载

使用modelscope下载指定模型,其中:

--model表示模型名称,可在modelscope官网找到

--local_dir代表模型下载地址

运行下面的命令,模型会下载到:./Qwen/Qwen2.5-VL-7B-Instruct目录下

bash 复制代码
modelscope download --model Qwen/Qwen2.5-VL-7B-Instruct --local_dir ./

下面脚本用于和模型进行对话,可以简单测试一下模型是否能够使用

bash 复制代码
CUDA_VISIBLE_DEVICES=1 swift infer --model_type qwen2_5_vl --ckpt_dir ./Qwen/Qwen2.5-VL-7B-Instruct

1.3数据集准备

下方是数据集格式,保存类型为.jsonl

bash 复制代码
[
    {
        "query": "OCR一下<image>",
        "response": "朵拉童衣",
        "images": [
            "datasets/lora_qwen/train/billboard_00001_010_朵拉童衣.jpg"
        ]
    },
    {
        "query": "OCR一下<image>",
        "response": "童衣雜貨舖",
        "images": [
            "datasets/lora_qwen/train/billboard_00002_010_童衣雜貨舖.jpg"
        ]
    },...
]

2.微调

2.1采用LoRA进行微调

对文件夹中之前下载的ms-swift-main/examples/train/multimodal/ocr.sh进行修改

bash 复制代码
# 20GB
CUDA_VISIBLE_DEVICES=0,1 \
MAX_PIXELS=1003520 \
swift sft \
    --model ./Qwen/Qwen2.5-VL-7B-Instruct \
    --model_type qwen2_5_vl \
    --dataset ./datatsets/train.jsonl \
    --val_dataset ./datatsets/val.jsonl \
    --train_type lora \
    --torch_dtype bfloat16 \
    --num_train_epochs 100 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 64 \
    --lora_alpha 64 \
    --target_modules all-linear \
    --freeze_vit true \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 10 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4
  • 常用参数解释:
bash 复制代码
--model:原模型的权重地址

--dataset:训练集的数据地址

--val_dataset:验证集的数据地址

--train_type:全参数训练(full) 或 LoRA微调训练(lora)

--num_train_epochs:总共要训练的轮数

--per_device_train_batch_size:训练阶段batchsize大小,根据显存大小来设置

--per_device_eval_batch_size:验证阶段batchsize大小,根据显存大小来设置

--learning_rate:学习率,一般设为0.0001或0.00001

--target_modules:需要做微调的目标模块,all-linear表示所有的线形层,也就是Attention和FeedForward层

--freeze_vit:一般设为true,不微调视觉编码器,只微调LLM部分

2.2使用Transformer进行推理

python 复制代码
import os
import re
import torch
from PIL import Image

from datasets import Dataset
from modelscope import AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Trainer, TrainingArguments,
    Seq2SeqTrainer, Seq2SeqTrainingArguments, 
    DataCollatorForSeq2Seq,
)
from qwen_vl_utils import process_vision_info

rewrite_print = print
def print(save_txt, *arg, **kwargs):
    rewrite_print(*arg, **kwargs)
    rewrite_print(*arg, **kwargs, file=open(save_txt, "a+", encoding="utf-8"))

def process_func(model, img_path, input_content):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {"type": "text", "text": input_content},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    generated_ids = model.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    print(save_txt_path, img_path)
    print(save_txt_path, output_text[0])
    print(save_txt_path, '\n')

def get_lora_model(model_path, lora_model_path):
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    model.enable_input_require_grads()

    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules="model\..*layers\.\d+\.(self_attn\.(q_proj|k_proj|v_proj|o_proj)|mlp\.(gate_proj|up_proj|down_proj))",
        inference_mode=True,
        r=64,
        lora_alpha=64,
        lora_dropout=0.05,
        bias="none",
    )

    peft_model = PeftModel.from_pretrained(model, model_id=lora_model_path, config=config)
    return peft_model

if __name__ == '__main__':
    save_txt_path = 'log.txt'

    model_path = "./Qwen2.5-VL-7B-Instruct"
    lora_model_path = "./output/v2-20250228-202446/checkpoint-900"
    lora_model = get_lora_model(model_path, lora_model_path)
    processor = AutoProcessor.from_pretrained(model_path)
    
    img_path = "图片路径"
    prompt = "OCR一下"
    process_func(lora_model, img_path, prompt)

3.实验参数情况

模型微调显存:30G左右(主要看数据集,图片越大,prompt,answer越多,占用显存越多);

模型微调后推理:20G左右;

相关推荐
机器学习之心1 分钟前
TCN-Transformer-BiGRU组合模型回归+SHAP分析+新数据预测+多输出!深度学习可解释分析
深度学习·回归·transformer·shap分析
Katecat996631 分钟前
YOLO11分割算法实现甲状腺超声病灶自动检测与定位_DWR方法应用
python
LLWZAI2 分钟前
让朱雀AI检测无法判断的AI公众号文章,当创作者开始与算法「躲猫猫」
大数据·人工智能·深度学习
玩大数据的龙威28 分钟前
农经权二轮延包—各种地块示意图
python·arcgis
ZH154558913130 分钟前
Flutter for OpenHarmony Python学习助手实战:数据库操作与管理的实现
python·学习·flutter
belldeep39 分钟前
python:用 Flask 3 , mistune 2 和 mermaid.min.js 10.9 来实现 Markdown 中 mermaid 图表的渲染
javascript·python·flask
喵手39 分钟前
Python爬虫实战:电商价格监控系统 - 从定时任务到历史趋势分析的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·电商价格监控系统·从定时任务到历史趋势分析·采集结果sqlite存储
霖大侠1 小时前
【无标题】
人工智能·深度学习·机器学习
喵手1 小时前
Python爬虫实战:京东/淘宝搜索多页爬虫实战 - 从反爬对抗到数据入库的完整工程化方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·京东淘宝页面数据采集·反爬对抗到数据入库·采集结果csv导出
B站_计算机毕业设计之家1 小时前
猫眼电影数据可视化与智能分析平台 | Python Flask框架 Echarts 推荐算法 爬虫 大数据 毕业设计源码
python·机器学习·信息可视化·flask·毕业设计·echarts·推荐算法