# 抽取语料库索引语义向量并建 Milvus 库

import random

import time

import os

import sys

from tqdm import tqdm

import numpy as np

import paddle

from paddle import inference

from paddlenlp.transformers import AutoModel, AutoTokenizer

from paddlenlp.data import Stack, Tuple, Pad

from paddlenlp.datasets import load_dataset

from paddlenlp.utils.log import logger

# Put the current working directory on the import path.
# NOTE(review): nothing visible in this file imports a local module — confirm
# this is still needed.
sys.path.append('.')

def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    pad_to_max_seq_len=False):
    """Tokenize every text value of *example* into id lists.

    Args:
        example: Mapping whose values are raw texts; the keys are ignored.
        tokenizer: Callable accepting ``text``, ``max_seq_len`` and
            ``pad_to_max_seq_len`` keyword arguments and returning a mapping
            that contains ``"input_ids"`` and ``"token_type_ids"``.
        max_seq_length: Maximum sequence length forwarded to the tokenizer.
        pad_to_max_seq_len: Whether the tokenizer should pad to that length.

    Returns:
        A flat list ``[input_ids_0, token_type_ids_0, input_ids_1, ...]``,
        one pair per value in the mapping's iteration order.
    """
    features = []
    for text in example.values():
        encoded = tokenizer(text=text,
                            max_seq_len=max_seq_length,
                            pad_to_max_seq_len=pad_to_max_seq_len)
        features.extend((encoded["input_ids"], encoded["token_type_ids"]))
    return features

# --- Corpus-embedding configuration ---
# Directory holding the exported inference model (the model/params file names
# are appended to it).
model_dir='./output/yysy/'
# Corpus to embed; read line by line, one text per line.
corpus_file='./datasets/yysy/milvus/milvus_data_s.csv'
# Token limit applied when tokenizing each text.
max_seq_length=64
# Texts per inference run. NOTE: this name is rebound to 5000 further down
# for the Milvus insert batches.
batch_size=64
# Inference device: "gpu", "cpu" or "xpu".
device='gpu'
# CPU math-library thread count; only used by the "cpu" device branch.
# NOTE(review): verify it is actually passed to Predictor.
cpu_threads=8
# Pretrained tokenizer name given to AutoTokenizer.from_pretrained.
model_name_or_path='rocketqa-zh-base-query-encoder'

