1. 创建RAG集合
示例代码:
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
# Open the "default" connection used by all Collection operations below.
# NOTE(review): assumes a standalone Milvus on localhost:19530 — confirm for other deployments.
connections.connect("default", host="localhost", port="19530")
# RAG专用集合设计
def create_rag_collection(collection_name="rag_documents"):
    """Create (or reuse) the Milvus collection backing the RAG knowledge base.

    The schema stores one row per document *chunk*: the embedding used for
    semantic search plus the scalar metadata needed for filtering/reranking.

    Args:
        collection_name: Name of the collection to create or reuse.

    Returns:
        A ``Collection`` handle for ``collection_name``.
    """
    # ROBUSTNESS FIX: creating an already-existing collection raises, so the
    # original failed on every re-run. Return the existing handle instead.
    if utility.has_collection(collection_name):
        return Collection(name=collection_name)
    fields = [
        # Primary key - a content-hash id so re-ingestion is idempotent.
        FieldSchema(name="doc_id", dtype=DataType.VARCHAR, is_primary=True, max_length=64),
        # Embedding vector for semantic search (768 = BERT-base dimension).
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768),
        # Document payload.
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),       # full original document
        FieldSchema(name="chunk_content", dtype=DataType.VARCHAR, max_length=2048),  # this chunk's text
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=256),   # data source
        FieldSchema(name="doc_type", dtype=DataType.VARCHAR, max_length=64),  # document type
        # Chunking metadata.
        FieldSchema(name="chunk_index", dtype=DataType.INT64),   # position of this chunk
        FieldSchema(name="total_chunks", dtype=DataType.INT64),  # chunk count of the document
        FieldSchema(name="created_time", dtype=DataType.INT64),
        FieldSchema(name="updated_time", dtype=DataType.INT64),
        # Tags / categories for metadata filtering.
        FieldSchema(name="tags", dtype=DataType.ARRAY, element_type=DataType.VARCHAR,
                    max_capacity=50, max_length=64),
        FieldSchema(name="categories", dtype=DataType.ARRAY, element_type=DataType.VARCHAR,
                    max_capacity=20, max_length=64),
        # Ranking signals.
        FieldSchema(name="relevance_score", dtype=DataType.FLOAT),  # editorial relevance score
        FieldSchema(name="access_count", dtype=DataType.INT64),     # access counter
        # Versioning.
        FieldSchema(name="version", dtype=DataType.INT64, default_value=1)
    ]
    schema = CollectionSchema(fields, description="RAG文档检索集合")
    collection = Collection(name=collection_name, schema=schema)
    # Vector index: HNSW gives low-latency ANN search suited to RAG.
    # IP (inner product) equals cosine similarity for L2-normalized vectors.
    vector_index = {
        "index_type": "HNSW",
        "metric_type": "IP",
        "params": {"M": 16, "efConstruction": 100}
    }
    collection.create_index(field_name="embedding", index_params=vector_index)
    # Scalar INVERTED indexes to speed up metadata filtering.
    # NOTE(review): INVERTED over FLOAT and ARRAY fields requires a recent
    # Milvus (2.4+) — confirm against the deployed server version.
    scalar_indexes = ["source", "doc_type", "tags", "categories", "relevance_score"]
    for field in scalar_indexes:
        collection.create_index(field_name=field, index_params={"index_type": "INVERTED"})
    return collection
if __name__ == "__main__":
    # Script entry point: provision the RAG collection once.
    collection_handle = create_rag_collection()
    print("RAG集合创建完成")
