一、实现
项目目录结构如下:
templates/
└── upload.html
faiss_app.py
前端代码(HTML):templates/upload.html
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Search and Show Multiple Images</title>
  <style>
    #image-container {
      display: flex;
      flex-wrap: wrap;
    }
    #image-container img {
      max-width: 150px;
      margin: 10px;
    }
  </style>
</head>
<body>
  <h1>Search Images</h1>
  <!-- Search box -->
  <form id="search-form">
    <input type="text" id="search-input" name="query" placeholder="Enter search term" required>
    <input type="submit" value="Search">
  </form>
  <h2>Search Results</h2>
  <!-- Progress / error message for the most recent search -->
  <p id="search-status" role="status"></p>
  <!-- The images returned by the search are rendered here -->
  <div id="image-container"></div>
  <!-- Handle the form submission in JS so the page does not reload -->
  <script>
    document.getElementById('search-form').addEventListener('submit', async function (event) {
      event.preventDefault(); // stop the browser's default form submission
      const query = document.getElementById('search-input').value;
      const imageContainer = document.getElementById('image-container');
      const status = document.getElementById('search-status');
      // Clear old results first so a failed search never shows stale images.
      imageContainer.innerHTML = '';
      status.textContent = 'Searching...';
      try {
        // Send the keyword to the backend as a GET query parameter.
        const response = await fetch(`/search?query=${encodeURIComponent(query)}`, {
          method: 'GET',
        });
        // Errors come back as {"error": "..."} with a non-2xx status; a body
        // that is not JSON (e.g. an HTML error page) is treated as empty.
        const data = await response.json().catch(() => ({}));
        if (!response.ok) {
          throw new Error(data.error || `HTTP ${response.status}`);
        }
        const urls = Array.isArray(data.image_urls) ? data.image_urls : [];
        // Create one <img> per returned URL.
        urls.forEach(url => {
          const imgElement = document.createElement('img');
          imgElement.src = url;
          imgElement.alt = 'Search result';
          imageContainer.appendChild(imgElement);
        });
        status.textContent = urls.length ? '' : 'No results.';
      } catch (error) {
        console.error('Error searching for images:', error);
        status.textContent = `Search failed: ${error.message}`;
      }
    });
  </script>
</body>
</html>
后端代码(Python):faiss_app.py
import hashlib
import os, glob
import shutil
import threading

import faiss
import numpy as np
from flask import Flask, request, jsonify, current_app, render_template, send_from_directory, url_for
from markupsafe import escape
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from werkzeug.utils import secure_filename
# Load the CLIP model once at import time; it is used to embed both the
# indexed images and the search queries.
model = SentenceTransformer('clip-ViT-B-32')

# Lower-case file extensions (with the leading dot) that are indexed as images.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
# Folder holding the saved Faiss index and the copies of result images.
UPLOAD_FOLDER = 'uploads/'
# Root folder scanned recursively for images to index. The default is a
# placeholder Windows path; set the IMAGES_PATH environment variable to point
# at a real picture folder without editing the code.
IMAGES_PATH = os.environ.get("IMAGES_PATH", "C:\\Users\\xxxx\\Pictures\\")
def generate_clip_embeddings(images_path, model):
    """Embed every image found under *images_path* with *model*.

    The directory tree is walked recursively (``os.walk``). Files whose
    lower-cased extension is in ``IMAGE_EXTENSIONS`` are opened and encoded.
    Files that cannot be opened or decoded are skipped with a message, so one
    bad file does not abort the whole indexing run.

    Args:
        images_path: root directory to scan.
        model: object with an ``encode(image)`` method (the CLIP
            ``SentenceTransformer`` in this file).

    Returns:
        ``(embeddings, image_paths)``: two lists of equal length where
        ``embeddings[i]`` is ``model.encode`` of the image at
        ``image_paths[i]``. Both lists are empty when no image is found.
    """
    embeddings = []
    image_paths = []
    for root, _dirs, files in os.walk(images_path):
        for file in files:
            if os.path.splitext(file)[1].lower() not in IMAGE_EXTENSIONS:
                continue
            img_path = os.path.join(root, file)
            try:
                # Decode inside ``with`` so the file handle is closed at once.
                # ``convert`` returns a fully loaded copy that outlives the file.
                with Image.open(img_path) as image:
                    rgb_image = image.convert('RGB')
            except OSError as err:  # also covers PIL.UnidentifiedImageError
                print(f"Skipping unreadable image {img_path}: {err}")
                continue
            # Append both together so the two lists stay aligned.
            embeddings.append(model.encode(rgb_image))
            image_paths.append(img_path)
    return embeddings, image_paths
