AI Skills Deep Dive: RAG Retrieval Optimization (Part 4) - Hands-On Implementation

A complete RAG system is far more than a simple query → retrieve → generate flow. It is closer to a precision pipeline that needs fine-grained control at multiple stages.

Here I share a modular implementation that combines the optimization techniques discussed earlier: query rewriting, hybrid retrieval, reranking, context compression, and evaluation. The project is organized into 8 modules and integrates RAGAS for systematic performance evaluation.


1. Project Structure

First, we set up a clear project layout so the code itself stays modular and easy to extend and maintain.

rag_optimized_system/
├── data/                     # raw documents and index data
├── logs/                     # log files
├── src/                      # core source code
│   ├── __init__.py
│   ├── ingestion.py          # document loading and indexing
│   ├── retrieval.py          # hybrid retrieval and reranking
│   ├── generation.py         # answer generation and context compression
│   ├── evaluation.py         # evaluation module
│   ├── query_processor.py    # query rewriting and HyDE
│   ├── config.py             # configuration management
│   └── utils.py              # utility functions
├── tests/                    # unit tests
│   └── test_retrieval.py
├── scripts/                  # helper scripts
│   ├── index_data.py         # index-building script
│   └── run_evaluation.py     # evaluation runner
├── requirements.txt          # project dependencies
└── README.md                 # project documentation

2. Environment and Dependencies

All required Python libraries are listed in requirements.txt to keep the environment consistent and reproducible.

# requirements.txt
haystack-ai>=2.16.0
sentence-transformers>=3.0.0
transformers>=4.36.0
torch>=2.1.0
rank-bm25>=0.2.2
numpy>=1.24.0
pandas>=2.0.0
tqdm>=4.66.0
openai>=1.30.0
python-dotenv>=1.0.0
cohere>=5.5.0
faiss-cpu>=1.7.4
langchain>=0.3.0
ragas>=0.1.0
elasticsearch>=8.10.0
fastapi>=0.100.0
uvicorn>=0.30.0
pytest>=8.0.0

3. Configuration Management

API keys and model parameters are managed through a .env file, which is both safer and easier to modify.

# src/config.py
import os
from dotenv import load_dotenv

load_dotenv()

class Config:
    # API Keys
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    COHERE_API_KEY = os.getenv("COHERE_API_KEY")
    
    # Model Configurations
    EMBEDDING_MODEL = "intfloat/e5-base-v2"
    RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    GENERATION_MODEL = "gpt-3.5-turbo"
    
    # Retrieval Config
    RETRIEVAL_TOP_K = 20
    RERANK_TOP_K = 5
    HYBRID_ALPHA = 0.5  # alpha weights dense scores; (1 - alpha) weights BM25
    
    # Chunking Config
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 50
    
    # Evaluation
    EVALUATION_DATASET_SIZE = 50
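For reference, `load_dotenv()` reads a `.env` file from the project root. The key names below follow config.py; the values are placeholders:

```text
# .env (keep this file out of version control)
OPENAI_API_KEY=sk-your-openai-key
COHERE_API_KEY=your-cohere-key
```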

4. Core Module Implementation

The implementations of the core modules follow.

📄 Module 1: Document Indexing (ingestion.py)

This module loads documents, slices them with a recursive chunking strategy, and generates vector embeddings with a sentence-transformers model.

# src/ingestion.py
import logging
from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, PDFMinerLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from src.config import Config

logger = logging.getLogger(__name__)

class DocumentIndexer:
    def __init__(self, embedding_model: str = Config.EMBEDDING_MODEL):
        """Initialize the document indexer."""
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=Config.CHUNK_SIZE,
            chunk_overlap=Config.CHUNK_OVERLAP,
            separators=["\n\n", "\n", "。", "!", "?", " ", ""]
        )
        self.vectorstore = None

    def load_documents(self, file_paths: List[str]) -> List[Document]:
        """Load documents in multiple formats."""
        documents = []
        for path in file_paths:
            try:
                if path.endswith('.txt'):
                    loader = TextLoader(path)
                elif path.endswith('.pdf'):
                    loader = PDFMinerLoader(path)
                else:
                    logger.warning(f"Unsupported file type: {path}")
                    continue
                loaded = loader.load()
                documents.extend(loaded)
                logger.info(f"Loaded {len(loaded)} documents from {path}")
            except Exception as e:
                logger.error(f"Failed to load {path}: {e}")
        return documents

    def create_index(self, documents: List[Document]):
        """Create the FAISS vector index."""
        # Split documents into chunks
        chunks = self.text_splitter.split_documents(documents)
        
        # Log before/after chunking stats
        logger.info(f"Original documents: {len(documents)}")
        logger.info(f"Chunks created: {len(chunks)}")
        
        # Build the vector store
        self.vectorstore = FAISS.from_documents(chunks, self.embeddings)
        logger.info("FAISS index created successfully")

    def save_index(self, path: str):
        """Save the index to disk."""
        if self.vectorstore:
            self.vectorstore.save_local(path)
            logger.info(f"Index saved to {path}")

    def load_index(self, path: str):
        """Load the index from disk."""
        self.vectorstore = FAISS.load_local(path, self.embeddings, allow_dangerous_deserialization=True)
        logger.info(f"Index loaded from {path}")
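To make CHUNK_SIZE and CHUNK_OVERLAP concrete: with 512/50, consecutive chunks share 50 boundary characters. Below is a minimal pure-Python sliding-window sketch; it is a simplification, since `RecursiveCharacterTextSplitter` additionally prefers to break at the separators listed above rather than at fixed offsets.

```python
def sliding_chunks(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Naive fixed-window chunking: each chunk starts (chunk_size - overlap) chars after the previous one."""
    step = chunk_size - overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

doc = "".join(str(i % 10) for i in range(1200))
chunks = sliding_chunks(doc, chunk_size=512, overlap=50)
print([len(c) for c in chunks])           # 3 chunks: [512, 512, 276]
print(chunks[0][-50:] == chunks[1][:50])  # True: the 50-char overlap
```

The overlap keeps a sentence that straddles a boundary fully visible in at least one chunk.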

🔍 Module 2: Retrieval Optimization (retrieval.py)

This module implements the two core optimization strategies: hybrid retrieval (BM25 + vectors) and reranking.

# src/retrieval.py
import logging
from typing import List, Tuple
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from src.config import Config

logger = logging.getLogger(__name__)

class OptimizedRetriever:
    def __init__(self, vectorstore: FAISS, documents: List[Document]):
        """Initialize the optimized retriever."""
        self.vectorstore = vectorstore
        self.documents = documents
        # Build the BM25 index
        self._init_bm25_index()
        # Load the reranking model
        self.reranker = CrossEncoder(Config.RERANK_MODEL) if Config.RERANK_MODEL else None
        
    def _init_bm25_index(self):
        """Build the BM25 inverted index over the chunk corpus."""
        tokenized_corpus = [doc.page_content.split() for doc in self.documents]
        self.bm25 = BM25Okapi(tokenized_corpus)
        logger.info("BM25 index built successfully")

    def hybrid_search(self, query: str, alpha: float = Config.HYBRID_ALPHA, top_k: int = Config.RETRIEVAL_TOP_K) -> List[Tuple[Document, float]]:
        """Hybrid retrieval: combine BM25 keyword search with dense vector search."""
        try:
            # 1. Dense vector retrieval (FAISS returns L2 distances: lower is better)
            dense_results = self.vectorstore.similarity_search_with_score(query, k=top_k)
            
            # 2. BM25 keyword retrieval over the whole chunk corpus
            tokenized_query = query.split()
            bm25_normalized = self._normalize_scores(self.bm25.get_scores(tokenized_query))
            
            # Flip distances to similarities, then normalize for fusion
            dense_normalized = self._normalize_scores(np.array([-score for _, score in dense_results]))
            
            # Map chunk content back to its corpus index so BM25 scores line up
            content_to_idx = {doc.page_content: i for i, doc in enumerate(self.documents)}
            
            # 3. Weighted fusion of the two normalized scores
            fused = []
            for i, (doc, _) in enumerate(dense_results):
                corpus_idx = content_to_idx.get(doc.page_content)
                bm25_score = bm25_normalized[corpus_idx] if corpus_idx is not None else 0.0
                fused.append((doc, alpha * dense_normalized[i] + (1 - alpha) * bm25_score))
            
            # Sort by fused score, descending
            return sorted(fused, key=lambda x: x[1], reverse=True)
            
        except Exception as e:
            logger.error(f"Hybrid search failed: {e}")
            return []

    def rerank(self, query: str, candidates: List[Tuple[Document, float]], top_k: int = Config.RERANK_TOP_K) -> List[Document]:
        """Rerank with a Cross-Encoder model to sharpen retrieval precision."""
        if not candidates:
            return []
        # Fall back to the fused ordering when no reranker is configured
        if self.reranker is None:
            return [doc for doc, _ in candidates[:top_k]]
        
        # Build (query, passage) input pairs
        pairs = [(query, doc.page_content) for doc, _ in candidates]
        # Score each pair for relevance
        rerank_scores = self.reranker.predict(pairs)
        
        # Re-sort by the new scores
        reranked = sorted(zip([doc for doc, _ in candidates], rerank_scores), key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in reranked[:top_k]]
    
    def _normalize_scores(self, scores) -> np.ndarray:
        """Normalize scores to [0, 1] so they can be fused with weights."""
        scores = np.asarray(scores, dtype=float)
        min_score = np.min(scores)
        max_score = np.max(scores)
        if max_score == min_score:
            return np.ones_like(scores)
        return (scores - min_score) / (max_score - min_score)
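The normalize-then-fuse logic is worth sanity-checking in isolation. Here is a pure-Python toy version (no FAISS or BM25 needed; the scores are invented), keeping in mind that FAISS distances must be flipped to similarities before this step so that higher always means more relevant:

```python
def normalize(scores: list[float]) -> list[float]:
    """Min-max normalize to [0, 1]; a constant list maps to all ones."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [1.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]

def fuse(dense: list[float], bm25: list[float], alpha: float = 0.5) -> list[float]:
    """Alpha-weighted fusion of two score lists for the same candidates."""
    d, b = normalize(dense), normalize(bm25)
    return [alpha * x + (1 - alpha) * y for x, y in zip(d, b)]

# Doc 0 wins on keywords, doc 1 on semantics, doc 2 is middling on both
dense_sims = [0.2, 0.9, 0.5]    # similarities: higher is better
bm25_scores = [8.0, 1.0, 4.0]
print(fuse(dense_sims, bm25_scores, alpha=0.5))
```

With alpha = 0.5 the keyword-strong and semantics-strong documents tie at 0.5; raising alpha toward 1 shifts the ranking toward the dense scores.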

📝 Module 3: Query Processing (query_processor.py)

This module optimizes the user's raw question before retrieval, via query rewriting or HyDE (Hypothetical Document Embeddings), so that it better matches potential answers.

# src/query_processor.py
import logging
from typing import List
from openai import OpenAI
from src.config import Config

logger = logging.getLogger(__name__)
client = OpenAI(api_key=Config.OPENAI_API_KEY)

class QueryProcessor:
    def __init__(self, use_hyde: bool = False):
        """Initialize the query processor."""
        self.use_hyde = use_hyde

    def expand_query(self, query: str) -> List[str]:
        """Query expansion: generate semantically related search queries."""
        prompt = f"""
        Original query: {query}
        Generate 3 alternative search queries that are semantically similar but use different keywords.
        Return as a list, one per line.
        """
        try:
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            expanded = response.choices[0].message.content.strip().split('\n')
            return [query] + expanded
        except Exception as e:
            logger.error(f"Query expansion failed: {e}")
            return [query]

    def generate_hypothetical_document(self, query: str) -> str:
        """HyDE: generate a hypothetical document to bridge the query-document semantic gap."""
        if not self.use_hyde:
            return query
            
        prompt = f"""
        Question: {query}
        Generate a comprehensive hypothetical answer that would address this question.
        The document should contain key information that could answer similar questions.
        """
        try:
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.3
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"HyDE generation failed: {e}")
            return query

    def process_query(self, query: str) -> str:
        """Full query-processing pipeline."""
        # Apply HyDE (no-op when disabled)
        processed = self.generate_hypothetical_document(query)
        # Query expansion usually pairs with hybrid retrieval; kept simple here
        return processed
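Query expansion pays off at retrieval time: each variant is searched separately and the result lists are merged. One simple policy keeps each document's best score across all variants; here is a sketch with made-up document IDs and scores:

```python
def merge_results(result_lists: list[list[tuple[str, float]]]) -> list[tuple[str, float]]:
    """Merge per-query result lists, keeping each doc's best score across query variants."""
    best: dict[str, float] = {}
    for results in result_lists:
        for doc_id, score in results:
            if score > best.get(doc_id, float("-inf")):
                best[doc_id] = score
    # Highest surviving score first
    return sorted(best.items(), key=lambda kv: kv[1], reverse=True)

# Result lists for the original query plus two rewrites (hypothetical)
merged = merge_results([
    [("d1", 0.9), ("d2", 0.4)],
    [("d2", 0.8), ("d3", 0.6)],
    [("d1", 0.5)],
])
print(merged)  # d3 only surfaces thanks to a rewrite
```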

✍️ Module 4: Answer Generation (generation.py)

This module generates the final answer from the retrieved documents. Its core idea is to pair an OpenAI model with a context-compression step, so the context fed to the LLM is lean and relevant.

# src/generation.py
import logging
from typing import List
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from src.config import Config

logger = logging.getLogger(__name__)

class AnswerGenerator:
    def __init__(self):
        """Initialize the answer generator."""
        self.llm = ChatOpenAI(model=Config.GENERATION_MODEL, temperature=0.2, api_key=Config.OPENAI_API_KEY)
        self._setup_prompt_template()
        self._setup_context_compressor()

    def _setup_prompt_template(self):
        """Set up the generation prompt template."""
        template = """Based on the following context, answer the question concisely and accurately.
        If the context doesn't contain the relevant information, say "I don't have enough information to answer that."
        
        Context: {context}
        
        Question: {question}
        
        Answer:"""
        self.prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    def _setup_context_compressor(self):
        """Set up the LLM-based context compressor."""
        self.compressor = LLMChainExtractor.from_llm(self.llm)

    def compress_context(self, query: str, documents: List[Document]) -> List[Document]:
        """LLM context compression: keep only the passages relevant to the query."""
        try:
            return list(self.compressor.compress_documents(documents, query))
        except Exception as e:
            logger.error(f"Context compression failed: {e}")
            return documents

    def generate_answer(self, query: str, documents: List[Document]) -> str:
        """Generate an answer grounded in the retrieved documents."""
        if not documents:
            return "No relevant documents found to answer the question."

        try:
            # 1. Compress the context
            compressed_docs = self.compress_context(query, documents)
            # 2. Assemble the context string
            context = "\n\n".join([doc.page_content for doc in compressed_docs])
            # 3. Generate the answer
            chain = self.prompt | self.llm
            answer = chain.invoke({"context": context, "question": query}).content
            logger.info(f"Answer generated: {answer[:100]}...")
            return answer.strip()
        except Exception as e:
            logger.error(f"Answer generation failed: {e}")
            return f"Answer generation failed: {str(e)}"
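Whatever compression strategy is used, it is prudent to cap the final context size before calling the LLM. Below is a character-budget packer that keeps documents in rank order until the budget is spent (the budget value is an arbitrary assumption; a token-based budget would be more precise):

```python
def pack_context(documents: list[str], budget: int = 6000) -> list[str]:
    """Greedily keep documents in rank order; stop at the first one that overflows the budget."""
    packed, used = [], 0
    for doc in documents:
        if used + len(doc) > budget:
            break
        packed.append(doc)
        used += len(doc)
    return packed

docs = ["a" * 3000, "b" * 2500, "c" * 1000]   # already rank-ordered
print([len(d) for d in pack_context(docs, budget=6000)])  # [3000, 2500]
```

Stopping at the first overflow (rather than skipping it) preserves rank order, which matters because later documents were judged less relevant.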

📊 Module 5: System Evaluation (evaluation.py)

The system integrates the Ragas evaluation framework to measure retrieval and generation quality objectively.

# src/evaluation.py
import logging
from typing import List, Dict
from ragas import evaluate
from ragas.metrics import context_relevancy, answer_relevancy, faithfulness
from datasets import Dataset
import pandas as pd
from langchain.schema import Document
from src.config import Config

logger = logging.getLogger(__name__)

class RAGEvaluator:
    def __init__(self):
        """Initialize the RAG evaluator."""
        self.metrics = [
            context_relevancy,   # context relevancy
            answer_relevancy,    # answer relevancy
            faithfulness,        # faithfulness (hallucination check)
        ]
    
    def create_evaluation_dataset(self, questions: List[str], contexts: List[List[str]], answers: List[str]) -> Dataset:
        """Build the dataset format Ragas expects."""
        data = {
            "question": questions,
            "contexts": contexts,
            "answer": answers,
        }
        dataset = Dataset.from_dict(data)
        return dataset
    
    def evaluate_system(self, predictions: Dataset) -> Dict:
        """Run the full system evaluation on a Ragas dataset."""
        try:
            # Run the Ragas evaluation
            result = evaluate(
                dataset=predictions,
                metrics=self.metrics,
            )
            
            # Parse the results
            scores = {
                "context_relevancy": round(result["context_relevancy"], 4),
                "answer_relevancy": round(result["answer_relevancy"], 4),
                "faithfulness": round(result["faithfulness"], 4),
            }
            
            logger.info(f"Evaluation completed: {scores}")
            return scores
            
        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            return {}
    
    def generate_test_cases(self, num_cases: int = Config.EVALUATION_DATASET_SIZE) -> List[str]:
        """Generate a set of test questions for later evaluation."""
        # In a real system this should derive questions from the document contents
        # For now, return a placeholder list
        return [f"Test question {i}" for i in range(min(num_cases, 10))]
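Ragas expects the three columns to be row-aligned, with each row's contexts being a list of strings; misaligned columns are a common source of confusing evaluation errors. Here is a small sanity check one might run before handing the dict to `Dataset.from_dict` (sample data is made up):

```python
def check_eval_rows(questions, contexts, answers):
    """Validate column alignment and per-row context types before building the dataset."""
    assert len(questions) == len(contexts) == len(answers), "column length mismatch"
    for ctx in contexts:
        assert isinstance(ctx, list) and all(isinstance(c, str) for c in ctx), \
            "each row's contexts must be a list of strings"
    return {"question": questions, "contexts": contexts, "answer": answers}

data = check_eval_rows(
    ["What is hybrid retrieval?"],
    [["Hybrid retrieval combines BM25 with dense vector search."]],
    ["It combines keyword and vector search."],
)
print(sorted(data))  # ['answer', 'contexts', 'question']
```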

🚀 Module 6: System Integration (main.py)

Finally, main.py wires all the modules together into a complete, runnable RAG pipeline.

# src/main.py
import logging
from pathlib import Path
from typing import List
from langchain.schema import Document

from src.config import Config
from src.ingestion import DocumentIndexer
from src.retrieval import OptimizedRetriever
from src.query_processor import QueryProcessor
from src.generation import AnswerGenerator
from src.evaluation import RAGEvaluator

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class RAGSystem:
    def __init__(self):
        """Initialize the complete RAG system."""
        logger.info("Initializing RAG system...")
        
        # Instantiate the components
        self.indexer = DocumentIndexer()
        self.query_processor = QueryProcessor(use_hyde=False)  # toggle HyDE here
        self.generator = AnswerGenerator()
        self.evaluator = RAGEvaluator()
        
        # Components that are only ready after indexing
        self.retriever = None
        self.is_initialized = False
    
    def initialize_with_documents(self, document_paths: List[str]):
        """Initialize the whole system from a set of documents."""
        try:
            # 1. Load the documents
            logger.info("Loading documents...")
            documents = self.indexer.load_documents(document_paths)
            
            if not documents:
                raise ValueError("No documents loaded")
                
            logger.info(f"Loaded {len(documents)} documents")
            
            # 2. Build the vector index
            self.indexer.create_index(documents)
            
            # 3. Initialize the retriever over the same chunks the index uses
            split_docs = self.indexer.text_splitter.split_documents(documents)
            self.retriever = OptimizedRetriever(self.indexer.vectorstore, split_docs)
            
            self.is_initialized = True
            logger.info("RAG system initialized successfully")
            
        except Exception as e:
            logger.error(f"Initialization failed: {e}")
            raise
    
    def answer_question(self, question: str) -> str:
        """Handle a single question-answering request."""
        if not self.is_initialized:
            return "System not initialized"
        
        try:
            # 1. Query processing
            processed_query = self.query_processor.process_query(question)
            logger.info(f"Original: {question} -> Processed: {processed_query}")
            
            # 2. Hybrid retrieval + reranking
            candidates = self.retriever.hybrid_search(processed_query)
            retrieved_docs = self.retriever.rerank(question, candidates)
            
            logger.info(f"Retrieved {len(retrieved_docs)} relevant documents")
            
            # 3. Generate the final answer
            answer = self.generator.generate_answer(question, retrieved_docs)
            return answer
            
        except Exception as e:
            logger.error(f"Question answering failed: {e}")
            return f"Error: {str(e)}"
    
    def batch_answer(self, questions: List[str]) -> List[str]:
        """Answer a batch of questions."""
        return [self.answer_question(q) for q in questions]

# Usage example
if __name__ == "__main__":
    system = RAGSystem()
    document_files = ["data/sample.txt", "data/manual.pdf"]
    system.initialize_with_documents(document_files)
    
    test_questions = [
        "What are the key features of the product?",
        "How to install the software?",
    ]
    
    for q in test_questions:
        answer = system.answer_question(q)
        print(f"Q: {q}")
        print(f"A: {answer}\n")

5. Running and Testing

  1. Install dependencies: pip install -r requirements.txt
  2. Put your documents in the data/ directory and update the file paths in main.py.
  3. Build the index: python scripts/index_data.py
  4. Start question answering: python src/main.py
  5. Run the evaluation: python scripts/run_evaluation.py

This cleanly structured codebase gives you a directly runnable RAG system and shows how to integrate advanced optimization techniques in a modular way.