class Predictor(object):
    """Paddle Inference wrapper that encodes texts into pooled embeddings.

    Loads ``inference.get_pooled_embedding.pdmodel`` and its ``.pdiparams``
    from *model_dir*, tokenizes texts with ``convert_example`` and runs them
    through the predictor in fixed-size batches.
    """

    def __init__(self,
                 model_dir,
                 device="gpu",
                 max_seq_length=128,
                 batch_size=32,
                 use_tensorrt=False,
                 precision="fp32",
                 cpu_threads=10,
                 enable_mkldnn=False):
        """Create the Paddle Inference predictor.

        Args:
            model_dir: Directory containing the exported model files. Joined
                with ``os.path.join``, so a trailing separator is optional.
            device: "gpu", "cpu" or "xpu".
            max_seq_length: Token limit applied when tokenizing each text.
            batch_size: Texts per inference run; also the TensorRT max batch.
            use_tensorrt: GPU only - enable the TensorRT engine.
            precision: GPU only - one of "fp16", "fp32", "int8".
            cpu_threads: CPU only - math-library thread count.
            enable_mkldnn: CPU only - enable oneDNN (MKL-DNN).

        Raises:
            ValueError: If the model or params file does not exist.
            KeyError: If *precision* is unsupported (checked on GPU only).
        """
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size

        model_file = os.path.join(model_dir,
                                  "inference.get_pooled_embedding.pdmodel")
        params_file = os.path.join(model_dir,
                                   "inference.get_pooled_embedding.pdiparams")
        if not os.path.exists(model_file):
            raise ValueError("not find model file path {}".format(model_file))
        if not os.path.exists(params_file):
            raise ValueError("not find params file path {}".format(params_file))

        config = paddle.inference.Config(model_file, params_file)
        if device == "gpu":
            # 100 MB initial GPU memory pool on device 0.
            config.enable_use_gpu(100, 0)
            precision_map = {
                "fp16": inference.PrecisionType.Half,
                "fp32": inference.PrecisionType.Float32,
                "int8": inference.PrecisionType.Int8,
            }
            precision_mode = precision_map[precision]
            if use_tensorrt:
                config.enable_tensorrt_engine(max_batch_size=batch_size,
                                              min_subgraph_size=30,
                                              precision_mode=precision_mode)
        elif device == "cpu":
            config.disable_gpu()
            if enable_mkldnn:
                # Bound the number of input shapes oneDNN caches.
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == "xpu":
            config.enable_xpu(100)

        # Inputs/outputs are exchanged through the tensor handles below.
        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)
        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        self.output_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])

    def predict(self, data, tokenizer, output_path='yysy_corpus_embedding'):
        """Embed every example in *data*, save the matrix and return it.

        Args:
            data: Iterable of mappings in the form ``convert_example`` expects
                (here ``{line_no: text}``, one entry each).
            tokenizer: Tokenizer exposing ``pad_token_id`` and
                ``pad_token_type_id``.
            output_path: Destination passed to ``np.save`` (NumPy appends
                ``.npy`` when missing). Defaults to the original file name.

        Returns:
            The concatenated embeddings, one row per example.

        Raises:
            ValueError: If *data* yields no examples.
        """
        # Pads input ids and segment ids of a batch into two int64 arrays.
        batchify_fn = Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # segment
        )

        all_embeddings = []
        examples = []
        for text in tqdm(data):
            input_ids, segment_ids = convert_example(
                text,
                tokenizer,
                max_seq_length=self.max_seq_length,
                pad_to_max_seq_len=True)
            examples.append((input_ids, segment_ids))
            if len(examples) >= self.batch_size:
                all_embeddings.append(self._run_batch(examples, batchify_fn))
                examples = []
        if examples:  # final, possibly short, batch
            all_embeddings.append(self._run_batch(examples, batchify_fn))

        if not all_embeddings:
            raise ValueError("predict() received no data to embed")
        all_embeddings = np.concatenate(all_embeddings, axis=0)
        np.save(output_path, all_embeddings)
        return all_embeddings

    def _run_batch(self, examples, batchify_fn):
        """Run one forward pass over *examples* and return the output array."""
        input_ids, segment_ids = batchify_fn(examples)
        # Assumes input 0 is token ids and input 1 is segment ids, matching
        # the order get_input_names() returned for the exported model.
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(segment_ids)
        self.predictor.run()
        return self.output_handle.copy_to_cpu()

def read_text(file_path):
    """Map each 0-based line number of *file_path* to its stripped text.

    The file is decoded as UTF-8 (the corpus is presumably Chinese text, so
    the platform default encoding is not safe) and closed after reading.
    """
    with open(file_path, encoding="utf-8") as file:
        return {idx: line.strip() for idx, line in enumerate(file)}

# Embed the whole corpus; Predictor.predict saves the matrix with np.save.
# cpu_threads is forwarded so the configured value applies on device="cpu"
# (previously the Predictor default of 10 was used silently).
predictor = Predictor(model_dir, device, max_seq_length, batch_size,
                      cpu_threads=cpu_threads)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
id2corpus = read_text(corpus_file)
# One {line_no: text} mapping per line: the texts used to build the index.
corpus_list = [{idx: text} for idx, text in id2corpus.items()]
predictor.predict(corpus_list, tokenizer)

