之前探索了如何用DSPy优化RAG prompt。
https://blog.csdn.net/liliang199/article/details/155893634
这里首先设定chromadb向量检索系统,然后通过DSPy优化基于retriever的RAG检索功能。
所用测试例和代码修改自网络资料。
1 retriever设置
1.1 向量计算设置
这里基于sentence-transformer框架,并采用bge-m3向量模型。
MiniLM是微软开发的轻量级语言模型,以较小参数量和计算成本实现。
all-MiniLM-L6-v2属于 MiniLM 系列,通过知识蒸馏从更大模型压缩而来,旨在保持较高性能同时减少计算资源需求。
代码示例如下
import os
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
这里设置HF_ENDPOINT是因为hf访问受限。
1.2 向量库设置
chromadb的目标是帮助用户更加便捷地构建大模型应用,更加轻松的将知识、事实和技能等现实世界中的文档整合进大模型中。
这里选用chromadb向量库,安装过程参考如下链接
https://blog.csdn.net/liliang199/article/details/149542926
设置示例如下所示,向量检索使用余弦相似度。
persist_directory = "./curr_chromadb"
collection_name = "knowledge_test"
client = chromadb.PersistentClient(persist_directory)
collection =client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"} # 使用余弦相似度
)
1.3 向量检索示例
以下是应用sentence-transformer和chromadb的向量检索器类,可修改后应用于生产环境。
包括向量库初始化,文档划分、向量生成、向量检索等检索器要求的基本功能。
1)retriever
import os
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
import dspy
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional
import hashlib
import logging
from sentence_transformers import SentenceTransformer
import numpy as np
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VectorRetriever:
"""
完整的向量检索器,支持文档初始化、嵌入生成和语义检索。
使用ChromaDB作为向量存储,SentenceTransformer生成嵌入。
"""
def __init__(
self,
collection_name: str = "knowledge_base",
embedding_model: str = "BAAI/bge-m3", # 轻量且效果不错的模型
persist_directory: str = "./chroma_db",
chunk_size: int = 500,
chunk_overlap: int = 50
):
"""
初始化检索器
Args:
collection_name: 集合名称
embedding_model: 嵌入模型名称
persist_directory: 向量数据库存储路径
chunk_size: 文本分块大小(字符数)
chunk_overlap: 分块重叠大小(字符数)
"""
self.embedding_model = SentenceTransformer(embedding_model)
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
# 初始化ChromaDB客户端
self.client = chromadb.PersistentClient(persist_directory)
# chromadb.Client(Settings(
# chroma_db_impl="duckdb+parquet",
# persist_directory=persist_directory,
# anonymized_telemetry=False # 禁用遥测
# ))
# 获取或创建集合
self.collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"} # 使用余弦相似度
)
logger.info(f"检索器初始化完成,使用模型: {embedding_model}")
def _generate_chunk_id(self, text: str, source: str, chunk_index: int) -> str:
"""生成唯一的块ID"""
content = f"{source}_{chunk_index}_{text[:50]}"
return hashlib.md5(content.encode()).hexdigest()[:16]
def split_text(self, text: str) -> List[str]:
"""
将长文本分割为重叠的块
Args:
text: 输入文本
Returns:
文本块列表
"""
if len(text) <= self.chunk_size:
return [text]
chunks = []
start = 0
while start < len(text):
# 计算块结束位置
end = start + self.chunk_size
# 如果不在文本末尾,尝试在句号、逗号或空格处分割
if end < len(text):
# 查找合适的分割点
for split_point in range(end, start + self.chunk_size - 100, -1):
if split_point < len(text) and text[split_point] in '.。!!??,,;;\n\t ':
end = split_point + 1 # 包含分割字符
break
chunk = text[start:end].strip()
if chunk: # 忽略空块
chunks.append(chunk)
# 移动起始位置,考虑重叠
start = end - self.chunk_overlap
# 防止无限循环
if start >= len(text) or (start == end and end >= len(text)):
break
logger.info(f"将文本分割为 {len(chunks)} 个块")
return chunks
def prepare_documents(
self,
documents: List[Dict[str, Any]],
batch_size: int = 32
) -> List[Dict[str, Any]]:
"""
准备文档以供索引
Args:
documents: 文档列表,每个文档是字典,必须包含"content"字段
batch_size: 批处理大小
Returns:
处理后的文档块列表
"""
all_chunks = []
total_docs = len(documents)
logger.info(f"开始处理 {total_docs} 个文档...")
for doc_idx, doc in enumerate(documents, 1):
content = doc.get("content", "")
if not content:
logger.warning(f"文档 {doc_idx} 无内容,跳过")
continue
# 获取元数据(排除content字段)
metadata = {k: v for k, v in doc.items() if k != "content"}
metadata["source_doc_id"] = doc.get("id", f"doc_{doc_idx}")
# 分割文本
chunks = self.split_text(content)
# 为每个块创建记录
for chunk_idx, chunk in enumerate(chunks):
chunk_id = self._generate_chunk_id(chunk, metadata["source_doc_id"], chunk_idx)
chunk_metadata = metadata.copy()
chunk_metadata.update({
"chunk_id": chunk_idx,
"total_chunks": len(chunks),
"char_length": len(chunk),
"doc_index": doc_idx - 1
})
all_chunks.append({
"id": chunk_id,
"text": chunk,
"metadata": chunk_metadata,
"embedding": None # 稍后生成
})
if doc_idx % 10 == 0 or doc_idx == total_docs:
logger.info(f"已处理 {doc_idx}/{total_docs} 个文档,生成 {len(all_chunks)} 个块")
logger.info(f"文档准备完成,共生成 {len(all_chunks)} 个文本块")
return all_chunks
def generate_embeddings(
self,
chunks: List[Dict[str, Any]],
batch_size: int = 32
) -> List[Dict[str, Any]]:
"""
为文本块生成嵌入向量
Args:
chunks: 文本块列表
batch_size: 批处理大小
Returns:
包含嵌入向量的文本块列表
"""
if not chunks:
return chunks
logger.info(f"开始为 {len(chunks)} 个文本块生成嵌入...")
# 提取文本
texts = [chunk["text"] for chunk in chunks]
# 分批生成嵌入
embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
batch_embeddings = self.embedding_model.encode(
batch_texts,
show_progress_bar=False,
normalize_embeddings=True # 归一化以使用余弦相似度
)
embeddings.extend(batch_embeddings)
if (i // batch_size) % 10 == 0:
logger.info(f"已生成 {min(i + batch_size, len(texts))}/{len(texts)} 个嵌入")
# 将嵌入添加到块中
for chunk, embedding in zip(chunks, embeddings):
chunk["embedding"] = embedding.tolist() if hasattr(embedding, 'tolist') else embedding
logger.info("嵌入生成完成")
return chunks
def index_documents(
self,
documents: List[Dict[str, Any]],
clear_existing: bool = False
) -> int:
"""
索引文档到向量数据库
Args:
documents: 原始文档列表
clear_existing: 是否清除现有索引
Returns:
索引的文档块数量
"""
if clear_existing:
self.client.delete_collection(self.collection.name)
self.collection = self.client.get_or_create_collection(
name=self.collection.name,
metadata={"hnsw:space": "cosine"}
)
logger.info("已清除现有索引")
# 准备文档
chunks = self.prepare_documents(documents)
if not chunks:
logger.warning("没有可索引的文档块")
return 0
# 生成嵌入
chunks_with_embeddings = self.generate_embeddings(chunks)
# 准备数据用于添加到集合
ids = [chunk["id"] for chunk in chunks_with_embeddings]
texts = [chunk["text"] for chunk in chunks_with_embeddings]
metadatas = [chunk["metadata"] for chunk in chunks_with_embeddings]
embeddings = [chunk["embedding"] for chunk in chunks_with_embeddings]
# 分批添加到集合
batch_size = 100
total_added = 0
for i in range(0, len(ids), batch_size):
batch_ids = ids[i:i + batch_size]
batch_texts = texts[i:i + batch_size]
batch_metadatas = metadatas[i:i + batch_size]
batch_embeddings = embeddings[i:i + batch_size]
self.collection.add(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas,
embeddings=batch_embeddings
)
total_added += len(batch_ids)
logger.info(f"已索引 {total_added}/{len(ids)} 个文档块")
logger.info(f"文档索引完成,共索引 {total_added} 个文档块")
return total_added
def retrieve(
self,
query: str,
k: int = 3,
score_threshold: float = 0.5,
include_metadata: bool = True
) -> List[Any]:
"""
检索与查询最相关的文档
Args:
query: 查询文本
k: 返回结果数量
score_threshold: 相似度阈值
include_metadata: 是否包含元数据
Returns:
检索结果列表
"""
if self.collection.count() == 0:
logger.warning("集合为空,请先索引文档")
return []
# 生成查询嵌入
query_embedding = self.embedding_model.encode(
query,
normalize_embeddings=True
).tolist()
# 执行查询
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=min(k * 2, 20), # 获取更多结果用于过滤
include=["documents", "metadatas", "distances"]
)
# 处理结果
retrieved_docs = []
if results["documents"] and results["documents"][0]:
for i, (doc, distance, metadata) in enumerate(zip(
results["documents"][0],
results["distances"][0],
results["metadatas"][0] if results["metadatas"] else [{}] * len(results["documents"][0])
)):
# 将距离转换为相似度分数(余弦相似度)
similarity = 1 - distance
if similarity >= score_threshold:
result_item = {
"text": doc,
"score": similarity,
"metadata": metadata if include_metadata else None
}
retrieved_docs.append(result_item)
# 达到所需数量时停止
if len(retrieved_docs) >= k:
break
# 按分数排序
retrieved_docs.sort(key=lambda x: x["score"], reverse=True)
logger.info(f"检索到 {len(retrieved_docs)} 个相关文档 (阈值: {score_threshold})")
return retrieved_docs
def get_collection_stats(self) -> Dict[str, Any]:
"""获取集合统计信息"""
count = self.collection.count()
# 获取一些样本的元数据以了解结构
sample = self.collection.peek(limit=1)
metadata_keys = set()
if sample["metadatas"] and sample["metadatas"][0]:
metadata_keys = set(sample["metadatas"][0].keys())
return {
"total_chunks": count,
"metadata_fields": list(metadata_keys),
"collection_name": self.collection.name
}
这里尝试初始化和测试检索器,示例代码如下所示。
2)retriever test
retriever的测试代码如下所示,包括准备示例文档、初始化检索器、索引文档、测试检索等过程。
# ===== 使用示例 =====
def retriever_test():
"""使用VectorRetriever的完整示例"""
# 1. 准备示例文档
sample_documents = [
{
"id": "doc_1",
"title": "人工智能简介",
"content": """
人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。
这些任务包括视觉感知、语音识别、决策和语言翻译。AI可以分为两类:弱人工智能和强人工智能。
弱人工智能专用于特定任务,而强人工智能具有通用的人类认知能力。
机器学习是AI的一个子集,它使计算机能够在没有明确编程的情况下从数据中学习。
深度学习是机器学习的一个子领域,使用神经网络模拟人脑的工作方式。
""",
"category": "科技",
"author": "张三",
"year": 2023
},
{
"id": "doc_2",
"title": "机器学习算法",
"content": """
机器学习算法可以分为三大类:监督学习、无监督学习和强化学习。
监督学习使用标记数据训练模型,例如分类和回归任务。常见算法包括线性回归、逻辑回归、支持向量机和神经网络。
无监督学习处理未标记数据,旨在发现数据中的模式。聚类和降维是无监督学习的典型任务。K-means和PCA是常用算法。
强化学习涉及智能体通过与环境交互学习最优策略。它通过奖励和惩罚机制进行学习,常用于游戏和机器人控制。
""",
"category": "科技",
"author": "李四",
"year": 2022
},
{
"id": "doc_3",
"title": "自然语言处理",
"content": """
自然语言处理(NLP)是AI的一个分支,专注于计算机与人类语言之间的交互。
NLP任务包括文本分类、情感分析、命名实体识别、机器翻译和问答系统。
近年来,基于Transformer的模型(如BERT、GPT)彻底改变了NLP领域。
这些模型使用自注意力机制,能够更好地理解上下文和语言语义。
NLP的应用包括聊天机器人、搜索引擎、语音助手和自动摘要等。
""",
"category": "科技",
"author": "王五",
"year": 2024
}
]
# 2. 初始化检索器
print("初始化向量检索器...")
retriever = VectorRetriever(
collection_name="ai_knowledge_base",
embedding_model="all-MiniLM-L6-v2",
persist_directory="./chroma_test_ai",
chunk_size=400,
chunk_overlap=30
)
# 3. 索引文档
print("\n索引文档...")
indexed_count = retriever.index_documents(
sample_documents,
clear_existing=True # 首次运行设为True,后续可以设为False以增量添加
)
print(f"已索引 {indexed_count} 个文档块")
# 4. 查看统计信息
stats = retriever.get_collection_stats()
print(f"\n集合统计:")
print(f" 文档块总数: {stats['total_chunks']}")
print(f" 元数据字段: {stats['metadata_fields']}")
# 5. 测试检索
test_queries = [
"什么是机器学习?",
"监督学习和无监督学习有什么区别?",
"介绍一下自然语言处理"
]
print("\n测试检索功能:")
for query in test_queries:
print(f"\n查询: '{query}'")
results = retriever.retrieve(query, k=2, score_threshold=0.3)
for i, result in enumerate(results, 1):
print(f" 结果 {i} (分数: {result['score']:.3f}):")
print(f" 文本: {result['text'][:100]}...")
if result['metadata']:
print(f" 来源: {result['metadata'].get('title', '未知')}")
return retriever
retriever = retriever_test()
输出如下
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
初始化向量检索器...
INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.
INFO:main:检索器初始化完成,使用模型: all-MiniLM-L6-v2
INFO:main:已清除现有索引
INFO:main:开始处理 3 个文档...
INFO:main:已处理 3/3 个文档,生成 3 个块
INFO:main:文档准备完成,共生成 3 个文本块
INFO:main:开始为 3 个文本块生成嵌入...
INFO:main:已生成 3/3 个嵌入
INFO:main:嵌入生成完成
INFO:main:已索引 3/3 个文档块
INFO:main:文档索引完成,共索引 3 个文档块
索引文档...
已索引 3 个文档块
集合统计:
文档块总数: 3
元数据字段: ['chunk_id', 'category', 'id', 'total_chunks', 'char_length', 'year', 'source_doc_id', 'author', 'title', 'doc_index']
测试检索功能:
查询: '什么是机器学习?'
Batches: 100%|██████████| 1/1 [00:00<00:00, 87.86it/s]
INFO:main:检索到 2 个相关文档 (阈值: 0.3)
结果 1 (分数: 0.475):
文本:
机器学习算法可以分为三大类:监督学习、无监督学习和强化学习。
监督学习使用标记数据训练模型,例如分类和回归任务。常见算法包括线性回归、逻辑回归、支持...
来源: 机器学习算法
结果 2 (分数: 0.375):
文本:
人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。
这些任务包括视觉感知、语音识别、决策和语言翻译。AI可以分为...
来源: 人工智能简介
查询: '监督学习和无监督学习有什么区别?'
Batches: 100%|██████████| 1/1 [00:00<00:00, 126.13it/s]
INFO:main:检索到 2 个相关文档 (阈值: 0.3)
结果 1 (分数: 0.478):
文本:
机器学习算法可以分为三大类:监督学习、无监督学习和强化学习。
监督学习使用标记数据训练模型,例如分类和回归任务。常见算法包括线性回归、逻辑回归、支持...
来源: 机器学习算法
结果 2 (分数: 0.416):
文本:
自然语言处理(NLP)是AI的一个分支,专注于计算机与人类语言之间的交互。
NLP任务包括文本分类、情感分析、命名实体识别、机器翻译和问答系统。
...
来源: 自然语言处理
查询: '介绍一下自然语言处理'
Batches: 100%|██████████| 1/1 [00:00<00:00, 141.48it/s]
INFO:main:检索到 2 个相关文档 (阈值: 0.3)
结果 1 (分数: 0.330):
文本:
自然语言处理(NLP)是AI的一个分支,专注于计算机与人类语言之间的交互。
NLP任务包括文本分类、情感分析、命名实体识别、机器翻译和问答系统。
...
来源: 自然语言处理
结果 2 (分数: 0.318):
文本:
人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。
这些任务包括视觉感知、语音识别、决策和语言翻译。AI可以分为...
来源: 人工智能简介
2 RAG设置
2.1 配置DSPy签名
沿用之前应用习惯,这里继续使用ollama/gemma3n:e2b模型,并假设ollama环境可用。
在此基础上,设置:
context: 问题相关的上下文背景,通过retriever向量检索获得
question: 用户查询问题
answer: llm依据context和question生成简洁、准确的答案。
示例代码如下。
# 配置DSPy
# lm = dspy.LM('openai/gpt-4o-mini')
lm = dspy.LM(model="ollama/gemma3n:e2b", api_base="http://localhost:11434")
dspy.configure(lm=lm)
# 定义RAG签名
class RAGSignature(dspy.Signature):
"""基于上下文回答问题"""
context = dspy.InputField(desc="相关背景信息")
question = dspy.InputField()
answer = dspy.OutputField(desc="简洁、准确的答案")
2.2 配置RAG系统
这里进一步设置RAG系统,在生成环节应用DSPy的ChainOfThought,通过思维链进行优化。
检索具体过程为:
首先,retriever检索question相关的上下文并合并
然后,基于检索上下文和问题生成答案,这一步会应用DSPy优化后的prompt。
最后,将上下文、问题、答案返回给用户。
示例代码如下。
# 创建RAG模块
class EnhancedRAG(dspy.Module):
def __init__(self, retriever):
super().__init__()
self.retriever = retriever
self.generate_answer = dspy.ChainOfThought(RAGSignature)
def forward(self, question):
# 检索相关文档
retrieved = self.retriever.retrieve(question, k=3)
if not retrieved:
return dspy.Prediction(
context="",
answer="未找到相关信息",
sources=[]
)
# 合并上下文
contexts = [r["text"] for r in retrieved]
context_str = "\n\n".join(contexts)
# 生成答案
prediction = self.generate_answer(
context=context_str,
question=question
)
# 返回结果
return dspy.Prediction(
context=context_str,
answer=prediction.answer,
sources=[r["metadata"] for r in retrieved if r["metadata"]],
reasoning=getattr(prediction, 'reasoning', '')
)
# 创建RAG系统并测试
rag_system = EnhancedRAG(retriever)
2.3 测试RAG系统
这里测试刚才定义好的RAG系统,直接输入用户问题。
test_question = "机器学习有哪些主要类型?"
测试代码示例如下
test_question = "机器学习有哪些主要类型?"
print(f"\nRAG系统测试问题: {test_question}")
result = rag_system(test_question)
print(f"答案: {result.answer}")
print(f"来源文档数: {len(result.sources)}")
输出如下,RAG检索到test_question相关的问题,并应用LLM进行回答。
并且回答内容简洁、准确,符合DSPy签名中定义的基本需求。
RAG系统测试问题: 机器学习有哪些主要类型?
Batches: 100%|██████████| 1/1 [00:00<00:00, 84.42it/s]
INFO:main:检索到 1 个相关文档 (阈值: 0.5)
答案: 监督学习、无监督学习和强化学习
来源文档数: 1
reference
如何用DSPy优化RAG prompt示例
https://blog.csdn.net/liliang199/article/details/155893634
开源向量LLM - BGE (BAAI General Embedding)
https://blog.csdn.net/liliang199/article/details/149773775
mac测试ollama llamaindex
https://blog.csdn.net/liliang199/article/details/149542926
chromedb
https://docs.trychroma.com/docs/overview/migration
bge embedding
https://github.com/FlagOpen/FlagEmbedding
一篇吃透模型:all-MiniLM-L6-v2
https://zhuanlan.zhihu.com/p/27730652617
向量数据库Chroma极简教程