def create_faiss_index(embeddings, image_paths, output_path):
    """Build a Faiss index over *embeddings*, save it, and return it.

    The index type depends on the number of vectors:

    - fewer than 39*256: exact ``IndexFlatL2``;
    - fewer than 39*4096: ``'OPQ64_256,PQ64x8'``;
    - otherwise: ``'OPQ64_256,IVF4096,PQ64x8'``.

    The quantized variants are trained on the embeddings before any vector is
    added, because Faiss rejects ``add`` on an untrained index. The index is
    wrapped in ``IndexIDMap`` so that the id of each vector is its position
    in *image_paths*.

    Side effects: writes the index to *output_path* and the image paths, one
    per line, to ``output_path + '.paths'``.

    Args:
        embeddings: sequence of equal-length vectors, one per image.
        image_paths: image paths aligned with *embeddings*.
        output_path: file the index is written to.

    Returns:
        The populated (CPU) Faiss index.

    Raises:
        ValueError: if *embeddings* is empty, is not a list of equal-length
            vectors, or differs in length from *image_paths*.
    """
    vectors = np.asarray(embeddings, dtype=np.float32)
    if len(vectors) == 0:
        raise ValueError("create_faiss_index needs at least one embedding")
    if vectors.ndim != 2:
        raise ValueError("embeddings must be a sequence of equal-length vectors")
    if len(vectors) != len(image_paths):
        raise ValueError(
            f"{len(vectors)} embeddings but {len(image_paths)} image paths")
    count, dimension = vectors.shape

    # Choose the Faiss index type by the number of entries.
    if count < 39 * 256:
        # Few entries: a plain exact L2 index is enough.
        base_index = faiss.IndexFlatL2(dimension)
    elif count < 39 * 4096:
        # Fewer than 39 x 4096 entries: PQ quantization only, no IVF.
        base_index = faiss.index_factory(dimension, 'OPQ64_256,PQ64x8')
    else:
        # Otherwise add IVF partitioning as well.
        base_index = faiss.index_factory(dimension, 'OPQ64_256,IVF4096,PQ64x8')
    # NOTE(review): CLIP embeddings are usually compared by cosine similarity;
    # L2 on unnormalized vectors may rank differently. Consider L2-normalizing
    # the vectors (here and for queries) - TODO confirm.

    # The index is built on the CPU. A GPU copy made with useFloat16 would
    # store the vectors at half precision before being copied back and saved.
    if not base_index.is_trained:
        base_index.train(vectors)
    faiss_index = faiss.IndexIDMap(base_index)
    # Faiss ids must be int64; np.array(range(n)) is int32 on Windows with
    # NumPy < 2, which add_with_ids rejects.
    faiss_index.add_with_ids(vectors, np.arange(count, dtype=np.int64))

    faiss.write_index(faiss_index, output_path)
    print(f"Index created and saved to {output_path}")
    # Line i of the .paths file is the image for id i.
    with open(output_path + '.paths', 'w') as f:
        f.writelines(img_path + '\n' for img_path in image_paths)
    return faiss_index
def load_faiss_index(index_path, use_gpu=None):
    """Load a saved Faiss index together with its id -> image-path list.

    Args:
        index_path: path of the saved index. Image paths are read from
            ``index_path + '.paths'``, one per line; line i belongs to id i.
        use_gpu: True moves the index to GPU 0 with float16 storage; False
            keeps it on the CPU. ``None`` (default) uses the GPU only when
            this faiss build has GPU support and at least one GPU is visible,
            so CPU-only installations keep working.

    Returns:
        ``(faiss_index, image_paths)``.

    Raises:
        RuntimeError: if the index is untrained, or its vector count differs
            from the number of stored paths (ids would map to wrong images).
    """
    faiss_index = faiss.read_index(index_path)
    with open(index_path + '.paths', 'r') as f:
        # Strip only the line terminator; file names may end in spaces.
        image_paths = [line.rstrip('\n') for line in f]
    print(f"Index loaded from {index_path}")
    if not faiss_index.is_trained:
        # Message: "the Faiss index loaded from [...] is not trained".
        raise RuntimeError(f'从[{index_path}]加载的Faiss索引未训练')
    if faiss_index.ntotal != len(image_paths):
        raise RuntimeError(
            f"Index {index_path} holds {faiss_index.ntotal} vectors but "
            f"{len(image_paths)} image paths are listed; rebuild the index")

    if use_gpu is None:
        use_gpu = (hasattr(faiss, 'StandardGpuResources')
                   and faiss.get_num_gpus() > 0)
    if use_gpu:
        res = faiss.StandardGpuResources()
        co = faiss.GpuClonerOptions()
        co.useFloat16 = True
        faiss_index = faiss.index_cpu_to_gpu(res, 0, faiss_index, co)
    return faiss_index, image_paths
def retrieve_similar_images(query, model, index, image_paths, top_k=3):
    """Return up to *top_k* indexed image paths most similar to *query*.

    Args:
        query: a text query, or a path to an image file. A string ending in
            an image extension (case-insensitive) that names an existing file
            is opened and embedded as an image. Anything else, including text
            such as "a cat.png" that is not a file, is embedded as text.
        model: object with an ``encode`` method (the CLIP model).
        index: Faiss index whose ids are positions in *image_paths*.
        image_paths: list mapping id -> image path.
        top_k: maximum number of results.

    Returns:
        ``(query, retrieved_images)``. *query* is the value that was embedded
        (a PIL image for image queries, otherwise the input string).
        *retrieved_images* holds at most *top_k* paths, in the order returned
        by ``index.search``.
    """
    # Query preprocessing: image-file queries are loaded as images.
    if (query.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif'))
            and os.path.isfile(query)):
        # Load fully inside ``with`` so the file handle is released.
        with Image.open(query) as image:
            query = image.convert('RGB')
    query_features = np.asarray(model.encode(query), dtype=np.float32).reshape(1, -1)
    _distances, indices = index.search(query_features, top_k)
    # Faiss pads with -1 when the index holds fewer than top_k vectors;
    # image_paths[-1] would otherwise return the last image repeatedly.
    retrieved_images = [image_paths[int(idx)] for idx in indices[0] if idx >= 0]
    return query, retrieved_images