2. 文档预处理与分块
示例代码:
import hashlib
import re
from datetime import datetime
from typing import List, Dict
class DocumentProcessor:
    """
    Document preprocessing and chunking for the RAG knowledge base.

    Splits raw text into chunks bounded by ``chunk_size`` characters and
    packages each chunk with the metadata expected by the Milvus schema.
    """

    def __init__(self, chunk_size=512, chunk_overlap=50):
        self.chunk_size = chunk_size
        # NOTE(review): chunk_overlap is stored but never applied by
        # split_text — kept for interface compatibility; TODO implement.
        self.chunk_overlap = chunk_overlap

    def generate_doc_id(self, content: str, source: str = "") -> str:
        """Return a stable id of at most 64 chars: "<source>_<md5(content)>"."""
        # md5 is a content fingerprint here, not a security primitive.
        content_hash = hashlib.md5(content.encode()).hexdigest()
        return f"{source}_{content_hash}"[:64]

    def split_text(self, text: str) -> List[str]:
        """Split *text* into chunks of at most ``chunk_size`` characters.

        Paragraphs (blank-line separated) are kept together when possible;
        over-long paragraphs fall back to sentence splitting, and over-long
        sentences are hard-wrapped at ``chunk_size``.
        """
        paragraphs = re.split(r'\n\s*\n', text.strip())
        chunks = []
        current_chunk = ""
        for paragraph in paragraphs:
            if len(paragraph) > self.chunk_size:
                # Paragraph too long: fall back to sentence-level packing.
                sentences = re.split(r'[.!?]+', paragraph)
                for sentence in sentences:
                    if len(current_chunk) + len(sentence) <= self.chunk_size:
                        current_chunk += sentence + ". "
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = sentence + ". "
                        # Hard-wrap sentences that alone exceed chunk_size.
                        while len(current_chunk) > self.chunk_size:
                            chunks.append(current_chunk[:self.chunk_size])
                            current_chunk = current_chunk[self.chunk_size:]
            else:
                # Start a new chunk when the paragraph would overflow it.
                if len(current_chunk) + len(paragraph) > self.chunk_size:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    current_chunk = paragraph
                else:
                    current_chunk += "\n\n" + paragraph
        # Flush the trailing chunk.
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def process_document(self, content: str, title: str = "", source: str = "",
                         doc_type: str = "text", tags: List[str] = None,
                         categories: List[str] = None) -> List[Dict]:
        """Chunk *content* and return one metadata dict per chunk.

        Returns:
            A list of dicts whose keys match the Milvus collection schema.
        """
        doc_id = self.generate_doc_id(content, source)
        chunks = self.split_text(content)
        processed_chunks = []
        timestamp = int(datetime.now().timestamp())
        for i, chunk in enumerate(chunks):
            # BUG FIX: the schema caps doc_id at 64 chars, but appending the
            # chunk suffix to an already-64-char base id overflowed the limit
            # and broke inserts. Trim the base so the final id always fits.
            suffix = f"_chunk_{i}"
            chunk_doc_id = doc_id[:64 - len(suffix)] + suffix
            chunk_data = {
                "doc_id": chunk_doc_id,
                # NOTE(review): stores the *full* document on every chunk;
                # large documents may exceed the 65535-char schema limit.
                "content": content,
                "chunk_content": chunk,
                "title": title,
                "source": source,
                "doc_type": doc_type,
                "chunk_index": i,
                "total_chunks": len(chunks),
                "created_time": timestamp,
                "updated_time": timestamp,
                "tags": tags or [],
                "categories": categories or [],
                "relevance_score": 1.0,
                "access_count": 0,
                "version": 1
            }
            processed_chunks.append(chunk_data)
        return processed_chunks

    def insert_documents(self, collection, documents: List[Dict], embedding_generator):
        """Embed chunk texts and insert *documents* into *collection*.

        The column order below must match the collection schema exactly.
        """
        # 1. Generate one embedding per chunk.
        contents = [doc["chunk_content"] for doc in documents]
        embeddings = embedding_generator.batch_generate_embeddings(contents)
        # 2. Column-oriented data in schema field order.
        insert_data = [
            [doc["doc_id"] for doc in documents],           # 0: doc_id
            embeddings,                                     # 1: embedding
            [doc["content"] for doc in documents],          # 2: content
            [doc["chunk_content"] for doc in documents],    # 3: chunk_content
            [doc["title"] for doc in documents],            # 4: title
            [doc["source"] for doc in documents],           # 5: source
            [doc["doc_type"] for doc in documents],         # 6: doc_type
            [doc["chunk_index"] for doc in documents],      # 7: chunk_index
            [doc["total_chunks"] for doc in documents],     # 8: total_chunks
            [doc["created_time"] for doc in documents],     # 9: created_time
            [doc["updated_time"] for doc in documents],     # 10: updated_time
            [doc["tags"] for doc in documents],             # 11: tags
            [doc["categories"] for doc in documents],       # 12: categories
            [doc["relevance_score"] for doc in documents],  # 13: relevance_score
            [doc["access_count"] for doc in documents],     # 14: access_count
            [doc["version"] for doc in documents]           # 15: version
        ]
        # 3. Insert and flush so the data is durable/visible.
        result = collection.insert(insert_data)
        collection.flush()
        return result
