# 抽取语料库的语义向量并建立 Milvus 索引库 (extract corpus embeddings and build a Milvus collection)

import random

import time

import os

import sys

from tqdm import tqdm

import numpy as np

import paddle

from paddle import inference

from paddlenlp.transformers import AutoModel, AutoTokenizer

from paddlenlp.data import Stack, Tuple, Pad

from paddlenlp.datasets import load_dataset

from paddlenlp.utils.log import logger

# Add the current working directory to the module search path.
sys.path.append('.')

def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    pad_to_max_seq_len=False):
    """Tokenize every text in ``example`` and flatten the results into one list.

    Args:
        example: Mapping whose values are the texts to encode; the keys are
            ignored.
        tokenizer: Callable invoked as ``tokenizer(text=..., max_seq_len=...,
            pad_to_max_seq_len=...)`` that returns a mapping containing
            ``"input_ids"`` and ``"token_type_ids"``.
        max_seq_length: Forwarded to the tokenizer as ``max_seq_len``.
        pad_to_max_seq_len: Forwarded to the tokenizer unchanged.

    Returns:
        A flat list ``[input_ids, token_type_ids, ...]`` with one pair per
        entry of ``example``, in iteration order.
    """
    features = []
    for _, sentence in example.items():
        encoding = tokenizer(text=sentence,
                             max_seq_len=max_seq_length,
                             pad_to_max_seq_len=pad_to_max_seq_len)
        features.extend((encoding["input_ids"], encoding["token_type_ids"]))
    return features

# --- Embedding-extraction settings -----------------------------------------
# Directory holding the exported inference model files read by Predictor.
model_dir = './output/yysy/'
# Corpus file whose lines are embedded (one text per line).
corpus_file = './datasets/yysy/milvus/milvus_data_s.csv'
# Passed to Predictor as the tokenizer sequence length and the batch size.
max_seq_length = 64
batch_size = 64
# Device passed to Predictor: "gpu", "cpu" or "xpu".
device = 'gpu'
# NOTE(review): defined here but not forwarded to Predictor in the visible code.
cpu_threads = 8
# Name handed to AutoTokenizer.from_pretrained.
model_name_or_path = 'rocketqa-zh-base-query-encoder'

class Predictor(object):
    """Embed texts with an exported Paddle Inference model and save the result.

    The constructor loads ``inference.get_pooled_embedding.pdmodel`` and
    ``inference.get_pooled_embedding.pdiparams`` from ``model_dir`` and builds
    a ``paddle.inference`` predictor configured for the requested device.
    """

    def __init__(self,
                 model_dir,
                 device="gpu",
                 max_seq_length=128,
                 batch_size=32,
                 use_tensorrt=False,
                 precision="fp32",
                 cpu_threads=10,
                 enable_mkldnn=False):
        """Create the inference predictor.

        Args:
            model_dir: Directory containing the exported model files.
            device: ``"gpu"``, ``"cpu"`` or ``"xpu"``. Any other value leaves
                the Paddle config at its defaults.
            max_seq_length: Sequence length used when tokenizing in ``predict``.
            batch_size: Number of examples per inference run; also used as
                the TensorRT ``max_batch_size``.
            use_tensorrt: Enable the TensorRT engine (GPU only).
            precision: ``"fp16"``, ``"fp32"`` or ``"int8"``; looked up whenever
                ``device == "gpu"``, so an unknown value raises ``KeyError``.
            cpu_threads: Math-library thread count (CPU only).
            enable_mkldnn: Enable MKL-DNN (CPU only).

        Raises:
            ValueError: If the model or params file does not exist.
        """
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size

        # os.path.join works whether or not model_dir ends with a separator.
        model_file = os.path.join(model_dir,
                                  "inference.get_pooled_embedding.pdmodel")
        params_file = os.path.join(model_dir,
                                   "inference.get_pooled_embedding.pdiparams")
        if not os.path.exists(model_file):
            raise ValueError("not find model file path {}".format(model_file))
        if not os.path.exists(params_file):
            raise ValueError(
                "not find params file path {}".format(params_file))
        config = paddle.inference.Config(model_file, params_file)

        if device == "gpu":
            # Arguments per the Paddle Inference API: initial GPU memory pool
            # size in MB, then the GPU device id.
            config.enable_use_gpu(100, 0)
            precision_map = {
                "fp16": inference.PrecisionType.Half,
                "fp32": inference.PrecisionType.Float32,
                "int8": inference.PrecisionType.Int8,
            }
            precision_mode = precision_map[precision]

            if use_tensorrt:
                config.enable_tensorrt_engine(max_batch_size=batch_size,
                                              min_subgraph_size=30,
                                              precision_mode=precision_mode)
        elif device == "cpu":
            config.disable_gpu()
            if enable_mkldnn:
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == "xpu":
            config.enable_xpu(100)

        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)
        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        # Only the model's first output is read back.
        self.output_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])

    # Backward-compatible alias: the constructor was previously spelled
    # ``init`` and therefore never ran on instantiation.
    init = __init__

    def _embed_batch(self, examples, batchify_fn):
        """Pad one batch of examples, run the predictor, and return its output."""
        input_ids, segment_ids = batchify_fn(examples)
        # Assumes the model's first input is the token ids and the second the
        # segment ids -- TODO confirm against the exported model.
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(segment_ids)
        self.predictor.run()
        return self.output_handle.copy_to_cpu()

    def predict(self, data, tokenizer, output_path="yysy_corpus_embedding"):
        """Embed every example in ``data`` and save the stacked result.

        Args:
            data: Iterable of single-entry mappings (``{key: text}``); each is
                passed to ``convert_example`` and must yield exactly one
                ``(input_ids, token_type_ids)`` pair.
            tokenizer: Tokenizer providing ``pad_token_id`` and
                ``pad_token_type_id`` and accepted by ``convert_example``.
            output_path: Target passed to ``np.save``. The default keeps the
                original file name.

        Returns:
            The array obtained by concatenating all batch outputs on axis 0.

        Raises:
            ValueError: If ``data`` is empty.
        """
        batchify_fn = Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id,
                dtype="int64"),  # input ids
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id,
                dtype="int64"),  # segment ids
        )

        all_embeddings = []
        examples = []
        for text in tqdm(data):
            input_ids, segment_ids = convert_example(
                text,
                tokenizer,
                max_seq_length=self.max_seq_length,
                pad_to_max_seq_len=True)
            examples.append((input_ids, segment_ids))
            if len(examples) >= self.batch_size:
                all_embeddings.append(self._embed_batch(examples, batchify_fn))
                examples = []

        # Flush the final, possibly smaller, batch.
        if examples:
            all_embeddings.append(self._embed_batch(examples, batchify_fn))

        if not all_embeddings:
            # np.concatenate would raise an opaque ValueError on an empty list.
            raise ValueError("no input examples to embed")

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        np.save(output_path, all_embeddings)
        return all_embeddings

