# 抽取语料库索引语义向量并建milvus库
# (Extract semantic vectors for the corpus and build a Milvus index from them.)

import random

import time

import os

import sys

from tqdm import tqdm

import numpy as np

import paddle

from paddle import inference

from paddlenlp.transformers import AutoModel, AutoTokenizer

from paddlenlp.data import Stack, Tuple, Pad

from paddlenlp.datasets import load_dataset

from paddlenlp.utils.log import logger

# Append the current working directory ('.') to the module search path.
# NOTE(review): nothing visible in this file imports a local module after
# this line; presumably kept for project-local imports — confirm.
sys.path.append(os.curdir)

def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    pad_to_max_seq_len=False):
    """Tokenize every text value of ``example`` and flatten the features.

    Args:
        example: mapping whose values are texts; the keys are ignored.
        tokenizer: callable accepting ``text``, ``max_seq_len`` and
            ``pad_to_max_seq_len`` keywords and returning a mapping with
            ``"input_ids"`` and ``"token_type_ids"``.
        max_seq_length: forwarded to the tokenizer as ``max_seq_len``.
        pad_to_max_seq_len: forwarded to the tokenizer unchanged.

    Returns:
        A flat list ``[input_ids_1, token_type_ids_1, input_ids_2, ...]``,
        one pair per value, in the mapping's iteration order.
    """
    features = []
    for text in example.values():
        encoded = tokenizer(text=text,
                            max_seq_len=max_seq_length,
                            pad_to_max_seq_len=pad_to_max_seq_len)
        features.extend((encoded["input_ids"], encoded["token_type_ids"]))
    return features

# Settings for the embedding-extraction step.
# Directory containing inference.get_pooled_embedding.pdmodel / .pdiparams.
model_dir='./output/yysy/'
# Corpus file; each line is one text to embed.
corpus_file='./datasets/yysy/milvus/milvus_data_s.csv'
# Token length each text is truncated/padded to before inference.
max_seq_length=64
# Number of texts per inference batch (the name is rebound to 5000 further
# down for the Milvus insert step).
batch_size=64
# Inference device: 'gpu', 'cpu' or 'xpu'.
device='gpu'
# Math-library thread count; only relevant for CPU inference.
cpu_threads=8
# Name passed to AutoTokenizer.from_pretrained.
model_name_or_path='rocketqa-zh-base-query-encoder'

