LLM应用开发之RAG检索增强生成详解

摘要: 检索增强生成(Retrieval-Augmented Generation,RAG)是一种将大规模语言模型与外部知识检索系统相结合的技术架构,旨在解决大模型幻觉、知识截止和领域知识缺失等问题。本文详细介绍了RAG的核心概念、工作流程、关键技术组件(包括文档分块、Embedding模型、向量数据库和检索策略),并提供了完整的Python代码示例,帮助开发者快速构建生产级RAG应用。实验表明,合理设计的RAG系统能够显著提升大模型在专业领域的问答准确率,是企业知识库、智能客服、文档问答等场景的核心技术方案。

关键词: RAG;检索增强生成;向量数据库;Embedding;知识库问答;LangChain;混合检索


一、为什么需要RAG?

1.1 大语言模型的固有局限

大语言模型(LLM)自诞生以来,在自然语言理解和生成任务上展现了惊人的能力。然而,LLM存在三个根本性局限:

幻觉问题(Hallucination):LLM在生成内容时,可能会产生看似合理但实际错误的信息。这是因为模型本质上是统计概率模型,在缺乏明确依据时容易"一本正经地胡说八道"。

知识截止问题(Knowledge Cutoff):模型的训练数据有固定的时间节点,无法获取最新信息。例如,一个在2023年训练的模型,无法回答2024年发生的最新事件。

领域知识缺失:通用LLM对企业内部知识、专业技术文档、特定业务规则等私有数据一无所知。用通用模型回答专业问题,往往缺乏针对性和准确性。

1.2 RAG vs Fine-tuning:如何选择?

面对上述问题,开发者通常有两种解决方案:微调(Fine-tuning)和检索增强生成(RAG)。下表对比了两者的核心差异:

对比维度 RAG Fine-tuning
知识更新 可实时更新,无需重新训练 需要重新训练或增量训练
部署成本 低(只需部署检索系统) 高(需要GPU进行训练和部署)
可解释性 高(答案可溯源到具体文档) 低(知识隐含在模型参数中)
** hallucination控制** 优秀(严格基于检索内容) 较弱(仍可能产生幻觉)
适用场景 知识库频繁更新、数据量大 任务模式固定、需要风格适配
训练数据需求 少量标注数据即可 需要大量高质量训练数据

如何选择? 简单来说,如果你的数据频繁更新或数据量巨大,优先选择RAG;如果你需要模型适配特定任务风格(如特定对话语气、输出格式),可以考虑Fine-tuning。在实际企业应用中,RAG是最主流的方案,因为它兼具效果和成本优势。


二、RAG工作流程详解

RAG的核心工作流程分为三个阶段:索引(Indexing)→ 检索(Retrieval)→ 生成(Generation)

复制代码
┌─────────────┐    ┌─────────────┐    ┌─────────────┐
│   索引阶段   │ →  │   检索阶段   │ →  │   生成阶段   │
│  Indexing   │    │  Retrieval  │    │ Generation  │
└─────────────┘    └─────────────┘    └─────────────┘

2.1 索引阶段

索引阶段是将文档转化为可检索格式的过程。具体步骤如下:

  1. 文档加载:从PDF、Word、网页等来源读取文本内容

  2. 文本分块:将长文档切分为适合检索的小段落(Chunk)

  3. 向量化:使用Embedding模型将每个Chunk转化为稠密向量

  4. 存储:将向量及其对应的原始文本存储到向量数据库中

2.2 检索阶段

当用户提出问题时,系统执行以下操作:

  1. 将用户问题同样转化为向量(使用相同的Embedding模型)

  2. 在向量数据库中计算相似度,找出最相关的Top-K个Chunk

  3. 可选:使用重排序(Reranking)算法进一步优化排序结果

2.3 生成阶段

  1. 将用户问题与检索到的相关文档组合成提示词(Prompt)

  2. 将提示词发送给LLM,生成最终答案

  3. 可选:将生成答案与源文档关联,支持引用溯源


三、文档处理与分块策略

文档分块(Chunking)是RAG系统中至关重要的环节,直接影响检索质量和生成效果。

3.1 固定大小分块

最简单的分块策略,按照固定字符数或token数进行切分:

复制代码
def fixed_size_chunk(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """
    固定大小分块
    
    Args:
        text: 输入文本
        chunk_size: 每个块的目标大小(字符数)
        overlap: 相邻块之间的重叠字符数
    
    Returns:
        分块后的文本列表
    """
    chunks = []
    start = 0
    text_length = len(text)
    
    while start < text_length:
        # 计算当前块的结束位置
        end = start + chunk_size
        
        # 如果不是最后一块,尽量在句子边界处截断
        if end < text_length:
            # 向前寻找最后一个句子结束符(。!?.!?)
            for i in range(end, max(start, end - 100), -1):
                if text[i] in '。!?.!?':
                    end = i + 1
                    break
        
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        
        # 移动起始位置(考虑重叠)
        start = end - overlap if end < text_length else text_length
    
    return chunks
​
# 使用示例
sample_text = "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。自诞生以来,人工智能理论和技术日益成熟,应用领域也不断扩大。可以设想,未来人工智能带来的科技产品,将会是人类智慧的"容器"。人工智能可以对人的意识、思维的信息过程进行模拟。人工智能不是人的智能,但能像人那样思考、也可能超过人的智能。"
chunks = fixed_size_chunk(sample_text, chunk_size=100, overlap=20)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1}: {chunk}")

3.2 递归字符分割

递归字符分割(Recursive Character Text Splitting)是一种更智能的分块方式,它按照优先级尝试不同的分隔符进行切分:

复制代码
import re
​
class RecursiveCharacterTextSplitter:
    """
    递归字符文本分割器
    
    按照优先级尝试不同的分隔符进行文本切分,
    确保每个块都在目标大小范围内,同时尽量保持语义完整性
    """
    
    def __init__(
        self,
        separators: list[str] = None,
        chunk_size: int = 500,
        overlap: int = 50,
        length_function: callable = len
    ):
        """
        初始化分割器
        
        Args:
            separators: 分隔符列表,按优先级排序
            chunk_size: 目标块大小
            overlap: 块之间的重叠大小
            length_function: 计算文本长度的函数
        """
        if separators is None:
            # 默认分隔符列表,按优先级排序
            separators = ['\n\n', '\n', '。', '!', '?', '. ', '! ', '? ', ' ', '']
        
        self.separators = separators
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.length_function = length_function
    
    def split_text(self, text: str) -> list[str]:
        """
        递归分割文本
        
        Args:
            text: 输入文本
        
        Returns:
            分割后的文本块列表
        """
        final_chunks = []
        
        # 递归分割的辅助函数
        def get_chunks(text: str, separator: str) -> list[str]:
            if separator == '':
                # 最终级别,按字符数直接分割
                return [text[i:i+self.chunk_size] 
                        for i in range(0, len(text), self.chunk_size - self.overlap)]
            
            # 按当前分隔符分割
            parts = text.split(separator)
            
            chunks = []
            current_chunk = ""
            
            for part in parts:
                # 如果加上当前部分会超过目标大小,先保存当前chunk
                if self.length_function(current_chunk + separator + part) > self.chunk_size:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                        # 开始新chunk,保留overlap部分
                        current_chunk = current_chunk[-self.overlap:] if self.overlap > 0 else ""
                
                if current_chunk:
                    current_chunk += separator + part
                else:
                    current_chunk = part
            
            # 处理最后一块
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            
            return chunks
        
        # 按分隔符优先级逐层尝试分割
        for separator in self.separators:
            if separator in text:
                # 找到合适的分隔符,进行分割
                parts = text.split(separator)
                result = []
                
                for part in parts:
                    # 如果单个部分仍然过长,递归使用更细的分隔符
                    if self.length_function(part) > self.chunk_size:
                        nested_chunks = self.split_text(part)
                        result.extend(nested_chunks)
                    else:
                        result.append(part)
                
                # 合并过小的块
                final_chunks = []
                buffer = ""
                
                for chunk in result:
                    if self.length_function(buffer + separator + chunk) <= self.chunk_size:
                        buffer = (buffer + separator + chunk).strip() if buffer else chunk
                    else:
                        if buffer:
                            final_chunks.append(buffer)
                        # 考虑重叠
                        buffer = chunk[-self.overlap:] + separator + chunk if self.overlap > 0 and buffer else chunk
                
                if buffer:
                    final_chunks.append(buffer)
                
                return final_chunks
        
        # 如果所有分隔符都无法分割,直接按字符数分割
        return get_chunks(text, '')
​
# 使用示例
splitter = RecursiveCharacterTextSplitter(chunk_size=100, overlap=20)
text = """机器学习是人工智能的一个重要分支,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。
​
机器学习算法可以分为多种类型,包括监督学习、无监督学习和强化学习。监督学习是最常见的类型,它使用标记的训练数据进行学习,常见的算法包括决策树、支持向量机和神经网络等。
​
无监督学习则不需要标记数据,它的目标是发现数据中的隐藏模式和结构。聚类和降维是无监督学习的两个主要任务。强化学习则是通过与环境交互来学习最优策略,常见的应用包括游戏和机器人控制。"""
​
chunks = splitter.split_text(text)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1} ({len(chunk)}字符): {chunk}\n")

3.3 分块大小的选择

分块大小的选择需要权衡以下因素:

Chunk大小 优点 缺点
过小(如100字符) 检索精准、主题集中 上下文信息不足,可能丢失重要语义
适中(如300-500字符) 平衡检索精度和语义完整 需要精细调优
过大(如1000+字符) 语义完整、上下文丰富 检索精度下降、引入噪声

经验法则

  • 通用场景:300-500字符(中文)或500-1000 token(英文)

  • 问答系统:200-400字符

  • 摘要任务:500-1000字符

  • 跨语言场景:适当增大,考虑翻译损失

3.4 元数据提取

元数据对于精准检索和结果过滤非常重要:

复制代码
from datetime import datetime
from typing import Optional
​
class Document:
    """文档对象,包含内容和元数据"""
    
    def __init__(
        self,
        content: str,
        metadata: Optional[dict] = None,
        embedding: Optional[list[float]] = None
    ):
        self.content = content
        self.metadata = metadata or {}
        self.embedding = embedding
    
    def add_metadata(self, key: str, value: any) -> 'Document':
        """添加元数据"""
        self.metadata[key] = value
        return self
    
    def to_dict(self) -> dict:
        return {
            'content': self.content,
            'metadata': self.metadata
        }
​
class MetadataExtractor:
    """元数据提取器"""
    
    @staticmethod
    def extract_from_file(file_path: str, content: str) -> dict:
        """从文件路径提取元数据"""
        import os
        filename = os.path.basename(file_path)
        name, ext = os.path.splitext(filename)
        
        return {
            'source': file_path,
            'filename': filename,
            'file_type': ext,
            'title': name,
            'created_at': datetime.now().isoformat(),
            'char_count': len(content),
            'word_count': len(content.split())
        }
    
    @staticmethod
    def extract_from_text(content: str, source: str = 'unknown') -> dict:
        """从文本内容提取元数据"""
        return {
            'source': source,
            'created_at': datetime.now().isoformat(),
            'char_count': len(content),
            'word_count': len(content.split()),
            'preview': content[:200] + '...' if len(content) > 200 else content
        }
​
# 使用示例
extractor = MetadataExtractor()
metadata = extractor.extract_from_file('/documents/技术文档.pdf', '这是一段技术文档内容...')
metadata['category'] = '技术文档'
metadata['author'] = '张三'
metadata['version'] = '1.0'
​
doc = Document('这是一段技术文档内容...', metadata)
print(f"元数据: {doc.metadata}")

四、Embedding模型选择

Embedding模型是将文本转化为稠密向量的核心组件,其质量直接影响检索效果。

4.1 OpenAI Embeddings

OpenAI提供了成熟的Embedding服务,支持多种模型:

复制代码
from openai import OpenAI
from typing import list
import numpy as np
​
class OpenAIEmbeddings:
    """OpenAI Embedding封装"""
    
    # 支持的模型列表
    MODELS = {
        'text-embedding-3-small': 1536,  # 新版小型模型,性价比高
        'text-embedding-3-large': 3072,  # 新版大型模型,效果更好
        'text-embedding-ada-002': 1536   # 旧版模型,兼容性好
    }
    
    def __init__(self, api_key: str, model: str = 'text-embedding-3-small'):
        """
        初始化OpenAI Embedding客户端
        
        Args:
            api_key: OpenAI API密钥
            model: 使用的模型名称
        """
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.dimension = self.MODELS.get(model, 1536)
    
    def embed_query(self, text: str) -> list[float]:
        """
        将单个文本转化为向量
        
        Args:
            text: 输入文本
        
        Returns:
            文本的向量表示
        """
        response = self.client.embeddings.create(
            model=self.model,
            input=text
        )
        return response.data[0].embedding
    
    def embed_documents(self, texts: list[str], batch_size: int = 100) -> list[list[float]]:
        """
        批量将文本转化为向量
        
        Args:
            texts: 文本列表
            batch_size: 批处理大小
        
        Returns:
            文本向量列表
        """
        all_embeddings = []
        
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(
                model=self.model,
                input=batch
            )
            all_embeddings.extend([item.embedding for item in response.data])
        
        return all_embeddings
    
    def cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """计算两个向量的余弦相似度"""
        vec1 = np.array(vec1)
        vec2 = np.array(vec2)
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
​
# 使用示例
# 注意:实际使用时需要替换为有效的API Key
# embeddings = OpenAIEmbeddings(api_key="sk-...")
# query_vector = embeddings.embed_query("什么是人工智能?")
# doc_vectors = embeddings.embed_documents(["文本1", "文本2", "文本3"])

4.2 开源Embedding模型

国内常用的开源Embedding模型有BGE、M3E等,可以本地部署:

复制代码
from sentence_transformers import SentenceTransformer
from typing import list
import numpy as np
​
class LocalEmbeddings:
    """本地Embedding模型封装,支持开源模型"""
    
    # 常用开源模型配置
    MODEL_CONFIGS = {
        'bge-base-zh': {
            'model_name': 'BAAI/bge-base-zh-v1.5',
            'dimension': 768,
            'description': '中文BGE基础模型,效果好,推荐使用'
        },
        'bge-small-zh': {
            'model_name': 'BAAI/bge-small-zh-v1.5',
            'dimension': 512,
            'description': '中文BGE小型模型,速度快,资源占用低'
        },
        'm3e-base': {
            'model_name': 'moka-ai/m3e-base',
            'dimension': 768,
            'description': 'M3E基础模型,支持中英文'
        }
    }
    
    def __init__(self, model_name: str = 'bge-base-zh', device: str = 'cpu'):
        """
        初始化本地Embedding模型
        
        Args:
            model_name: 模型名称或本地路径
            device: 推理设备 ('cpu', 'cuda', 'mps')
        """
        if model_name in self.MODEL_CONFIGS:
            config = self.MODEL_CONFIGS[model_name]
            self.model_name = config['model_name']
            self.dimension = config['dimension']
        else:
            self.model_name = model_name
            self.dimension = 768  # 默认维度
        
        # 加载模型
        self.model = SentenceTransformer(self.model_name, device=device)
        print(f"模型加载完成: {self.model_name}, 维度: {self.dimension}")
    
    def embed_query(self, text: str) -> list[float]:
        """将单个查询文本转化为向量"""
        embedding = self.model.encode(text, normalize_embeddings=True)
        return embedding.tolist()
    
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """批量将文档转化为向量"""
        embeddings = self.model.encode(texts, normalize_embeddings=True, show_progress_bar=True)
        return embeddings.tolist()
    
    def cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """计算余弦相似度"""
        vec1 = np.array(vec1)
        vec2 = np.array(vec2)
        return np.dot(vec1, vec2)
​
# 使用示例(需要先安装: pip install sentence-transformers)
# embeddings = LocalEmbeddings(model_name='bge-base-zh', device='cuda')
# query_vector = embeddings.embed_query("什么是机器学习?")
# doc_vectors = embeddings.embed_documents(["文档1内容", "文档2内容"])

4.3 模型选择指南

模型 维度 优势 适用场景
OpenAI text-embedding-3-large 3072 效果最好、稳定性强 生产环境、高质量需求
OpenAI text-embedding-3-small 1536 性价比高、速度快 中等质量需求、成本敏感
BGE-base-zh-v1.5 768 开源、中文优化、可本地部署 中文场景、私有化部署
M3E-base 768 开源、中英文兼顾 中英文混合场景

选择建议

  • 优先使用OpenAI官方模型,效果有保障

  • 中文私有化部署场景,推荐BGE或M3E

  • 注意向量维度与向量数据库的兼容性


五、向量数据库

向量数据库是存储和检索高维向量的专用数据库,是RAG系统的核心存储组件。

5.1 Chroma(轻量级方案)

Chroma是一个专为AI应用设计的轻量级向量数据库,易于上手,适合原型开发和小型项目:

复制代码
import chromadb
from chromadb.config import Settings
from typing import list, Optional
​
class ChromaVectorStore:
    """Chroma向量数据库封装"""
    
    def __init__(self, persist_directory: str = './chroma_db'):
        """
        初始化Chroma客户端
        
        Args:
            persist_directory: 数据持久化目录
        """
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = None
    
    def create_collection(self, name: str = 'documents', metadata: Optional[dict] = None):
        """创建或获取集合"""
        self.collection = self.client.get_or_create_collection(
            name=name,
            metadata=metadata or {'description': '文档向量集合'}
        )
        print(f"集合已创建/获取: {name}, 现有数据量: {self.collection.count()}")
    
    def add_documents(
        self,
        ids: list[str],
        embeddings: list[list[float]],
        documents: list[str],
        metadata: Optional[list[dict]] = None
    ):
        """
        添加文档到集合
        
        Args:
            ids: 文档ID列表
            embeddings: 向量列表
            documents: 原始文档内容
            metadata: 元数据列表
        """
        if self.collection is None:
            raise ValueError("请先调用create_collection创建集合")
        
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadata or [{} for _ in range(len(documents))]
        )
        print(f"已添加 {len(documents)} 个文档")
    
    def similarity_search(
        self,
        query_embedding: list[float],
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None
    ) -> dict:
        """
        相似度搜索
        
        Args:
            query_embedding: 查询向量
            n_results: 返回结果数量
            where: 元数据过滤条件
            where_document: 文档内容过滤条件
        
        Returns:
            搜索结果字典
        """
        if self.collection is None:
            raise ValueError("请先调用create_collection创建集合")
        
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document
        )
        
        return {
            'ids': results['ids'][0],
            'distances': results['distances'][0],
            'documents': results['documents'][0],
            'metadatas': results['metadatas'][0]
        }
    
    def delete(self, ids: list[str]):
        """删除指定ID的文档"""
        if self.collection:
            self.collection.delete(ids=ids)
            print(f"已删除 {len(ids)} 个文档")
​
# 使用示例
# 初始化
store = ChromaVectorStore(persist_directory='./my_chroma_db')
store.create_collection(name='tech_docs')
​
# 添加文档
doc_ids = ['doc1', 'doc2', 'doc3']
doc_embeddings = [[0.1, 0.2, ...], [0.3, 0.4, ...], [0.5, 0.6, ...]]  # 实际为Embedding模型生成
doc_contents = [
    '人工智能是计算机科学的一个分支',
    '机器学习是人工智能的重要分支',
    '深度学习是机器学习的一个研究方向'
]
doc_metadata = [
    {'category': 'AI', 'author': '张三'},
    {'category': 'ML', 'author': '李四'},
    {'category': 'DL', 'author': '王五'}
]
​
store.add_documents(doc_ids, doc_embeddings, doc_contents, doc_metadata)
​
# 检索
query_vector = [0.15, 0.25, ...]  # 实际为Embedding模型生成
results = store.similarity_search(query_vector, n_results=2)
print(f"检索到 {len(results['documents'])} 个相关文档")
for i, doc in enumerate(results['documents']):
    print(f"文档 {i+1}: {doc}, 距离: {results['distances'][i]:.4f}")

5.2 FAISS(Facebook开源方案)

FAISS是Facebook开源的高效向量检索库,适合大规模数据:

