抽取语料库索引语义向量并建立 Milvus 库

import random

import time

import os

import sys

from tqdm import tqdm

import numpy as np

import paddle

from paddle import inference

from paddlenlp.transformers import AutoModel, AutoTokenizer

from paddlenlp.data import Stack, Tuple, Pad

from paddlenlp.datasets import load_dataset

from paddlenlp.utils.log import logger

# Add the current working directory to the module search path.
sys.path.append('.')

def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    pad_to_max_seq_len=False):
    """Tokenize every text in ``example`` and return the flattened features.

    Args:
        example: Mapping whose values are the texts to encode. The keys are
            not used.
        tokenizer: Callable invoked as ``tokenizer(text=..., max_seq_len=...,
            pad_to_max_seq_len=...)``; its result must support lookup of
            ``"input_ids"`` and ``"token_type_ids"``.
        max_seq_length: Forwarded to the tokenizer as ``max_seq_len``.
        pad_to_max_seq_len: Forwarded to the tokenizer unchanged.

    Returns:
        A flat list ``[input_ids, token_type_ids, ...]`` holding one pair per
        value of ``example``, in the mapping's iteration order.
    """
    result = []
    # Only the texts matter here, so iterate the values directly.
    for text in example.values():
        encoded_inputs = tokenizer(text=text,
                                   max_seq_len=max_seq_length,
                                   pad_to_max_seq_len=pad_to_max_seq_len)
        result.append(encoded_inputs["input_ids"])
        result.append(encoded_inputs["token_type_ids"])
    return result

# --- Settings for the embedding-extraction stage ---------------------------

# Directory holding the exported inference model files (note trailing slash).
model_dir = "./output/yysy/"

# Corpus file to embed; it is read line by line.
corpus_file = "./datasets/yysy/milvus/milvus_data_s.csv"

# Tokenizer sequence length and number of texts per inference batch.
max_seq_length = 64
batch_size = 64

# Inference device selector.
device = "gpu"

# NOTE(review): defined here but not passed to the Predictor call that
# follows in this file — confirm whether that is intended.
cpu_threads = 8

# Name handed to AutoTokenizer.from_pretrained.
model_name_or_path = "rocketqa-zh-base-query-encoder"

class Predictor(object):
    """Run an exported Paddle Inference model to embed a text corpus.

    The constructor loads ``inference.get_pooled_embedding.pdmodel`` /
    ``.pdiparams`` from ``model_dir`` and builds a predictor for the chosen
    device. ``predict`` tokenizes the input texts, runs them through the
    predictor in batches and saves the stacked outputs to disk.
    """

    def __init__(self,
                 model_dir,
                 device="gpu",
                 max_seq_length=128,
                 batch_size=32,
                 use_tensorrt=False,
                 precision="fp32",
                 cpu_threads=10,
                 enable_mkldnn=False):
        """Load the inference model and create the predictor.

        Args:
            model_dir: Directory containing the exported model files.
            device: ``"gpu"``, ``"cpu"`` or ``"xpu"``; any other value leaves
                the config's device settings untouched.
            max_seq_length: Sequence length used when tokenizing in
                ``predict``.
            batch_size: Number of texts per inference batch; also used as the
                TensorRT ``max_batch_size``.
            use_tensorrt: Enable the TensorRT engine (GPU only).
            precision: ``"fp16"``, ``"fp32"`` or ``"int8"``; only looked up on
                GPU.
            cpu_threads: Math-library thread count (CPU only).
            enable_mkldnn: Enable MKLDNN (CPU only).

        Raises:
            ValueError: If the model file or the params file does not exist.
            KeyError: If ``device == "gpu"`` and ``precision`` is unknown.
        """
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size

        # os.path.join works whether or not model_dir ends with a separator.
        model_file = os.path.join(model_dir,
                                  "inference.get_pooled_embedding.pdmodel")
        params_file = os.path.join(
            model_dir, "inference.get_pooled_embedding.pdiparams")
        if not os.path.exists(model_file):
            raise ValueError("not find model file path {}".format(model_file))
        if not os.path.exists(params_file):
            raise ValueError(
                "not find params file path {}".format(params_file))

        config = paddle.inference.Config(model_file, params_file)

        if device == "gpu":
            # Arguments (100, 0) are kept exactly as in the original call.
            config.enable_use_gpu(100, 0)
            precision_map = {
                "fp16": inference.PrecisionType.Half,
                "fp32": inference.PrecisionType.Float32,
                "int8": inference.PrecisionType.Int8,
            }
            precision_mode = precision_map[precision]
            if use_tensorrt:
                config.enable_tensorrt_engine(max_batch_size=batch_size,
                                              min_subgraph_size=30,
                                              precision_mode=precision_mode)
        elif device == "cpu":
            config.disable_gpu()
            if enable_mkldnn:
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == "xpu":
            config.enable_xpu(100)

        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)
        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        # Only the first output of the model is read back.
        self.output_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])

    def _run_batch(self, examples, batchify_fn):
        """Collate ``examples``, run one forward pass and return the output."""
        input_ids, segment_ids = batchify_fn(examples)
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(segment_ids)
        self.predictor.run()
        return self.output_handle.copy_to_cpu()

    def predict(self, data, tokenizer):
        """Embed every item of ``data`` and save the result with ``np.save``.

        Args:
            data: Iterable of mappings accepted by ``convert_example`` (each
                is expected to yield exactly one text).
            tokenizer: Tokenizer providing ``pad_token_id`` and
                ``pad_token_type_id`` and callable as ``convert_example``
                requires.

        Raises:
            ValueError: If ``data`` yields no items.

        Side effects:
            Writes the concatenated outputs via
            ``np.save('yysy_corpus_embedding', ...)``.
        """
        batchify_fn = Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id,
                dtype="int64"),  # segment
        )

        all_embeddings = []
        examples = []
        for text in tqdm(data):
            input_ids, segment_ids = convert_example(
                text,
                tokenizer,
                max_seq_length=self.max_seq_length,
                pad_to_max_seq_len=True)
            examples.append((input_ids, segment_ids))
            if len(examples) >= self.batch_size:
                all_embeddings.append(self._run_batch(examples, batchify_fn))
                examples = []

        # Flush the final, partially filled batch.
        if examples:
            all_embeddings.append(self._run_batch(examples, batchify_fn))

        if not all_embeddings:
            # np.concatenate would raise an opaque ValueError on an empty list.
            raise ValueError("no input texts to embed")

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        np.save('yysy_corpus_embedding', all_embeddings)