class Predictor(object):
    """Run an exported Paddle Inference embedding model over a corpus.

    The model directory must contain ``inference.get_pooled_embedding.pdmodel``
    and ``inference.get_pooled_embedding.pdiparams``.
    """

    def __init__(self,
                 model_dir,
                 device="gpu",
                 max_seq_length=128,
                 batch_size=32,
                 use_tensorrt=False,
                 precision="fp32",
                 cpu_threads=10,
                 enable_mkldnn=False):
        """Load the model and create a Paddle Inference predictor.

        Args:
            model_dir: directory holding the exported model/params files.
            device: 'gpu', 'cpu' or 'xpu'.
            max_seq_length: token length used when converting examples.
            batch_size: texts per inference call (also TensorRT max batch).
            use_tensorrt: enable the TensorRT engine (GPU only).
            precision: 'fp16', 'fp32' or 'int8'; looked up only on GPU.
            cpu_threads: math-library threads (CPU only).
            enable_mkldnn: enable MKL-DNN (CPU only).

        Raises:
            ValueError: if the model or params file does not exist.
            KeyError: on GPU, if ``precision`` is not one of the keys above.
        """
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size

        # os.path.join works whether or not model_dir ends with a separator.
        model_file = os.path.join(model_dir, "inference.get_pooled_embedding.pdmodel")
        params_file = os.path.join(model_dir, "inference.get_pooled_embedding.pdiparams")
        if not os.path.exists(model_file):
            raise ValueError("not find model file path {}".format(model_file))
        if not os.path.exists(params_file):
            raise ValueError("not find params file path {}".format(params_file))

        config = paddle.inference.Config(model_file, params_file)
        if device == "gpu":
            # 100 MB initial GPU memory pool on GPU id 0.
            config.enable_use_gpu(100, 0)
            precision_map = {
                "fp16": inference.PrecisionType.Half,
                "fp32": inference.PrecisionType.Float32,
                "int8": inference.PrecisionType.Int8,
            }
            precision_mode = precision_map[precision]
            if use_tensorrt:
                config.enable_tensorrt_engine(max_batch_size=batch_size,
                                              min_subgraph_size=30,
                                              precision_mode=precision_mode)
        elif device == "cpu":
            config.disable_gpu()
            if enable_mkldnn:
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == "xpu":
            config.enable_xpu(100)

        # Disable feed/fetch ops so the input/output handles below are used
        # directly via copy_from_cpu / copy_to_cpu.
        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)
        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        # Only the first model output is read (presumably the pooled
        # embedding — confirm against the exported model).
        self.output_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])

    def predict(self, data, tokenizer, save_path='yysy_corpus_embedding'):
        """Embed every example in ``data``, save the result and return it.

        Args:
            data: iterable of single-text mappings such as ``{idx: text}``;
                each one must yield exactly one (input_ids, token_type_ids)
                pair from ``convert_example``.
            tokenizer: tokenizer providing ``pad_token_id`` and
                ``pad_token_type_id``.
            save_path: passed to ``np.save`` (which appends ``.npy``).

        Returns:
            The concatenated output array, one row per example.

        Raises:
            ValueError: if ``data`` is empty.
        """
        # Pads a batch of (input_ids, segment_ids) pairs into two int64 arrays.
        batchify_fn = Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # segment
        )

        all_embeddings = []
        examples = []
        for example in tqdm(data):
            input_ids, segment_ids = convert_example(
                example,
                tokenizer,
                max_seq_length=self.max_seq_length,
                pad_to_max_seq_len=True)
            examples.append((input_ids, segment_ids))
            if len(examples) >= self.batch_size:
                all_embeddings.append(self._run_batch(examples, batchify_fn))
                examples = []
        if examples:  # trailing partial batch
            all_embeddings.append(self._run_batch(examples, batchify_fn))

        if not all_embeddings:
            raise ValueError("predict() received no examples; nothing to embed")
        all_embeddings = np.concatenate(all_embeddings, axis=0)
        np.save(save_path, all_embeddings)
        return all_embeddings

    def _run_batch(self, examples, batchify_fn):
        """Batchify ``examples``, run the predictor once and return its output."""
        input_ids, segment_ids = batchify_fn(examples)
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(segment_ids)
        self.predictor.run()
        return self.output_handle.copy_to_cpu()

def read_text(file_path, encoding="utf-8"):
    """Read ``file_path`` into ``{line_index: stripped_line}``.

    Args:
        file_path: text file with one entry per line.
        encoding: file encoding; UTF-8 by default so Chinese text decodes the
            same way regardless of the platform's locale.

    Returns:
        Dict mapping the 0-based line number to the line without surrounding
        whitespace. Blank lines are kept as empty strings.
    """
    # ``with`` closes the handle; the original leaked it.
    with open(file_path, encoding=encoding) as file:
        return {idx: line.strip() for idx, line in enumerate(file)}

# Embed every line of the corpus and save the vectors to disk.
# Keyword arguments make the mapping to Predictor's parameters explicit, and
# cpu_threads is now forwarded instead of being silently ignored.
predictor = Predictor(model_dir,
                      device=device,
                      max_seq_length=max_seq_length,
                      batch_size=batch_size,
                      cpu_threads=cpu_threads)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
id2corpus = read_text(corpus_file)
# Texts used to build the index: one single-entry dict per corpus line, the
# shape convert_example expects.
corpus_list = [{idx: text} for idx, text in id2corpus.items()]
predictor.predict(corpus_list, tokenizer)

