抽取语料库索引语义向量并建milvus库

import random

import time

import os

import sys

from tqdm import tqdm

import numpy as np

import paddle

from paddle import inference

from paddlenlp.transformers import AutoModel, AutoTokenizer

from paddlenlp.data import Stack, Tuple, Pad

from paddlenlp.datasets import load_dataset

from paddlenlp.utils.log import logger

sys.path.append('.')

def convert_example(example,

tokenizer,

max_seq_length=512,

pad_to_max_seq_len=False):

result = []

for key, text in example.items():

encoded_inputs = tokenizer(text=text,

max_seq_len=max_seq_length,

pad_to_max_seq_len=pad_to_max_seq_len)

input_ids = encoded_inputs["input_ids"]

token_type_ids = encoded_inputs["token_type_ids"]

result += [input_ids, token_type_ids]

return result

model_dir='./output/yysy/'

corpus_file='./datasets/yysy/milvus/milvus_data_s.csv'

max_seq_length=64

batch_size=64

device='gpu'

cpu_threads=8

model_name_or_path='rocketqa-zh-base-query-encoder'

class Predictor(object):

def init(self,

model_dir,

device="gpu",

max_seq_length=128,

batch_size=32,

use_tensorrt=False,

precision="fp32",

cpu_threads=10,

enable_mkldnn=False):

self.max_seq_length = max_seq_length

self.batch_size = batch_size

model_file = model_dir + "inference.get_pooled_embedding.pdmodel"

params_file = model_dir + "inference.get_pooled_embedding.pdiparams"

if not os.path.exists(model_file):

raise ValueError("not find model file path {}".format(model_file))

if not os.path.exists(params_file):

raise ValueError("not find params file path {}".format(params_file))

config = paddle.inference.Config(model_file, params_file)

if device == "gpu":

config.enable_use_gpu(100, 0)

precision_map = {

"fp16": inference.PrecisionType.Half,

"fp32": inference.PrecisionType.Float32,

"int8": inference.PrecisionType.Int8

}

precision_mode = precision_map[precision]

if use_tensorrt:

config.enable_tensorrt_engine(max_batch_size=batch_size,

min_subgraph_size=30,

precision_mode=precision_mode)

elif device == "cpu":

config.disable_gpu()

if enable_mkldnn:

config.set_mkldnn_cache_capacity(10)

config.enable_mkldnn()

config.set_cpu_math_library_num_threads(cpu_threads)

elif device == "xpu":

config.enable_xpu(100)

config.switch_use_feed_fetch_ops(False)

self.predictor = paddle.inference.create_predictor(config)

self.input_handles = [

self.predictor.get_input_handle(name)

for name in self.predictor.get_input_names()

]

self.output_handle = self.predictor.get_output_handle(

self.predictor.get_output_names()[0])

def predict(self, data, tokenizer):

batchify_fn = lambda samples, fn=Tuple(

Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input

Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"

), # segment

): fn(samples)

all_embeddings = []

examples = []

for idx, text in enumerate(tqdm(data)):

input_ids, segment_ids = convert_example(

text,

tokenizer,

max_seq_length=self.max_seq_length,

pad_to_max_seq_len=True)

examples.append((input_ids, segment_ids))

if (len(examples) >=self.batch_size):

input_ids, segment_ids = batchify_fn(examples)

self.input_handles[0].copy_from_cpu(input_ids)

self.input_handles[1].copy_from_cpu(segment_ids)

self.predictor.run()

logits = self.output_handle.copy_to_cpu()

all_embeddings.append(logits)

examples = []

if (len(examples) > 0):

input_ids, segment_ids = batchify_fn(examples)

self.input_handles[0].copy_from_cpu(input_ids)

self.input_handles[1].copy_from_cpu(segment_ids)

self.predictor.run()

logits = self.output_handle.copy_to_cpu()

all_embeddings.append(logits)

all_embeddings = np.concatenate(all_embeddings, axis=0)

np.save('yysy_corpus_embedding', all_embeddings)

def read_text(file_path):

file = open(file_path)

id2corpus = {}

for idx, data in enumerate(file.readlines()):

id2corpus[idx] = data.strip()

return id2corpus

predictor = Predictor(model_dir, device, max_seq_length, batch_size)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

id2corpus = read_text(corpus_file)

corpus_list = [{idx: text} for idx, text in id2corpus.items()]#用来构建索引库的文本

predictor.predict(corpus_list, tokenizer)

MILVUS_HOST = '127.0.0.1'

MILVUS_PORT = 19530

data_dim = 256

top_k = 20

collection_name = 'literature_search'

partition_tag = 'partition_1'

embedding_name = 'embeddings'

index_config = {

"index_type": "IVF_FLAT",

"metric_type": "L2",

"params": {

"nlist": 1000

},

}

search_params = {

"metric_type": "L2",

"params": {

"nprobe": top_k

},

}

from pymilvus import (

connections,

utility,

FieldSchema,

CollectionSchema,

DataType,

Collection,

)

fmt = "\n=== {:30} ===\n"

text_max_len = 1000

fields = [

FieldSchema(name="pk",

dtype=DataType.INT64,

is_primary=True,#主键

auto_id=False,#不自动增长

max_length=100),#id

FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=text_max_len),#text

FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=data_dim)#embedding

]

schema = CollectionSchema(fields, "Neural Search Index")

class VecToMilvus():#语义向量-->milvus

def init(self):

print(fmt.format("start connecting to Milvus"))

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

self.collection = None

def has_collection(self, collection_name):

try:

has = utility.has_collection(collection_name)

print(f"Does collection {collection_name} exist in Milvus: {has}")