def read_text(file_path, encoding=None):
    """Read a text file into a ``{line_index: stripped_line}`` dict.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding passed to ``open``; ``None`` keeps the
            platform default, matching the previous behavior.

    Returns:
        Dict mapping each 0-based line number to that line with surrounding
        whitespace removed. Blank lines are kept as empty strings.
    """
    # The context manager guarantees the handle is closed, even on error.
    with open(file_path, encoding=encoding) as handle:
        return {idx: line.strip() for idx, line in enumerate(handle)}

# Build the inference predictor and the tokenizer from the settings above.
predictor = Predictor(model_dir=model_dir,
                      device=device,
                      max_seq_length=max_seq_length,
                      batch_size=batch_size)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Texts used to build the index: one single-entry {line_index: text} dict per
# corpus line.
id2corpus = read_text(corpus_file)
corpus_list = [{idx: id2corpus[idx]} for idx in id2corpus]

# Embeds the whole corpus; predict() saves the result to disk itself.
predictor.predict(corpus_list, tokenizer)

# --- Milvus connection and index settings ----------------------------------

MILVUS_HOST = "127.0.0.1"
MILVUS_PORT = 19530

# Dimension declared for the vector field of the collection schema.
data_dim = 256

# Number of hits requested per query; also reused as nprobe below.
top_k = 20

collection_name = "literature_search"
partition_tag = "partition_1"
embedding_name = "embeddings"

# Parameters handed to Collection.create_index.
index_config = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 1000},
}

# Parameters handed to Collection.search.
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": top_k},
}

from pymilvus import (

connections,

utility,

FieldSchema,

CollectionSchema,

DataType,

Collection,

)

# Banner template used by the Milvus helper classes for progress messages.
fmt = "\n=== {:30} ===\n"

# Maximum length declared for the VARCHAR "text" field.
text_max_len = 1000

# Collection layout: integer primary key, raw text, embedding vector.
fields = [
    # Primary key supplied by the caller (no auto-increment).
    # NOTE(review): max_length is passed for an INT64 field here — confirm
    # whether pymilvus uses it for non-VARCHAR types.
    FieldSchema(
        name="pk",
        dtype=DataType.INT64,
        is_primary=True,
        auto_id=False,
        max_length=100,
    ),
    # Corpus text.
    FieldSchema(
        name="text",
        dtype=DataType.VARCHAR,
        max_length=text_max_len,
    ),
    # Embedding vector.
    FieldSchema(
        name="embeddings",
        dtype=DataType.FLOAT_VECTOR,
        dim=data_dim,
    ),
]

schema = CollectionSchema(fields, description="Neural Search Index")