# --- Milvus connection and collection settings ---
# Server address used by connections.connect in the Milvus helpers.
MILVUS_HOST = '127.0.0.1'
MILVUS_PORT = 19530
# Width of the FLOAT_VECTOR field; assumed to equal the width of the saved
# embeddings — TODO confirm against the exported model's output.
data_dim = 256
# Number of hits returned per query (search limit); also reused as nprobe.
top_k = 20
collection_name = 'literature_search'
partition_tag = 'partition_1'
# Name of the vector field targeted by index creation and search.
embedding_name = 'embeddings'
# IVF_FLAT index with L2 distance, partitioned into nlist clusters.
index_config = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {
        "nlist": 1000
    },
}
# NOTE(review): nprobe (how many clusters are scanned) is tied to top_k here,
# although the two are independent knobs.
search_params = {
    "metric_type": "L2",
    "params": {
        "nprobe": top_k
    },
}

from pymilvus import (

connections,

utility,

FieldSchema,

CollectionSchema,

DataType,

Collection,

)

# Banner template for the progress messages printed by the Milvus helpers.
fmt = "\n=== {:30} ===\n"
# max_length of the VARCHAR "text" field. NOTE(review): Milvus documents this
# limit in bytes, not characters — confirm for the server version in use.
text_max_len = 1000

fields = [
    # Primary key: the corpus line number, supplied by the caller.
    # NOTE(review): max_length has no documented meaning on INT64 fields.
    FieldSchema(name="pk",
                dtype=DataType.INT64,
                is_primary=True,
                auto_id=False,
                max_length=100),
    # Raw corpus text, returned as an output field on recall.
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=text_max_len),
    # Vector field, named after embedding_name so that index creation and
    # search (which both reference embedding_name) target this field.
    FieldSchema(name=embedding_name, dtype=DataType.FLOAT_VECTOR, dim=data_dim),
]
schema = CollectionSchema(fields, "Neural Search Index")

