如何用DSPy优化chromadb向量检索的RAG系统

之前探索了如何用DSPy优化RAG prompt。

https://blog.csdn.net/liliang199/article/details/155893634

这里首先设定chromadb向量检索系统,然后通过DSPy优化基于retriever的RAG检索功能。

所用测试例和代码修改自网络资料。

1 retriever设置

1.1 向量计算设置

这里基于sentence-transformer框架,并采用bge-m3向量模型。

MiniLM是微软开发的轻量级语言模型,以较小参数量和计算成本实现。

all-MiniLM-L6-v2属于 MiniLM 系列,通过知识蒸馏从更大模型压缩而来,旨在保持较高性能同时减少计算资源需求。

代码示例如下

import os

os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

这里设置HF_ENDPOINT是因为hf访问受限。

1.2 向量库设置

chromadb的目标是帮助用户更加便捷地构建大模型应用,更加轻松的将知识、事实和技能等现实世界中的文档整合进大模型中。

这里选用chromadb向量库,安装过程参考如下链接

https://blog.csdn.net/liliang199/article/details/149542926

设置示例如下所示,向量检索使用余弦相似度。

persist_directory = "./curr_chromadb"

collection_name = "knowledge_test"

client = chromadb.PersistentClient(persist_directory)

collection =client.get_or_create_collection(

name=collection_name,

metadata={"hnsw:space": "cosine"} # 使用余弦相似度

)

1.3 向量检索示例

以下是应用sentence-transformer和chromadb的向量检索器类,可修改后应用于生产环境。

包括向量库初始化,文档划分、向量生成、向量检索等检索器要求的基本功能。

1)retriever
复制代码
import os
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