# Module-level default processor shared by the pipeline scripts below.
# 1024-char chunks; the overlap value is currently unused by split_text.
doc_processor = DocumentProcessor(chunk_size=1024, chunk_overlap=100)
3. 向量嵌入生成器
示例代码:
from typing import List
import numpy as np
class EmbeddingGenerator:
    """
    Deterministic mock embedding generator.

    Stands in for a real embedding model (sentence-transformers, OpenAI
    embeddings, ...). Vectors are pseudo-random but stable per input text,
    and L2-normalized so inner product equals cosine similarity.
    """

    def __init__(self, dim=768):
        # Output dimensionality; must match the collection's FLOAT_VECTOR dim.
        self.dim = dim

    def generate_embedding(self, text: str) -> List[float]:
        """Return a unit-norm vector of length ``dim`` derived from *text*.

        BUG FIX: the original seeded NumPy with ``hash(text)``, but Python
        string hashing is randomized per process (PYTHONHASHSEED), so "same
        text yields the same vector" only held within one run. Seed from a
        content digest instead, and use a local Generator so the global
        NumPy RNG state is left untouched.
        """
        seed = int.from_bytes(hashlib.md5(text.encode()).digest()[:4], "big")
        rng = np.random.default_rng(seed)
        vector = rng.standard_normal(self.dim)
        norm = np.linalg.norm(vector)
        # Guard against the (theoretical) zero vector before normalizing.
        return (vector / norm).tolist() if norm > 0 else vector.tolist()

    def batch_generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Vectorize a batch of texts (one embedding per input, same order)."""
        return [self.generate_embedding(text) for text in texts]
# Module-level generator matching the collection's 768-dim embedding field.
embedding_generator = EmbeddingGenerator(dim=768)
4. 知识库管理和维护
示例代码:
from datetime import datetime
from typing import Dict
from pymilvus import Collection
class KnowledgeBaseManager:
    """
    Maintenance operations for the RAG knowledge base collection.

    NOTE(review): update_document_relevance is a non-atomic delete-then-insert;
    a crash between the two calls loses the row — confirm this is acceptable.
    """

    def __init__(self, collection):
        # Milvus Collection handle; assumed already created and reachable.
        self.collection = collection

    def get_collection_stats(self) -> Dict:
        """Return total entity count plus per-source and per-type histograms."""
        self.collection.flush()  # flush so num_entities reflects recent inserts
        total_docs = self.collection.num_entities
        # Per-source histogram.
        # NOTE(review): empty expr + limit=10000 samples at most 10k rows and
        # may be rejected by older Milvus servers — confirm server behavior.
        sources = self.collection.query(
            expr="",
            output_fields=["source"],
            limit=10000
        )
        source_stats = {}
        for doc in sources:
            source = doc.get("source", "unknown")
            source_stats[source] = source_stats.get(source, 0) + 1
        # Per-doc-type histogram (same sampling caveat as above).
        doc_types = self.collection.query(
            expr="",
            output_fields=["doc_type"],
            limit=10000
        )
        type_stats = {}
        for doc in doc_types:
            doc_type = doc.get("doc_type", "unknown")
            type_stats[doc_type] = type_stats.get(doc_type, 0) + 1
        return {
            "total_documents": total_docs,
            "sources": source_stats,
            "document_types": type_stats
        }

    def delete_documents_by_source(self, source: str) -> int:
        """Delete every chunk whose source equals *source*; return the count.

        NOTE(review): *source* is interpolated unescaped — a value containing
        a double quote would corrupt the filter expression.
        """
        expr = f'source == "{source}"'
        result = self.collection.delete(expr=expr)
        return len(result.primary_keys)

    def update_document_relevance(self, doc_id: str, new_score: float):
        """Set a new relevance score on *doc_id* via delete + re-insert.

        Returns:
            False if the document does not exist, True on success.
        """
        # Fetch every field so the row can be reconstructed verbatim.
        original_docs = self.collection.query(
            expr=f'doc_id == "{doc_id}"',
            output_fields=["doc_id", "embedding", "content", "chunk_content", "title",
                           "source", "doc_type", "chunk_index", "total_chunks",
                           "created_time", "updated_time", "tags", "categories",
                           "access_count", "version"]
        )
        if not original_docs:
            return False
        original = original_docs[0]
        # Delete first: doc_id is the primary key, so a plain insert would
        # not overwrite the existing row.
        self.collection.delete(expr=f'doc_id == "{doc_id}"')
        # Apply the update and bump bookkeeping fields.
        original["relevance_score"] = new_score
        original["updated_time"] = int(datetime.now().timestamp())
        original["version"] = original.get("version", 1) + 1
        # BUG FIX: the original built [[v0, ..., v15]] — a single column of
        # 16 values. Column-based insert needs 16 columns of 1 row each
        # ([[v0], [v1], ...]), matching the schema field order.
        insert_data = [[original[field]] for field in [
            "doc_id", "embedding", "content", "chunk_content", "title",
            "source", "doc_type", "chunk_index", "total_chunks",
            "created_time", "updated_time", "tags", "categories",
            "relevance_score", "access_count", "version"
        ]]
        self.collection.insert(insert_data)
        return True

    def cleanup_old_documents(self, days_old: int = 30) -> int:
        """Delete chunks older than *days_old* days; return the count deleted."""
        cutoff_time = int(datetime.now().timestamp()) - (days_old * 24 * 3600)
        expr = f'created_time < {cutoff_time}'
        result = self.collection.delete(expr=expr)
        return len(result.primary_keys)
if __name__ == "__main__":
    # Demo: print knowledge-base statistics for the RAG collection.
    rag_collection = Collection(name="rag_documents", using="default")
    manager = KnowledgeBaseManager(rag_collection)
    summary = manager.get_collection_stats()
    print("=== 知识库统计 ===")
    print(f"总文档数: {summary['total_documents']}")
    print("按来源统计:", summary['sources'])
    print("按类型统计:", summary['document_types'])
    # Example relevance update:
    # manager.update_document_relevance("some_doc_id", 4.5)
5. RAG流程
示例代码:
from typing import List, Dict
from pymilvus import Collection
from tmp.milvus.rag.document_processor import doc_processor
from tmp.milvus.rag.embedding_generator import embedding_generator
from tmp.milvus.rag.rag_retriever import RAGRetriever
class RAGPipeline:
    """
    End-to-end RAG flow: ingest documents, retrieve context, build prompts.
    """

    def __init__(self, collection, embedding_generator, doc_processor):
        self.retriever = RAGRetriever(collection, embedding_generator)
        self.collection = collection
        self.embedding_generator = embedding_generator
        self.doc_processor = doc_processor

    def add_document(self, content: str, title: str = "", source: str = "",
                     doc_type: str = "text", tags: List[str] = None,
                     categories: List[str] = None):
        """Chunk *content*, embed each chunk, and insert into the collection."""
        chunks = self.doc_processor.process_document(
            content=content,
            title=title,
            source=source,
            doc_type=doc_type,
            tags=tags,
            categories=categories
        )
        result = self.doc_processor.insert_documents(self.collection, chunks, self.embedding_generator)
        return result

    def query(self, question: str, context_tokens: int = 2000) -> Dict:
        """Retrieve context for *question* and build an LLM-ready prompt.

        Returns:
            Dict with the question, assembled context, prompt text, and the
            raw retrieved documents.
        """
        # PERF FIX: the original ran hybrid_search twice per query — once
        # inside get_context_for_llm and once for retrieved_docs — doubling
        # the vector-search cost and risking inconsistent result sets.
        # Retrieve once and assemble the context from that single list.
        retrieved_docs = self.retriever.hybrid_search(question)
        context = self._build_context(retrieved_docs, context_tokens)
        # Prompt template; adapt to the target LLM as needed.
        prompt = f"""
