1. 引言
多模态大模型(如 LLaVA、Qwen-VL、InternVL)能够同时理解图像和文本,实现视觉问答、图像描述、OCR 等任务。本文将介绍如何微调这些模型以适应特定领域。
主流多模态架构对比:
| 模型 | 视觉编码器 | LLM | 参数量 | 特点 |
|---|---|---|---|---|
| LLaVA-1.5 | CLIP-ViT-L | Vicuna/LLaMA | 7B/13B | 简单高效 |
| Qwen-VL | ViT-bigG | Qwen-7B | 9.6B | 中文优秀 |
| InternVL2 | InternViT-300M / InternViT-6B | InternLM2 等(随规模不同) | 1B-76B | 开源性能领先 |
| Phi-3-Vision | CLIP-ViT | Phi-3 | 4.2B | 轻量级 |
2. LLaVA 架构解析
2.1 三组件架构
图像 → Vision Encoder (CLIP ViT-L/14) → 视觉特征 → Projection Layer (MLP) → 视觉 tokens
文本 → Tokenizer → 文本 tokens
视觉 tokens + 文本 tokens → 拼接 → LLM → 回答
2.2 两阶段训练
阶段一:预训练投影层
- 冻结 Vision Encoder 和 LLM
- 只训练 Projection Layer
- 数据:558K 图文对(图像描述)
- 目标:对齐视觉和语言空间
阶段二:指令微调
- 冻结 Vision Encoder
- 训练 Projection Layer + LLM
- 数据:665K 多模态指令数据
- 目标:学习遵循指令回答问题
3. 数据准备
3.1 数据格式
json
{
"id": "vqa_001",
"image": "images/001.jpg",
"conversations": [
{"from": "human", "value": "<image>\n这张图片中有什么?"},
{"from": "gpt", "value": "图片中显示了一条繁忙的城市街道,有多个行人和车辆。"}
]
}
3.2 数据处理脚本
python
import json
import os

