Elasticsearch:运用 JINA 来实现多模态搜索的 RAG

Jina Embeddings v4 是一个 38 亿参数的通用向量模型,用于多模态多语言检索,支持单向量和多向量输出。那么我们该如何使用它对图片及文字进行搜索,并最终对搜索的结果做 RAG。

下载源码

闲话少说,我们直接到地址 https://github.com/liu-xiao-guo/jina_multimodal_rag 下载源码。

复制代码
git clone https://github.com/liu-xiao-guo/jina_multimodal_rag

$ pwd
/Users/liuxg/python/jina_multimodal_rag
$ tree -L 3
.
├── README.md
├── app.py
├── images
│   ├── bladerunner-city.jpg
│   ├── images (1).jpeg
│   ├── images (2).jpeg
│   ├── images (3).jpeg
│   ├── matrix-code.jpg
│   ├── starwars-lightsaber.jpg
│   └── tfa_poster_wide_header-1536x864-324397389357.0.0.1537961254.webp
├── pics
│   ├── pic1.png
│   ├── pic2.png
│   └── pic3.png
├── requirements.txt
└── texts
    ├── 1.txt
    ├── 10.txt
    ├── 2.txt
    ├── 3.txt
    ├── 4.txt
    ├── 5.txt
    ├── 6.txt
    ├── 7.txt
    ├── 8.txt
    └── 9.txt

如上所示,我们的代码在 app.py 里。我们可以把所有需要向量化的图片放入到 images 目录下,把所有需要向量化的文字放入到 texts 目录下的文件中。

除了上面的文件之外,还有一个叫做 .env 的文件:

.env

复制代码
ES_URL="<Your ES_URL>"
ES_API_KEY="<Your ES_API_KEY>"
GEMINI_FLASH_API_KEY="<Your Gemini Flash API Key>"

我们需要根据自己的配置填入相应的设置。在今天的使用中,我们使用 https://openrouter.ai/ 来调用 Gemini 3 Flash multimodal LLM 来完成我们的 RAG。

代码设计

为了方便我们的代码设计,我们使用 streamlit 来设计界面:

app.py

复制代码
import os
import torch
import streamlit as st
from PIL import Image
from transformers import AutoModel
from elasticsearch import Elasticsearch
from dotenv import load_dotenv
import openai
import base64
from io import BytesIO

# Load environment variables from .env if exists
load_dotenv()

# -------------------------
# Config
# -------------------------
INDEX_NAME = "multimodal-index"  # single ES index holding both image and text docs
IMAGE_FOLDER = "./images"  # local image folder
TEXT_FOLDER = "./texts"    # local text folder
ES_URL = os.getenv("ES_URL", "https://localhost:9200")  # falls back to a local HTTPS node
ES_API_KEY = os.getenv("ES_API_KEY", "")
# OpenRouter key; the env var is named after the Gemini model it fronts
OPENROUTER_API_KEY = os.getenv("GEMINI_FLASH_API_KEY")

# -------------------------
# Elasticsearch
# -------------------------
@st.cache_resource
def get_es():
    """Return a cached Elasticsearch client (created once per Streamlit session).

    NOTE(review): verify_certs=False disables TLS certificate validation —
    fine for local development, not for production.
    """
    return Elasticsearch(ES_URL, verify_certs=False, api_key=ES_API_KEY)

es = get_es() # Initialize Elasticsearch client

# -------------------------
# Model loading
# -------------------------
@st.cache_resource
def load_model():
    """Load jina-embeddings-v4 once and move it to the best available device.

    Returns:
        (model, device): the eval-mode model and the device string it lives on.
    """
    # Prefer Apple MPS, then CUDA, else CPU.
    device = (
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )

    model = AutoModel.from_pretrained(
        "jinaai/jina-embeddings-v4",
        trust_remote_code=True,  # model ships custom encode_image/encode_text code
        torch_dtype=torch.float32,
    ).to(device)
    model.eval()  # inference only — no gradients or dropout
    return model, device

model, device = load_model()

# -------------------------
# LLM Client loading (OpenRouter)
# -------------------------
@st.cache_resource
def load_llm_client():
    """Return a cached OpenAI-compatible client pointed at OpenRouter.

    Returns None (and shows a Streamlit error) when the API key is missing,
    so the rest of the app can still run in retrieval-only mode.
    """
    if not OPENROUTER_API_KEY:
        st.error("GEMINI_FLASH_API_KEY not found in environment variables.")
        return None
    
    client = openai.OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=OPENROUTER_API_KEY,
    )
    return client

llm_client = load_llm_client()
# OpenRouter model slug used for RAG answer generation
LLM_MODEL_NAME = "google/gemini-3-flash-preview"

# -------------------------
# Index setup
# -------------------------
def create_index():
    """Create the multimodal index if it does not exist yet.

    The mapping stores a 2048-dim cosine dense_vector alongside the source
    filename/path and a plain-text caption.
    """
    if es.indices.exists(index=INDEX_NAME):
        return

    vector_mapping = {
        "type": "dense_vector",
        "dims": 2048,
        "index": True,
        "similarity": "cosine",
    }
    es.indices.create(
        index=INDEX_NAME,
        body={
            "mappings": {
                "properties": {
                    "filename": {"type": "keyword"},
                    "path": {"type": "keyword"},
                    "caption": {"type": "text"},
                    "vector_field": vector_mapping,
                }
            }
        },
    )