def read_text(file_path, encoding=None):
    """Read a text file into a ``{line_index: stripped_line}`` dict.

    Args:
        file_path: Path of the file to read.
        encoding: Optional text encoding passed to ``open``. ``None`` keeps
            the platform default, as before.

    Returns:
        Dict mapping the 0-based line index to the line with leading and
        trailing whitespace removed. Blank lines are kept as empty strings.
    """
    # The context manager closes the handle, which the original never did.
    with open(file_path, encoding=encoding) as corpus_file_handle:
        return {
            idx: line.strip()
            for idx, line in enumerate(corpus_file_handle)
        }

# Build the inference predictor and the matching tokenizer.
predictor = Predictor(model_dir=model_dir,
                      device=device,
                      max_seq_length=max_seq_length,
                      batch_size=batch_size)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Load the corpus and wrap each line as a single-entry dict, the shape
# Predictor.predict hands to convert_example. These texts build the index.
id2corpus = read_text(corpus_file)
corpus_list = [{idx: text} for idx, text in id2corpus.items()]

# Embed the whole corpus; predict() writes the vectors to disk.
predictor.predict(corpus_list, tokenizer)

# --- Milvus connection and index settings ----------------------------------
MILVUS_HOST = '127.0.0.1'
MILVUS_PORT = 19530

# Dimension declared for the vector field of the collection.
data_dim = 256
# Used both as the search ``limit`` and as ``nprobe`` below.
top_k = 20

collection_name = 'literature_search'
partition_tag = 'partition_1'
# Name of the vector field the index is built on and searched against.
embedding_name = 'embeddings'

# Parameters handed to Collection.create_index.
index_config = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 1000},
}

# Parameters handed to Collection.search.
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": top_k},
}

from pymilvus import (

connections,

utility,

FieldSchema,

CollectionSchema,

DataType,

Collection,

)

# Banner template used for progress messages.
fmt = "\n=== {:30} ===\n"

# Declared maximum length of the VARCHAR "text" field.
text_max_len = 1000

# Collection layout: integer primary key, raw text, embedding vector.
fields = [
    # Primary key, supplied by the caller (not auto-generated).
    FieldSchema(
        name="pk",
        dtype=DataType.INT64,
        is_primary=True,
        auto_id=False,
        max_length=100,
    ),
    # Original corpus text.
    FieldSchema(
        name="text",
        dtype=DataType.VARCHAR,
        max_length=text_max_len,
    ),
    # Embedding vector.
    FieldSchema(
        name="embeddings",
        dtype=DataType.FLOAT_VECTOR,
        dim=data_dim,
    ),
]

schema = CollectionSchema(fields=fields, description="Neural Search Index")

