import os
import sys
import time

import numpy as np
import paddle
from paddle import inference
from paddlenlp.data import Pad, Tuple
from paddlenlp.transformers import AutoTokenizer
from tqdm import tqdm

sys.path.append('.')
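
# Tokenize each {id: text} example into flattened (input_ids, token_type_ids)
# pairs so the batchify function below can pad them into batch tensors.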
def convert_example(example,
tokenizer,
max_seq_length=512,
pad_to_max_seq_len=False):
result = []
for key, text in example.items():
encoded_inputs = tokenizer(text=text,
max_seq_len=max_seq_length,
pad_to_max_seq_len=pad_to_max_seq_len)
input_ids = encoded_inputs["input_ids"]
token_type_ids = encoded_inputs["token_type_ids"]
result += [input_ids, token_type_ids]
return result
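
# --- Inference configuration: exported model path, corpus file, and hyperparameters ---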
model_dir = './output/yysy/'
corpus_file = './datasets/yysy/milvus/milvus_data_s.csv'
max_seq_length = 64
batch_size = 64
device = 'gpu'
cpu_threads = 8
model_name_or_path = 'rocketqa-zh-base-query-encoder'
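
# Predictor wraps a static-graph Paddle inference model exported with
# get_pooled_embedding, and turns raw texts into dense sentence embeddings.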
class Predictor(object):
    def __init__(self,
model_dir,
device="gpu",
max_seq_length=128,
batch_size=32,
use_tensorrt=False,
precision="fp32",
cpu_threads=10,
enable_mkldnn=False):
self.max_seq_length = max_seq_length
self.batch_size = batch_size
        model_file = os.path.join(model_dir, "inference.get_pooled_embedding.pdmodel")
        params_file = os.path.join(model_dir, "inference.get_pooled_embedding.pdiparams")
        if not os.path.exists(model_file):
            raise ValueError("model file not found: {}".format(model_file))
        if not os.path.exists(params_file):
            raise ValueError("params file not found: {}".format(params_file))
config = paddle.inference.Config(model_file, params_file)
if device == "gpu":
config.enable_use_gpu(100, 0)
precision_map = {
"fp16": inference.PrecisionType.Half,
"fp32": inference.PrecisionType.Float32,
"int8": inference.PrecisionType.Int8
}
precision_mode = precision_map[precision]
if use_tensorrt:
config.enable_tensorrt_engine(max_batch_size=batch_size,
min_subgraph_size=30,
precision_mode=precision_mode)
elif device == "cpu":
config.disable_gpu()
if enable_mkldnn:
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.set_cpu_math_library_num_threads(cpu_threads)
elif device == "xpu":
config.enable_xpu(100)
config.switch_use_feed_fetch_ops(False)
self.predictor = paddle.inference.create_predictor(config)
self.input_handles = [
self.predictor.get_input_handle(name)
for name in self.predictor.get_input_names()
]
self.output_handle = self.predictor.get_output_handle(
self.predictor.get_output_names()[0])
def predict(self, data, tokenizer):
        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),       # input_ids
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # token_type_ids
        ): fn(samples)
all_embeddings = []
examples = []
for idx, text in enumerate(tqdm(data)):
input_ids, segment_ids = convert_example(
text,
tokenizer,
max_seq_length=self.max_seq_length,
pad_to_max_seq_len=True)
examples.append((input_ids, segment_ids))
            if len(examples) >= self.batch_size:
input_ids, segment_ids = batchify_fn(examples)
self.input_handles[0].copy_from_cpu(input_ids)
self.input_handles[1].copy_from_cpu(segment_ids)
self.predictor.run()
logits = self.output_handle.copy_to_cpu()
all_embeddings.append(logits)
examples = []
        if examples:  # flush the last partial batch
input_ids, segment_ids = batchify_fn(examples)
self.input_handles[0].copy_from_cpu(input_ids)
self.input_handles[1].copy_from_cpu(segment_ids)
self.predictor.run()
logits = self.output_handle.copy_to_cpu()
all_embeddings.append(logits)
all_embeddings = np.concatenate(all_embeddings, axis=0)
        np.save('yysy_corpus_embedding', all_embeddings)  # written to yysy_corpus_embedding.npy
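
# Read the corpus: one text per line, keyed by line number.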
def read_text(file_path):
    id2corpus = {}
    with open(file_path) as f:
        for idx, line in enumerate(f):
            id2corpus[idx] = line.strip()
    return id2corpus
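
# --- Extract an embedding for every corpus text and save them to disk ---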
predictor = Predictor(model_dir, device, max_seq_length, batch_size)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
id2corpus = read_text(corpus_file)
corpus_list = [{idx: text} for idx, text in id2corpus.items()]  # texts used to build the index
predictor.predict(corpus_list, tokenizer)
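
# --- Milvus configuration ---
# IVF_FLAT clusters the vectors into `nlist` buckets at build time; a query
# only scans the `nprobe` closest buckets, trading a little recall for speed.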
MILVUS_HOST = '127.0.0.1'
MILVUS_PORT = 19530
data_dim = 256  # must match the embedding model's output dimension
top_k = 20
collection_name = 'literature_search'
partition_tag = 'partition_1'
embedding_name = 'embeddings'
index_config = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {
"nlist": 1000
},
}
search_params = {
"metric_type": "L2",
"params": {
"nprobe": top_k
},
}
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
fmt = "\n=== {:30} ===\n"
text_max_len = 1000
fields = [
FieldSchema(name="pk",
dtype=DataType.INT64,
is_primary=True,#主键
auto_id=False,#不自动增长
max_length=100),#id
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=text_max_len),#text
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=data_dim)#embedding
]
schema = CollectionSchema(fields, "Neural Search Index")
class VecToMilvus:  # write semantic vectors into Milvus
    def __init__(self):
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
self.collection = None
def has_collection(self, collection_name):
try:
has = utility.has_collection(collection_name)
print(f"Does collection {collection_name} exist in Milvus: {has}")
return has
except Exception as e:
print("Milvus has_table error:", e)
    def create_collection(self, collection_name):
try:
print(fmt.format("Create collection {}".format(collection_name)))
self.collection = Collection(collection_name,
schema,
consistency_level="Strong")
except Exception as e:
print("Milvus create collection error:", e)
def drop_collection(self, collection_name):
try:
utility.drop_collection(collection_name)
except Exception as e:
print("Milvus delete collection error:", e)
def create_index(self, index_name):
try:
print(fmt.format("Start Creating index"))
self.collection.create_index(index_name, index_config)
print(fmt.format("Start loading"))
self.collection.load()
except Exception as e:
print("Milvus create index error:", e)
def has_partition(self, partition_tag):
try:
result = self.collection.has_partition(partition_tag)
return result
except Exception as e:
print("Milvus has partition error: ", e)
def create_partition(self, partition_tag):
try:
self.collection.create_partition(partition_tag)
print('create partition {} successfully'.format(partition_tag))
except Exception as e:
print('Milvus create partition error: ', e)
def insert(self, entities, collection_name, index_name, partition_tag=None):
try:
if not self.has_collection(collection_name):
                self.create_collection(collection_name)
self.create_index(index_name)
else:
self.collection = Collection(collection_name)
            if partition_tag is not None and not self.has_partition(partition_tag):
self.create_partition(partition_tag)
self.collection.insert(entities, partition_name=partition_tag)
            print(
                f"Number of entities in Milvus: {self.collection.num_entities}"
            )  # check num_entities after the insert
except Exception as e:
print("Milvus insert error:", e)
class RecallByMilvus:  # recall nearest vectors from Milvus
    def __init__(self):
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
self.collection = None
def get_collection(self, collection_name):
try:
print(fmt.format("Connect collection {}".format(collection_name)))
self.collection = Collection(collection_name)
except Exception as e:
print("Milvus create collection error:", e)
    def search(self,
               vectors,
               embedding_name,
               collection_name,
               partition_names=None,
               output_fields=None):
try:
self.get_collection(collection_name)
result = self.collection.search(vectors,
embedding_name,
search_params,
limit=top_k,
partition_names=partition_names,
output_fields=output_fields)
return result
except Exception as e:
print('Milvus recall error: ', e)
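
# --- Build the index: reload the corpus and the saved embeddings, then insert in batches ---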
data_path = './datasets/yysy/milvus/milvus_data_s.csv'
embedding_path = './yysy_corpus_embedding.npy'
index = 18
batch_size = 5000
def read_text(file_path):
    # Redefines the earlier read_text: this version returns a plain list.
    corpus = []
    with open(file_path) as f:
        for line in f:
            corpus.append(line.strip())
    return corpus
corpus_list_embed=read_text(data_path)
print(corpus_list_embed[:5])  # preview the first five corpus lines
embeddings = np.load(embedding_path)
embedding_ids = list(range(embeddings.shape[0]))  # one id per embedding row
client = VecToMilvus()
client.has_collection(collection_name)
client.drop_collection(collection_name)  # drop any stale collection so the index is rebuilt from scratch
data_size = len(embedding_ids)
# Sanity check: truncate a sample of texts and confirm the longest fits the schema.
x = [corpus_list_embed[j][:1000] for j in range(10000, 15000)]
print(max(len(i) for i in x))
# Insert the embeddings batch by batch; i steps through offsets 0, 5000, 10000, ...
for i in range(0, data_size, batch_size):
    print(i)
cur_end = i + batch_size
    if cur_end > data_size:  # keep the last batch within bounds
        cur_end = data_size
    batch_emb = embeddings[np.arange(i, cur_end)]  # embeddings for this batch
    entities = [
        [j for j in range(i, cur_end)],  # ids
        [corpus_list_embed[j][:text_max_len - 1] for j in range(i, cur_end)],  # texts, truncated to the schema limit
        batch_emb  # batch embedding vectors
    ]
client.insert(collection_name=collection_name,
entities=entities,
index_name=embedding_name,
partition_tag=partition_tag)
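
# --- Recall: use one corpus embedding as the query and fetch its top_k neighbours ---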
recall_client = RecallByMilvus()
query_embedding = embeddings[index:index + 1]  # a single corpus vector used as the query
time_start = time.time()
result = recall_client.search(query_embedding,
                              embedding_name,
                              collection_name,
                              partition_names=[partition_tag],
                              output_fields=['pk', 'text'])
time_end = time.time()
sum_t = time_end - time_start
print('time cost', sum_t, 's')
for hits in result:
for hit in hits:
print(f"hit: {hit}, text field: {hit.entity.get('text')}")