# Milvus connection and collection settings.
MILVUS_HOST = '127.0.0.1'
MILVUS_PORT = 19530
# Width of the FLOAT_VECTOR field; must equal the width of the saved
# embeddings — TODO confirm against the exported model's output size.
data_dim = 256
# Number of hits returned per query (used as the search limit).
top_k = 20
collection_name = 'literature_search'
partition_tag = 'partition_1'
# Name of the vector field that is indexed and searched.
embedding_name = 'embeddings'
# Index parameters passed to Collection.create_index.
index_config = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {
        "nlist": 1000  # number of IVF clusters
    },
}
# Search parameters; metric_type matches index_config.
search_params = {
    "metric_type": "L2",
    "params": {
        # NOTE(review): nprobe (clusters probed per query) reuses top_k;
        # the two are independent knobs in Milvus — confirm this is intended.
        "nprobe": top_k
    },
}

from pymilvus import (

connections,

utility,

FieldSchema,

CollectionSchema,

DataType,

Collection,

)

# Banner format used when printing progress messages.
fmt = "\n=== {:30} ===\n"
# max_length of the VARCHAR text field; inserted texts are sliced to
# text_max_len - 1 characters.
# NOTE(review): if the server counts VARCHAR max_length in bytes, multi-byte
# (e.g. Chinese) text can still exceed it — confirm for the deployed version.
text_max_len = 1000
fields = [
    # Primary key supplied by the caller (row index of the embedding).
    # max_length only applies to VARCHAR fields, so it is not set here.
    FieldSchema(name="pk",
                dtype=DataType.INT64,
                is_primary=True,
                auto_id=False),
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=text_max_len),
    # Use embedding_name so the schema field and the searched/indexed field
    # name cannot drift apart.
    FieldSchema(name=embedding_name, dtype=DataType.FLOAT_VECTOR, dim=data_dim),
]
schema = CollectionSchema(fields, "Neural Search Index")