基于以下上下文回答问题。如果上下文中没有相关信息,请说明无法回答。
上下文:
{context}
问题:{question}
回答:
"""
        # In a real application the prompt would be sent to an LLM here.
        return {
            "question": question,
            "context": context,
            "prompt": prompt,
            "retrieved_docs": retrieved_docs
        }

    @staticmethod
    def _build_context(retrieved_docs: List[Dict], max_tokens: int) -> str:
        """Concatenate retrieved docs into a context string capped at roughly
        *max_tokens* (4 chars ~ 1 token; mirrors RAGRetriever.get_context_for_llm)."""
        context_parts = []
        total_tokens = 0
        for doc in retrieved_docs:
            doc_text = f"文档标题: {doc['title']}\n"
            doc_text += f"内容: {doc['content']}\n"
            doc_text += f"来源: {doc['source']}\n"
            doc_text += f"相关性得分: {doc['final_score']:.3f}\n"
            doc_text += "---\n"
            # Crude token estimate; replace with a real tokenizer if needed.
            estimated_tokens = len(doc_text) // 4
            if total_tokens + estimated_tokens > max_tokens:
                break
            context_parts.append(doc_text)
            total_tokens += estimated_tokens
        return "\n".join(context_parts)

    def update_document_access_count(self, doc_ids: List[str]):
        """Record accesses for *doc_ids*.

        TODO: not implemented — would require delete + re-insert (see
        KnowledgeBaseManager.update_document_relevance) or an upsert API.
        """
        pass
if __name__ == '__main__':
    # Demo: ingest one sample document, then run a RAG query against it.
    collection = Collection(name="rag_documents", using="default")
    pipeline = RAGPipeline(collection, embedding_generator, doc_processor)
    sample_documents = [
        {
            "title": "数据结构与算法",
            "content": """数据结构是计算机存储、组织数据的方式。常见的数据结构包括数组、链表、栈、队列、树、图等。