复制代码
import faiss
import numpy as np
from typing import list, Tuple
​
class FAISSVectorStore:
    """FAISS向量数据库封装"""
    
    def __init__(self, dimension: int, index_type: str = 'IVF'):
        """
        初始化FAISS索引
        
        Args:
            dimension: 向量维度
            index_type: 索引类型 ('Flat', 'IVF', 'HNSW', 'PQ')
        """
        self.dimension = dimension
        self.index_type = index_type
        self.index = None
        self.id_map = {}  # 存储原始ID和索引的映射
        self.documents = {}  # 存储原始文档内容
    
    def _create_index(self, nlist: int = 100):
        """创建索引"""
        if self.index_type == 'Flat':
            # 精确搜索,无索引结构
            self.index = faiss.IndexFlatIP(self.dimension)  # 内积相似度
        elif self.index_type == 'IVF':
            # 倒排索引,适合大数据集
            quantizer = faiss.IndexFlatIP(self.dimension)
            self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
        elif self.index_type == 'HNSW':
            # 分层可导航小世界图,高速搜索
            self.index = faiss.IndexHNSWFlat(self.dimension, 32)  # 32为连接数
        elif self.index_type == 'PQ':
            # 产品量化,大幅压缩内存
            self.index = faiss.IndexPQ(self.dimension, 16, 8)  # 16字节码,8个分组
        else:
            raise ValueError(f"不支持的索引类型: {self.index_type}")
        
        print(f"FAISS索引创建成功: {self.index_type}, 维度: {self.dimension}")
    
    def build(self, documents: list[str], embeddings: list[list[float]], ids: list[str]):
        """
        构建索引
        
        Args:
            documents: 文档内容列表
            embeddings: 向量列表
            ids: 文档ID列表
        """
        self._create_index()
        
        # 转换为numpy数组
        embeddings_array = np.array(embeddings).astype('float32')
        
        # 归一化(用于余弦相似度)
        faiss.normalize_L2(embeddings_array)
        
        # 训练索引(IVF和PQ需要训练)
        if hasattr(self.index, 'is_trained') and not self.index.is_trained:
            self.index.train(embeddings_array)
        
        # 添加向量
        self.index.add(embeddings_array)
        
        # 保存映射
        for i, (doc_id, doc) in enumerate(zip(ids, documents)):
            self.id_map[i] = doc_id
            self.documents[doc_id] = doc
        
        print(f"索引构建完成: {len(documents)} 个文档")
    
    def search(self, query_embedding: list[float], k: int = 5) -> list[dict]:
        """
        搜索相似向量
        
        Args:
            query_embedding: 查询向量
            k: 返回结果数量
        
        Returns:
            搜索结果列表
        """
        if self.index is None:
            raise ValueError("索引未构建,请先调用build方法")
        
        # 转换查询向量
        query = np.array([query_embedding]).astype('float32')
        faiss.normalize_L2(query)
        
        # 执行搜索
        distances, indices = self.index.search(query, k)
        
        # 整理结果
        results = []
        for distance, idx in zip(distances[0], indices[0]):
            if idx >= 0 and idx in self.id_map:
                doc_id = self.id_map[idx]
                results.append({
                    'id': doc_id,
                    'document': self.documents[doc_id],
                    'distance': float(distance)
                })
        
        return results
​
# 使用示例
# 初始化
dimension = 768  # Embedding向量维度
store = FAISSVectorStore(dimension=dimension, index_type='HNSW')
​
# 准备数据
documents = [
    '人工智能是计算机科学的一个分支',
    '机器学习是人工智能的重要分支',
    '深度学习是机器学习的一个研究方向'
]
# 实际使用时用Embedding模型生成
embeddings = np.random.rand(3, dimension).astype('float32')
ids = ['doc1', 'doc2', 'doc3']
​
# 构建索引
store.build(documents, embeddings, ids)
​
# 搜索
query = np.random.rand(dimension).astype('float32')
results = store.search(query, k=2)
print(f"检索到 {len(results)} 个结果:")
for r in results:
    print(f"  ID: {r['id']}, 距离: {r['distance']:.4f}")

5.3 Milvus(生产级方案)

Milvus是专门面向AI应用的大规模向量数据库,支持分布式部署:

复制代码
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
from pymilvus import utility
from typing import list, Optional
import numpy as np
​
class MilvusVectorStore:
    """Milvus向量数据库封装"""
    
    def __init__(
        self,
        host: str = 'localhost',
        port: str = '19530',
        collection_name: str = 'documents'
    ):
        """
        初始化Milvus连接
        
        Args:
            host: Milvus服务器地址
            port: Milvus服务器端口
            collection_name: 集合名称
        """
        self.collection_name = collection_name
        self.collection: Optional[Collection] = None
        
        # 连接Milvus
        connections.connect(host=host, port=port)
        print(f"已连接到Milvus服务器: {host}:{port}")
    
    def create_collection(self, dimension: int, description: str = ''):
        """创建集合"""
        # 定义字段
        fields = [
            FieldSchema(name='id', dtype=DataType.VARCHAR, max_length=64, is_primary=True),
            FieldSchema(name='document', dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name='category', dtype=DataType.VARCHAR, max_length=64)
        ]
        
        schema = CollectionSchema(fields=fields, description=description)
        
        # 创建集合
        if utility.has_collection(self.collection_name):
            utility.drop_collection(self.collection_name)
        
        self.collection = Collection(name=self.collection_name, schema=schema)
        
        # 创建索引
        index_params = {
            'index_type': 'IVF_FLAT',
            'metric_type': 'IP',  # 内积相似度
            'params': {'nlist': 128}
        }
        self.collection.create_index(field_name='embedding', index_params=index_params)
        self.collection.load()
        
        print(f"集合已创建: {self.collection_name}, 维度: {dimension}")
    
    def insert(self, ids: list[str], documents: list[str], embeddings: list[list[float]], categories: list[str]):
        """
        插入数据
        
        Args:
            ids: 文档ID列表
            documents: 文档内容列表
            embeddings: 向量列表
            categories: 类别列表
        """
        if self.collection is None:
            raise ValueError("集合未创建,请先调用create_collection")
        
        data = [
            ids,
            documents,
            embeddings,
            categories
        ]
        
        result = self.collection.insert(data)
        self.collection.flush()
        print(f"已插入 {len(ids)} 个文档")
        return result
    
    def search(self, query_embedding: list[float], top_k: int = 5, category: Optional[str] = None) -> list[dict]:
        """
        向量搜索
        
        Args:
            query_embedding: 查询向量
            top_k: 返回数量
            category: 可选的类别过滤
        
        Returns:
            搜索结果
        """
        if self.collection is None:
            raise ValueError("集合未创建,请先调用create_collection")
        
        search_params = {'metric_type': 'IP', 'params': {'nprobe': 10}}
        
        # 构建搜索表达式
        expr = f'category == "{category}"' if category else None
        
        results = self.collection.search(
            data=[query_embedding],
            anns_field='embedding',
            param=search_params,
            limit=top_k,
            expr=expr,
            output_fields=['id', 'document', 'category']
        )
        
        # 整理结果
        search_results = []
        for hits in results:
            for hit in hits:
                search_results.append({
                    'id': hit.entity.get('id'),
                    'document': hit.entity.get('document'),
                    'category': hit.entity.get('category'),
                    'distance': hit.distance
                })
        
        return search_results
    
    def delete_by_ids(self, ids: list[str]):
        """根据ID删除文档"""
        if self.collection:
            expr = f'id in {ids}'
            self.collection.delete(expr)
            self.collection.flush()
            print(f"已删除 {len(ids)} 个文档")
    
    def close(self):
        """关闭连接"""
        connections.disconnect()
        print("已断开Milvus连接")