class VecToMilvus():
    """Write embedding vectors into a Milvus collection.

    Every method is best-effort: errors raised by Milvus are printed and
    swallowed, so the query-style methods return ``None`` on failure.
    """

    def __init__(self):
        """Connect to Milvus at ``MILVUS_HOST:MILVUS_PORT``."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        # Set by creat_collection() or insert().
        self.collection = None

    def has_collection(self, collection_name):
        """Return whether ``collection_name`` exists, or ``None`` on error."""
        try:
            has = utility.has_collection(collection_name)
            print(f"Does collection {collection_name} exist in Milvus: {has}")
            return has
        except Exception as e:
            print("Milvus has_table error:", e)

    def creat_collection(self, collection_name):
        """Create ``collection_name`` with the module-level ``schema``."""
        try:
            print(fmt.format("Create collection {}".format(collection_name)))
            self.collection = Collection(collection_name,
                                         schema,
                                         consistency_level="Strong")
        except Exception as e:
            print("Milvus create collection error:", e)

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` from Milvus."""
        try:
            utility.drop_collection(collection_name)
        except Exception as e:
            print("Milvus delete collection error:", e)

    def create_index(self, index_name):
        """Build the index described by ``index_config`` and load the collection.

        Args:
            index_name: Passed as the first argument of
                ``Collection.create_index``; callers in this file pass the
                vector field name.
        """
        try:
            print(fmt.format("Start Creating index"))
            self.collection.create_index(index_name, index_config)
            print(fmt.format("Start loading"))
            self.collection.load()
        except Exception as e:
            print("Milvus create index error:", e)

    def has_partition(self, partition_tag):
        """Return whether the partition exists, or ``None`` on error."""
        try:
            result = self.collection.has_partition(partition_tag)
            return result
        except Exception as e:
            print("Milvus has partition error: ", e)

    def create_partition(self, partition_tag):
        """Create ``partition_tag`` in the current collection."""
        try:
            self.collection.create_partition(partition_tag)
            print('create partition {} successfully'.format(partition_tag))
        except Exception as e:
            print('Milvus create partition error: ', e)

    def insert(self, entities, collection_name, index_name, partition_tag=None):
        """Insert ``entities``, creating collection/index/partition as needed.

        Args:
            entities: Data passed straight to ``Collection.insert``.
            collection_name: Target collection; created (and indexed) when
                ``has_collection`` reports it missing.
            index_name: Forwarded to ``create_index`` for a new collection.
            partition_tag: Optional partition; created when absent.
        """
        try:
            if not self.has_collection(collection_name):
                self.creat_collection(collection_name)
                self.create_index(index_name)
            else:
                self.collection = Collection(collection_name)
            if (partition_tag
                    is not None) and (not self.has_partition(partition_tag)):
                self.create_partition(partition_tag)

            self.collection.insert(entities, partition_name=partition_tag)
            # Report the entity count after the insert.
            print(
                f"Number of entities in Milvus: {self.collection.num_entities}"
            )
        except Exception as e:
            print("Milvus insert error:", e)