算法是对特定问题求解步骤的描述。好的算法应该具有正确性、可读性、健壮性和高效性。
时间复杂度和空间复杂度是衡量算法效率的重要指标。
排序算法如快速排序、归并排序、堆排序等是算法学习的基础。""",
            "source": "cs_fundamentals",
            "doc_type": "educational",
            "tags": ["data-structure", "algorithm", "computer-science"],
            "categories": ["computer-science", "education"]
        }
    ]
    for entry in sample_documents:
        pipeline.add_document(
            content=entry["content"],
            title=entry["title"],
            source=entry["source"],
            doc_type=entry["doc_type"],
            tags=entry["tags"],
            categories=entry["categories"]
        )
    # Run one query and show what was retrieved.
    answer = pipeline.query("什么是数据结构?请详细解释。")
    print("=== RAG查询结果 ===")
    print(f"问题: {answer['question']}")
    print(f"检索到的文档数量: {len(answer['retrieved_docs'])}")
    print(f"\n生成的Prompt预览:\n{answer['prompt'][:500]}...")
    print("\n=== 检索到的相关文档 ===")
    for rank, hit in enumerate(answer['retrieved_docs'], start=1):
        print(f"{rank}. {hit['title']} (得分: {hit['final_score']:.3f})")
        print(f" 内容: {hit['content'][:150]}...")
        print()
6. RAG检索
示例代码:
from typing import List, Dict
from pymilvus import Collection
from tmp.milvus.rag.embedding_generator import embedding_generator
class RAGRetriever:
    """
    Retrieval layer for RAG: semantic search, metadata-filtered hybrid
    search, weighted reranking, and LLM context assembly.
    """

    def __init__(self, collection, embedding_generator, top_k=5, rerank_top_k=10):
        """
        Args:
            collection: Milvus Collection holding the RAG chunks.
            embedding_generator: Object exposing generate_embedding(text).
            top_k: Number of results returned by hybrid_search.
            rerank_top_k: Candidate pool size fetched before reranking.
        """
        self.collection = collection
        self.embedding_generator = embedding_generator
        self.top_k = top_k
        self.rerank_top_k = rerank_top_k
        self.collection.load()  # load segments into memory so search works

    def semantic_search(self, query: str, filter_expr: str = None,
                        top_k: int = None) -> List[Dict]:
        """Vector-search *query* and return hit metadata dicts.

        Args:
            query: Natural-language query text.
            filter_expr: Optional Milvus boolean expression to pre-filter.
            top_k: Result count; defaults to self.top_k.
        """
        if top_k is None:
            top_k = self.top_k
        query_vector = self.embedding_generator.generate_embedding(query)
        # ef controls HNSW search breadth (recall/latency trade-off).
        search_params = {
            "metric_type": "IP",
            "params": {"ef": 64}
        }
        results = self.collection.search(
            data=[query_vector],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            expr=filter_expr,
            output_fields=[
                "doc_id", "chunk_content", "title", "source", "doc_type",
                "chunk_index", "relevance_score", "access_count", "tags", "categories"
            ]
        )
        # Flatten per-query hit lists into plain dicts.
        retrieved_docs = []
        for hits in results:
            for hit in hits:
                doc_info = {
                    "doc_id": hit.entity.get("doc_id"),
                    "content": hit.entity.get("chunk_content"),
                    "title": hit.entity.get("title"),
                    "source": hit.entity.get("source"),
                    "doc_type": hit.entity.get("doc_type"),
                    "chunk_index": hit.entity.get("chunk_index"),
                    "relevance_score": hit.entity.get("relevance_score"),
                    "access_count": hit.entity.get("access_count"),
                    "tags": hit.entity.get("tags"),
                    "categories": hit.entity.get("categories"),
                    # With the IP metric, hit.distance is the inner product
                    # (higher = more similar).
                    "similarity": hit.distance
                }
                retrieved_docs.append(doc_info)
        return retrieved_docs

    def hybrid_search(self, query: str, filters: Dict = None,
                      boost_factors: Dict = None) -> List[Dict]:
        """Semantic search + metadata filtering + weighted reranking."""
        filter_expr = self._build_filter_expression(filters)
        # Over-fetch (rerank_top_k) so reranking has candidates to reorder.
        semantic_results = self.semantic_search(query, filter_expr, self.rerank_top_k)
        reranked_results = self._rerank_results(semantic_results, query, boost_factors)
        return reranked_results[:self.top_k]

    def _build_filter_expression(self, filters: Dict) -> str:
        """Translate a filters dict into a Milvus boolean expression.

        Supported keys: "sources"/"doc_types" (scalar or list, OR within a
        list) and "tags"/"categories" (scalar or list, AND across values).
        Returns None when there is nothing to filter on.

        NOTE(review): values are interpolated unescaped — a value containing
        a double quote would corrupt the expression.
        """
        if not filters:
            return None
        conditions = []
        # Source filter.
        if "sources" in filters:
            sources = filters["sources"]
            if isinstance(sources, list):
                sources_str = '", "'.join(sources)
                conditions.append(f'source in ["{sources_str}"]')
            else:
                conditions.append(f'source == "{sources}"')
        # Document-type filter.
        if "doc_types" in filters:
            doc_types = filters["doc_types"]
            if isinstance(doc_types, list):
                types_str = '", "'.join(doc_types)
                conditions.append(f'doc_type in ["{types_str}"]')
            else:
                conditions.append(f'doc_type == "{doc_types}"')
        # Tag filter (every listed tag must be present).
        if "tags" in filters:
            tags = filters["tags"]
            if isinstance(tags, list):
                for tag in tags:
                    conditions.append(f'array_contains(tags, "{tag}")')
            else:
                conditions.append(f'array_contains(tags, "{tags}")')
        # Category filter (every listed category must be present).
        if "categories" in filters:
            categories = filters["categories"]
            if isinstance(categories, list):
                for category in categories:
                    conditions.append(f'array_contains(categories, "{category}")')
            else:
                conditions.append(f'array_contains(categories, "{categories}")')
        return " and ".join(conditions) if conditions else None

    def _rerank_results(self, results: List[Dict], query: str,
                        boost_factors: Dict = None) -> List[Dict]:
        """Attach a weighted "final_score" to each result and sort by it.

        final_score = w_sim*similarity + w_rel*relevance + w_pop*popularity,
        each component normalized to [0, 1].
        """
        if not boost_factors:
            boost_factors = {
                "similarity_weight": 0.7,
                "relevance_weight": 0.2,
                "popularity_weight": 0.1
            }
        # PERF FIX: the max access count is loop-invariant, but the original
        # recomputed it inside the per-result loop (accidental O(n^2)).
        max_access = max([r.get("access_count", 0) for r in results] + [1])
        for result in results:
            # IP similarity lies in [-1, 1] for unit vectors; map to [0, 1].
            normalized_similarity = (result["similarity"] + 1) / 2
            # Assumes editorial relevance_score is on a 1-5 scale.
            normalized_relevance = result["relevance_score"] / 5.0
            normalized_popularity = result.get("access_count", 0) / max_access
            result["final_score"] = (
                boost_factors["similarity_weight"] * normalized_similarity +
                boost_factors["relevance_weight"] * normalized_relevance +
                boost_factors["popularity_weight"] * normalized_popularity
            )
        return sorted(results, key=lambda x: x["final_score"], reverse=True)

    def get_context_for_llm(self, query: str, max_tokens: int = 2000) -> str:
        """Build a context string from hybrid_search results, capped at
        roughly *max_tokens* (4 characters ~ 1 token)."""
        retrieved_docs = self.hybrid_search(query)
        context_parts = []
        total_tokens = 0
        for doc in retrieved_docs:
            doc_text = f"文档标题: {doc['title']}\n"
            doc_text += f"内容: {doc['content']}\n"
            doc_text += f"来源: {doc['source']}\n"
            doc_text += f"相关性得分: {doc['final_score']:.3f}\n"
            doc_text += "---\n"
            # Crude token estimate; swap in a real tokenizer for production.
            estimated_tokens = len(doc_text) // 4
            if total_tokens + estimated_tokens <= max_tokens:
                context_parts.append(doc_text)
                total_tokens += estimated_tokens
            else:
                break
        return "\n".join(context_parts)
if __name__ == '__main__':
    # Demo: plain semantic search first, then filtered hybrid search.
    collection = Collection(name="rag_documents", using="default")
    retriever = RAGRetriever(collection, embedding_generator, top_k=3)
    question = "Python编程语言的特点是什么?"
    hits = retriever.semantic_search(question)
    print("=== 语义搜索结果 ===")
    for rank, hit in enumerate(hits, start=1):
        print(f"{rank}. 标题: {hit['title']}")
        print(f" 内容: {hit['content'][:100]}...")
        print(f" 相似度: {hit['similarity']:.3f}")
        print(f" 来源: {hit['source']}")
        print("---")
    # Restrict to python_docs / programming and reweight the rerank.
    search_filters = {
        "sources": ["python_docs"],
        "categories": ["programming"]
    }
    weights = {
        "similarity_weight": 0.6,
        "relevance_weight": 0.3,
        "popularity_weight": 0.1
    }
    reranked = retriever.hybrid_search(question, search_filters, weights)
    print("\n=== 混合搜索结果 ===")
    for rank, hit in enumerate(reranked, start=1):
        print(f"{rank}. 标题: {hit['title']}")
        print(f" 内容: {hit['content'][:100]}...")
        print(f" 综合得分: {hit['final_score']:.3f}")
        print(f" 来源: {hit['source']}")
        print("---")
7. RAG执行器
示例代码:
from pymilvus import connections, Collection
from tmp.milvus.rag.create_collection import create_rag_collection
from tmp.milvus.rag.document_processor import DocumentProcessor
from tmp.milvus.rag.embedding_generator import EmbeddingGenerator
from tmp.milvus.rag.knowledge_base_manager import KnowledgeBaseManager
from tmp.milvus.rag.rag_pipeline import RAGPipeline
def main_rag_application():
    """Run the end-to-end RAG demo.

    Connects to Milvus, ensures the collection exists, ingests a sample
    document, then serves an interactive query loop until the user types
    'quit' (or hits Ctrl-C).
    """
    print("=== Milvus RAG应用启动 ===")
    # 1. Connect to the local Milvus server.
    connections.connect("default", host="localhost", port="19530")
    # 2. Reuse the collection if it exists, otherwise create it.
    try:
        collection = Collection("rag_documents")
        print("使用现有集合")
    except Exception:
        # BUG FIX: a bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit; catch Exception so Ctrl-C still exits cleanly.
        # NOTE(review): relying on Collection() raising for a missing
        # collection is implicit — utility.has_collection() is the explicit check.
        collection = create_rag_collection()
        print("创建新集合")
    # 3. Wire up the components.
    doc_processor = DocumentProcessor(chunk_size=1024)
    embedding_generator = EmbeddingGenerator(dim=768)
    rag_pipeline = RAGPipeline(collection, embedding_generator, doc_processor)
    kb_manager = KnowledgeBaseManager(collection)
    # 4. Ingest sample content.
    sample_docs = [
        {
            "title": "人工智能发展史",
            "content": """人工智能的发展可以分为几个重要阶段。1950年代是AI的萌芽期,图灵测试在这一时期提出。