​
# 使用示例
# 注意:需要Milvus服务器运行中
# store = MilvusVectorStore(host='localhost', port='19530', collection_name='tech_docs')
# store.create_collection(dimension=768, description='技术文档集合')
# store.insert(ids, documents, embeddings, categories)
# results = store.search(query_vector, top_k=5, category='AI')
# store.close()

5.4 向量数据库对比

数据库 特点 适用场景 部署难度
Chroma 轻量级、易用、内嵌式 原型开发、小型项目 ⭐ 简单
FAISS 高性能、支持GPU加速 中等规模、需本地部署 ⭐⭐ 中等
Milvus 分布式、可扩展、云原生 生产环境、大规模数据 ⭐⭐⭐ 复杂
Pinecone 全托管、云服务 快速上线、无运维 ⭐ 简单(付费)
Weaviate 混合搜索、原生GraphQL 混合检索场景 ⭐⭐ 中等

六、检索策略

检索是RAG系统的核心环节,直接决定生成效果的上限。

6.1 相似度搜索与Top-K召回

复制代码
import numpy as np
from typing import list, Tuple
​
class VectorRetriever:
    """向量检索器"""
    
    def __init__(self, dimension: int, index_type: str = 'cosine'):
        """
        初始化检索器
        
        Args:
            dimension: 向量维度
            index_type: 相似度类型 ('cosine', 'euclidean', 'dot')
        """
        self.dimension = dimension
        self.index_type = index_type
        self.documents = []
        self.embeddings = None
    
    def add_documents(self, documents: list[str], embeddings: list[list[float]]):
        """添加文档"""
        self.documents.extend(documents)
        if self.embeddings is None:
            self.embeddings = np.array(embeddings)
        else:
            self.embeddings = np.vstack([self.embeddings, embeddings])
        
        # 归一化(用于余弦相似度)
        if self.index_type == 'cosine':
            norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
            self.embeddings = self.embeddings / (norms + 1e-8)
    
    def _compute_similarity(self, query_vec: np.ndarray) -> np.ndarray:
        """计算查询向量与所有文档向量的相似度"""
        if self.index_type == 'cosine':
            return np.dot(self.embeddings, query_vec)
        elif self.index_type == 'euclidean':
            return -np.linalg.norm(self.embeddings - query_vec, axis=1)
        elif self.index_type == 'dot':
            return np.dot(self.embeddings, query_vec)
        else:
            raise ValueError(f"不支持的相似度类型: {self.index_type}")
    
    def retrieve(self, query: str, query_embedding: list[float], top_k: int = 5) -> list[dict]:
        """
        检索最相关的文档
        
        Args:
            query: 查询文本
            query_embedding: 查询向量
            top_k: 返回数量
        
        Returns:
            检索结果列表,按相似度降序排列
        """
        query_vec = np.array(query_embedding)
        if self.index_type == 'cosine':
            query_vec = query_vec / (np.linalg.norm(query_vec) + 1e-8)
        
        # 计算相似度
        similarities = self._compute_similarity(query_vec)
        
        # 获取Top-K索引
        top_indices = np.argsort(similarities)[::-1][:top_k]
        
        results = []
        for idx in top_indices:
            results.append({
                'index': int(idx),
                'document': self.documents[idx],
                'similarity': float(similarities[idx]),
                'embedding': self.embeddings[idx].tolist()
            })
        
        return results
​
# 使用示例
retriever = VectorRetriever(dimension=768, index_type='cosine')
​
# 模拟文档数据
docs = [
    '人工智能是计算机科学的一个分支,致力于研究智能的实质。',
    '机器学习是人工智能的重要分支,使用数据来提升模型性能。',
    '深度学习是机器学习的一个研究方向,使用多层神经网络。',
    '自然语言处理是人工智能的另一个重要分支。',
    '计算机视觉是研究如何让机器看懂图像的学科。'
]
​
# 模拟的embeddings(实际使用时用Embedding模型生成)
embeddings = np.random.rand(5, 768).astype('float32')
# 让第二个文档与机器学习query更相关
embeddings[1] = embeddings[0] + np.random.rand(768) * 0.1
​
retriever.add_documents(docs, embeddings)
​
# 检索
query_embedding = np.random.rand(768).astype('float32)
results = retriever.retrieve('机器学习', query_embedding, top_k=3)
​
print("检索结果:")
for i, r in enumerate(results):
    print(f"{i+1}. [相似度: {r['similarity']:.4f}] {r['document']}")

6.2 混合检索策略

混合检索结合了向量检索和关键词检索的优势,能够处理不同类型的查询:

复制代码
from typing import List, Dict, Tuple, Optional
import re
from collections import Counter
​
class BM25KeywordSearch:
    """BM25关键词搜索实现"""
    
    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """
        初始化BM25
        
        Args:
            k1: 词频饱和参数
            b: 文档长度归一化参数
        """
        self.k1 = k1
        self.b = b
        self.documents = []
        self.doc_lengths = []
        self.avg_doc_length = 0
        self.doc_freqs = {}  # 词频统计
        self.idf = {}  # 逆文档频率
        self.vocab = set()
    
    def _tokenize(self, text: str) -> List[str]:
        """简单分词(实际可用jieba等更专业的库)"""
        # 简单的中文分词,按标点和空格分割
        text = re.sub(r'[^\w\s]', ' ', text)
        tokens = text.lower().split()
        # 过滤停用词
        stopwords = {'的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这', '那', '么'}
        return [t for t in tokens if t not in stopwords and len(t) > 1]
    
    def fit(self, documents: List[str]):
        """构建索引"""
        self.documents = documents
        self.doc_lengths = [len(self._tokenize(doc)) for doc in documents]
        self.avg_doc_length = sum(self.doc_lengths) / len(documents) if documents else 1
        
        # 统计文档频率
        for doc in documents:
            tokens = set(self._tokenize(doc))
            for token in tokens:
                self.vocab.add(token)
                self.doc_freqs[token] = self.doc_freqs.get(token, 0) + 1
        
        # 计算IDF
        n = len(documents)
        for token, df in self.doc_freqs.items():
            self.idf[token] = max(0, (n - df + 0.5) / (df + 0.5))
    
    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """搜索"""
        query_tokens = self._tokenize(query)
        doc_scores = []
        
        for i, doc in enumerate(self.documents):
            doc_tokens = self._tokenize(doc)
            doc_token_freqs = Counter(doc_tokens)
            
            score = 0.0
            for qt in query_tokens:
                if qt in doc_token_freqs:
                    tf = doc_token_freqs[qt]
                    idf = self.idf.get(qt, 0)
                    doc_len = self.doc_lengths[i]
                    
                    # BM25公式
                    term_score = idf * (tf * (self.k1 + 1)) / (tf + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_length))
                    score += term_score
            
            doc_scores.append({
                'index': i,
                'document': doc,
                'score': score
            })
        
        # 排序
        doc_scores.sort(key=lambda x: x['score'], reverse=True)
        return doc_scores[:top_k]