class RecallByMilvus():
    """Recall vectors from a Milvus collection.

    Errors raised by Milvus are printed and swallowed, so ``search`` returns
    ``None`` on failure.
    """

    def __init__(self):
        """Connect to Milvus at ``MILVUS_HOST:MILVUS_PORT``."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        # Set by get_collection().
        self.collection = None

    def get_collection(self, collection_name):
        """Bind ``self.collection`` to the existing ``collection_name``."""
        try:
            print(fmt.format("Connect collection {}".format(collection_name)))
            self.collection = Collection(collection_name)
        except Exception as e:
            # This method opens a collection; it does not create one.
            print("Milvus get collection error:", e)

    def search(self,
               vectors,
               embedding_name,
               collection_name,
               partition_names=None,
               output_fields=None):
        """Search ``collection_name`` for the nearest neighbours of ``vectors``.

        Uses the module-level ``search_params`` and limits results to
        ``top_k`` hits per query.

        Args:
            vectors: Query vectors passed to ``Collection.search``.
            embedding_name: Name of the vector field to search on.
            collection_name: Collection to search.
            partition_names: Partitions to restrict the search to; ``None``
                is treated as an empty list.
            output_fields: Extra fields to return with each hit; ``None`` is
                treated as an empty list.

        Returns:
            Whatever ``Collection.search`` returns, or ``None`` if an error
            was caught.
        """
        # Fresh lists per call instead of shared mutable defaults.
        if partition_names is None:
            partition_names = []
        if output_fields is None:
            output_fields = []
        try:
            self.get_collection(collection_name)
            result = self.collection.search(vectors,
                                            embedding_name,
                                            search_params,
                                            limit=top_k,
                                            partition_names=partition_names,
                                            output_fields=output_fields)
            return result
        except Exception as e:
            print('Milvus recall error: ', e)

# --- Settings for the Milvus insert / recall stage -------------------------

# Corpus text file, read line by line.
data_path = "./datasets/yysy/milvus/milvus_data_s.csv"

# Embedding matrix loaded with np.load.
embedding_path = "./yysy_corpus_embedding.npy"

# Row of the embedding matrix used as the demo query vector.
index = 18

# Rows inserted into Milvus per request (rebinds the module-level name).
batch_size = 5000

def read_text(file_path, encoding=None):
    """Read a text file into a list of stripped lines.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding passed to ``open``; ``None`` keeps the
            platform default, matching the previous behavior.

    Returns:
        List with one entry per line, surrounding whitespace removed. Blank
        lines are kept as empty strings.
    """
    # The context manager guarantees the handle is closed, even on error.
    with open(file_path, encoding=encoding) as handle:
        return [line.strip() for line in handle]

# Corpus texts, one entry per line of data_path.
corpus_list_embed = read_text(data_path)
corpus_list_embed[:5]  # bare expression: evaluated and discarded

# Embedding matrix and one integer id per row.
embeddings = np.load(embedding_path)
embedding_ids = list(range(embeddings.shape[0]))  # embedding ids

# Drop any existing collection so it is rebuilt from scratch.
client = VecToMilvus()
client.has_collection(collection_name)
client.drop_collection(collection_name)

data_size = len(embedding_ids)

# Text slicing probe over rows 10000-14999; the max() result is discarded.
# NOTE(review): the hard-coded range raises IndexError for a corpus with
# fewer than 15000 lines.
x = [corpus_list_embed[j][:1000] for j in range(10000, 15000)]
max(len(text) for text in x)

# Print the start offset of every batch.
for i in range(0, data_size, batch_size):
    print(i)

# Insert the corpus into Milvus batch by batch: i = 0, batch_size, ...
for i in range(0, data_size, batch_size):
    # Clamp the end of the last batch so indices stay in range.
    cur_end = min(i + batch_size, data_size)
    # Embedding vectors of this batch.
    batch_emb = embeddings[i:cur_end]
    # Column-ordered data: primary keys, texts, vectors.
    # NOTE(review): texts are cut to text_max_len - 1 characters; confirm
    # whether Milvus measures VARCHAR max_length in bytes, in which case
    # multi-byte text could still exceed the limit.
    entities = [
        list(range(i, cur_end)),  # ids
        [corpus_list_embed[j][:text_max_len - 1]
         for j in range(i, cur_end)],  # texts
        batch_emb,  # embedding vectors of this batch
    ]
    client.insert(collection_name=collection_name,
                  entities=entities,
                  index_name=embedding_name,
                  partition_tag=partition_tag)

# Recall demo: query Milvus with one stored embedding and time the search.
recall_client = RecallByMilvus()

# Keep only row `index` (shape stays 2-D) as the query batch.
embeddings = embeddings[np.arange(index, index + 1)]

time_start = time.time()  # start
result = recall_client.search(embeddings,
                              embedding_name,
                              collection_name,
                              partition_names=[partition_tag],
                              output_fields=['pk', 'text'])
time_end = time.time()  # end

sum_t = time_end - time_start
print('time cost', sum_t, 's')

# search() returns None when it caught an error, so guard the iteration.
if result is not None:
    for hits in result:
        for hit in hits:
            print(f"hit: {hit}, text field: {hit.entity.get('text')}")

相关推荐
ghostwritten1 天前
Run Milvus in Kubernetes with Milvus Operator
容器·kubernetes·milvus
weixin_435208161 天前
通过 Markdown 改进 RAG 文档处理
人工智能·python·算法·自然语言处理·面试·nlp·aigc
Chaos_Wang_1 天前
NLP高频面试题(三十三)——Vision Transformer(ViT)模型架构介绍
人工智能·自然语言处理·transformer
weixin_435208162 天前
论文浅尝 | Interactive-KBQA:基于大语言模型的多轮交互KBQA(ACL2024)
人工智能·语言模型·自然语言处理
姚瑞南2 天前
从模糊感知到量化评估:构建一个Prompt打分工具
人工智能·自然语言处理·chatgpt·prompt·aigc
人工智能培训咨询叶梓2 天前
LLAMAFACTORY:一键优化大型语言模型微调的利器
人工智能·语言模型·自然语言处理·性能优化·调优·大模型微调·llama factory
sauTCc2 天前
N元语言模型的时间和空间复杂度计算
人工智能·语言模型·自然语言处理
鸿蒙布道师3 天前
OpenAI战略转向:开源推理模型背后的行业博弈与技术趋势
人工智能·深度学习·神经网络·opencv·自然语言处理·openai·deepseek
pen-ai3 天前
【NLP】15. NLP推理方法详解 --- 动态规划:序列标注,语法解析,共同指代
人工智能·自然语言处理·动态规划
Chaos_Wang_3 天前
NLP高频面试题(二十九)——大模型解码常见参数解析
人工智能·自然语言处理