import dspy
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional
import hashlib
import logging
from sentence_transformers import SentenceTransformer
import numpy as np

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class VectorRetriever:
    """
    完整的向量检索器,支持文档初始化、嵌入生成和语义检索。
    使用ChromaDB作为向量存储,SentenceTransformer生成嵌入。
    """
    
    def __init__(
        self, 
        collection_name: str = "knowledge_base",
        embedding_model: str = "BAAI/bge-m3",  # 轻量且效果不错的模型
        persist_directory: str = "./chroma_db",
        chunk_size: int = 500,
        chunk_overlap: int = 50
    ):
        """
        初始化检索器
        
        Args:
            collection_name: 集合名称
            embedding_model: 嵌入模型名称
            persist_directory: 向量数据库存储路径
            chunk_size: 文本分块大小(字符数)
            chunk_overlap: 分块重叠大小(字符数)
        """
        self.embedding_model = SentenceTransformer(embedding_model)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        
        # 初始化ChromaDB客户端
        self.client = chromadb.PersistentClient(persist_directory)
        # chromadb.Client(Settings(
        #     chroma_db_impl="duckdb+parquet",
        #     persist_directory=persist_directory,
        #     anonymized_telemetry=False  # 禁用遥测
        # ))
        
        # 获取或创建集合
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}  # 使用余弦相似度
        )
        
        logger.info(f"检索器初始化完成,使用模型: {embedding_model}")
    
    def _generate_chunk_id(self, text: str, source: str, chunk_index: int) -> str:
        """生成唯一的块ID"""
        content = f"{source}_{chunk_index}_{text[:50]}"
        return hashlib.md5(content.encode()).hexdigest()[:16]
    
    def split_text(self, text: str) -> List[str]:
        """
        将长文本分割为重叠的块
        
        Args:
            text: 输入文本
            
        Returns:
            文本块列表
        """
        if len(text) <= self.chunk_size:
            return [text]
        
        chunks = []
        start = 0
        
        while start < len(text):
            # 计算块结束位置
            end = start + self.chunk_size
            
            # 如果不在文本末尾,尝试在句号、逗号或空格处分割
            if end < len(text):
                # 查找合适的分割点
                for split_point in range(end, start + self.chunk_size - 100, -1):
                    if split_point < len(text) and text[split_point] in '.。!!??,,;;\n\t ':
                        end = split_point + 1  # 包含分割字符
                        break
            
            chunk = text[start:end].strip()
            if chunk:  # 忽略空块
                chunks.append(chunk)
            
            # 移动起始位置,考虑重叠
            start = end - self.chunk_overlap
            
            # 防止无限循环
            if start >= len(text) or (start == end and end >= len(text)):
                break
        
        logger.info(f"将文本分割为 {len(chunks)} 个块")
        return chunks
    
    def prepare_documents(
        self, 
        documents: List[Dict[str, Any]],
        batch_size: int = 32
    ) -> List[Dict[str, Any]]:
        """
        准备文档以供索引
        
        Args:
            documents: 文档列表,每个文档是字典,必须包含"content"字段
            batch_size: 批处理大小
            
        Returns:
            处理后的文档块列表
        """
        all_chunks = []
        total_docs = len(documents)
        
        logger.info(f"开始处理 {total_docs} 个文档...")
        
        for doc_idx, doc in enumerate(documents, 1):
            content = doc.get("content", "")
            if not content:
                logger.warning(f"文档 {doc_idx} 无内容,跳过")
                continue
            
            # 获取元数据(排除content字段)
            metadata = {k: v for k, v in doc.items() if k != "content"}
            metadata["source_doc_id"] = doc.get("id", f"doc_{doc_idx}")
            
            # 分割文本
            chunks = self.split_text(content)
            
            # 为每个块创建记录
            for chunk_idx, chunk in enumerate(chunks):
                chunk_id = self._generate_chunk_id(chunk, metadata["source_doc_id"], chunk_idx)
                
                chunk_metadata = metadata.copy()
                chunk_metadata.update({
                    "chunk_id": chunk_idx,
                    "total_chunks": len(chunks),
                    "char_length": len(chunk),
                    "doc_index": doc_idx - 1
                })
                
                all_chunks.append({
                    "id": chunk_id,
                    "text": chunk,
                    "metadata": chunk_metadata,
                    "embedding": None  # 稍后生成
                })
            
            if doc_idx % 10 == 0 or doc_idx == total_docs:
                logger.info(f"已处理 {doc_idx}/{total_docs} 个文档,生成 {len(all_chunks)} 个块")
        
        logger.info(f"文档准备完成,共生成 {len(all_chunks)} 个文本块")
        return all_chunks
    
    def generate_embeddings(
        self, 
        chunks: List[Dict[str, Any]],
        batch_size: int = 32
    ) -> List[Dict[str, Any]]:
        """
        为文本块生成嵌入向量
        
        Args:
            chunks: 文本块列表
            batch_size: 批处理大小
            
        Returns:
            包含嵌入向量的文本块列表
        """
        if not chunks:
            return chunks
        
        logger.info(f"开始为 {len(chunks)} 个文本块生成嵌入...")
        
        # 提取文本
        texts = [chunk["text"] for chunk in chunks]
        
        # 分批生成嵌入
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_embeddings = self.embedding_model.encode(
                batch_texts, 
                show_progress_bar=False,
                normalize_embeddings=True  # 归一化以使用余弦相似度
            )
            embeddings.extend(batch_embeddings)
            
            if (i // batch_size) % 10 == 0:
                logger.info(f"已生成 {min(i + batch_size, len(texts))}/{len(texts)} 个嵌入")
        
        # 将嵌入添加到块中
        for chunk, embedding in zip(chunks, embeddings):
            chunk["embedding"] = embedding.tolist() if hasattr(embedding, 'tolist') else embedding
        
        logger.info("嵌入生成完成")
        return chunks
    
    def index_documents(
        self, 
        documents: List[Dict[str, Any]],
        clear_existing: bool = False
    ) -> int:
        """
        索引文档到向量数据库
        
        Args:
            documents: 原始文档列表
            clear_existing: 是否清除现有索引
            
        Returns:
            索引的文档块数量
        """
        if clear_existing:
            self.client.delete_collection(self.collection.name)
            self.collection = self.client.get_or_create_collection(
                name=self.collection.name,
                metadata={"hnsw:space": "cosine"}
            )
            logger.info("已清除现有索引")
        
        # 准备文档
        chunks = self.prepare_documents(documents)
        
        if not chunks:
            logger.warning("没有可索引的文档块")
            return 0
        
        # 生成嵌入
        chunks_with_embeddings = self.generate_embeddings(chunks)
        
        # 准备数据用于添加到集合
        ids = [chunk["id"] for chunk in chunks_with_embeddings]
        texts = [chunk["text"] for chunk in chunks_with_embeddings]
        metadatas = [chunk["metadata"] for chunk in chunks_with_embeddings]
        embeddings = [chunk["embedding"] for chunk in chunks_with_embeddings]
        
        # 分批添加到集合
        batch_size = 100
        total_added = 0
        
        for i in range(0, len(ids), batch_size):
            batch_ids = ids[i:i + batch_size]
            batch_texts = texts[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size]
            batch_embeddings = embeddings[i:i + batch_size]
            
            self.collection.add(
                ids=batch_ids,
                documents=batch_texts,
                metadatas=batch_metadatas,
                embeddings=batch_embeddings
            )
            
            total_added += len(batch_ids)
            logger.info(f"已索引 {total_added}/{len(ids)} 个文档块")
        
        logger.info(f"文档索引完成,共索引 {total_added} 个文档块")
        return total_added
    
    def retrieve(
        self, 
        query: str, 
        k: int = 3,
        score_threshold: float = 0.5,
        include_metadata: bool = True
    ) -> List[Any]:
        """
        检索与查询最相关的文档
        
        Args:
            query: 查询文本
            k: 返回结果数量
            score_threshold: 相似度阈值
            include_metadata: 是否包含元数据
            
        Returns:
            检索结果列表
        """
        if self.collection.count() == 0:
            logger.warning("集合为空,请先索引文档")
            return []
        
        # 生成查询嵌入
        query_embedding = self.embedding_model.encode(
            query, 
            normalize_embeddings=True
        ).tolist()
        
        # 执行查询
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(k * 2, 20),  # 获取更多结果用于过滤
            include=["documents", "metadatas", "distances"]
        )
        
        # 处理结果
        retrieved_docs = []
        
        if results["documents"] and results["documents"][0]:
            for i, (doc, distance, metadata) in enumerate(zip(
                results["documents"][0],
                results["distances"][0],
                results["metadatas"][0] if results["metadatas"] else [{}] * len(results["documents"][0])
            )):
                # 将距离转换为相似度分数(余弦相似度)
                similarity = 1 - distance
                
                if similarity >= score_threshold:
                    result_item = {
                        "text": doc,
                        "score": similarity,
                        "metadata": metadata if include_metadata else None
                    }
                    retrieved_docs.append(result_item)
                
                # 达到所需数量时停止
                if len(retrieved_docs) >= k:
                    break
        
        # 按分数排序
        retrieved_docs.sort(key=lambda x: x["score"], reverse=True)
        
        logger.info(f"检索到 {len(retrieved_docs)} 个相关文档 (阈值: {score_threshold})")
        return retrieved_docs
    
    def get_collection_stats(self) -> Dict[str, Any]:
        """获取集合统计信息"""
        count = self.collection.count()
        
        # 获取一些样本的元数据以了解结构
        sample = self.collection.peek(limit=1)
        metadata_keys = set()
        
        if sample["metadatas"] and sample["metadatas"][0]:
            metadata_keys = set(sample["metadatas"][0].keys())
        
        return {
            "total_chunks": count,
            "metadata_fields": list(metadata_keys),
            "collection_name": self.collection.name
        }