class VecToMilvus():
    """Write semantic vectors into a Milvus collection.

    All methods are best effort: Milvus errors are printed, not raised.
    """

    def __init__(self):
        """Open the "default" connection to MILVUS_HOST:MILVUS_PORT."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    def has_collection(self, collection_name):
        """Return whether *collection_name* exists, or None if the check failed."""
        try:
            has = utility.has_collection(collection_name)
            print(f"Does collection {collection_name} exist in Milvus: {has}")
            return has
        except Exception as e:
            print("Milvus has_table error:", e)
            return None

    def creat_collection(self, collection_name):
        """Create *collection_name* from the module-level ``schema``.

        On success the new collection (Strong consistency) is stored in
        ``self.collection``; on failure the error is printed.
        """
        try:
            print(fmt.format("Create collection {}".format(collection_name)))
            self.collection = Collection(collection_name,
                                         schema,
                                         consistency_level="Strong")
        except Exception as e:
            print("Milvus create collection error:", e)

    # Correctly spelled alias; ``creat_collection`` is kept for existing callers.
    create_collection = creat_collection

    def drop_collection(self, collection_name):
        """Drop *collection_name*; errors are printed."""
        try:
            utility.drop_collection(collection_name)
        except Exception as e:
            print("Milvus delete collection error:", e)

    def create_index(self, index_name):
        """Build ``index_config`` on field *index_name*, then load the collection."""
        try:
            print(fmt.format("Start Creating index"))
            self.collection.create_index(index_name, index_config)
            print(fmt.format("Start loading"))
            self.collection.load()
        except Exception as e:
            print("Milvus create index error:", e)

    def has_partition(self, partition_tag):
        """Return whether the current collection has *partition_tag* (None on error)."""
        try:
            return self.collection.has_partition(partition_tag)
        except Exception as e:
            print("Milvus has partition error: ", e)
            return None

    def create_partition(self, partition_tag):
        """Create *partition_tag* in the current collection; errors are printed."""
        try:
            self.collection.create_partition(partition_tag)
            print('create partition {} successfully'.format(partition_tag))
        except Exception as e:
            print('Milvus create partition error: ', e)

    def insert(self, entities, collection_name, index_name, partition_tag=None):
        """Insert column-ordered *entities* into *collection_name*.

        A missing collection is created, indexed on *index_name* and loaded
        first; a missing *partition_tag* (when given) is created.

        Args:
            entities: Column data in schema order: primary keys, texts, vectors.
            collection_name: Target collection.
            index_name: Vector field to index when the collection is new.
            partition_tag: Optional partition to insert into.
        """
        try:
            if self.has_collection(collection_name):
                self.collection = Collection(collection_name)
            else:
                # Reset so a failed creation is detected instead of silently
                # writing into a collection left over from an earlier call.
                self.collection = None
                self.creat_collection(collection_name)
                if self.collection is None:
                    print("Milvus insert error: collection {} is not available"
                          .format(collection_name))
                    return
                self.create_index(index_name)
            if (partition_tag is not None) and (not self.has_partition(partition_tag)):
                self.create_partition(partition_tag)
            self.collection.insert(entities, partition_name=partition_tag)
            # NOTE(review): num_entities may lag until data is flushed —
            # confirm for the pymilvus version in use.
            print(f"Number of entities in Milvus: {self.collection.num_entities}")
        except Exception as e:
            print("Milvus insert error:", e)

class RecallByMilvus():
    """Recall nearest vectors (and their fields) from a Milvus collection."""

    def __init__(self):
        """Open the "default" connection to MILVUS_HOST:MILVUS_PORT."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    def get_collection(self, collection_name):
        """Bind the existing *collection_name* to ``self.collection``; errors are printed."""
        try:
            print(fmt.format("Connect collection {}".format(collection_name)))
            self.collection = Collection(collection_name)
        except Exception as e:
            print("Milvus get collection error:", e)

    def search(self,
               vectors,
               embedding_name,
               collection_name,
               partition_names=None,
               output_fields=None):
        """Search *collection_name* for the ``top_k`` nearest neighbours.

        Args:
            vectors: Query vectors.
            embedding_name: Vector field to search.
            collection_name: Collection to search.
            partition_names: Partitions to restrict the search to (default: none).
            output_fields: Scalar fields to return with each hit (default: none).

        Returns:
            The pymilvus search result, or None if the search failed (the
            error is printed).
        """
        # None defaults avoid sharing one mutable list between calls.
        if partition_names is None:
            partition_names = []
        if output_fields is None:
            output_fields = []
        try:
            self.get_collection(collection_name)
            return self.collection.search(vectors,
                                          embedding_name,
                                          search_params,
                                          limit=top_k,
                                          partition_names=partition_names,
                                          output_fields=output_fields)
        except Exception as e:
            print('Milvus recall error: ', e)
            return None

# --- Milvus ingestion / demo-query configuration ---
# Corpus file; the same path as corpus_file, which was embedded above.
data_path='./datasets/yysy/milvus/milvus_data_s.csv'
# Embedding matrix written by np.save('yysy_corpus_embedding', ...).
embedding_path='./yysy_corpus_embedding.npy'
# Row of the embedding matrix used as the demo query vector.
index=18
# Rows per Milvus insert call (rebinds the inference batch_size of 64).
batch_size=5000

def read_text(file_path):
    """Return the stripped lines of *file_path* as a list.

    NOTE: this redefines the dict-returning ``read_text`` defined earlier in
    this file; code after this point gets the list version. The file is read
    as UTF-8 and closed after reading.
    """
    with open(file_path, encoding="utf-8") as file:
        return [line.strip() for line in file]

# Load the corpus texts and their saved embeddings.
corpus_list_embed = read_text(data_path)
corpus_list_embed[:5]  # notebook-style preview; no effect when run as a script
embeddings = np.load(embedding_path)
embedding_ids = list(range(embeddings.shape[0]))  # one id per embedding row

# Start ingestion from an empty collection.
client = VecToMilvus()
if client.has_collection(collection_name):
    client.drop_collection(collection_name)
data_size = len(embedding_ids)