class VecToMilvus():
    """Write embedding vectors into a Milvus collection.

    Every method catches ``Exception`` and prints the error instead of
    raising, so failures show up only on stdout and as ``None`` returns.
    """

    def __init__(self):
        """Open the default Milvus connection; no collection is bound yet."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    # Backward-compatible alias: the constructor was previously spelled
    # ``init`` and therefore never ran on instantiation.
    init = __init__

    def has_collection(self, collection_name):
        """Return whether ``collection_name`` exists (``None`` on error)."""
        try:
            has = utility.has_collection(collection_name)
            print(f"Does collection {collection_name} exist in Milvus: {has}")
            return has
        except Exception as e:
            print("Milvus has_table error:", e)

    def creat_collection(self, collection_name):
        """Create ``collection_name`` with the module-level ``schema`` and bind it."""
        try:
            print(fmt.format("Create collection {}".format(collection_name)))
            self.collection = Collection(collection_name,
                                         schema,
                                         consistency_level="Strong")
        except Exception as e:
            print("Milvus create collection error:", e)

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` from Milvus."""
        try:
            utility.drop_collection(collection_name)
        except Exception as e:
            print("Milvus delete collection error:", e)

    def create_index(self, index_name):
        """Build the index on field ``index_name`` and load the collection."""
        try:
            print(fmt.format("Start Creating index"))
            self.collection.create_index(index_name, index_config)
            print(fmt.format("Start loading"))
            self.collection.load()
        except Exception as e:
            print("Milvus create index error:", e)

    def has_partition(self, partition_tag):
        """Return whether the bound collection has ``partition_tag`` (``None`` on error)."""
        try:
            result = self.collection.has_partition(partition_tag)
            return result
        except Exception as e:
            print("Milvus has partition error: ", e)

    def create_partition(self, partition_tag):
        """Create ``partition_tag`` in the bound collection."""
        try:
            self.collection.create_partition(partition_tag)
            print('create partition {} successfully'.format(partition_tag))
        except Exception as e:
            print('Milvus create partition error: ', e)

    def insert(self, entities, collection_name, index_name, partition_tag=None):
        """Insert ``entities``, creating collection, index and partition as needed.

        Args:
            entities: Data passed straight to ``Collection.insert``.
            collection_name: Target collection. Created together with its
                index when ``has_collection`` reports it missing.
            index_name: Field to index when the collection is created.
            partition_tag: Optional partition. Created if it does not exist.
        """
        try:
            if not self.has_collection(collection_name):
                self.creat_collection(collection_name)
                self.create_index(index_name)
            else:
                self.collection = Collection(collection_name)
            if (partition_tag is not None) and (
                    not self.has_partition(partition_tag)):
                self.create_partition(partition_tag)

            self.collection.insert(entities, partition_name=partition_tag)
            # Report the collection's entity count after the insert.
            print(
                f"Number of entities in Milvus: {self.collection.num_entities}"
            )
        except Exception as e:
            print("Milvus insert error:", e)

