如何打造自己的视觉版GPT

一、前言

自ChatGPT出现后,Yann LeCun就一直抨击GPT无法成为AGI的最终形态,其原因在于文本只是人类智能的一种载体,有许多东西是文本不能表示的,也无法从GPT的训练中理解,Yann LeCun认为视觉大模型是一种更有前景的方向。

如今GPT4V、MiniGPT5等模型已经具备了视觉能力,基于这种视觉能力,做出来非常多有趣是应用。比如:AI游戏解说、AI赛事解说、TinyBot等。

不过视觉比文本要更加昂贵,以GPT4V为例,理解一个一分钟的视频需要十几到几十美元不等。今天我们要做的就是打造一个低成本的视觉版GPT,这里的低成本不仅体现在价钱,还体现在硬件资源。

二、实现原理

让GPT理解图像的方式有两种,分别是端到端的和非端到端的。

首先是第一种非端到端的,这种方式的想法很简单,就是利用图像描述模型生成图像的描述,然后利用Prompt工程实现图像理解。

我们只需要使用Image Captioning相关的网络生成一个图像描述即可。

而端到端的网络则是输入图片和问题,直接返回回答。在这个过程中,需要用到两个Transformer。首先需要对输入的问题进行编码,可以使用基础的Transformer Encoder做到。为了方便Decoder处理,我们可以用一个ViT(Vision Transformer)对图像进行编码。ViT和传统的Transformer非常相似,只不过把原本输入的token变成了图片的patch。最后图像、文本被编码成768的向量,再把向量拼接,传递给Tranformer Decoder生成回答。

第一种方式可以理解为GPT本身是一个盲人,但是有人在给他描述场景。而第二种方式则是GPT本身就能看到东西。

在理解能力上,第二种实现要更加准确,不过资源的消耗是巨大的。因此本文使用第一种实现方式。

三、基于YOLO的图像描述

3.1 yolo生成描述

一种简单的思路是使用YOLO、ResNet等网络来生成描述。这种描述通常比较机械,但是非常直观。这里我们以YOLO为例。

首先需要使用yolo来完成目标检测的操作,这里需要使用:

pip install ultralytics

安装必要的库。 然后是目标检测,代码如下:

python 复制代码
from ultralytics import YOLO
# 从huggingface加载模型
model = YOLO('ultralyticsplus/yolov8s')
image = 'https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg'
# 预测
results = model.predict(image)
# 绘制框
render = render_result(model=model, image=image, result=results[0])
render.show()

使用yolo我们可以得到两个信息,第一个是图像中有什么物体,一个是物体所在的区域。根据这两个信息,我们可以使用一套模板当做图像描述,比如:

图片左上角(右下角)有一个猫(狗、人)

在图像中可以检测到多个物体,一个物体生成一句描述,最后合起来就是图像的最终描述了。生成描述的代码如下:

python 复制代码
from collections import Counter
from ultralyticsplus import YOLO


def location_to_description(yolo_results):
    counter = Counter()
    for category, box in zip(yolo_results[0].boxes.cls.cpu().numpy(), yolo_results[0].boxes.xywhn.cpu().numpy()):
        x, y, w, h = box
        if x < 0.33 and y < 0.33:
            location = "upper left area"
        elif x > 0.66 and y > 0.66:
            location = "lower right area"
        elif x < 0.33 and y > 0.66:
            location = "lower left area"
        elif x > 0.66 and y < 0.33:
            location = "upper right area"
        elif 0.33 < x < 0.66 < y:
            location = "bottom center area"
        elif 0.66 > x > 0.33 > y:
            location = "top center area"
        else:
            location = "center area"
        counter.update((f"There is ?? {yolo.model.names[int(category)]} located in the {location}.",))
    description = ""
    for obj in counter.most_common():
        tmp, count = obj
        if count > 1:
            tmp = tmp.replace("is", "are")
        tmp = tmp.replace("??", str(count))
        description += tmp + "\n"
    return description

location_to_description函数接收yolo的输出,并返回一段描述,比如用下面的图片测试:

测试代码如下:

ini 复制代码
if __name__ == '__main__':
    yolo = YOLO('ultralyticsplus/yolov8s')
    results = yolo.predict("dongwu.jpg")
    print(location_to_description(results))

得到如下输出:

scss 复制代码
There are 2 elephant located in the upper left area.
There are 2 elephant located in the center area.
There are 2 elephant located in the top center area.

因为是用规则生成的描述,难免会有一些语法问题,不过这个可以交给LLM自己解决。

3.2 图文问答

接下来就是把描述接入LLM了,这里选择开源的Llama2模型,使用Llamacpp部署。我们可以用fastapi写一个简单的接口,也可以用llama-cpp-python提供的api,这里选择前者。代码如下:

python 复制代码
import uvicorn
from llama_cpp import Llama
from fastapi import FastAPI, Request

app = FastAPI()
llm = Llama(
    model_path=r'G:\models\llama2\llama-2-13b-chat-q4\ggml-model-q4_0.gguf',
    n_ctx=2048
)


@app.post("/chat")
async def chat(request: Request):
    global llm
    jdata = await request.json()
    prompt = jdata['prompt']
    return llm(prompt, stop=['Human'])


if __name__ == '__main__':
    uvicorn.run(app, host='127.0.0.1', port=8000, workers=1)

这里把model_path设置成自己的模型位置即可。运行后就可以用post访问http://127.0.0.1:8000/chat 了,调用接口的代码如下:

python 复制代码
async def chat(prompt):
    async with aiohttp.ClientSession() as session:
        async with session.post('http://127.0.0.1:8000/chat', json={'prompt': prompt}) as response:
            response = await response.json()
            return response['choices'][0]['text']

我们只需传入prompt即可。接下来就是界面的搭建,这里选择streamlit,代码如下:

python 复制代码
import aiohttp
import asyncio
from io import BytesIO
from PIL import Image
import streamlit as st
from ultralyticsplus import YOLO

prompt = ""
# 加载历史消息
messages = st.session_state.get('history_chat')
if not messages:
    messages = []
# 加载yolo
yolo = st.session_state.get('yolo')
if not yolo:
    yolo = YOLO('ultralyticsplus/yolov8s')
    st.session_state['yolo'] = yolo

# 界面
st.title("图文对话")
if file := st.file_uploader(label="请上传图片"):
    image = Image.open(BytesIO(file.getvalue()))
    st.sidebar.image(image)
    results = yolo.predict(image)
    description = location_to_description(results)
    st.sidebar.write(description)
    prompt += (
        "System: You need to answer the questions based on the description of the picture given below."
        "If the description has nothing to do with the question, "
        "you should just answer using your own language abilities."
        "Do not imagine non-existent facts.\n\n"
        f"Description: {description}."
    )
for role, text in messages:
    st.chat_message(role).write(text)
if message := st.chat_input("请输入问题:"):
    messages.append(['user', message])
    prompt += (
        f"\n\nHuman: {message}. \n\nAssistant: "
    )
    st.chat_message('user').write(message)
    response = asyncio.run(chat(prompt))
    messages.append(['assistant', response])
    st.chat_message('assistant').write(response)
    st.session_state['history_chat'] = messages

运行界面后就可以上传图片进行对话了。

四、基于BLIP的图像描述

在这种方法中,上面的大部分代码都可以复用,我们只需要重写一个生成描述的方法即可。代码如下:

python 复制代码
# 导入新模块
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration


prompt = ""
# 加载历史消息
messages = st.session_state.get('history_chat')
if not messages:
    messages = []
# 把加载yolo改成加载blip
model_path = r"G:\huggingface\hub\models--Salesforce--blip-image-captioning-large"
processor = st.session_state.get('processor')
if not processor:
    processor = BlipProcessor.from_pretrained(model_path)
blip = st.session_state.get('blip')
if not blip:
    blip = BlipForConditionalGeneration.from_pretrained(model_path,
                                                        torch_dtype=torch.float16).to("cuda")
# 界面
st.title("图文对话")
if file := st.file_uploader(label="请上传图片"):

    image = Image.open(BytesIO(file.getvalue()))
    st.sidebar.image(image)
    
    # 使用blip生成描述
    text = "a photography of"
    inputs = processor(image, text, return_tensors="pt").to("cuda", torch.float16)
    out = blip.generate(**inputs)
    description = processor.decode(out[0], skip_special_tokens=True)
    st.sidebar.write(description)
    prompt += (
        "System: You need to answer the questions based on the description of the picture given below."
        "If the description has nothing to do with the question, "
        "you should just answer using your own language abilities."
        "Do not imagine non-existent facts.\n\n"
        f"Description: {description}."
    )
for role, text in messages:
    st.chat_message(role).write(text)
if message := st.chat_input("请输入问题:"):
    messages.append(['user', message])
    prompt += (
        f"\n\nHuman: {message}. \n\nAssistant: "
    )
    st.chat_message('user').write(message)
    response = asyncio.run(chat(prompt))
    messages.append(['assistant', response])
    st.chat_message('assistant').write(response)
    st.session_state['history_chat'] = messages

上述代码有三处修改。第一处是导包,这里把yolo去掉了,导入了BLIP。第二处则是加载模型,同样把yolo换成了BLIP。第三处则是生成描述,这里去掉了location_to_description函数,换成了使用blip生成图像描述。其余部分维持原样,此时我们再来运行会发现结果比原本要更自然。

五、总结

本文使用了一种类似于文档问答的方式实现了图文QA的操作。这种方法不需要额外的训练,可以在极低的成本下实现,但是效果也取决于描述的质量,对于某些细节的把握并不是那么准确。读者可以尝试更多生成描述的方法以提升问答的准确性。

关于实现的详细教程可以参考:www.bilibili.com/video/BV19H...

项目代码上传至github:github.com/IronSpiderM...

相关推荐
好奇龙猫1 小时前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-11-01
人工智能·深度学习·神经网络·算法·机器学习·语言模型·数据挖掘
香菜大丸2 小时前
链表的归并排序
数据结构·算法·链表
jrrz08282 小时前
LeetCode 热题100(七)【链表】(1)
数据结构·c++·算法·leetcode·链表
oliveira-time2 小时前
golang学习2
算法
南宫生3 小时前
贪心算法习题其四【力扣】【算法学习day.21】
学习·算法·leetcode·链表·贪心算法
懒惰才能让科技进步4 小时前
从零学习大模型(十二)-----基于梯度的重要性剪枝(Gradient-based Pruning)
人工智能·深度学习·学习·算法·chatgpt·transformer·剪枝
Ni-Guvara4 小时前
函数对象笔记
c++·算法
泉崎4 小时前
11.7比赛总结
数据结构·算法
你好helloworld4 小时前
滑动窗口最大值
数据结构·算法·leetcode