这里尝试初始化和测试检索器,示例代码如下所示。

2)retriever test

retriever的测试代码如下所示,包括准备示例文档、初始化检索器、索引文档、测试检索等过程。

复制代码
# ===== 使用示例 =====
def retriever_test():
    """使用VectorRetriever的完整示例"""
    
    # 1. 准备示例文档
    sample_documents = [
        {
            "id": "doc_1",
            "title": "人工智能简介",
            "content": """
            人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。
            这些任务包括视觉感知、语音识别、决策和语言翻译。AI可以分为两类:弱人工智能和强人工智能。
            弱人工智能专用于特定任务,而强人工智能具有通用的人类认知能力。
            
            机器学习是AI的一个子集,它使计算机能够在没有明确编程的情况下从数据中学习。
            深度学习是机器学习的一个子领域,使用神经网络模拟人脑的工作方式。
            """,
            "category": "科技",
            "author": "张三",
            "year": 2023
        },
        {
            "id": "doc_2", 
            "title": "机器学习算法",
            "content": """
            机器学习算法可以分为三大类:监督学习、无监督学习和强化学习。
            
            监督学习使用标记数据训练模型,例如分类和回归任务。常见算法包括线性回归、逻辑回归、支持向量机和神经网络。
            
            无监督学习处理未标记数据,旨在发现数据中的模式。聚类和降维是无监督学习的典型任务。K-means和PCA是常用算法。
            
            强化学习涉及智能体通过与环境交互学习最优策略。它通过奖励和惩罚机制进行学习,常用于游戏和机器人控制。
            """,
            "category": "科技",
            "author": "李四",
            "year": 2022
        },
        {
            "id": "doc_3",
            "title": "自然语言处理",
            "content": """
            自然语言处理(NLP)是AI的一个分支,专注于计算机与人类语言之间的交互。
            NLP任务包括文本分类、情感分析、命名实体识别、机器翻译和问答系统。
            
            近年来,基于Transformer的模型(如BERT、GPT)彻底改变了NLP领域。
            这些模型使用自注意力机制,能够更好地理解上下文和语言语义。
            
            NLP的应用包括聊天机器人、搜索引擎、语音助手和自动摘要等。
            """,
            "category": "科技", 
            "author": "王五",
            "year": 2024
        }
    ]
    
    # 2. 初始化检索器
    print("初始化向量检索器...")
    retriever = VectorRetriever(
        collection_name="ai_knowledge_base",
        embedding_model="all-MiniLM-L6-v2",
        persist_directory="./chroma_test_ai",
        chunk_size=400,
        chunk_overlap=30
    )
    
    # 3. 索引文档
    print("\n索引文档...")
    indexed_count = retriever.index_documents(
        sample_documents,
        clear_existing=True  # 首次运行设为True,后续可以设为False以增量添加
    )
    print(f"已索引 {indexed_count} 个文档块")
    
    # 4. 查看统计信息
    stats = retriever.get_collection_stats()
    print(f"\n集合统计:")
    print(f"  文档块总数: {stats['total_chunks']}")
    print(f"  元数据字段: {stats['metadata_fields']}")
    
    # 5. 测试检索
    test_queries = [
        "什么是机器学习?",
        "监督学习和无监督学习有什么区别?",
        "介绍一下自然语言处理"
    ]
    
    print("\n测试检索功能:")
    for query in test_queries:
        print(f"\n查询: '{query}'")
        results = retriever.retrieve(query, k=2, score_threshold=0.3)
        
        for i, result in enumerate(results, 1):
            print(f"  结果 {i} (分数: {result['score']:.3f}):")
            print(f"    文本: {result['text'][:100]}...")
            if result['metadata']:
                print(f"    来源: {result['metadata'].get('title', '未知')}")
    
    return retriever