class RecallByMilvus():
    """Search a Milvus collection for the nearest stored vectors.

    Methods catch ``Exception`` and print the error instead of raising.
    """

    def __init__(self):
        """Open the default Milvus connection; no collection is bound yet."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    # Backward-compatible alias: the constructor was previously spelled
    # ``init`` and therefore never ran on instantiation.
    init = __init__

    def get_collection(self, collection_name):
        """Bind ``self.collection`` to the existing ``collection_name``."""
        try:
            print(fmt.format("Connect collection {}".format(collection_name)))
            self.collection = Collection(collection_name)
        except Exception as e:
            # The old message said "create collection", but nothing is
            # created here.
            print("Milvus get collection error:", e)

    def search(self,
               vectors,
               embedding_name,
               collection_name,
               partition_names=[],
               output_fields=[]):
        """Search ``collection_name`` for the neighbours of ``vectors``.

        Uses the module-level ``search_params`` and ``top_k`` as the limit.

        Args:
            vectors: Query vectors passed to ``Collection.search``.
            embedding_name: Vector field to search against.
            collection_name: Collection to query.
            partition_names: Partitions to restrict the search to. The list
                defaults are forwarded as-is and never mutated here.
            output_fields: Extra fields to return with each hit.

        Returns:
            The ``Collection.search`` result, or ``None`` if an error occurred.
        """
        try:
            self.get_collection(collection_name)
            result = self.collection.search(vectors,
                                            embedding_name,
                                            search_params,
                                            limit=top_k,
                                            partition_names=partition_names,
                                            output_fields=output_fields)
            return result
        except Exception as e:
            print('Milvus recall error: ', e)

# --- Milvus insertion / recall settings ------------------------------------
# Corpus text file and the saved embedding matrix loaded below.
data_path = './datasets/yysy/milvus/milvus_data_s.csv'
embedding_path = './yysy_corpus_embedding.npy'
# Row of the embedding matrix used as the demo query vector.
index = 18
# Number of rows sent per insert call (rebinds the earlier batch_size).
batch_size = 5000

def read_text(file_path, encoding=None):
    """Read a text file into a list of stripped lines.

    Args:
        file_path: Path of the file to read.
        encoding: Optional text encoding passed to ``open``. ``None`` keeps
            the platform default, as before.

    Returns:
        List with one entry per line, leading and trailing whitespace
        removed. Blank lines are kept as empty strings.
    """
    # The context manager closes the handle, which the original never did.
    with open(file_path, encoding=encoding) as corpus_file_handle:
        return [line.strip() for line in corpus_file_handle]

# Load the corpus lines that are stored next to the vectors.
corpus_list_embed = read_text(data_path)
# Notebook-style preview of the first rows; has no effect when run as a script.
corpus_list_embed[:5]

# Load the saved embedding matrix and assign one integer id per row.
embeddings = np.load(embedding_path)
embedding_ids = list(range(embeddings.shape[0]))

client = VecToMilvus()
client.has_collection(collection_name)
# Drop the collection, if any, before the inserts that follow.
client.drop_collection(collection_name)

data_size = len(embedding_ids)

# Sanity check of text lengths in rows 10000-14999, truncated to 1000 chars.
# Slicing avoids the IndexError that indexing raised on a corpus with fewer
# than 15000 rows, and default=0 covers the empty case.
x = [text[:1000] for text in corpus_list_embed[10000:15000]]
max((len(text) for text in x), default=0)

# Print the start offset of every batch before inserting anything.
for i in range(0, data_size, batch_size):
    print(i)

# Insert the corpus in batches of `batch_size` rows: i = 0, 5000, 10000, ...
for i in range(0, data_size, batch_size):
    # Clamp the end so the last batch does not run past the data.
    cur_end = min(i + batch_size, data_size)
    batch_emb = embeddings[i:cur_end]  # embedding vectors of this batch
    # One column per schema field, in field order: pk, text, embeddings.
    # NOTE(review): text is truncated by character count; if Milvus measures
    # VARCHAR max_length in bytes, multi-byte text may still exceed it --
    # confirm against the Milvus documentation.
    entities = [
        list(range(i, cur_end)),  # primary keys
        [corpus_list_embed[j][:text_max_len - 1]
         for j in range(i, cur_end)],  # texts
        batch_emb,  # embedding vectors
    ]
    client.insert(collection_name=collection_name,
                  entities=entities,
                  index_name=embedding_name,
                  partition_tag=partition_tag)

recall_client = RecallByMilvus()

# Use a single stored embedding (row `index`) as the query vector. This
# rebinds `embeddings` to a one-row array.
embeddings = embeddings[np.arange(index, index + 1)]

time_start = time.time()  # start
result = recall_client.search(embeddings,
                              embedding_name,
                              collection_name,
                              partition_names=[partition_tag],
                              output_fields=['pk', 'text'])
time_end = time.time()  # end
sum_t = time_end - time_start
print('time cost', sum_t, 's')

# RecallByMilvus.search returns None after printing an error, so guard the
# loop instead of failing with "NoneType is not iterable".
if result is None:
    print('Milvus search returned no result')
else:
    for hits in result:
        for hit in hits:
            print(f"hit: {hit}, text field: {hit.entity.get('text')}")

相关推荐
聚客AI4 小时前
系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型
人工智能·pytorch·python·rnn·神经网络·机器学习·自然语言处理
盛寒10 小时前
词法分析和词性标注 自然语言处理
人工智能·自然语言处理
橙子小哥的代码世界12 小时前
【大模型RAG】Docker 一键部署 Milvus 完整攻略
linux·docker·大模型·milvus·向量数据库·rag
pen-ai12 小时前
【NLP】 38. Agent
人工智能·自然语言处理
春末的南方城市13 小时前
腾讯开源视频生成工具 HunyuanVideo-Avatar,上传一张图+一段音频,就能让图中的人物、动物甚至虚拟角色“活”过来,开口说话、唱歌、演相声!
人工智能·计算机视觉·自然语言处理·aigc·音视频·视频生成
UQI-LIUWJ13 小时前
论文笔记:Urban Computing in the Era of Large Language Models
人工智能·语言模型·自然语言处理
yvestine1 天前
自然语言处理——文本分类
自然语言处理·分类·文本分类·评价指标·pr·roc
码界奇点1 天前
Python Flask文件处理与异常处理实战指南
开发语言·python·自然语言处理·flask·python3.11
伪_装1 天前
大语言模型(LLM)面试问题集
人工智能·语言模型·自然语言处理
yvestine1 天前
基于规则的自然语言处理
自然语言处理·中文分词·规则方法