使用 GPT-4 Vision 的 CLIP 嵌入来改进多模态 RAG

多模态 RAG 将附加模态集成到传统的基于文本的 RAG 中,通过提供额外的背景信息和基础文本数据来增强 LLM 的问答能力,从而提高理解力。

我们直接嵌入图像进行相似性搜索,绕过文本字幕的有损过程,以提高检索准确性。

使用基于 CLIP 的嵌入可以进一步使用特定数据进行微调或使用看不见的图像进行更新。

该技术通过使用用户提供的技术图像搜索企业知识库来提供相关信息来展示。

安装

首先我们来安装相关的软件包。

python 复制代码
#installations
%pip install clip
%pip install torch
%pip install pillow
%pip install faiss-cpu
%pip install numpy
%pip install git+https://github.com/openai/CLIP.git
%pip install openai

然后让我们导入所有需要的包。

# model imports
import faiss
import json
import torch
from openai import OpenAI
import torch.nn as nn
from torch.utils.data import DataLoader
import clip
client = OpenAI()

# helper imports
from tqdm import tqdm
import json
import os
import numpy as np
import pickle
from typing import List, Union, Tuple

# visualisation imports
from PIL import Image
import matplotlib.pyplot as plt
import base64

现在让我们加载 CLIP 模型。

python 复制代码
#load model on device. The device you are running inference/training on is either a CPU or GPU if you have.
device = "cpu"
model, preprocess = clip.load("ViT-B/32",device=device)

我们现在将:

  1. 创建图像嵌入数据库
  2. 设置对视觉模型的查询
  3. 执行语义搜索
  4. 将用户查询传递给图像

创建图像嵌入数据库

接下来,我们将从图像目录中创建图像嵌入知识库。这将是我们搜索的技术知识库,用于向用户提供他们上传的图像的信息。

我们传入存储图像的目录(JPEG 格式)并循环遍历每个目录以创建嵌入。

我们还有一个 description.json。它为我们知识库中的每个图像都有一个条目。它有两个键:"image_path"和"description"。它将每张图片映射到该图片的有用描述,以帮助回答用户的问题。

首先,让我们编写一个函数来获取给定目录中的所有图像路径。然后,我们将从名为"image_database"的目录中获取所有 jpeg

python 复制代码
def get_image_paths(directory: str, number: int = None) -> List[str]:
    image_paths = []
    count = 0
    for filename in os.listdir(directory):
        if filename.endswith('.jpeg'):
            image_paths.append(os.path.join(directory, filename))
            if number is not None and count == number:
                return [image_paths[-1]]
            count += 1
    return image_paths
direc = 'image_database/'
image_paths = get_image_paths(direc)

接下来,我们将编写一个函数,根据一系列路径从 CLIP 模型中获取图像嵌入。

我们首先使用之前得到的预处理函数对图像进行预处理。这会执行一些操作以确保 CLIP 模型的输入具有正确的格式和维度,包括调整大小、规范化、颜色通道调整等。

然后,我们将这些预处理过的图像堆叠在一起,这样我们就可以一次性将它们传递到模型中,而不是循环传递。最后返回模型输出,即嵌入数组。

python 复制代码
def get_features_from_image_path(image_paths):
  images = [preprocess(Image.open(image_path).convert("RGB")) for image_path in image_paths]
  image_input = torch.tensor(np.stack(images))
  with torch.no_grad():
    image_features = model.encode_image(image_input).float()
  return image_features
image_features = get_features_from_image_path(image_paths)

我们现在可以创建我们的矢量数据库。

python 复制代码
index = faiss.IndexFlatIP(image_features.shape[1])
index.add(image_features)

并提取我们的 json 以进行图像描述映射并创建一个 json 列表。我们还创建了一个辅助函数来搜索此列表以查找我们想要的图像,这样我们就可以获得该图像的描述

python 复制代码
data = []
image_path = 'train1.jpeg'
with open('description.json', 'r') as file:
    for line in file:
        data.append(json.loads(line))
def find_entry(data, key, value):
    for entry in data:
        if entry.get(key) == value:
            return entry
    return None

让我们显示一个示例图像,这将是用户上传的图像。这是在 2024 年 CES 上亮相的一项技术。它是 DELTA Pro Ultra 全屋电池发电机。

python 复制代码
im = Image.open(image_path)
plt.imshow(im)
plt.show()

查询视觉模型

现在我们来看看 GPT-4 Vision(之前它肯定没有见过这项技术)会把它标记为什么。

首先,我们需要编写一个函数以 base64 格式对图像进行编码,因为这是我们将传递到视觉模型的格式。然后,我们将创建一个通用的 image_query 函数,以便我们可以使用图像输入查询 LLM。

python 复制代码
def encode_image(image_path):
    with open(image_path, 'rb') as image_file:
        encoded_image = base64.b64encode(image_file.read())
        return encoded_image.decode('utf-8')

