【VLM】——vlm计算ppl损失

计算vlm模型的ppl损失。

代码:

python 复制代码
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
from torch.nn import CrossEntropyLoss
from PIL import Image


# 配置
DEVICE = "cuda:0"
MODEL_NAME = "/data1/chenjun/huf/Qwen2-VL-2B-Instruct"
IMAGE_SIZE = 384


def resize_image(path, max_side=384):
    """调整图片大小,保持宽高比"""
    image = Image.open(path).convert("RGB")
    width, height = image.size
    if width > height:
        new_width = max_side
        new_height = int(height * (max_side / width))
    else:
        new_height = max_side
        new_width = int(width * (max_side / height))
    return [image.resize((new_width, new_height), Image.Resampling.LANCZOS)]


def main():
    # 加载模型和处理器
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME, dtype=torch.float32, device_map=DEVICE
    )
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # 构建消息
    file = 'outputs/ppl_vlm_qwen3-vl-2b-axera-384/vit/0000.png'
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": file},
                {"type": "text", "text": "描述这张图片"},
            ],
        }
    ]

    # 应用chat模板
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # 处理图片
    image_inputs = resize_image(file, IMAGE_SIZE)
    inputs = processor(text=[text], images=image_inputs, return_tensors="pt").to(DEVICE)
    gen_idx = inputs['input_ids'].shape[1]

    # 生成文本
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # 计算PPL
    text_with_response = text + output_text
    image_inputs = resize_image(file, IMAGE_SIZE)
    inputs2 = processor(text=[text_with_response], images=image_inputs, return_tensors="pt").to(DEVICE)


    with torch.no_grad():
        outputs = model(**inputs2, max_new_tokens=1)
        logits = outputs.logits

        # 计算交叉熵损失
        shift_labels = inputs2['input_ids'][..., gen_idx+1:].contiguous().to(DEVICE)
        shift_logits = logits[..., gen_idx:-1, :].contiguous().to(dtype=torch.float32)
        loss_fct = CrossEntropyLoss()
        ce_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )
        print(f"ce_loss: {ce_loss:.3f}, ppl: {ce_loss.exp():.3f}")


if __name__ == "__main__":
    main()
相关推荐
极欧互联1 分钟前
2026素材网站推荐排行 商用/自媒体/影视后期专用
大数据·人工智能·媒体
GOWIN革文品牌咨询3 分钟前
机器人企业品牌语言体系怎么搭建:一句话定位、产品逻辑与解决方案表达
人工智能·机器人
techdashen4 分钟前
Unweight:Cloudflare 如何在不损失精度的情况下把大模型压缩 22%
网络·人工智能
前端不太难4 分钟前
AI 能力如何变成鸿蒙 App 的基础设施
人工智能·状态模式·harmonyos
龙山云仓8 分钟前
无忧智脑-让企业拥抱智能,让管理回归简单
人工智能·深度学习·机器学习
2501_9333295511 分钟前
Infoseek数字公关AI中台技术解析:基于DeepSeek+NLP的全网舆情监测与智能处置系统
人工智能·架构·数据库开发
QFIUNE12 分钟前
【文献阅读】化学空间边缘的分子深度学习
论文阅读·人工智能·笔记·深度学习
新新学长搞科研13 分钟前
【最新】2026年能源方向学术会议征稿/交流资讯
人工智能·功能测试·计算机视觉·自动化·能源·新能源·材料工程
Coovally AI模型快速验证15 分钟前
多校联合提出LLM-as-Judge:大模型评判无人机电力线分割,无真值场景下守护安全
人工智能·计算机视觉·电力巡检
AI阿阳18 分钟前
✅真・喂饭级教程:2026 年 OpenClaw(Clawdbot)新手部署 + 飞书接入步骤流程
人工智能·windows·飞书·openclaw·openclaw 教程·本地 ai 部署