Milvus + Flask 山寨复刻《从零构建向量数据库》第 7 章

常规练手：山寨一个图片搜索。拜读了罗云大佬的著作，结果发现只有操作层面的内容能上手实践。

书中用的是作者自己实现的向量数据库，这里直接用 Python 对接现成的 Milvus 向量数据库。

  1. 创建 Milvus 集合（Collection，相当于数据表）并建立向量索引：
（Python 代码）
from pymilvus import connections, utility
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema

# ---- Milvus connection / collection settings ----
COLLECTION_NAME = 'animal_search'
DIMENSION = 2048  # size of the ResNet-50 feature vector once the fc layer is removed
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# ---- Inference settings ----
BATCH_SIZE = 128

# Open the default connection used by the pymilvus ORM calls below.
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

# Start from a clean slate: an existing collection with this name is dropped.
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

# Schema: auto-generated int64 primary key, the image file path, and its embedding.
fields = [
    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION),
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

# IVF_FLAT index with L2 distance; searches must use the same metric type.
# NOTE(review): nlist=16384 is very large for a small image set; Milvus guidance
# is roughly 4 * sqrt(row_count) clusters — consider tuning.
index_params = {
    'metric_type': 'L2',
    'index_type': "IVF_FLAT",
    'params': {'nlist': 16384},
}
collection.create_index(field_name="image_embedding", index_params=index_params)

# Load the collection into memory so it can be searched.
collection.load()
  1. 批量写入图片：用去掉最后一层（fc 分类层）的 ResNet-50 把每张图片编码成 2048 维特征向量（可以理解为图像的高层语义特征，而不是简单的像素值），连同文件路径一起存入 Milvus。
（Python 代码）
# Milvus Setup Arguments
COLLECTION_NAME = 'animal_search'
DIMENSION = 2048
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128  # number of images embedded and inserted per Milvus call

from pymilvus import connections

# Connect to the instance
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

import glob
import os

# Collect every image file under the dataset root.
# `recursive=True` only has an effect when the pattern contains `**`. With a bare
# `*`, the pattern matched just the top level and also returned sub-directories,
# which Image.open() cannot read. Use `**/*`, which also matches top-level files,
# keep regular files only, and sort so the insertion order is deterministic.
paths = sorted(
    p
    for p in glob.glob('/mcm/vectorDB_training/animals_db/**/*', recursive=True)
    if os.path.isfile(p)
)

# Batch preprocessing of the data
import torch
# Load the embedding model with the last layer (the fc classifier) removed, so the
# output is the 2048-d pooled feature rather than 1000 class scores.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval()  # inference mode

from torchvision import transforms
# Preprocessing for images (ImageNet normalization statistics).
# Must stay identical to the query-side pipeline in extract_features.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Insert the data
from PIL import Image
from tqdm import tqdm

# Embed function that embeds the batch and inserts it
def embed(data):
    """Embed one batch of images and insert it into the Milvus collection.

    Args:
        data: two parallel lists. ``data[0]`` holds preprocessed image tensors
            (all the same shape, so they can be stacked); ``data[1]`` holds the
            corresponding file paths.
    """
    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
        ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    with torch.no_grad():
        # The truncated ResNet returns (batch, 2048, 1, 1). flatten(start_dim=1)
        # always gives (batch, 2048). A plain .squeeze() also dropped the batch
        # axis for a 1-image batch. That produced a flat list of 2048 floats that
        # no longer lined up row-for-row with the file-path column.
        output = model(torch.stack(data[0])).flatten(start_dim=1)
    # Column order follows the schema minus the auto_id primary key.
    collection.insert([data[1], output.tolist()])
    # Persist this batch before the next one is processed.
    collection.flush()

# Two parallel lists: [preprocessed tensors, file paths].
data_batch = [[], []]

