目标检测多模态大模型实践:貌似是全网唯一Shikra的部署和测试教程,内含各种踩坑以及demo代码

原文:

Shikra: Unleashing Multimodal LLM's Referential Dialogue Magic

代码:

https://github.com/shikras/shikra

模型:

https://huggingface.co/shikras/shikra-7b-delta-v1

https://huggingface.co/shikras/shikra7b-delta-v1-0708

第一个是论文用的,第二个会有迭代。

本人的shikra论文解读,逐行解读,非常详细!

多模态大模型目标检测,精读,Shikra

部署:

  1. 下载GitHub工程,和shikras的模型参数,注意,还要下载LLaMA-7b的模型;

  2. 创建环境:

    conda create -n shikra python=3.10
    conda activate shikra
    pip install -r requirements.txt

后面我运行的时候报缺包了,又pip install了以下包,不过每个人情况不同:

复制代码
pip install uvicorn
pip install mmengine

然后还会报错:

复制代码
File "/usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py", line 171, in <module>
    LayerId = cv2.dnn.DictValue
AttributeError: module 'cv2.dnn' has no attribute 'DictValue'

解决方案:

修改/usr/local/lib/python3.10/dist-packages/cv2/typing/init .py

注释掉LayerId = cv2.dnn.DictValue这行即可。

  1. 权重下载和合并
    shikra官方提供的模型权重需要和llama1-7b合并之后才能用,然而llama1需要申请,比较麻烦,在hf上找到了平替(这一步我走了好久QwQ):
    https://huggingface.co/huggyllama/llama-7b
    大家自己下载,然后运行官方提供的合并代码:

    python mllm/models/shikra/apply_delta.py
    --base /path/to/llama-7b
    --target /output/path/to/shikra-7b-merge
    --delta shikras/shikra-7b-delta-v1

得到了可用的模型参数shikra-7b-merge。

注意要把参数文件夹里config里的模型路径改成merge版的。

此外还需要下载clip 模型参数:

https://huggingface.co/openai/clip-vit-large-patch14

代码和配置文件中有多处调用/openai/clip-vit-large-patch14,要改成本地版本。如果不预先下载,应该会在运行时自动下载,大家看网络情况自行选择。

  1. 我写的demo文件,用于在命令行测试模型效果,主要是为了不用gradiofastapi这些东西。
python 复制代码
import argparse
import os
import sys
import base64
import logging
import time
from pathlib import Path
from io import BytesIO

import torch
import uvicorn
import transformers
from PIL import Image
from mmengine import Config
from transformers import BitsAndBytesConfig

sys.path.append(str(Path(__file__).parent.parent.parent))

from mllm.dataset.process_function import PlainBoxFormatter
from mllm.dataset.builder import prepare_interactive
from mllm.models.builder.build_shikra import load_pretrained_shikra
from mllm.dataset.utils.transform import expand2square, box_xyxy_expand2square

# Set up logging
log_level = logging.DEBUG
transformers.logging.set_verbosity(log_level)
transformers.logging.enable_default_handler()
transformers.logging.enable_explicit_format()

# prompt for coco

# Argument parsing
parser = argparse.ArgumentParser("Shikra Local Demo")
parser.add_argument('--model_path', default = "xxx/shikra-merge", help="Path to the model")
parser.add_argument('--load_in_8bit', action='store_true', help="Load model in 8-bit precision")
parser.add_argument('--image_path', default = "xxx/shikra-main/mllm/demo/assets/ball.jpg", help="Path to the image file")
parser.add_argument('--text', default="What do you see in this image? Please mention the objects and their locations using the format [x1,y1,x2,y2].", help="Text prompt")
parser.add_argument('--boxes_value', nargs='+', type=int, default=[], help="Bounding box values (x1, y1, x2, y2)")
parser.add_argument('--boxes_seq', nargs='+', type=int, default=[], help="Sequence of bounding boxes")
parser.add_argument('--do_sample', action='store_true', help="Use sampling during generation")
parser.add_argument('--max_length', type=int, default=512, help="Maximum length of the output")
parser.add_argument('--top_p', type=float, default=1.0, help="Top-p value for sampling")
parser.add_argument('--temperature', type=float, default=1.0, help="Temperature for sampling")