def image_query(query, image_path):
    response = client.chat.completions.create(
        model='gpt-4-vision-preview',
        messages=[
            {
            "role": "user",
            "content": [
                {
                "type": "text",
                "text": query,
                },
                {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encode_image(image_path)}",
                },
                }
            ],
            }
        ],
        max_tokens=300,
    )
    # Extract relevant features from the response
    return response.choices[0].message.content
image_query('Write a short label of what is show in this image?', image_path)

"自动送货机器人"

我们可以看到,它尽了最大努力利用训练过的信息,但由于在训练数据中没有看到任何类似的东西,它犯了一个错误。这是因为这是一张模糊的图像,很难推断和推论。

执行语义搜索

现在让我们执行相似性搜索,以在我们的知识库中找到两张最相似的图像。我们通过获取用户输入的 image_path 的嵌入,检索数据库中相似图像的索引和距离来实现此目的。距离将成为我们相似性的代理指标,距离越小意味着越相似。然后我们根据距离按降序排序。

python 复制代码
image_search_embedding = get_features_from_image_path([image_path])
distances, indices = index.search(image_search_embedding.reshape(1, -1), 2) #2 signifies the number of topmost similar images to bring back
distances = distances[0]
indices = indices[0]
indices_distances = list(zip(indices, distances))
indices_distances.sort(key=lambda x: x[1], reverse=True)

我们需要索引,因为我们将使用它来搜索我们的 image_directory 并选择索引位置处的图像以输入到 RAG 的视觉模型中。

让我们看看它带回了什么(我们按相似度顺序显示它们):

python 复制代码
#display similar images
for idx, distance in indices_distances:
    print(idx)
    path = get_image_paths(direc, idx)[0]
    im = Image.open(path)
    plt.imshow(im)
    plt.show()


我们可以看到它返回了两张包含 DELTA Pro Ultra 全屋电池发电机的图像。其中一张图像中还有一些背景,可能会分散注意力,但它设法找到了正确的图像。

用户查询最相似的图像

现在,对于我们最相似的图像,我们希望将其及其描述与用户查询一起传递给 GPT-V,以便他们可以查询他们可能购买的技术。这就是视觉模型的强大之处,您可以向模型提出尚未明确训练的一般查询,并且它会以高精度做出响应。

在下面的例子中,我们将询问所讨论物品的容量。

python 复制代码
similar_path = get_image_paths(direc, indices_distances[0][0])[0]
element = find_entry(data, 'image_path', similar_path)

user_query = 'What is the capacity of this item?'
prompt = f"""
Below is a user query, I want you to answer the query using the description and image provided.

user query:
{user_query}

description:
{element['description']}
"""
image_query(prompt, similar_path)

"便携式家用电池 DELTA Pro 的基本容量为 3.6kWh。使用附加电池,容量可扩大至 25kWh。图片展示了 DELTA Pro,其交流输出功率容量也高达 3600W。"

我们发现它能够回答这个问题。这只有通过直接匹配图像并从中收集相关描述作为上下文才有可能。

结论

在本文,我们介绍了如何使用 CLIP 模型,并使用 CLIP 模型创建图像嵌入数据库、执行语义搜索并最终提供用户查询来回答问题的示例。

相关推荐
985小水博一枚呀3 小时前
【深度学习|可视化】如何以图形化的方式展示神经网络的结构、训练过程、模型的中间状态或模型决策的结果??
人工智能·python·深度学习·神经网络·机器学习·计算机视觉·cnn
明明真系叻6 小时前
第十二周:机器学习笔记
人工智能·机器学习
跟着大数据和AI去旅行7 小时前
使用肘部法则确定K-Means中的k值
python·机器学习·kmeans
QuantumYou7 小时前
【对比学习串烧】 SWav和 BYOL
学习·机器学习
学不会lostfound7 小时前
一、机器学习算法与实践_03概率论与贝叶斯算法笔记
算法·机器学习·概率论·高斯贝叶斯
正义的彬彬侠9 小时前
举例说明计算一个矩阵的秩的完整步骤
人工智能·机器学习·矩阵·回归
鸽芷咕9 小时前
【Python报错已解决】xlrd.biffh.XLRDError: Excel xlsx file; not supported
开发语言·python·机器学习·bug·excel
zhangbin_23710 小时前
【Python机器学习】NLP信息提取——提取人物/事物关系
开发语言·人工智能·python·机器学习·自然语言处理
王豫翔10 小时前
OpenAl o1论文:Let’s Verify Step by Step 快速解读
人工智能·深度学习·机器学习·chatgpt
叫我:松哥12 小时前
基于机器学习的癌症数据分析与预测系统实现,有三种算法,bootstrap前端+flask
前端·python·随机森林·机器学习·数据分析·flask·bootstrap