# Read the images into batches for embedding and insertion.
for path in tqdm(paths):
    # Image.open is lazy and keeps the file handle open; the `with` closes it.
    # convert('RGB') returns an independent in-memory image, so it can be
    # preprocessed after the file is closed.
    with Image.open(path) as im:
        data_batch[0].append(preprocess(im.convert('RGB')))
    data_batch[1].append(path)
    # The batch is reset after each insert, so "full" means exactly BATCH_SIZE.
    if len(data_batch[0]) == BATCH_SIZE:
        embed(data_batch)
        data_batch = [[], []]

# Embed and insert the remainder (the last, partially filled batch).
if data_batch[0]:
    embed(data_batch)
  1. 把图片向量化的函数单独拎出来（保存为 image_eb.py），做搜索功能时复用它。
（Python 代码）
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image

import functools


@functools.lru_cache(maxsize=1)
def _load_model_and_preprocess():
    """Build the ResNet-50 feature extractor and its preprocessing pipeline once.

    Cached so repeated calls (e.g. one per HTTP request) reuse the same model
    instead of reloading the pretrained weights every time.
    """
    model = resnet50(pretrained=True)
    # Remove the fc layer: without it the output is the 2048-d pooled feature
    # instead of 1000 ImageNet class scores.
    model = torch.nn.Sequential(*list(model.children())[:-1])
    model.eval()

    # Image preprocessing (ImageNet statistics); must match the indexing pipeline.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return model, preprocess


def extract_features(image_path):
    """Return the ResNet-50 feature vector for the image at ``image_path``.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        A 1-D numpy array (2048 values) suitable for the ``image_embedding`` field.
    """
    model, preprocess = _load_model_and_preprocess()

    # Read the image and force 3 channels. Grayscale, RGBA or palette images would
    # otherwise break Normalize, which has 3-channel mean/std. The indexing script
    # also converts to RGB, so query and stored vectors stay comparable.
    with Image.open(image_path) as img:
        img_t = preprocess(img.convert('RGB'))
    batch_t = torch.unsqueeze(img_t, 0)  # add a batch axis of size 1

    # Extract features without tracking gradients.
    with torch.no_grad():
        out = model(batch_t)

    # Flatten the feature map to a 1-D array and return it.
    return out.flatten().numpy()
  1. 用 Flask 做的 Web 界面（首页、上传接口、搜索接口）：
（Python 代码）
from flask import Flask,request,jsonify
from flask import render_template
from image_eb import extract_features
#from pymilvus import connections
from pymilvus import MilvusClient

import logging
import os
import shutil

MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
COLLECTION_NAME = 'animal_search'
TOP_K = 3  # number of nearest neighbours returned by /search

app = Flask(__name__)
# Build the URI from the constants above instead of repeating them as a literal,
# so host/port only need to change in one place.
milvus_client = MilvusClient(uri=f"http://{MILVUS_HOST}:{MILVUS_PORT}")

@app.route("/")
def index():
    """Render the front page template ``index.html``."""
    page = "index.html"
    return render_template(page)

@app.route("/upload", methods=["POST"])
def upload_image():
    """Save an uploaded image and insert its embedding into Milvus.

    Expects a multipart form with an ``image`` file and an integer ``image_id``.
    The file is saved as ``static/images/<image_id>`` and a row
    ``{filepath, image_embedding}`` is inserted. Milvus assigns its own primary
    key (the schema uses ``auto_id``), so the returned ``id`` only echoes the
    form value.

    Returns:
        JSON ``{"message", "id"}`` on success, or JSON ``{"message"}`` with
        HTTP 400 when the file or the id is missing/invalid.
    """
    image_file = request.files.get("image")
    if image_file is None:
        return jsonify({"message": "Image file is required"}), 400

    # Check that image_id is present.
    image_id_str = request.form.get("image_id")
    if not image_id_str:
        return jsonify({"message": "Image ID is required"}), 400
    # Convert the image id to an integer.
    try:
        image_id = int(image_id_str)
    except ValueError:
        return jsonify({"message": "Invalid image ID. It must be an integer"}), 400

    # Name the file after the parsed integer, not the raw form text, so inputs
    # that int() accepts (e.g. " 7" or "7\n") don't leak whitespace into paths.
    image_dir = "static/images"
    os.makedirs(image_dir, exist_ok=True)
    image_path = os.path.join(image_dir, str(image_id))
    image_file.save(image_path)

    # Plain Python list, the same form /search sends to Milvus.
    image_embedding = extract_features(image_path).tolist()
    # Add the record to the database.
    milvus_client.insert(
        collection_name=COLLECTION_NAME,
        data=[{"filepath": image_path, "image_embedding": image_embedding}],
    )
    return jsonify({"message": "Image uploaded successfully", "id": image_id})