create_index()

# -------------------------
# Embedding helpers
# -------------------------
def embed_image(pil_image):
    """Return a single retrieval embedding (numpy array) for one PIL image."""
    with torch.inference_mode():
        vectors = model.encode_image(
            images=[pil_image],
            task="retrieval",
            return_numpy=True,
        )
    return vectors[0]

def embed_text(text):
    """Return a single retrieval embedding (numpy array) for one text string."""
    with torch.inference_mode():
        vectors = model.encode_text(
            texts=[text],
            task="retrieval",
            prompt_name="query",
            return_numpy=True,
        )
    return vectors[0]

# -------------------------
# Batch ingestion for images
# -------------------------
def ingest_image_folder(folder):
    """Embed every supported image in *folder* and bulk-index it into Elasticsearch.

    Note: repeated calls re-index the same files as new documents — there is
    no deduplication by filename.
    """
    from elasticsearch.helpers import bulk

    docs = []
    for fname in os.listdir(folder):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            continue

        path = os.path.join(folder, fname)
        # Image.open is lazy; use a context manager so the underlying file
        # handle is closed promptly instead of leaking until GC.
        with Image.open(path) as img:
            image = img.convert("RGB")
        vec = embed_image(image)

        docs.append({
            "_index": INDEX_NAME,
            "_source": {
                "filename": fname,
                "path": path,
                "caption": fname.replace("_", " "),  # crude caption derived from filename
                "vector_field": vec.tolist(),
            }
        })

    if docs:
        bulk(es, docs)

# -------------------------
# Batch ingestion for text files
# -------------------------
def ingest_text_folder(folder):
    """Embed every .txt file in *folder* and bulk-index it into Elasticsearch."""
    pending = []
    for fname in os.listdir(folder):
        if not fname.lower().endswith(".txt"):
            continue

        path = os.path.join(folder, fname)
        with open(path, "r", encoding="utf-8") as handle:
            text = handle.read().strip()

        pending.append({
            "_index": INDEX_NAME,
            "_source": {
                "filename": fname,
                "path": path,
                # First 500 chars double as a display caption.
                "caption": text[:500],
                "vector_field": embed_text(text).tolist(),
            }
        })

    if pending:
        from elasticsearch.helpers import bulk
        bulk(es, pending)

# -------------------------
# KNN search only
# -------------------------
def knn_search(query, k=10):
    """Embed *query* with the text encoder and return the top-k ES kNN hits."""
    query_vector = embed_text(query).tolist()
    search_body = {
        "size": k,
        "query": {
            "knn": {
                "field": "vector_field",
                "query_vector": query_vector,
                "k": k,
                "num_candidates": 50,
            }
        },
    }
    response = es.search(index=INDEX_NAME, body=search_body)
    return response["hits"]["hits"]

# -------------------------
# Image to Base64 helper
# -------------------------
def pil_to_base64(image, format="jpeg"):
    """Encode a PIL image as a base64 data URL (JPEG by default).

    The MIME subtype is lower-cased so the data URL stays well-formed even
    when callers pass an upper-case format name (PIL's save() is
    case-insensitive about the format, but "data:image/JPEG" is not a
    standard MIME type).
    """
    buffered = BytesIO()
    image.save(buffered, format=format)
    encoded = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{encoded}"

# -------------------------
# RAG Augmentation
# -------------------------
def generate_rag_response(user_query: str, k: int = 3):
    """
    Retrieves top K documents, creates a multimodal prompt, and generates a response from Gemini via OpenRouter.

    Image hits are attached as base64 data URLs; all other hits contribute
    their stored caption as plain-text context. Results are rendered
    directly into the Streamlit page (no return value).
    """
    st.write(f"Searching for top {k} relevant documents for RAG...")
    results = knn_search(user_query, k=k)
    
    if not results:
        st.warning("No relevant documents found for RAG.")
        return

    # Build the multimodal prompt for OpenAI-compatible API
    content_parts = []
    text_context = "Based on the following information:\n"

    for hit in results:
        src = hit["_source"]
        path = src.get("path", "")
        # A hit whose stored path points at an existing image file is sent
        # to the LLM as an inline image; everything else is treated as text.
        if path and os.path.exists(path) and path.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            text_context += f"- Image: {src.get('filename', 'N/A')}\n"
            try:
                img = Image.open(path).convert("RGB")
                base64_image = pil_to_base64(img)
                content_parts.append({
                    "type": "image_url",
                    "image_url": {"url": base64_image}
                })
            except Exception as e:
                # Best-effort: a broken image is reported but does not abort the prompt.
                st.error(f"Could not load image {path}: {e}")
        else:
            text_context += f"- Text Content: {src.get('caption', 'N/A')}\n"

    text_context += f"\nAnswer the question: {user_query}"
    # The text part goes first so the model reads the context before the images.
    content_parts.insert(0, {"type": "text", "text": text_context})

    messages = [{"role": "user", "content": content_parts}]

    # Show the raw prompt for debugging/teaching purposes.
    st.subheader("Gemini Flash Multimodal Prompt:")
    st.json(messages)

    # llm_client is None when the API key was missing — retrieval-only mode.
    if llm_client:
        with st.spinner("Gemini Flash is generating a response via OpenRouter..."):
            try:
                response = llm_client.chat.completions.create(
                    model=LLM_MODEL_NAME,
                    messages=messages,
                    max_tokens=1024,
                )
                st.markdown("**LLM Generated Response:**")
                st.markdown(response.choices[0].message.content)
            except Exception as e:
                st.error(f"Error generating response from OpenRouter: {e}")

