多模态大模型微调:LLaVA 与 Qwen-VL 视觉语言模型训练

1. 引言

多模态大模型(如 LLaVA、Qwen-VL、InternVL)能够同时理解图像和文本,实现视觉问答、图像描述、OCR 等任务。本文将介绍如何微调这些模型以适应特定领域。

主流多模态架构对比:

模型 视觉编码器 LLM 参数量 特点
LLaVA-1.5 CLIP-ViT-L Vicuna/LLaMA 7B/13B 简单高效
Qwen-VL ViT-bigG Qwen-7B 9.6B 中文优秀
InternVL-2 InternViT-6B InternLM2 8B-76B 开源最强
Phi-3-Vision CLIP-ViT Phi-3 4.2B 轻量级

2. LLaVA 架构解析

2.1 三组件架构

复制代码
图像 → Vision Encoder (CLIP ViT-L/14) → 视觉 tokens
                                           ↓
                                      Projection Layer (MLP)
                                           ↓
文本 → Tokenizer → 文本 tokens ──────→ 拼接 → LLM → 回答

2.2 两阶段训练

复制代码
阶段一:预训练投影层
  - 冻结 Vision Encoder 和 LLM
  - 只训练 Projection Layer
  - 数据:558K 图文对(图像描述)
  - 目标:对齐视觉和语言空间

阶段二:指令微调
  - 冻结 Vision Encoder
  - 训练 Projection Layer + LLM
  - 数据:665K 多模态指令数据
  - 目标:学习遵循指令回答问题

3. 数据准备

3.1 数据格式

json 复制代码
{
  "id": "vqa_001",
  "image": "images/001.jpg",
  "conversations": [
    {"from": "human", "value": "<image>\n这张图片中有什么?"},
    {"from": "gpt", "value": "图片中显示了一条繁忙的城市街道,有多个行人和车辆。"}
  ]
}

3.2 数据处理脚本

python 复制代码
import json
from PIL import Image
from torch.utils.data import Dataset

class MultimodalDataset(Dataset):
    """多模态指令微调数据集"""

    def __init__(self, data_path, image_dir, processor, tokenizer, max_length=2048):
        with open(data_path) as f:
            self.data = json.load(f)
        self.image_dir = image_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # 加载图像
        image_path = f"{self.image_dir}/{item['image']}"
        image = Image.open(image_path).convert("RGB")

        # 处理对话
        conversations = item["conversations"]
        prompt = conversations[0]["value"].replace("<image>", "")
        answer = conversations[1]["value"]

        # 构造输入
        input_text = f"USER: <image>\n{prompt}\nASSISTANT: {answer}"

        # 编码
        image_inputs = self.processor(images=image, return_tensors="pt")
        text_inputs = self.tokenizer(
            input_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        return {
            "pixel_values": image_inputs["pixel_values"].squeeze(),
            "input_ids": text_inputs["input_ids"].squeeze(),
            "attention_mask": text_inputs["attention_mask"].squeeze(),
        }

4. LLaVA 微调

4.1 环境准备

bash 复制代码
pip install transformers accelerate peft
pip install flash-attn --no-build-isolation

4.2 加载模型

python 复制代码
from transformers import LlavaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

model_id = "llava-hf/llava-1.5-7b-hf"

# QLoRA 配置
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# 加载模型
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

processor = AutoProcessor.from_pretrained(model_id)

# LoRA 配置(只适配语言模型部分)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

4.3 训练

python 复制代码
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./llava-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="epoch",
    remove_unused_columns=False,
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=lambda batch: {
        "pixel_values": torch.stack([b["pixel_values"] for b in batch]),
        "input_ids": torch.stack([b["input_ids"] for b in batch]),
        "attention_mask": torch.stack([b["attention_mask"] for b in batch]),
        "labels": torch.stack([b["input_ids"] for b in batch]),
    },
)

trainer.train()

5. Qwen-VL 微调

5.1 加载 Qwen-VL

python 复制代码
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen-VL-Chat"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    bf16=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

5.2 Qwen-VL 数据格式

json 复制代码
{
  "id": "vqa_001",
  "conversations": [
    {
      "from": "user",
      "value": "Picture 1: images/001.jpg\n这张图片中有什么?"
    },
    {
      "from": "assistant",
      "value": "图片中显示了一条繁忙的城市街道。"
    }
  ]
}

6. 推理与评估

6.1 推理代码

python 复制代码
from PIL import Image

def inference(model, processor, image_path, question):
    """多模态推理"""
    image = Image.open(image_path).convert("RGB")

    prompt = f"USER: <image>\n{question}\nASSISTANT:"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
    )

    response = processor.decode(output[0], skip_special_tokens=True)
    # 提取 ASSISTANT 后的回答
    return response.split("ASSISTANT:")[-1].strip()

# 使用
answer = inference(model, processor, "test.jpg", "描述这张图片的内容")
print(answer)

6.2 评估指标

python 复制代码
# 常用评估指标
# 1. VQA Accuracy: 答案是否正确
# 2. CIDEr: 图像描述质量
# 3. BLEU: 文本生成质量
# 4. ROUGE: 召考率导向的评估

from pycocoevalcap.cider.cider import Cider

def evaluate_cider(predictions, references):
    """CIDEr 评估"""
    scorer = Cider()
    score, _ = scorer.compute_score(references, predictions)
    return score

7. 领域适配示例

7.1 医学影像分析

python 复制代码
# 医学影像数据集格式
medical_data = {
    "id": "med_001",
    "image": "xray/chest_001.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\n分析这张X光片,指出异常区域。"},
        {"from": "gpt", "value": "右下肺野可见片状模糊影,考虑肺炎可能。建议CT进一步检查。"}
    ]
}

7.2 工业质检

python 复制代码
# 工业质检数据格式
qc_data = {
    "id": "qc_001",
    "image": "defect/scratch_001.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\n检测这张产品图片中的缺陷。"},
        {"from": "gpt", "value": "检测到表面划痕缺陷,位于图片中部偏左位置,长度约2cm。严重程度:中等。"}
    ]
}

8. 总结

多模态大模型微调的关键要点:

  1. 数据质量 > 数据量:1000 条高质量标注 > 10000 条低质量数据
  2. 两阶段训练:先预训练投影层对齐模态,再指令微调
  3. LoRA 微调:只适配语言模型部分,视觉编码器通常冻结
  4. 领域数据:收集领域特定的图文对是成功的关键