如何打造自己的视觉版GPT

一、前言

自ChatGPT出现后,Yann LeCun就一直抨击GPT无法成为AGI的最终形态,其原因在于文本只是人类智能的一种载体,有许多东西是文本不能表示的,也无法从GPT的训练中理解,Yann LeCun认为视觉大模型是一种更有前景的方向。

如今GPT4V、MiniGPT5等模型已经具备了视觉能力,基于这种视觉能力,做出来非常多有趣是应用。比如:AI游戏解说、AI赛事解说、TinyBot等。

不过视觉比文本要更加昂贵,以GPT4V为例,理解一个一分钟的视频需要十几到几十美元不等。今天我们要做的就是打造一个低成本的视觉版GPT,这里的低成本不仅体现在价钱,还体现在硬件资源。

二、实现原理

让GPT理解图像的方式有两种,分别是端到端的和非端到端的。

首先是第一种非端到端的,这种方式的想法很简单,就是利用图像描述模型生成图像的描述,然后利用Prompt工程实现图像理解。

我们只需要使用Image Captioning相关的网络生成一个图像描述即可。

而端到端的网络则是输入图片和问题,直接返回回答。在这个过程中,需要用到两个Transformer。首先需要对输入的问题进行编码,可以使用基础的Transformer Encoder做到。为了方便Decoder处理,我们可以用一个ViT(Vision Transformer)对图像进行编码。ViT和传统的Transformer非常相似,只不过把原本输入的token变成了图片的patch。最后图像、文本被编码成768的向量,再把向量拼接,传递给Tranformer Decoder生成回答。

第一种方式可以理解为GPT本身是一个盲人,但是有人在给他描述场景。而第二种方式则是GPT本身就能看到东西。

在理解能力上,第二种实现要更加准确,不过资源的消耗是巨大的。因此本文使用第一种实现方式。

三、基于YOLO的图像描述

3.1 yolo生成描述

一种简单的思路是使用YOLO、ResNet等网络来生成描述。这种描述通常比较机械,但是非常直观。这里我们以YOLO为例。

首先需要使用yolo来完成目标检测的操作,这里需要使用:

复制代码
pip install ultralytics

安装必要的库。 然后是目标检测,代码如下:

python 复制代码
from ultralytics import YOLO
# 从huggingface加载模型
model = YOLO('ultralyticsplus/yolov8s')
image = 'https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg'
# 预测
results = model.predict(image)
# 绘制框
render = render_result(model=model, image=image, result=results[0])
render.show()

使用yolo我们可以得到两个信息,第一个是图像中有什么物体,一个是物体所在的区域。根据这两个信息,我们可以使用一套模板当做图像描述,比如:

图片左上角(右下角)有一个猫(狗、人)

在图像中可以检测到多个物体,一个物体生成一句描述,最后合起来就是图像的最终描述了。生成描述的代码如下:

python 复制代码
from collections import Counter
from ultralyticsplus import YOLO


def location_to_description(yolo_results):
    counter = Counter()
    for category, box in zip(yolo_results[0].boxes.cls.cpu().numpy(), yolo_results[0].boxes.xywhn.cpu().numpy()):
        x, y, w, h = box
        if x < 0.33 and y < 0.33:
            location = "upper left area"
        elif x > 0.66 and y > 0.66:
            location = "lower right area"
        elif x < 0.33 and y > 0.66:
            location = "lower left area"
        elif x > 0.66 and y < 0.33:
            location = "upper right area"
        elif 0.33 < x < 0.66 < y:
            location = "bottom center area"
        elif 0.66 > x > 0.33 > y:
            location = "top center area"
        else:
            location = "center area"
        counter.update((f"There is ?? {yolo.model.names[int(category)]} located in the {location}.",))
    description = ""
    for obj in counter.most_common():
        tmp, count = obj
        if count > 1:
            tmp = tmp.replace("is", "are")
        tmp = tmp.replace("??", str(count))
        description += tmp + "\n"
    return description

location_to_description函数接收yolo的输出,并返回一段描述,比如用下面的图片测试:

测试代码如下:

ini 复制代码
if __name__ == '__main__':
    yolo = YOLO('ultralyticsplus/yolov8s')
    results = yolo.predict("dongwu.jpg")
    print(location_to_description(results))

得到如下输出:

scss 复制代码
There are 2 elephant located in the upper left area.
There are 2 elephant located in the center area.
There are 2 elephant located in the top center area.

因为是用规则生成的描述,难免会有一些语法问题,不过这个可以交给LLM自己解决。

3.2 图文问答

接下来就是把描述接入LLM了,这里选择开源的Llama2模型,使用Llamacpp部署。我们可以用fastapi写一个简单的接口,也可以用llama-cpp-python提供的api,这里选择前者。代码如下:

python 复制代码
import uvicorn
from llama_cpp import Llama
from fastapi import FastAPI, Request

app = FastAPI()
llm = Llama(
    model_path=r'G:\models\llama2\llama-2-13b-chat-q4\ggml-model-q4_0.gguf',
    n_ctx=2048
)


@app.post("/chat")
async def chat(request: Request):
    global llm
    jdata = await request.json()
    prompt = jdata['prompt']
    return llm(prompt, stop=['Human'])


if __name__ == '__main__':
    uvicorn.run(app, host='127.0.0.1', port=8000, workers=1)

这里把model_path设置成自己的模型位置即可。运行后就可以用post访问http://127.0.0.1:8000/chat 了,调用接口的代码如下:

python 复制代码
async def chat(prompt):
    async with aiohttp.ClientSession() as session:
        async with session.post('http://127.0.0.1:8000/chat', json={'prompt': prompt}) as response:
            response = await response.json()
            return response['choices'][0]['text']

我们只需传入prompt即可。接下来就是界面的搭建,这里选择streamlit,代码如下:

python 复制代码
import aiohttp
import asyncio
from io import BytesIO
from PIL import Image
import streamlit as st
from ultralyticsplus import YOLO

prompt = ""
# 加载历史消息
messages = st.session_state.get('history_chat')
if not messages:
    messages = []
# 加载yolo
yolo = st.session_state.get('yolo')
if not yolo:
    yolo = YOLO('ultralyticsplus/yolov8s')
    st.session_state['yolo'] = yolo

# 界面
st.title("图文对话")
if file := st.file_uploader(label="请上传图片"):
    image = Image.open(BytesIO(file.getvalue()))
    st.sidebar.image(image)
    results = yolo.predict(image)
    description = location_to_description(results)
    st.sidebar.write(description)
    prompt += (
        "System: You need to answer the questions based on the description of the picture given below."
        "If the description has nothing to do with the question, "
        "you should just answer using your own language abilities."
        "Do not imagine non-existent facts.\n\n"
        f"Description: {description}."
    )
for role, text in messages:
    st.chat_message(role).write(text)
if message := st.chat_input("请输入问题:"):
    messages.append(['user', message])
    prompt += (
        f"\n\nHuman: {message}. \n\nAssistant: "
    )
    st.chat_message('user').write(message)
    response = asyncio.run(chat(prompt))
    messages.append(['assistant', response])
    st.chat_message('assistant').write(response)
    st.session_state['history_chat'] = messages

运行界面后就可以上传图片进行对话了。

四、基于BLIP的图像描述

在这种方法中,上面的大部分代码都可以复用,我们只需要重写一个生成描述的方法即可。代码如下:

python 复制代码
# 导入新模块
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration


prompt = ""
# 加载历史消息
messages = st.session_state.get('history_chat')
if not messages:
    messages = []
# 把加载yolo改成加载blip
model_path = r"G:\huggingface\hub\models--Salesforce--blip-image-captioning-large"
processor = st.session_state.get('processor')
if not processor:
    processor = BlipProcessor.from_pretrained(model_path)
blip = st.session_state.get('blip')
if not blip:
    blip = BlipForConditionalGeneration.from_pretrained(model_path,
                                                        torch_dtype=torch.float16).to("cuda")
# 界面
st.title("图文对话")
if file := st.file_uploader(label="请上传图片"):

    image = Image.open(BytesIO(file.getvalue()))
    st.sidebar.image(image)
    
    # 使用blip生成描述
    text = "a photography of"
    inputs = processor(image, text, return_tensors="pt").to("cuda", torch.float16)
    out = blip.generate(**inputs)
    description = processor.decode(out[0], skip_special_tokens=True)
    st.sidebar.write(description)
    prompt += (
        "System: You need to answer the questions based on the description of the picture given below."
        "If the description has nothing to do with the question, "
        "you should just answer using your own language abilities."
        "Do not imagine non-existent facts.\n\n"
        f"Description: {description}."
    )
for role, text in messages:
    st.chat_message(role).write(text)
if message := st.chat_input("请输入问题:"):
    messages.append(['user', message])
    prompt += (
        f"\n\nHuman: {message}. \n\nAssistant: "
    )
    st.chat_message('user').write(message)
    response = asyncio.run(chat(prompt))
    messages.append(['assistant', response])
    st.chat_message('assistant').write(response)
    st.session_state['history_chat'] = messages

上述代码有三处修改。第一处是导包,这里把yolo去掉了,导入了BLIP。第二处则是加载模型,同样把yolo换成了BLIP。第三处则是生成描述,这里去掉了location_to_description函数,换成了使用blip生成图像描述。其余部分维持原样,此时我们再来运行会发现结果比原本要更自然。

五、总结

本文使用了一种类似于文档问答的方式实现了图文QA的操作。这种方法不需要额外的训练,可以在极低的成本下实现,但是效果也取决于描述的质量,对于某些细节的把握并不是那么准确。读者可以尝试更多生成描述的方法以提升问答的准确性。

关于实现的详细教程可以参考:www.bilibili.com/video/BV19H...

项目代码上传至github:github.com/IronSpiderM...

相关推荐
智者知已应修善业2 小时前
【51单片机8位数码管同时倒计时从9999】2024-1-25
c++·经验分享·笔记·算法·51单片机
洛水水2 小时前
【力扣100题】86.柱状图中最大的矩形
算法·leetcode·职场和发展
渡之2 小时前
GRiM-Net 深度解析 | 无人机 GNSS 拒止场景下两阶段跨视角视觉定位框架
深度学习·算法·动态规划·无人机
测试仪器廖生135902563853 小时前
罗德与施瓦茨 FSP13频谱分析仪FSP30
网络·人工智能·算法
happymaker06263 小时前
LeetCodeHot100——560.和为K的子数组
算法
dtq04243 小时前
C语言刷题数组5,6(求平均值,求最大值)
c语言·数据结构·算法
郭梧悠3 小时前
Hash算法入门Hash冲突解决方案
算法·哈希算法
洛水水4 小时前
【力扣100题】81.寻找两个正序数组的中位数
数据结构·算法·leetcode
happymaker06264 小时前
LeetCodeHot100——155.最小栈
算法
洛水水4 小时前
【力扣100题】85.每日温度
算法·leetcode·职场和发展