# -------------------------
# Streamlit UI
# -------------------------
st.title("🖼️📄 Multimodal Image & Text KNN Search")

# Batch ingestion buttons — note these APPEND to the index; they do not
# clear previous documents, so repeated clicks create duplicates.
st.subheader("Ingest Data")
if st.button("📥 Ingest image folder"):
    ingest_image_folder(IMAGE_FOLDER)
    st.success("Images ingested successfully")

if st.button("📥 Ingest text folder"):
    ingest_text_folder(TEXT_FOLDER)
    st.success("Text files ingested successfully")

# Full reset: drop the index, recreate the mapping, re-ingest everything.
col1, col2 = st.columns([3, 1])
with col2:
    if st.button("⚠️ Delete & Re-ingest All"):
        with st.spinner("Deleting index and re-ingesting all data..."):
            if es.indices.exists(index=INDEX_NAME):
                es.indices.delete(index=INDEX_NAME)
                st.toast(f"Index '{INDEX_NAME}' deleted.")
            
            create_index()
            st.toast("Index created.")
            ingest_image_folder(IMAGE_FOLDER)
            st.toast("Images ingested.")
            ingest_text_folder(TEXT_FOLDER)
            st.toast("Texts ingested.")
        st.success("All data has been re-ingested successfully!")

# Search box for typing text queries
st.subheader("Search")
user_query = st.text_input("Type your search query here", key="search_query")
k_value = st.slider("Number of results to retrieve (K)", min_value=1, max_value=10, value=4)

if user_query:
    st.subheader(f"Retrieval Results (Top {k_value})")
    retrieval_results = knn_search(user_query, k=k_value)

    if retrieval_results:
        # Lay hits out in a grid, cols_per_row per row.
        cols_per_row = 3
        for i in range(0, len(retrieval_results), cols_per_row):
            row = retrieval_results[i:i + cols_per_row]
            cols = st.columns(len(row), gap="medium")

            for col, hit in zip(cols, row):
                src = hit["_source"]
                path = src.get("path", "")

                # Image hits render as thumbnails; text hits show their caption.
                if path and os.path.exists(path) and path.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
                    col.image(path, caption=f"{src.get('filename', '')}", width=200)
                else:
                    col.write(f"**[TEXT]**\n\n{src.get('caption', '')}")

                col.write(f"Score: {hit['_score']:.3f}")
    else:
        st.info("No relevant documents found in the index.")

    # Hand the same query to the RAG pipeline below the raw results.
    st.subheader("RAG System Output")
    st.write("---")
    generate_rag_response(user_query, k=k_value)

代码不是很长。

运行代码:

我们需要在虚拟环境中使用如下的命令来安装所需要的库:

复制代码
pip install -r requirements.txt

我们使用如下的命令来执行:

复制代码
streamlit run app.py

首次运行,我们可以直接点击 Delete & Re-ingest All 按钮来写入所有的 images 及 texts。当然我们也可以分别使用 Ingest image folder 及 Ingest text folder 按钮来完成文件的写入。值得注意的是:这两个按钮并不会删除之前的索引数据,而是会把 images 或 texts 目录里的文件重新写入。如果多次点击这些按钮,该文件夹中的文件会被多次写入索引,从而产生重复文档。

如下是搜索的结果:

当我们搜索 Star wars:

祝大家学习愉快!

相关推荐
永霖光电_UVLED2 小时前
氧化镓高体积热容的特性,集成高介电常数界面的结侧冷却架构
人工智能·生成对抗网络·架构·汽车·制造
lishutong10062 小时前
基于 Perfetto 与 AI 的 Android 性能自动化诊断方案
android·人工智能·自动化
lifewange2 小时前
Git版本管理
大数据·git·elasticsearch
面向Google编程2 小时前
从零学习Kafka:位移与高水位
大数据·后端·kafka
code_pgf2 小时前
Transformer 原理讲解及可视化算子操作
人工智能·深度学习·transformer
碑 一2 小时前
视频分割VisTR算法
人工智能·深度学习·计算机视觉
木斯佳2 小时前
前端八股文面经大全:腾讯前端一面(2026-04-04)·深度解析
前端·ai·鉴权·monorepo
AI2512242 小时前
免费AI视频生成工具技术解析与功能对比
人工智能·音视频