# Inspect the longest text in a sample window (each cut to 1000 characters).
# The window end is clamped so corpora under 15000 lines do not raise
# IndexError.
sample_end = min(15000, len(corpus_list_embed))
x = [corpus_list_embed[j][:1000] for j in range(10000, sample_end)]
max((len(i) for i in x), default=0)

# Show the start offset of every insert batch.
for i in range(0, data_size, batch_size):
    print(i)

def _truncate_utf8(text, max_bytes):
    """Return the longest prefix of *text* whose UTF-8 encoding fits in *max_bytes*.

    Milvus limits VARCHAR values by byte length, and a CJK character takes
    three UTF-8 bytes, so a character-based slice can still overflow.
    """
    encoded = text.encode("utf-8")
    if len(encoded) <= max_bytes:
        return text
    # Cutting may split a multi-byte character; errors="ignore" drops it.
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


# Insert the corpus in batches of `batch_size` rows (offsets 0, 5000, ...).
for i in range(0, data_size, batch_size):
    cur_end = min(i + batch_size, data_size)  # last batch may be short
    batch_emb = embeddings[i:cur_end]  # embedding rows of this batch
    # Columns in schema order: primary keys, texts, vectors.
    entities = [
        list(range(i, cur_end)),  # primary keys = line numbers
        [_truncate_utf8(corpus_list_embed[j], text_max_len)
         for j in range(i, cur_end)],  # texts, within the VARCHAR limit
        batch_emb,  # embedding vectors of the batch
    ]
    client.insert(collection_name=collection_name,
                  entities=entities,
                  index_name=embedding_name,
                  partition_tag=partition_tag)

# Query Milvus with one stored embedding and print its nearest neighbours.
recall_client = RecallByMilvus()
# Slice (not index) to keep a 2-D array holding one query vector. A new name
# keeps the full `embeddings` matrix intact, so re-running this cell works.
query_embeddings = embeddings[index:index + 1]

time_start = time.time()
result = recall_client.search(query_embeddings,
                              embedding_name,
                              collection_name,
                              partition_names=[partition_tag],
                              output_fields=['pk', 'text'])
time_end = time.time()
sum_t = time_end - time_start
print('time cost', sum_t, 's')

# search() returns None when Milvus reported an error (already printed).
if result is None:
    print('Milvus recall returned no result')
else:
    for hits in result:
        for hit in hits:
            print(f"hit: {hit}, text field: {hit.entity.get('text')}")

相关推荐
AI木马人2 小时前
1.【AI系统架构设计】如何设计一个高效、安全的人性化AI工具系统?(从0到1完整方案)
人工智能·深度学习·神经网络·计算机视觉·自然语言处理
YiRan_Zhao3 小时前
milvus面试题
milvus
周末也要写八哥4 小时前
大语言模型的“自我迭代”
人工智能·语言模型·自然语言处理
财经资讯数据_灵砚智能4 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年4月24日
人工智能·python·信息可视化·自然语言处理·ai编程
许彰午4 小时前
# 约94万条热线问题怎么去重?动态相似度阈值+Milvus,不用LLM一毛钱
人工智能·milvus
AI木马人4 小时前
2.【多模型接入架构】如何同时接入GPT、Gemini、Claude并统一管理?(完整实现方案)
人工智能·gpt·深度学习·神经网络·自然语言处理
Zzj_tju6 小时前
大语言模型部署实战:生产环境怎么做高并发、监控、限流与故障恢复?
人工智能·语言模型·自然语言处理
程序员老邢7 小时前
【技术底稿 23】Ollama + Docker + Ubuntu 部署踩坑实录:网络通了,参数还在调
java·经验分享·后端·ubuntu·docker·容器·milvus
格鸰爱童话7 小时前
python使用milvus向量库
python·milvus
阿杰学AI7 小时前
AI核心知识140—大语言模型之 推理期算力(简洁且通俗易懂版)
人工智能·语言模型·自然语言处理·思维链·思维树·慢思考·推理期算力