​
​
class HybridRetriever:
    """混合检索器:结合向量检索和关键词检索"""
    
    def __init__(
        self,
        vector_weight: float = 0.5,
        keyword_weight: float = 0.5
    ):
        """
        初始化混合检索器
        
        Args:
            vector_weight: 向量检索权重
            keyword_weight: 关键词检索权重
        """
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight
        self.vector_retriever = None
        self.keyword_search = BM25KeywordSearch()
        self.is_fitted = False
    
    def fit(self, documents: List[str], embeddings: List[List[float]]):
        """构建索引"""
        self.vector_retriever = VectorRetriever(dimension=len(embeddings[0]))
        self.vector_retriever.add_documents(documents, embeddings)
        self.keyword_search.fit(documents)
        self.is_fitted = True
        print(f"混合检索索引构建完成: {len(documents)} 个文档")
    
    def retrieve(
        self,
        query: str,
        query_embedding: List[float],
        top_k: int = 5,
        alpha: Optional[float] = None
    ) -> List[Dict]:
        """
        混合检索
        
        Args:
            query: 查询文本
            query_embedding: 查询向量
            top_k: 返回数量
            alpha: 动态权重参数,0-1之间,越大越偏向向量检索
        
        Returns:
            融合后的检索结果
        """
        if not self.is_fitted:
            raise ValueError("索引未构建,请先调用fit方法")
        
        # 向量检索
        vector_results = self.vector_retriever.retrieve(query, query_embedding, top_k * 2)
        
        # 关键词检索
        keyword_results = self.keyword_search.search(query, top_k * 2)
        keyword_score_map = {r['index']: r['score'] for r in keyword_results}
        
        # 归一化分数
        vector_scores = [r['similarity'] for r in vector_results]
        max_vec_score = max(vector_scores) if vector_scores else 1
        max_kw_score = max(keyword_score_map.values()) if keyword_score_map else 1
        
        # 融合分数
        fused_scores = {}
        
        # 动态调整权重
        if alpha is not None:
            self.vector_weight = alpha
            self.keyword_weight = 1 - alpha
        
        for r in vector_results:
            idx = r['index']
            norm_vector_score = r['similarity'] / max_vec_score if max_vec_score > 0 else 0
            kw_score = keyword_score_map.get(idx, 0) / max_kw_score if max_kw_score > 0 else 0
            fused_scores[idx] = {
                'document': r['document'],
                'vector_score': norm_vector_score,
                'keyword_score': kw_score,
                'fused_score': self.vector_weight * norm_vector_score + self.keyword_weight * kw_score,
                'index': idx
            }
        
        # 排序
        sorted_results = sorted(fused_scores.values(), key=lambda x: x['fused_score'], reverse=True)
        return sorted_results[:top_k]