args = parser.parse_args()
model_name_or_path = args.model_path
# Model initialization
model_args = Config(dict(
    type='shikra',
    version='v1',

    # checkpoint config
    cache_dir=None,
    model_name_or_path=model_name_or_path,
    vision_tower=r'xxx/clip-vit-large-patch14',
    pretrain_mm_mlp_adapter=None,

    # model config
    mm_vision_select_layer=-2,
    model_max_length=2048,

    # finetune config
    freeze_backbone=False,
    tune_mm_mlp_adapter=False,
    freeze_mm_mlp_adapter=False,

    # data process config
    is_multimodal=True,
    sep_image_conv_front=False,
    image_token_len=256,
    mm_use_im_start_end=True,

    target_processor=dict(
        boxes=dict(type='PlainBoxFormatter'),
    ),

    process_func_args=dict(
        conv=dict(type='ShikraConvProcess'),
        target=dict(type='BoxFormatProcess'),
        text=dict(type='ShikraTextProcess'),
        image=dict(type='ShikraImageProcessor'),
    ),

    conv_args=dict(
        conv_template='vicuna_v1.1',
        transforms=dict(type='Expand2square'),
        tokenize_kwargs=dict(truncation_size=None),
    ),

    gen_kwargs_set_pad_token_id=True,
    gen_kwargs_set_bos_token_id=True,
    gen_kwargs_set_eos_token_id=True,
))
training_args = Config(dict(
    bf16=False,
    fp16=True,
    device='cuda',
    fsdp=None,
))

quantization_kwargs = dict(
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=args.load_in_8bit,
    )
) if args.load_in_8bit else dict()

model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs)

# Convert the model and vision tower to float16
if not getattr(model, 'is_quantized', False):
    model.to(dtype=torch.float16, device=torch.device('cuda'))
if not getattr(model.model.vision_tower[0], 'is_quantized', False):
    model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda'))

preprocessor['target'] = {'boxes': PlainBoxFormatter()}
tokenizer = preprocessor['text']

# Load and preprocess the image
pil_image = Image.open(args.image_path).convert("RGB")
ds = prepare_interactive(model_args, preprocessor)

image = expand2square(pil_image)
boxes_value = [box_xyxy_expand2square(box, w=pil_image.width, h=pil_image.height) for box in zip(args.boxes_value[::2], args.boxes_value[1::2], args.boxes_value[2::2], args.boxes_value[3::2])]

ds.set_image(image)
ds.append_message(role=ds.roles[0], message=args.text, boxes=boxes_value, boxes_seq=args.boxes_seq)
model_inputs = ds.to_model_input()
model_inputs['images'] = model_inputs['images'].to(torch.float16)

# Generate
gen_kwargs = dict(
    use_cache=True,
    do_sample=args.do_sample,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=args.max_length,
    top_p=args.top_p,
    temperature=args.temperature,
)

input_ids = model_inputs['input_ids']
st_time = time.time()
with torch.inference_mode():
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output_ids = model.generate(**model_inputs, **gen_kwargs)
print(f"Generated in {time.time() - st_time} seconds")

input_token_len = input_ids.shape[-1]
response = tokenizer.batch_decode(output_ids[:, input_token_len:])[0]
print(f"Response: {response}")

这么良心,点个关注吧,会持续更新多模态大模型相关内容。

相关推荐
DuHz1 天前
论文精读:大语言模型 (Large Language Models, LLM) —— 一项调查
论文阅读·人工智能·深度学习·算法·机器学习·计算机视觉·语言模型
offer收割机小鹅1 天前
大学生求职必备:AI面试、AI写作与设计工具助力职场发展
人工智能·ai·面试·aigc·ai写作
XinZong1 天前
【AI社交Skill】禁止人类发言”的AI闭环社交社区_ 到底什么是 clawreach 虾聊?
aigc·openai·ai编程
爱吃的小肥羊1 天前
裂开!ChatGPT 居然开始要手机号验证,附详细解决方法
aigc
阿祖zu1 天前
本地到生产,解决 AI 全栈最后一公里——构建&部署&运维
运维·架构·aigc
奇牙1 天前
OpenClaw Channel 插件开发实战:从零写一个自定义模型接入插件(2026)
aigc·mcp
AI医影跨模态组学1 天前
如何将影像组学特征与肿瘤免疫微环境中的关键信号通路及免疫细胞浸润建立关联,并进一步解释其与胃癌术后复发、预后的机制联系
人工智能·深度学习·计算机视觉·论文·医学影像
阿杰学AI1 天前
AI核心知识141—大语言模型之 对齐难题(简洁且通俗易懂版)
人工智能·安全·ai·语言模型·自然语言处理·aigc·ai对齐
Elastic 中国社区官方博客1 天前
在 Discover 中探索来自新的时间序列数据流的指标
大数据·数据库·目标检测·elasticsearch·搜索引擎·数据分析·全文检索