retriever = retriever_test()

输出如下

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2

初始化向量检索器...

INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.

INFO:main:检索器初始化完成,使用模型: all-MiniLM-L6-v2

INFO:main:已清除现有索引

INFO:main:开始处理 3 个文档...

INFO:main:已处理 3/3 个文档,生成 3 个块

INFO:main:文档准备完成,共生成 3 个文本块

INFO:main:开始为 3 个文本块生成嵌入...

INFO:main:已生成 3/3 个嵌入

INFO:main:嵌入生成完成

INFO:main:已索引 3/3 个文档块

INFO:main:文档索引完成,共索引 3 个文档块

索引文档...

已索引 3 个文档块

集合统计:

文档块总数: 3

元数据字段: ['chunk_id', 'category', 'id', 'total_chunks', 'char_length', 'year', 'source_doc_id', 'author', 'title', 'doc_index']

测试检索功能:

查询: '什么是机器学习?'

Batches: 100%|██████████| 1/1 [00:00<00:00, 87.86it/s]

INFO:main:检索到 2 个相关文档 (阈值: 0.3)

结果 1 (分数: 0.475):

文本:

机器学习算法可以分为三大类:监督学习、无监督学习和强化学习。

监督学习使用标记数据训练模型,例如分类和回归任务。常见算法包括线性回归、逻辑回归、支持...

