计算 VLM（视觉语言模型）对自身生成回答的困惑度（PPL）及交叉熵损失。
代码（Python）：
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
from torch.nn import CrossEntropyLoss
from PIL import Image
# Configuration ---------------------------------------------------------------
# Device onto which the model and all input tensors are placed.
DEVICE: str = "cuda:0"
# Checkpoint directory handed to from_pretrained for both model and processor.
MODEL_NAME: str = "/data1/chenjun/huf/Qwen2-VL-2B-Instruct"
# Length in pixels that the longer side of the input image is resized to.
IMAGE_SIZE: int = 384
def resize_image(path, max_side=384):
    """Load an image as RGB and rescale it so its longer side is *max_side*.

    The aspect ratio is preserved: the shorter side is scaled by the same
    factor and truncated to an int, then clamped to at least 1 pixel so very
    elongated images never produce a zero-sized target. The image is always
    rescaled, so smaller images are upscaled as well.

    Args:
        path: Anything accepted by ``PIL.Image.open`` (typically a file path).
        max_side: Target length in pixels of the longer side; must be > 0.

    Returns:
        A one-element list holding the resized image, so the result can be
        passed directly as the processor's ``images`` argument.

    Raises:
        ValueError: If ``max_side`` is not positive.
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side}")
    # Close the underlying file deterministically; convert() loads the pixel
    # data into a new image, so it stays valid after the file is closed.
    with Image.open(path) as src:
        image = src.convert("RGB")
    width, height = image.size
    # Same factor the original used: max_side divided by the longer side.
    scale = max_side / max(width, height)
    if width > height:
        new_size = (max_side, max(1, int(height * scale)))
    else:
        new_size = (max(1, int(width * scale)), max_side)
    return [image.resize(new_size, Image.Resampling.LANCZOS)]
def _response_ce_loss(logits, input_ids, prompt_len):
    """Return the mean next-token cross-entropy over tokens at positions >= prompt_len.

    Logit position ``t`` predicts token ``t + 1``. The response tokens
    ``input_ids[..., prompt_len:]`` are therefore scored with
    ``logits[..., prompt_len - 1:-1, :]``, which includes the first response
    token.

    Args:
        logits: Tensor indexed as (batch, seq, vocab), assumed to come from a
            forward pass over ``input_ids``.
        input_ids: Tensor indexed as (batch, seq).
        prompt_len: Number of leading prompt tokens that are not scored.

    Returns:
        Scalar float32 tensor holding the mean cross-entropy (natural log).

    Raises:
        ValueError: If there is no prompt token or no response token.
    """
    seq_len = input_ids.shape[-1]
    if not 1 <= prompt_len < seq_len:
        raise ValueError(
            f"need 1 <= prompt_len < seq_len, got prompt_len={prompt_len}, "
            f"seq_len={seq_len}"
        )
    labels = input_ids[..., prompt_len:].contiguous()
    pred = logits[..., prompt_len - 1:-1, :].contiguous().to(dtype=torch.float32)
    loss_fct = CrossEntropyLoss()
    return loss_fct(pred.view(-1, pred.size(-1)), labels.view(-1).to(pred.device))


def main(
    image_path='outputs/ppl_vlm_qwen3-vl-2b-axera-384/vit/0000.png',
    prompt="描述这张图片",
    max_new_tokens=256,
):
    """Generate a response for one image and print its cross-entropy and PPL.

    Steps:
      1. Load the model and processor from ``MODEL_NAME`` onto ``DEVICE``.
      2. Build a single-turn chat prompt containing the image and ``prompt``.
      3. Generate up to ``max_new_tokens`` tokens with the model's default
         generation settings.
      4. Re-run the model over prompt + generated tokens and compute the mean
         cross-entropy of the generated tokens and its exponential (PPL).

    Args:
        image_path: Image file to describe.
            NOTE(review): the default path's directory mentions qwen3-vl while
            MODEL_NAME is a Qwen2-VL checkpoint; confirm that pairing is intended.
        prompt: User instruction placed after the image.
        max_new_tokens: Generation budget for the response.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME, dtype=torch.float32, device_map=DEVICE
    )
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The pixels actually fed to the model come from resize_image, not from
    # the path inside `messages` (which only drives the chat template).
    image_inputs = resize_image(image_path, IMAGE_SIZE)
    inputs = processor(text=[text], images=image_inputs, return_tensors="pt").to(DEVICE)
    prompt_len = inputs["input_ids"].shape[1]

    # generate() returns the prompt ids followed by the new tokens.
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    output_text = processor.batch_decode(
        generated_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print(f"response: {output_text}")

    # Score exactly the generated token ids instead of re-tokenizing the
    # decoded text. Decoding with skip_special_tokens drops tokens, and
    # re-tokenizing is not guaranteed to reproduce the same ids or prompt
    # length. The image tensors from the first processor call are reused.
    # NOTE(review): some transformers versions cache rope_deltas in generate()
    # and fall back to plain positions when no prefill cache_position is given.
    # A cache_position starting at 0 should force the multimodal position ids
    # to be recomputed; confirm against the installed version.
    seq_len = generated_ids.shape[1]
    with torch.no_grad():
        outputs = model(
            input_ids=generated_ids,
            attention_mask=torch.ones_like(generated_ids),
            pixel_values=inputs["pixel_values"],
            image_grid_thw=inputs["image_grid_thw"],
            cache_position=torch.arange(seq_len, device=generated_ids.device),
            use_cache=False,
        )

    ce_loss = _response_ce_loss(outputs.logits, generated_ids, prompt_len)
    print(f"ce_loss: {ce_loss:.3f}, ppl: {ce_loss.exp():.3f}")
if __name__ == "__main__":
    # Run the end-to-end measurement only when executed as a script.
    main()