from PIL import Image
from torch.utils.data import Dataset
class MultimodalDataset(Dataset):
    """Single-turn multimodal instruction-tuning dataset.

    ``data_path`` is a JSON file holding a list of records: each has an
    ``image`` path (relative to ``image_dir``) and a ``conversations`` list.
    Only the first two turns are used: turn 0 is the user question and
    turn 1 is the answer. Later turns are ignored.

    Each item is a dict of tensors: ``pixel_values``, ``input_ids``,
    ``attention_mask`` and ``labels``. ``labels`` is a copy of ``input_ids``
    with padding and prompt positions set to ``IGNORE_INDEX``, so only
    answer tokens contribute to the loss.
    """

    # Label value excluded from the loss (the default ignore_index of
    # torch.nn.CrossEntropyLoss).
    IGNORE_INDEX = -100

    def __init__(self, data_path, image_dir, processor, tokenizer, max_length=2048):
        # Records contain non-ASCII (e.g. Chinese) text, so decode explicitly
        # as UTF-8 instead of relying on the platform default encoding.
        with open(data_path, encoding="utf-8") as f:
            self.data = json.load(f)
        self.image_dir = image_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _build_texts(conversations):
        """Return ``(prompt_text, full_text)`` for the first question/answer pair."""
        # Remove the user's <image> placeholder, and the newline that usually
        # follows it, because the template below re-inserts it at a fixed position.
        question = conversations[0]["value"].replace("<image>", "").strip()
        answer = conversations[1]["value"]
        prompt_text = f"USER: <image>\n{question}\nASSISTANT:"
        return prompt_text, f"{prompt_text} {answer}"

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load the image and close the file handle right away.
        with Image.open(os.path.join(self.image_dir, item["image"])) as img:
            image = img.convert("RGB")

        prompt_text, input_text = self._build_texts(item["conversations"])

        image_inputs = self.processor(images=image, return_tensors="pt")
        # NOTE(review): the text is tokenized separately from the image, so the
        # <image> placeholder stays a single token. Some transformers versions
        # expect the LLaVA processor to expand it to one token per image patch;
        # if so, tokenize text+image together with the processor. Confirm
        # against the installed version.
        text_inputs = self.tokenizer(
            input_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Take row 0 of the batch dimension. .squeeze() would also drop any
        # other size-1 dimension.
        input_ids = text_inputs["input_ids"][0]
        attention_mask = text_inputs["attention_mask"][0]

        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX
        # Mask the prompt so the loss covers only the answer. The prompt length
        # is measured by tokenizing the prompt alone. This matches the prefix of
        # the full text unless the tokenizer merges tokens across "ASSISTANT:".
        prompt_len = len(
            self.tokenizer(prompt_text, truncation=True, max_length=self.max_length)["input_ids"]
        )
        # First non-padding position (non-zero only when padding is on the left).
        start = int(attention_mask.argmax())
        labels[start:start + prompt_len] = self.IGNORE_INDEX

        return {
            "pixel_values": image_inputs["pixel_values"][0],
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
4. LLaVA 微调
4.1 环境准备
bash
# bitsandbytes is required for the 4-bit BitsAndBytesConfig (QLoRA); pillow provides PIL.
pip install transformers accelerate peft bitsandbytes pillow
# flash-attn compiles against the already-installed torch, hence --no-build-isolation.
pip install flash-attn --no-build-isolation
# Only needed for the CIDEr evaluation example.
pip install pycocoevalcap
4.2 加载模型
python
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "llava-hf/llava-1.5-7b-hf"

# QLoRA: 4-bit NF4 weight quantization with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized model.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)

# Standard PEFT preparation for training on a k-bit model: it readies the
# model for gradient checkpointing and lets gradients flow through the frozen
# input embeddings.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# LoRA configuration (language model only).
# A plain list such as ["q_proj", "k_proj", ...] would also match the CLIP
# vision tower's q_proj/k_proj/v_proj. The regex (matched against full module
# names) restricts adapters to modules under `language_model`.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)",
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
4.3 训练
python
import torch
from transformers import TrainingArguments, Trainer

# Label value excluded from the loss (the default ignore_index of
# torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100


def collate_fn(batch):
    """Stack per-sample tensors into a batch and attach LM labels.

    If every sample already carries ``labels``, for example with the prompt
    masked, those are stacked as-is. Otherwise labels are copied from
    ``input_ids`` and padding positions (``attention_mask == 0``) are set to
    ``IGNORE_INDEX``, so padding tokens do not contribute to the loss.
    """
    input_ids = torch.stack([b["input_ids"] for b in batch])
    attention_mask = torch.stack([b["attention_mask"] for b in batch])
    if all("labels" in b for b in batch):
        labels = torch.stack([b["labels"] for b in batch])
    else:
        labels = input_ids.masked_fill(attention_mask == 0, IGNORE_INDEX)
    return {
        "pixel_values": torch.stack([b["pixel_values"] for b in batch]),
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


training_args = TrainingArguments(
    output_dir="./llava-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # LoRA adapters are usually trained with a larger learning rate than full
    # fine-tuning (the QLoRA recipe uses 2e-4 for 7B/13B models).
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    # Non-reentrant checkpointing keeps gradients flowing even when the
    # checkpointed inputs do not require grad (only LoRA weights are trainable).
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=10,
    save_strategy="epoch",
    # Keep every key returned by the dataset (e.g. pixel_values).
    remove_unused_columns=False,
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=training_args,
    # Defined elsewhere, presumably a MultimodalDataset instance (section 3.2).
    train_dataset=train_dataset,
    data_collator=collate_fn,
)
trainer.train()
5. Qwen-VL 微调
5.1 加载 Qwen-VL
python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen-VL-Chat"

# Qwen-VL-Chat is loaded with trust_remote_code=True, i.e. using the model and
# tokenizer code published alongside the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    # NOTE(review): `bf16` is not a standard from_pretrained argument. It is
    # presumably consumed by Qwen-VL's remote config to select bfloat16; confirm.
    bf16=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)
5.2 Qwen-VL 数据格式
json
{
"id": "vqa_001",
"conversations": [
{
"from": "user",
"value": "Picture 1: <img>images/001.jpg</img>\n这张图片中有什么?"
},
{
"from": "assistant",
"value": "图片中显示了一条繁忙的城市街道。"
}
]
}
6. 推理与评估
6.1 推理代码
python
from PIL import Image


def inference(model, processor, image_path, question,
              max_new_tokens=512, do_sample=True, temperature=0.7):
    """Answer ``question`` about the image at ``image_path`` (LLaVA-1.5 prompt format).

    Args:
        model: the multimodal model; its ``.device`` and ``.generate`` are used.
        processor: the matching processor, used to encode prompt + image and
            to decode the generated ids.
        image_path: path of the image file to describe.
        question: the user question inserted into the prompt template.
        max_new_tokens: maximum number of tokens to generate.
        do_sample: sample instead of greedy decoding. Pass False for
            reproducible evaluation.
        temperature: sampling temperature; only passed when ``do_sample`` is True.

    Returns:
        The generated answer text with surrounding whitespace stripped.
    """
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    prompt = f"USER: <image>\n{question}\nASSISTANT:"

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        # Temperature only matters when sampling.
        gen_kwargs["temperature"] = temperature
    output = model.generate(**inputs, **gen_kwargs)

    # The returned sequence starts with the prompt ids. Decode only the newly
    # generated tokens instead of splitting on "ASSISTANT:", which breaks if
    # the answer itself contains that string.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return processor.decode(new_tokens, skip_special_tokens=True).strip()


# Usage
answer = inference(model, processor, "test.jpg", "描述这张图片的内容")
print(answer)
6.2 评估指标
python
# Common evaluation metrics
# 1. VQA Accuracy: whether the predicted answer is correct
# 2. CIDEr: image-captioning quality (consensus with reference captions)
# 3. BLEU: n-gram-precision-based text generation quality
# 4. ROUGE: recall-oriented evaluation
from pycocoevalcap.cider.cider import Cider


def _cider_text(text, char_level):
    """Return ``text`` in the whitespace-tokenized form CIDEr scores."""
    if char_level:
        # CIDEr splits on whitespace. Unsegmented Chinese would otherwise be
        # scored as a single "word", so separate every character.
        return " ".join(ch for ch in text if not ch.isspace())
    return text


def _as_coco_dict(texts, char_level):
    """Normalize texts to pycocoevalcap's ``{id: [str, ...]}`` layout.

    Accepts either that dict layout or a sequence, whose positions become the
    ids. Each value may be a single string or a list of strings.
    """
    items = texts.items() if isinstance(texts, dict) else enumerate(texts)
    return {
        key: [_cider_text(t, char_level) for t in ([value] if isinstance(value, str) else value)]
        for key, value in items
    }


def evaluate_cider(predictions, references, char_level=False):
    """Compute the corpus-level CIDEr score.

    Args:
        predictions: ``{id: [prediction]}`` or a sequence of prediction strings.
        references: ``{id: [ref, ...]}`` or a sequence aligned with
            ``predictions`` whose items are a reference string or a list of
            reference strings.
        char_level: if True, score character-level tokens (useful for Chinese).

    Returns:
        The CIDEr score computed by pycocoevalcap over all ids.
    """
    scorer = Cider()
    gts = _as_coco_dict(references, char_level)
    res = _as_coco_dict(predictions, char_level)
    score, _ = scorer.compute_score(gts, res)
    return score
7. 领域适配示例
7.1 医学影像分析
python
# Medical-imaging sample record (same schema as the generic VQA sample)
medical_data = {
    "id": "med_001",
    "image": "xray/chest_001.jpg",
    "conversations": [
        {"from": speaker, "value": text}
        for speaker, text in (
            ("human", "<image>\n分析这张X光片,指出异常区域。"),
            ("gpt", "右下肺野可见片状模糊影,考虑肺炎可能。建议CT进一步检查。"),
        )
    ],
}
7.2 工业质检
python
# Industrial quality-inspection sample record
qc_data = {"id": "qc_001", "image": "defect/scratch_001.jpg"}
qc_data["conversations"] = [
    {"from": speaker, "value": text}
    for speaker, text in (
        ("human", "<image>\n检测这张产品图片中的缺陷。"),
        ("gpt", "检测到表面划痕缺陷,位于图片中部偏左位置,长度约2cm。严重程度:中等。"),
    )
]
8. 总结
多模态大模型微调的关键要点:
- 数据质量 > 数据量:少量(如 1000 条)高质量标注通常优于大量(如 10000 条)低质量数据
- 两阶段训练:先预训练投影层对齐模态,再指令微调
- LoRA 微调:只适配语言模型部分,视觉编码器通常冻结
- 领域数据:收集领域特定的图文对是成功的关键