return has

except Exception as e:

print("Milvus has_table error:", e)

def creat_collection(self, collection_name):

try:

print(fmt.format("Create collection {}".format(collection_name)))

self.collection = Collection(collection_name,

schema,

consistency_level="Strong")

except Exception as e:

print("Milvus create collection error:", e)

def drop_collection(self, collection_name):

try:

utility.drop_collection(collection_name)

except Exception as e:

print("Milvus delete collection error:", e)

def create_index(self, index_name):

try:

print(fmt.format("Start Creating index"))

self.collection.create_index(index_name, index_config)

print(fmt.format("Start loading"))

self.collection.load()

except Exception as e:

print("Milvus create index error:", e)

def has_partition(self, partition_tag):

try:

result = self.collection.has_partition(partition_tag)

return result

except Exception as e:

print("Milvus has partition error: ", e)

def create_partition(self, partition_tag):

try:

self.collection.create_partition(partition_tag)

print('create partition {} successfully'.format(partition_tag))

except Exception as e:

print('Milvus create partition error: ', e)

def insert(self, entities, collection_name, index_name, partition_tag=None):

try:

if not self.has_collection(collection_name):

self.creat_collection(collection_name)

self.create_index(index_name)

else:

self.collection = Collection(collection_name)

if (partition_tag

is not None) and (not self.has_partition(partition_tag)):

self.create_partition(partition_tag)

self.collection.insert(entities, partition_name=partition_tag)

print(

f"Number of entities in Milvus: {self.collection.num_entities}"

) # check the num_entites

except Exception as e:

print("Milvus insert error:", e)

class RecallByMilvus():#从milvus召回向量

def init(self):

print(fmt.format("start connecting to Milvus"))

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

self.collection = None

def get_collection(self, collection_name):

try:

print(fmt.format("Connect collection {}".format(collection_name)))

self.collection = Collection(collection_name)

except Exception as e:

print("Milvus create collection error:", e)

def search(self,

vectors,

embedding_name,

collection_name,

partition_names=[],

output_fields=[]):

try:

self.get_collection(collection_name)

result = self.collection.search(vectors,

embedding_name,

search_params,

limit=top_k,

partition_names=partition_names,

output_fields=output_fields)

return result

except Exception as e:

print('Milvus recall error: ', e)

data_path='./datasets/yysy/milvus/milvus_data_s.csv'

embedding_path='./yysy_corpus_embedding.npy'

index=18

batch_size=5000

def read_text(file_path):

file = open(file_path)

id2corpus = []

for idx, data in enumerate(file.readlines()):

id2corpus.append(data.strip())

return id2corpus

corpus_list_embed=read_text(data_path)

corpus_list_embed[:5]

embeddings = np.load(embedding_path)

embedding_ids = [i for i in range(embeddings.shape[0])]#嵌入ids

client = VecToMilvus()

client.has_collection(collection_name)

client.drop_collection(collection_name)

data_size = len(embedding_ids)

x=[corpus_list_embed[j][:1000]for j in range(10000, 15000,1)]#[:200]文本切片操作

max([len(i) for i in x])

for i in range(0, data_size, batch_size):

print(i)

for i in range(0, data_size, batch_size):#i:0-5000-10000-....

cur_end = i + batch_size

if (cur_end > data_size):#确保下标不越界

cur_end = data_size

batch_emb = embeddings[np.arange(i, cur_end)]#一个批次的嵌入向量

entities = [

j for j in range(i, cur_end, 1)\],#索引 \[corpus_list_embed\[j\]\[:text_max_len - 1\] for j in range(i, cur_end, 1)\],#文本 batch_emb #每个批次嵌入向量

client.insert(collection_name=collection_name,

entities=entities,

index_name=embedding_name,

partition_tag=partition_tag)

recall_client = RecallByMilvus()

embeddings = embeddings[np.arange(index, index + 1)]

time_start = time.time()# start

result = recall_client.search(embeddings,

embedding_name,

collection_name,

partition_names=[partition_tag],

output_fields=['pk', 'text'])

time_end = time.time()# end

sum_t = time_end - time_start

print('time cost', sum_t, 's')

for hits in result:

for hit in hits:

print(f"hit: {hit}, text field: {hit.entity.get('text')}")

相关推荐
沉浸式学习ing2 小时前
B站视频怎么快速总结?AI自动生成要点+思维导图+逐字稿
人工智能·ai·自然语言处理·音视频·语音识别·notion
隔窗听雨眠3 小时前
基于Milvus混合检索与Java SpringBoot的全栈实现
milvus
tzc_fly7 小时前
LLaDA2.0:块扩散语言模型
人工智能·语言模型·自然语言处理
毋语天7 小时前
从零搭建 RAG 系统:Milvus 向量数据库 + 大模型完整实战指南
数据库·milvus
财经资讯数据_灵砚智能8 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年5月14日
人工智能·python·信息可视化·自然语言处理·ai编程
k09338 小时前
免费大语言模型API平台汇总指南(2026年最新)
人工智能·语言模型·自然语言处理
Hui_AI72011 小时前
电商桌面自动化实战:用RPA实现抖店批量铺货
运维·开发语言·人工智能·自然语言处理·自动化·开源软件·rpa
财经资讯数据_灵砚智能12 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月15日
人工智能·python·信息可视化·自然语言处理·ai编程
victory043112 小时前
DeepSeek-R1:通过强化学习激励大语言模型的推理能力 技术报告中文翻译
人工智能·语言模型·自然语言处理
kishu_iOS&AI1 天前
NLP —— 英译法实例
人工智能·ai·自然语言处理