1960-1970年代是AI的黄金期,专家系统开始出现。1980年代遭遇了AI寒冬,资金和兴趣都有所下降。
1990年代机器学习开始兴起,统计方法逐渐取代了符号推理。21世纪深度学习的突破带来了AI的复兴。""",
            "source": "ai_history",
            "doc_type": "historical",
            "tags": ["ai", "history", "development"],
            "categories": ["ai", "history"]
        }
    ]
    for doc in sample_docs:
        rag_pipeline.add_document(
            content=doc["content"],
            title=doc["title"],
            source=doc["source"],
            doc_type=doc["doc_type"],
            tags=doc["tags"],
            categories=doc["categories"]
        )
    # 5. Interactive query loop.
    print("\n=== RAG查询系统已就绪 ===")
    print("输入问题进行查询,输入'quit'退出")
    while True:
        try:
            question = input("\n请输入您的问题: ").strip()
            if question.lower() == 'quit':
                break
            if not question:
                continue
            result = rag_pipeline.query(question)
            print(f"\n问题: {result['question']}")
            print(f"检索到 {len(result['retrieved_docs'])} 个相关文档")
            print("\n检索到的相关文档:")
            # Show only the top 3 hits to keep the console readable.
            for i, doc in enumerate(result['retrieved_docs'][:3]):
                print(f" {i + 1}. {doc['title']} (相关性: {doc['final_score']:.3f})")
                print(f" 内容: {doc['content'][:100]}...")
            print("\n建议的Prompt结构:")
            print(f"{result['prompt'][:300]}...")
        except KeyboardInterrupt:
            print("\n程序被用户中断")
            break
        except Exception as e:
            # Keep the loop alive on per-query failures, but report them.
            print(f"查询出错: {e}")
    # 6. Final statistics.
    stats = kb_manager.get_collection_stats()
    print(f"\n=== 知识库最终统计 ===")
    print(f"总文档数: {stats['total_documents']}")
    print("来源分布:", stats['sources'])
# Run the full demo application when executed as a script.
if __name__ == "__main__":
    main_rag_application()