​
# 使用示例
documents = [
    '人工智能是计算机科学的一个分支,致力于研究智能的实质。',
    '机器学习是人工智能的重要分支,使用数据来提升模型性能。',
    '深度学习是机器学习的一个研究方向,使用多层神经网络。',
    '自然语言处理是人工智能的另一个重要分支,研究语言理解和生成。',
    '计算机视觉是研究如何让机器看懂图像的学科。',
    '强化学习是机器学习的一个分支,通过与环境交互来学习决策。'
]
​
# 模拟embeddings
embeddings = np.random.rand(len(documents), 768).astype('float32')
​
# 构建混合检索器
hybrid_retriever = HybridRetriever(vector_weight=0.5, keyword_weight=0.5)
hybrid_retriever.fit(documents, embeddings)
​
# 执行混合检索
query = "机器学习的分支有哪些?"
query_embedding = np.random.rand(768).astype('float32)
​
results = hybrid_retriever.retrieve(query, query_embedding, top_k=3, alpha=0.5)
​
print(f"\n查询: {query}")
print("=" * 60)
for i, r in enumerate(results):
    print(f"{i+1}. [融合分: {r['fused_score']:.4f}] "
          f"向量:{r['vector_score']:.4f} | 关键词:{r['keyword_score']:.4f}")
    print(f"   文档: {r['document']}")

6.3 重排序(Reranking)

重排序是在初步检索后,使用更精细的模型对结果进行二次排序:

复制代码
from typing import List, Dict
​
class SimpleReranker:
    """简化版重排序器(实际生产中可使用Cross-Encoder模型)"""
    
    def __init__(self):
        """初始化重排序器"""
        # 模拟的权重配置
        self.exact_match_bonus = 0.3
        self.semantic_bonus = 0.2
        self.length_penalty_factor = 0.1
    
    def _compute_exact_match_score(self, query: str, document: str) -> float:
        """计算关键词匹配分数"""
        query_words = set(query.lower().split())
        doc_words = set(document.lower().split())
        intersection = query_words & doc_words
        return len(intersection) / len(query_words) if query_words else 0
    
    def _compute_length_penalty(self, document: str) -> float:
        """计算长度惩罚(避免过长或过短的文档)"""
        doc_len = len(document)
        # 理想长度在200-500字之间
        if 200 <= doc_len <= 500:
            return 1.0
        elif doc_len < 200:
            return 0.8 + 0.2 * (doc_len / 200)
        else:
            return max(0.5, 1.0 - 0.1 * ((doc_len - 500) / 500))
    
    def rerank(self, query: str, initial_results: List[Dict], top_k: int = 5) -> List[Dict]:
        """
        对初步检索结果进行重排序
        
        Args:
            query: 查询文本
            initial_results: 初步检索结果
            top_k: 返回数量
        
        Returns:
            重排序后的结果
        """
        scored_results = []
        
        for result in initial_results:
            document = result.get('document', '')
            
            # 计算各项分数
            exact_match = self._compute_exact_match_score(query, document)
            length_penalty = self._compute_length_penalty(document)
            
            # 综合分数 = 原始相似度 + 额外加分
            original_score = result.get('similarity', 0) or result.get('fused_score', 0)
            final_score = original_score + self.exact_match_bonus * exact_match + self.length_penalty_factor * (length_penalty - 1)
            
            scored_results.append({
                **result,
                'exact_match_score': exact_match,
                'length_penalty': length_penalty,
                'final_score': final_score
            })
        
        # 按最终分数排序
        scored_results.sort(key=lambda x: x['final_score'], reverse=True)
        return scored_results[:top_k]
​
# 使用示例
reranker = SimpleReranker()
​
# 假设这是初步检索结果
initial_results = [
    {'index': 0, 'document': '人工智能是计算机科学的一个分支,致力于研究智能的实质。', 'similarity': 0.85},
    {'index': 1, 'document': '机器学习是人工智能的重要分支,使用数据来提升模型性能。', 'similarity': 0.78},
    {'index': 2, 'document': '深度学习是机器学习的一个研究方向,使用多层神经网络。', 'similarity': 0.72},
]
​
query = "人工智能和机器学习的关系"
reranked = reranker.rerank(query, initial_results, top_k=3)
​
print("重排序结果:")
for r in reranked:
    print(f"  文档: {r['document']}")
    print(f"  最终分数: {r['final_score']:.4f} = 原始:{r['similarity']} + 匹配:{r['exact_match_score']:.2f}")
    print()

七、RAG完整实现(LangChain)

LangChain是当前最流行的RAG应用开发框架,以下是完整的RAG实现:

复制代码
from langchain.document_loaders import TextLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.vectorstores import Chroma, FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from typing import List, Optional
import os
​
class RAGSystem:
    """完整的RAG系统"""
    
    def __init__(
        self,
        embedding_model: str = 'openai',
        llm_model: str = 'gpt-3.5-turbo',
        api_key: Optional[str] = None,
        vector_store_type: str = 'chroma',
        persist_directory: str = './vector_store'
    ):
        """
        初始化RAG系统
        
        Args:
            embedding_model: Embedding模型类型 ('openai', 'huggingface')
            llm_model: LLM模型名称
            api_key: OpenAI API密钥
            vector_store_type: 向量存储类型 ('chroma', 'faiss')
            persist_directory: 向量存储持久化目录
        """
        self.embedding_model = embedding_model
        self.llm_model = llm_model
        self.persist_directory = persist_directory
        self.vector_store_type = vector_store_type
        
        # 初始化Embedding模型
        if embedding_model == 'openai':
            self.embeddings = OpenAIEmbeddings(
                model='text-embedding-3-small',
                openai_api_key=api_key or os.environ.get('OPENAI_API_KEY')
            )
        else:
            self.embeddings = HuggingFaceEmbeddings(
                model_name='BAAI/bge-base-zh-v1.5',
                model_kwargs={'device': 'cpu'}
            )
        
        # 初始化LLM
        self.llm = ChatOpenAI(
            model_name=llm_model,
            openai_api_key=api_key or os.environ.get('OPENAI_API_KEY'),
            temperature=0.3  # 低温度保证回答准确性
        )
        
        self.vectorstore = None
        self.qa_chain = None
    
    def load_documents(self, file_paths: List[str]) -> List:
        """
        加载文档
        
        Args:
            file_paths: 文件路径列表
        
        Returns:
            文档列表
        """
        documents = []
        
        for path in file_paths:
            if path.endswith('.txt'):
                loader = TextLoader(path, encoding='utf-8')
            elif path.endswith('.pdf'):
                loader = PyPDFLoader(path)
            else:
                print(f"不支持的文件类型: {path}")
                continue
            
            docs = loader.load()
            documents.extend(docs)
            print(f"已加载: {path}, 页数: {len(docs)}")
        
        return documents
    
    def process_documents(
        self,
        documents: List,
        chunk_size: int = 500,
        chunk_overlap: int = 50
    ) -> List:
        """
        处理文档:分块
        
        Args:
            documents: 文档列表
            chunk_size: 块大小
            chunk_overlap: 块重叠大小
        
        Returns:
            分块后的文档列表
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=['\n\n', '\n', '。', '!', '?', '. ', '! ', '? ', ' ']
        )
        
        chunks = text_splitter.split_documents(documents)
        print(f"文档分块完成: {len(chunks)} 个块")
        return chunks
    
    def build_index(self, chunks: List, force_rebuild: bool = False):
        """
        构建向量索引
        
        Args:
            chunks: 文档块列表
            force_rebuild: 是否强制重建
        """
        if self.vector_store_type == 'chroma':
            # 使用Chroma
            self.vectorstore = Chroma.from_documents(
                documents=chunks,
                embedding=self.embeddings,
                persist_directory=self.persist_directory
            )
        elif self.vector_store_type == 'faiss':
            # 使用FAISS
            self.vectorstore = FAISS.from_documents(
                documents=chunks,
                embedding=self.embeddings
            )
        else:
            raise ValueError(f"不支持的向量存储类型: {self.vector_store_type}")
        
        print(f"向量索引构建完成: {self.vectorstore.index.num_vectors if hasattr(self.vectorstore.index, 'num_vectors') else 'N/A'} 个向量")
    
    def setup_qa_chain(self, prompt_template: Optional[str] = None):
        """
        设置问答链
        
        Args:
            prompt_template: 自定义提示词模板
        """
        if self.vectorstore is None:
            raise ValueError("请先调用build_index构建索引")
        
        # 默认提示词模板
        if prompt_template is None:
            prompt_template = """基于以下上下文信息回答问题。如果上下文中没有相关信息,请如实说明。
​
上下文信息:
{context}
​
用户问题:{question}
​
请提供准确、专业的回答:"""
        
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=['context', 'question']
        )
        
        # 创建检索问答链
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type='stuff',
            retriever=self.vectorstore.as_retriever(
                search_kwargs={'k': 5}  # 返回Top-5相关文档
            ),
            chain_type_kwargs={'prompt': prompt},
            return_source_documents=True  # 返回源文档用于溯源
        )
        
        print("问答链配置完成")
    
    def query(self, question: str, verbose: bool = False) -> dict:
        """
        问答
        
        Args:
            question: 用户问题
            verbose: 是否显示详细信息
        
        Returns:
            包含答案和源文档的字典
        """
        if self.qa_chain is None:
            raise ValueError("请先调用setup_qa_chain配置问答链")
        
        result = self.qa_chain({'query': question})
        
        if verbose:
            print(f"\n问题: {question}")
            print(f"\n答案: {result['result']}")
            print(f"\n参考文档 ({len(result['source_documents'])} 个):")
            for i, doc in enumerate(result['source_documents']):
                print(f"  [{i+1}] {doc.page_content[:100]}...")
        
        return {
            'question': question,
            'answer': result['result'],
            'source_documents': result['source_documents']
        }
    
    @classmethod
    def from_existing_index(
        cls,
        persist_directory: str,
        embedding_model: str = 'openai',
        llm_model: str = 'gpt-3.5-turbo',
        api_key: Optional[str] = None,
        vector_store_type: str = 'chroma'
    ) -> 'RAGSystem':
        """
        从已有索引加载RAG系统
        
        Args:
            persist_directory: 向量存储目录
            embedding_model: Embedding模型类型
            llm_model: LLM模型名称
            api_key: API密钥
            vector_store_type: 向量存储类型
        
        Returns:
            RAGSystem实例
        """
        system = cls(
            embedding_model=embedding_model,
            llm_model=llm_model,
            api_key=api_key,
            vector_store_type=vector_store_type,
            persist_directory=persist_directory
        )
        
        # 加载已有索引
        if vector_store_type == 'chroma':
            system.vectorstore = Chroma(
                persist_directory=persist_directory,
                embedding_function=system.embeddings
            )
        elif vector_store_type == 'faiss':
            system.vectorstore = FAISS.load_local(
                persist_directory,
                system.embeddings
            )
        
        system.setup_qa_chain()
        return system
​
​
# 使用示例
def main():
    """完整使用流程示例"""
    
    # 初始化RAG系统
    rag = RAGSystem(
        embedding_model='openai',  # 或 'huggingface' 使用本地模型
        llm_model='gpt-3.5-turbo',
        vector_store_type='chroma',
        persist_directory='./my_vector_store'
    )
    
    # 步骤1: 加载文档
    documents = rag.load_documents(['./docs/技术文档1.txt', './docs/技术文档2.pdf'])
    
    # 步骤2: 处理文档(分块)
    chunks = rag.process_documents(documents, chunk_size=500, chunk_overlap=50)
    
    # 步骤3: 构建向量索引
    rag.build_index(chunks, force_rebuild=False)
    
    # 步骤4: 设置问答链
    rag.setup_qa_chain()
    
    # 步骤5: 问答
    result = rag.query("什么是人工智能?", verbose=True)
    
    # 从已有索引加载
    # loaded_rag = RAGSystem.from_existing_index('./my_vector_store')
    # result = loaded_rag.query("另一个问题")
    
    return result
​
# if __name__ == '__main__':
#     main()

八、使用场景

8.1 企业知识库

企业知识库是RAG最典型的应用场景,包括:

  • 内部制度文档:员工手册、财务制度、审批流程等

  • 技术文档:API文档、架构设计、技术方案等

  • 产品文档:产品规格、使用说明、FAQ等

RAG系统让员工能够通过自然语言快速检索和理解企业知识,无需逐个文档查找。

8.2 智能客服

在客服场景中,RAG可以:

  • 从产品手册、常见问题中检索相关信息

  • 结合历史对话上下文提供个性化回答

  • 保持回答与官方文档的一致性

  • 支持多轮对话和意图理解

复制代码
class CustomerServiceRAG:
    """客服场景RAG实现"""
    
    def __init__(self, knowledge_base_path: str):
        self.rag = RAGSystem(
            embedding_model='openai',
            llm_model='gpt-3.5-turbo',
            vector_store_type='chroma',
            persist_directory='./客服知识库索引'
        )
        self.knowledge_base_path = knowledge_base_path
        self.conversation_history = []
    
    def initialize(self):
        """初始化知识库"""
        docs = self.rag.load_documents([
            f'{self.knowledge_base_path}/faq.txt',
            f'{self.knowledge_base_path}/products.txt',
            f'{self.knowledge_base_path}/policies.txt'
        ])
        chunks = self.rag.process_documents(docs)
        self.rag.build_index(chunks)
        self.rag.setup_qa_chain()
    
    def chat(self, user_message: str) -> str:
        """对话"""
        # 构建带上下文的查询
        context_prompt = ""
        if self.conversation_history:
            context_prompt = "对话历史:\n"
            for h in self.conversation_history[-3:]:  # 最近3轮
                context_prompt += f"用户:{h['user']}\n助手:{h['assistant']}\n"
            context_prompt += "\n"
        
        full_query = context_prompt + f"用户问题:{user_message}"
        
        # 查询
        result = self.rag.query(user_message)
        
        # 更新历史
        self.conversation_history.append({
            'user': user_message,
            'assistant': result['answer']
        })
        
        return result['answer']
    
    def clear_history(self):
        """清除对话历史"""
        self.conversation_history = []
​
# 使用示例
# chatbot = CustomerServiceRAG('./knowledge_base')
# chatbot.initialize()
# response = chatbot.chat("你们的退货政策是什么?")
# response = chatbot.chat("那换货呢?")  # 带上下文

8.3 文档问答

文档问答系统允许用户针对特定文档提出问题:

复制代码
class DocumentQASystem:
    """文档问答系统"""
    
    def __init__(self):
        self.documents = {}
        self.active_doc_id = None
        self.rag = RAGSystem(
            embedding_model='openai',
            vector_store_type='chroma'
        )
    
    def load_document(self, doc_id: str, file_path: str):
        """加载文档"""
        docs = self.rag.load_documents([file_path])
        chunks = self.rag.process_documents(docs)
        
        # 为每个文档创建独立索引
        doc_store_path = f'./doc_index_{doc_id}'
        self.rag.build_index(chunks, force_rebuild=True)
        
        self.documents[doc_id] = {
            'path': file_path,
            'chunk_count': len(chunks),
            'index_path': doc_store_path
        }
        self.active_doc_id = doc_id
        
        print(f"文档 '{doc_id}' 加载完成: {len(chunks)} 个块")
    
    def ask(self, question: str, doc_id: Optional[str] = None) -> dict:
        """提问"""
        target_doc_id = doc_id or self.active_doc_id
        
        if target_doc_id not in self.documents:
            raise ValueError(f"文档 '{target_doc_id}' 未加载")
        
        # 切换到目标文档索引
        # 实际实现中可能需要切换索引或添加过滤条件
        return self.rag.query(question)
​
# 使用示例
doc_qa = DocumentQASystem()
doc_qa.load_document('技术方案', './docs/技术方案.pdf')
result = doc_qa.ask("这个方案的技术架构是什么?")
print(f"答案: {result['answer']}")

九、总结与展望

本文详细介绍了RAG(检索增强生成)技术的完整知识体系,包括:

  1. 核心概念:RAG通过将外部知识检索与LLM生成相结合,有效解决了大模型的幻觉和知识截止问题

  2. 关键技术组件

    • 文档分块策略影响检索精度

    • Embedding模型决定语义理解能力

    • 向量数据库支撑大规模数据存储和检索

    • 混合检索和重排序提升检索效果

  3. 最佳实践

    • 根据场景选择合适的分块大小

    • 中文场景推荐BGE、M3E等开源模型

    • 生产环境推荐Milvus等分布式方案

    • 混合检索+重排序是提升效果的关键

  4. 应用场景:企业知识库、智能客服、文档问答等场景已广泛采用RAG技术

未来发展趋势

  • 多模态RAG:支持图像、音频、视频等多种模态的检索和生成

  • 动态知识更新:实现知识的实时更新和增量索引

  • 个性化RAG:根据用户画像和偏好定制检索策略

  • RAG与Fine-tuning融合:结合两者优势,打造更强大的AI应用

RAG作为当前最实用的LLM应用架构之一,正在快速发展并持续演进。建议开发者在实践中根据具体需求选择合适的技术方案,并持续关注该领域的最新进展。

相关推荐
用户6000718191010 小时前
【翻译】给Agent配上解释器
人工智能
明志数科10 小时前
仿真数据与真实数据:机器人训练的数据策略选择
人工智能·算法·机器学习
老司机张师傅10 小时前
AI第一章:虚拟环境库安装
人工智能
深度学习lover10 小时前
<数据集>yolo汉字识别<目标检测>
人工智能·yolo·目标检测·数据集·汉字识别
Master_oid10 小时前
机器学习43:线性回归进阶篇①
人工智能·机器学习·线性回归
香蕉鼠片10 小时前
CNN学习时的代码
人工智能·学习·cnn
AskHarries10 小时前
Google Trends 找蓝海赛道:独立开发者如何挖出没人做、但有人搜的项目
人工智能
searchforAI10 小时前
5款AI笔记工具实测:导入体验、结构化输出、后续能力逐项对比
人工智能·笔记·学习·ai·chatgpt·aigc·音视频
深度学习lover10 小时前
<项目代码>yolo缆绳识别<目标检测>
人工智能·深度学习·yolo·目标检测·项目代码·缆绳识别