class VecToMilvus():  # semantic vectors --> Milvus
    """Best-effort writer that stores (pk, text, embedding) rows in Milvus.

    Every method catches ``Exception`` and prints it instead of raising, so
    failures show up on stdout and the script keeps going. Relies on the
    module-level ``schema``, ``index_config``, ``fmt``, ``MILVUS_HOST`` and
    ``MILVUS_PORT``.
    """

    def __init__(self):
        # Must be named __init__ so that VecToMilvus() actually opens the
        # connection and defines self.collection.
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    def has_collection(self, collection_name):
        """Return whether ``collection_name`` exists; False if the check fails."""
        try:
            has = utility.has_collection(collection_name)
            print(f"Does collection {collection_name} exist in Milvus: {has}")
            return has
        except Exception as e:
            print("Milvus has_table error:", e)
            return False

    def creat_collection(self, collection_name):
        """Create ``collection_name`` with the module-level ``schema``."""
        try:
            print(fmt.format("Create collection {}".format(collection_name)))
            self.collection = Collection(collection_name,
                                         schema,
                                         consistency_level="Strong")
        except Exception as e:
            print("Milvus create collection error:", e)

    # Correctly spelled alias; the original name is kept for existing callers.
    create_collection = creat_collection

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` (errors are printed, not raised)."""
        try:
            utility.drop_collection(collection_name)
        except Exception as e:
            print("Milvus delete collection error:", e)

    def create_index(self, index_name):
        """Build ``index_config`` on field ``index_name``, then load the collection.

        ``index_name`` is passed as the field name to ``create_index``; callers
        pass the vector field's name.
        """
        try:
            print(fmt.format("Start Creating index"))
            self.collection.create_index(index_name, index_config)
            print(fmt.format("Start loading"))
            self.collection.load()
        except Exception as e:
            print("Milvus create index error:", e)

    def has_partition(self, partition_tag):
        """Return whether the current collection has ``partition_tag`` (None on error)."""
        try:
            return self.collection.has_partition(partition_tag)
        except Exception as e:
            print("Milvus has partition error: ", e)

    def create_partition(self, partition_tag):
        """Create ``partition_tag`` in the current collection."""
        try:
            self.collection.create_partition(partition_tag)
            print('create partition {} successfully'.format(partition_tag))
        except Exception as e:
            print('Milvus create partition error: ', e)

    def insert(self, entities, collection_name, index_name, partition_tag=None):
        """Insert ``entities`` into ``collection_name``, creating it if needed.

        Args:
            entities: column-ordered data matching ``schema``
                (pk list, text list, embedding rows).
            collection_name: target collection; created, indexed and loaded
                on first use.
            index_name: field to index when the collection is created.
            partition_tag: optional partition, created if missing.
        """
        try:
            if not self.has_collection(collection_name):
                self.creat_collection(collection_name)
                self.create_index(index_name)
            else:
                self.collection = Collection(collection_name)
            if (partition_tag is not None) and (not self.has_partition(partition_tag)):
                self.create_partition(partition_tag)
            self.collection.insert(entities, partition_name=partition_tag)
            # NOTE(review): without an explicit flush this count may lag the
            # rows just inserted, depending on the pymilvus version — confirm.
            print(
                f"Number of entities in Milvus: {self.collection.num_entities}"
            )  # check the num_entites
        except Exception as e:
            print("Milvus insert error:", e)

class RecallByMilvus():  # recall vectors from Milvus
    """Best-effort nearest-neighbour search against a Milvus collection.

    Errors are printed rather than raised. Relies on the module-level
    ``fmt``, ``MILVUS_HOST``, ``MILVUS_PORT``, ``search_params`` and ``top_k``.
    """

    def __init__(self):
        # Must be named __init__ so that RecallByMilvus() actually connects.
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    def get_collection(self, collection_name):
        """Open ``collection_name`` and keep it as ``self.collection``."""
        try:
            print(fmt.format("Connect collection {}".format(collection_name)))
            self.collection = Collection(collection_name)
        except Exception as e:
            print("Milvus get collection error:", e)

    def search(self,
               vectors,
               embedding_name,
               collection_name,
               partition_names=None,
               output_fields=None,
               limit=None,
               param=None):
        """Search ``collection_name`` for the nearest neighbours of ``vectors``.

        Args:
            vectors: query vectors (one row per query).
            embedding_name: vector field to search.
            collection_name: collection to open and search.
            partition_names: partitions to restrict to; default: none given.
            output_fields: extra fields to return with each hit.
            limit: hits per query; defaults to the module-level ``top_k``.
            param: search parameters; defaults to ``search_params``.

        Returns:
            The pymilvus search result, or None if the search failed.
        """
        # None defaults avoid shared mutable default lists.
        if partition_names is None:
            partition_names = []
        if output_fields is None:
            output_fields = []
        try:
            self.get_collection(collection_name)
            return self.collection.search(
                vectors,
                embedding_name,
                search_params if param is None else param,
                limit=top_k if limit is None else limit,
                partition_names=partition_names,
                output_fields=output_fields)
        except Exception as e:
            print('Milvus recall error: ', e)
            return None

# Inputs for building the Milvus collection.
# Same corpus file as corpus_file above; line i is the text of embedding row i.
data_path='./datasets/yysy/milvus/milvus_data_s.csv'
# File written by Predictor.predict (np.save appends ".npy").
embedding_path='./yysy_corpus_embedding.npy'
# Row whose embedding is used as the demo query at the end of the script.
index=18
# Rows per Milvus insert; rebinds the inference batch_size defined earlier.
batch_size=5000

def read_text(file_path, encoding="utf-8"):
    """Read ``file_path`` into a list of stripped lines.

    Note: this redefines the dict-returning ``read_text`` above; from here on
    the list form is used.

    Args:
        file_path: text file with one entry per line.
        encoding: file encoding; UTF-8 by default so Chinese text decodes the
            same way regardless of the platform's locale.

    Returns:
        List of lines without surrounding whitespace, in file order.
    """
    # ``with`` closes the handle; the original leaked it.
    with open(file_path, encoding=encoding) as file:
        return [line.strip() for line in file]

# Load the corpus texts and their saved embeddings, then rebuild the Milvus
# collection in batches of `batch_size` rows.
corpus_list_embed = read_text(data_path)
embeddings = np.load(embedding_path)
embedding_ids = list(range(embeddings.shape[0]))  # primary keys = row indices
data_size = len(embedding_ids)
# Text i must describe embedding row i; fail clearly instead of with an
# IndexError (or silently misaligned rows) halfway through the inserts.
if len(corpus_list_embed) != data_size:
    raise ValueError(
        f"{data_path} has {len(corpus_list_embed)} lines but {embedding_path} "
        f"holds {data_size} embeddings; regenerate the embeddings first")

client = VecToMilvus()
client.has_collection(collection_name)
# Drop any existing collection so the rebuild starts from an empty index.
client.drop_collection(collection_name)

for start in range(0, data_size, batch_size):  # start: 0, 5000, 10000, ...
    print(start)
    end = min(start + batch_size, data_size)  # last batch may be shorter
    entities = [
        embedding_ids[start:end],  # pk
        # Truncate to stay under the VARCHAR field's max_length.
        [text[:text_max_len - 1] for text in corpus_list_embed[start:end]],  # text
        embeddings[start:end],  # embedding vectors for this batch
    ]
    client.insert(collection_name=collection_name,
                  entities=entities,
                  index_name=embedding_name,
                  partition_tag=partition_tag)

# Demo query: search with the embedding of corpus row `index` and print hits.
recall_client = RecallByMilvus()
# List indexing keeps the result 2-D (a batch of one query) and raises
# IndexError for an out-of-range index. It uses a separate name so the
# full `embeddings` matrix is not shadowed.
query_embedding = embeddings[[index]]
time_start = time.perf_counter()  # start
result = recall_client.search(query_embedding,
                              embedding_name,
                              collection_name,
                              partition_names=[partition_tag],
                              output_fields=['pk', 'text'])
time_end = time.perf_counter()  # end
print('time cost', time_end - time_start, 's')

# search() returns None after printing an error; iterating None would crash.
if result is None:
    print('no search result returned (see the Milvus error above)')
else:
    for hits in result:
        for hit in hits:
            print(f"hit: {hit}, text field: {hit.entity.get('text')}")

相关推荐
TracyCoder12320 小时前
词嵌入来龙去脉:One-hot、Word2Vec、GloVe、ELMo
人工智能·自然语言处理·word2vec
dog2501 天前
LLM(大语言模型)和高尔顿板
人工智能·语言模型·自然语言处理·高尔顿板
2401_841495641 天前
【自然语言处理】自然语言理解:从技术基础到多元应用的全景探索
人工智能·python·自然语言处理·语音助手·翻译工具·自然语言理解·企业服务
阿杰学AI1 天前
AI核心知识52——大语言模型之Model Quantization(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·模型量化·ai-native
FF-Studio1 天前
解决 NVIDIA RTX 50 系列 (sm_120) 架构下的 PyTorch 与 Unsloth 依赖冲突
pytorch·自然语言处理·cuda·unsloth·rtx 50 series
努力毕业的小土博^_^1 天前
【AI课程领学】基于SmolVLM2与Qwen3的多模态模型拼接实践:从零构建视觉语言模型(一)
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
fishfuck1 天前
MMEvol: Empowering Multimodal Large Language Models with Evol-Instruct
人工智能·语言模型·自然语言处理
阿正的梦工坊2 天前
ProRL:延长强化学习训练,扩展大语言模型推理边界——NeurIPS 2025论文解读
人工智能·语言模型·自然语言处理
MARS_AI_2 天前
大模型呼叫技术:客服行业的智能化演进与云蝠实践
人工智能·自然语言处理·交互·信息与通信·agi
渡我白衣2 天前
AI应用层革命(六)——智能体的伦理边界与法律框架:当机器开始“做决定”
人工智能·深度学习·神经网络·机器学习·计算机视觉·自然语言处理·语音识别