测试一下 Python 版 Qdrant 向量数据库的效果，完整代码如下：
安装库
# Quote the requirement: unquoted, the shell parses ">=1.1.1" as an output
# redirect, so pip installs qdrant-client with NO version constraint and its
# output is written to a file literally named "=1.1.1".
!pip install "qdrant-client>=1.1.1"
!pip install -U sentence-transformers
导入
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
# Sentence-embedding model used for both documents and queries.
# device=None lets sentence-transformers choose the device itself (a GPU when
# one can be used); hard-coding device="cuda" makes the script fail outright
# on machines without CUDA.
encoder = SentenceTransformer("all-MiniLM-L6-v2", device=None)
准备测试数据集
# Test corpus: three sample books.
# - Each description string is repeated (8x, 4x, 6x) to produce longer input
#   texts. The copies are concatenated with no separator
#   (e.g. "...humanity.A man travels...").
# - The three-book list is then repeated 50,000 times, giving 150,000 entries.
#   List repetition copies references, so every copy points at the same three
#   dict objects.
documents = [
    dict(
        name="The Time Machine",
        description=(
            "A man travels through time and witnesses the evolution of humanity."
            * 8
        ),
        author="H.G. Wells",
        year=1895,
    ),
    dict(
        name="Ender's Game",
        description=(
            "A young boy is trained to become a military leader in a war against an alien race."
            * 4
        ),
        author="Orson Scott Card",
        year=1985,
    ),
    dict(
        name="Brave New World",
        description=(
            "A dystopian society where people are genetically engineered and conditioned to conform to a strict social hierarchy."
            * 6
        ),
        author="Aldous Huxley",
        year=1932,
    ),
] * 50000
print(len(documents))
创建 Qdrant 客户端（可选择内存存储或本地磁盘存储）
# Qdrant client backed by an in-process, in-memory store (data is not persisted).
qdrant = QdrantClient(location=":memory:")
# Alternative: persist the data to a local directory instead.
# qdrant = QdrantClient(path='./qdrant')
在数据库中创建一个collection(类似一个存储桶)
# (Re)create the "my_books" collection from scratch: drop any existing
# collection of that name, then create an empty one. This is exactly what the
# `recreate_collection` helper did; it is deprecated in recent qdrant-client
# releases, so the two steps are spelled out explicitly here.
qdrant.delete_collection(collection_name="my_books")
qdrant.create_collection(
    collection_name="my_books",
    vectors_config=models.VectorParams(
        # Vector size must match the embedding model's output dimension.
        size=encoder.get_sentence_embedding_dimension(),
        distance=models.Distance.COSINE,
    ),
)
对文档进行向量化
import hashlib
from tqdm import tqdm
def sha256(text):
    """Return the hexadecimal SHA-256 digest of ``text`` encoded as UTF-8.

    Not called by the rest of this script; kept as an optional helper.
    """
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
# Embed the descriptions in batches of `bs` and wrap each vector into a point
# ready for upload. A point's ID is the document's index in `documents`.
# (A sha256 hex digest of the description is not a usable ID: Qdrant point
# IDs must be unsigned integers or UUIDs, and the corpus repeats each
# description 50,000 times, so digests would collide.)
records = []
bs = 256
for i in tqdm(range(0, len(documents), bs)):
    docs = documents[i : i + bs]
    vectors = encoder.encode(
        [doc["description"] for doc in docs], normalize_embeddings=True
    ).tolist()
    # upload_points expects models.PointStruct; models.Record is the type that
    # read APIs (scroll/retrieve) return, not an upload type.
    records.extend(
        models.PointStruct(id=idx, vector=vec, payload=doc)
        for idx, (vec, doc) in enumerate(zip(vectors, docs), start=i)
    )
上传到向量数据库中指定的collection
# Send every prepared point to the "my_books" collection, 128 points per
# request, with parallel=12 as the requested degree of upload parallelism.
qdrant.upload_points(
    points=records,
    collection_name="my_books",
    parallel=12,
    batch_size=128,
)
语义搜索
# Plain semantic search: embed the query text and fetch the 6 nearest points
# under the collection's distance metric (cosine), then print each payload
# with its similarity score.
query = "Aliens attack our planet"
query_vector = encoder.encode(query).tolist()
hits = qdrant.search(
    limit=6,
    query_vector=query_vector,
    collection_name="my_books",
)
for hit in hits:
    print(f"{hit.payload} score: {hit.score}")
条件搜索
# search only for books published in 1980 or later (year >= 1980)
# Filtered search: only points whose payload field "year" is >= 1980 are
# candidates (in this corpus that is only "Ender's Game", 1985); return the
# top 3 and print each payload with its score.
year_filter = models.Filter(
    must=[models.FieldCondition(key="year", range=models.Range(gte=1980))]
)
hits = qdrant.search(
    collection_name="my_books",
    query_vector=encoder.encode("Tyranic society").tolist(),
    query_filter=year_filter,
    limit=3,
)
for hit in hits:
    print(f"{hit.payload} score: {hit.score}")