来源: 机器学习算法

结果 2 (分数: 0.375):

文本:

人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。

这些任务包括视觉感知、语音识别、决策和语言翻译。AI可以分为...

来源: 人工智能简介

查询: '监督学习和无监督学习有什么区别?'

Batches: 100%|██████████| 1/1 [00:00<00:00, 126.13it/s]

INFO:main:检索到 2 个相关文档 (阈值: 0.3)

结果 1 (分数: 0.478):

文本:

机器学习算法可以分为三大类:监督学习、无监督学习和强化学习。

监督学习使用标记数据训练模型,例如分类和回归任务。常见算法包括线性回归、逻辑回归、支持...

来源: 机器学习算法

结果 2 (分数: 0.416):

文本:

自然语言处理(NLP)是AI的一个分支,专注于计算机与人类语言之间的交互。

NLP任务包括文本分类、情感分析、命名实体识别、机器翻译和问答系统。

...

来源: 自然语言处理

查询: '介绍一下自然语言处理'

Batches: 100%|██████████| 1/1 [00:00<00:00, 141.48it/s]

INFO:main:检索到 2 个相关文档 (阈值: 0.3)

结果 1 (分数: 0.330):

文本:

自然语言处理(NLP)是AI的一个分支,专注于计算机与人类语言之间的交互。

NLP任务包括文本分类、情感分析、命名实体识别、机器翻译和问答系统。

...

来源: 自然语言处理

结果 2 (分数: 0.318):

文本:

人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。

这些任务包括视觉感知、语音识别、决策和语言翻译。AI可以分为...

来源: 人工智能简介

2 RAG设置

2.1 配置DSPy签名

沿用之前应用习惯,这里继续使用ollama/gemma3n:e2b模型,并假设ollama环境可用。

在此基础上,设置:

context: 问题相关的上下文背景,通过retriever向量检索获得

question: 用户查询问题

answer: llm依据context和question生成简洁、准确的答案。

示例代码如下。

复制代码
# 配置DSPy
# lm = dspy.LM('openai/gpt-4o-mini')
lm = dspy.LM(model="ollama/gemma3n:e2b", api_base="http://localhost:11434")
dspy.configure(lm=lm)
    
# 定义RAG签名
class RAGSignature(dspy.Signature):
    """基于上下文回答问题"""
    context = dspy.InputField(desc="相关背景信息")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="简洁、准确的答案")

2.2 配置RAG系统

这里进一步设置RAG系统,在生成环节应用DSPy的ChainOfThought,通过思维链进行优化。

检索具体过程为:

首先,retriever检索question相关的上下文并合并

然后,基于检索上下文和问题生成答案,这一步会应用DSPy优化后的prompt。

最后,将上下文、问题、答案返回给用户。

示例代码如下。