def allowed_file(filename, extensions=None):
    """Return True if *filename* has an allowed (image) file extension.

    The comparison is case-insensitive and looks only at the final suffix.

    Args:
        filename: file name or path to check.
        extensions: set of lower-case extensions *including* the leading dot
            (e.g. ``{'.png'}``). Defaults to ``IMAGE_EXTENSIONS``, the same
            set used when indexing images.
    """
    if extensions is None:
        extensions = IMAGE_EXTENSIONS
    return os.path.splitext(filename)[1].lower() in extensions
# Serializes loading/building the index and querying it. Requests may run
# concurrently (Flask's development server is threaded by default), building
# the index writes files, and Faiss GPU indexes must not be queried from
# several threads at once.
_SEARCH_LOCK = threading.Lock()


def _get_search_index():
    """Return ``(faiss_index, image_paths)`` for the current app, or None.

    The first call loads ``<UPLOAD_FOLDER>/vector.index``. If that file does
    not exist, it embeds every image under ``IMAGES_PATH`` and builds and
    saves the index. The result is cached in ``current_app.extensions``, so
    later requests reuse it instead of reloading. Returns None (and caches
    nothing) when no image was found. Caller must hold ``_SEARCH_LOCK``.
    """
    cached = current_app.extensions.get('faiss_search_index')
    if cached is None:
        index_path = os.path.join(current_app.config['UPLOAD_FOLDER'], 'vector.index')
        if os.path.exists(index_path):
            cached = load_faiss_index(index_path)
        else:
            embeddings, image_paths = generate_clip_embeddings(IMAGES_PATH, model)
            if not embeddings:
                return None
            cached = (create_faiss_index(embeddings, image_paths, index_path), image_paths)
        current_app.extensions['faiss_search_index'] = cached
    return cached


def _publish_image(path):
    """Copy *path* into the upload folder and return the URL serving the copy.

    The copy is named ``<hash of full path>_<basename>``, so images with the
    same file name in different folders do not overwrite each other.
    """
    digest = hashlib.sha1(os.fsencode(path)).hexdigest()[:12]
    served_name = f"{digest}_{os.path.basename(path)}"
    shutil.copy(path, os.path.join(current_app.config['UPLOAD_FOLDER'], served_name))
    return url_for('uploaded_file_path', filename=served_name)


def search():
    """Handle ``GET /search?query=...`` and return ``{"image_urls": [...]}``.

    Responds 400 with ``{"error": ...}`` when the query is missing or blank,
    and 500 with ``{"error": ...}`` when there are no images to search.
    """
    query = (request.args.get('query') or '').strip()  # search keyword
    if not query:
        return jsonify({"error": "No search query provided"}), 400
    # NOTE(security): retrieve_similar_images treats a query ending in an
    # image extension as a file path and opens it on the server, so web
    # users can make the server read local image files. Consider restricting
    # web queries to text.
    with _SEARCH_LOCK:
        loaded = _get_search_index()
        if loaded is None:
            return jsonify({"error": "No images found to index"}), 500
        faiss_index, image_paths = loaded
        _, retrieved_images = retrieve_similar_images(
            query, model, faiss_index, image_paths, top_k=5)
    image_urls = [_publish_image(path) for path in retrieved_images]
    return jsonify({"image_urls": image_urls})
def index():
    """Render the search page, ``templates/upload.html``."""
    template_name = 'upload.html'
    return render_template(template_name)
def uploaded_file_path(filename):
    """Serve a file from the upload folder (copies of search-result images).

    Only files whose extension is in ``IMAGE_EXTENSIONS`` are served; anything
    else gets a 404. The upload folder also holds the Faiss index
    (``vector.index``) and its ``.paths`` file listing every local image path,
    and without this check both could be downloaded through this route.
    """
    if os.path.splitext(filename)[1].lower() not in IMAGE_EXTENSIONS:
        return "Not found", 404
    return send_from_directory(current_app.config['UPLOAD_FOLDER'], filename)
if __name__ == "__main__":
    app = Flask(__name__)
    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    # Home page: the search form.
    app.route('/')(index)
    app.route('/search', methods=['GET'])(search)
    # Serves the result copies that search() places in UPLOAD_FOLDER.
    app.route('/uploads/images/<filename>')(uploaded_file_path)
    # The server listens on all interfaces, so the Werkzeug debugger (which
    # can execute arbitrary code) is opt-in: set FLASK_DEBUG=1 to enable it.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(host='0.0.0.0', port=8080, debug=debug)