@app.route("/search", methods=["POST"])
def search_image():
    """Return URLs of the TOP_K stored images closest to the uploaded query image.

    The query is saved to ``static/images/temp_image.jpg``, embedded once, and
    searched with L2 distance (the index metric). Each hit's file is copied into
    ``/mcm/vectorDB_training/static/images`` so Flask can serve it as
    ``/static/images/<basename>``.

    Returns:
        JSON ``{"image_urls": [...]}`` in the order returned by Milvus, or
        JSON ``{"message"}`` with HTTP 400 when no image is supplied.
    """
    image_file = request.files.get("image")
    if image_file is None:
        return jsonify({"message": "Image file is required"}), 400

    # NOTE(review): this fixed temp filename is shared by concurrent requests.
    os.makedirs("static/images", exist_ok=True)
    image_path = os.path.join("static/images", "temp_image.jpg")
    image_file.save(image_path)
    # Embed once: extract_features runs a full ResNet-50 forward pass.
    query_vector = extract_features(image_path).tolist()

    search_result = milvus_client.search(
        collection_name=COLLECTION_NAME,
        data=[query_vector],
        output_fields=["filepath"],
        limit=TOP_K,
        search_params={'metric_type': 'L2', 'params': {}},
    )

    destination_folder = '/mcm/vectorDB_training/static/images'
    image_urls = []
    # One query vector was sent, so search_result[0] is its hit list.
    for hit in search_result[0]:
        source_file = hit["entity"]["filepath"]
        base_file_name = os.path.basename(source_file)
        destination_file = os.path.join(destination_folder, base_file_name)
        try:
            shutil.copy(source_file, destination_file)
        except shutil.SameFileError:
            # The hit already lives in the served folder. This is the case for
            # images stored via /upload when the app runs from
            # /mcm/vectorDB_training.
            pass
        # URL path: always '/'-separated, independent of the host OS.
        image_urls.append(f"/static/images/{base_file_name}")

    return jsonify({"image_urls": image_urls})

if __name__ == "__main__":
    # Flask development server, listening on all interfaces on port 5020.
    # NOTE(review): debug=True enables the Werkzeug interactive debugger; do not
    # expose this server to untrusted networks.
    app.run(debug=True, port=5020, host='0.0.0.0')

小网站的目录结构以及其他零散代码，可以在这里查看和直接下载：https://www.ituring.com.cn/book/3305

相关推荐
游王子19 小时前
Milvus(18):IVF_PQ、HNSW
milvus
AI大模型顾潇20 小时前
Milvus + LLM大模型:打造智能电影知识库系统
数据库·人工智能·机器学习·大模型·llm·llama·milvus
Timmer丿2 天前
AI开发跃迁指南(第三章:第四维度1——Milvus、weaviate、redis等向量数据库介绍及对比选型)
数据库·人工智能·milvus
游王子2 天前
Milvus(16):索引解释
milvus
桥Dopey3 天前
Milvus 向量数据库详解与实践指南
推荐系统·milvus·向量数据库·图像检索
游王子3 天前
Milvus(15):插入和删除
milvus
桥Dopey3 天前
mac 使用 Docker 安装向量数据库Milvus独立版的保姆级别教程
milvus·向量数据库
游王子6 天前
Milvus(10):JSON 字段、数组字段
json·milvus
游王子6 天前
Milvus(13):自定义分析器、过滤器
milvus