复制代码
# 创建RAG模块
class EnhancedRAG(dspy.Module):
    def __init__(self, retriever):
        super().__init__()
        self.retriever = retriever
        self.generate_answer = dspy.ChainOfThought(RAGSignature)
        
    def forward(self, question):
        # 检索相关文档
        retrieved = self.retriever.retrieve(question, k=3)
            
        if not retrieved:
            return dspy.Prediction(
                context="",
                answer="未找到相关信息",
                sources=[]
            )
            
        # 合并上下文
        contexts = [r["text"] for r in retrieved]
        context_str = "\n\n".join(contexts)
            
        # 生成答案
        prediction = self.generate_answer(
            context=context_str,
            question=question
        )
            
        # 返回结果
        return dspy.Prediction(
            context=context_str,
            answer=prediction.answer,
            sources=[r["metadata"] for r in retrieved if r["metadata"]],
            reasoning=getattr(prediction, 'reasoning', '')
        )
    
# 创建RAG系统并测试
rag_system = EnhancedRAG(retriever)

2.3 测试RAG系统

这里测试刚才定义好的RAG系统,直接输入用户问题。

test_question = "机器学习有哪些主要类型?"

测试代码示例如下

复制代码
test_question = "机器学习有哪些主要类型?"
print(f"\nRAG系统测试问题: {test_question}")
    
result = rag_system(test_question)
print(f"答案: {result.answer}")
print(f"来源文档数: {len(result.sources)}")

输出如下,RAG检索到test_question相关的问题,并应用LLM进行回答。

并且回答内容简洁、准确,符合DSPy签名中定义的基本需求。

RAG系统测试问题: 机器学习有哪些主要类型?

Batches: 100%|██████████| 1/1 [00:00<00:00, 84.42it/s]

INFO:main:检索到 1 个相关文档 (阈值: 0.5)

答案: 监督学习、无监督学习和强化学习

来源文档数: 1

reference


如何用DSPy优化RAG prompt示例

https://blog.csdn.net/liliang199/article/details/155893634

开源向量LLM - BGE (BAAI General Embedding)

https://blog.csdn.net/liliang199/article/details/149773775

mac测试ollama llamaindex

https://blog.csdn.net/liliang199/article/details/149542926

chromedb

https://docs.trychroma.com/docs/overview/migration

bge embedding

https://github.com/FlagOpen/FlagEmbedding

一篇吃透模型:all-MiniLM-L6-v2

https://zhuanlan.zhihu.com/p/27730652617

向量数据库Chroma极简教程

https://zhuanlan.zhihu.com/p/665715823

相关推荐
寂寞恋上夜13 小时前
枚举值怎么管理:固定枚举/字典表/接口动态(附管理策略)
prompt·状态模式·markdown转xmind·deepseek思维导图
沛沛老爹19 小时前
Skills高级设计模式(一):向导式工作流与模板生成
java·人工智能·设计模式·prompt·aigc·agent·web转型
minhuan20 小时前
大模型应用:大模型权限管控设计:角色权限分配与违规 Prompt 拦截.49
prompt·大模型应用·大模型权限管控·违规提示词监测
Helson@lin1 天前
Vibe Coding-Web端UI分享Prompt 可复刻
prompt
victory04311 天前
同一prompt下 doubao qwen gpt kimi的模型训练时长预测不同表现
gpt·prompt
后端小张1 天前
【AI 学习】AI提示词工程:从入门到实战的全栈指南
java·人工智能·深度学习·学习·语言模型·prompt·知识图谱
reddingtons1 天前
【游戏宣发】PS “生成式扩展”流,30秒无损适配全渠道KV
游戏·设计模式·新媒体运营·prompt·aigc·教育电商·游戏美术
Chasing Aurora1 天前
数据库连接+查询优化
数据库·sql·mysql·prompt·约束
效率客栈老秦2 天前
Python Trae提示词开发实战(2):2026 最新 10个自动化批处理场景 + 完整代码
人工智能·python·ai·prompt·trae
GISer_Jing2 天前
提示链(Prompt Chaining)、路由、并行化和反思
人工智能·设计模式·prompt·aigc