目录
[6.1 数据摄取管道](#6.1 数据摄取管道)
[6.1.1 文档解析引擎:Unstructured.io处理PDF表格与层级标题](#6.1.1 文档解析引擎:Unstructured.io处理PDF表格与层级标题)
[6.1.2 网页爬取集成:Crawl4AI异步抓取与内容清洗](#6.1.2 网页爬取集成:Crawl4AI异步抓取与内容清洗)
[6.1.3 分块策略:语义分块(Semantic Chunking)与递归字符切分对比](#6.1.3 分块策略:语义分块(Semantic Chunking)与递归字符切分对比)
[6.1.4 元数据提取:文件名、章节、时间戳增强检索过滤](#6.1.4 元数据提取:文件名、章节、时间戳增强检索过滤)
[6.2 向量存储架构](#6.2 向量存储架构)
[6.2.1 嵌入模型管理:Sentence-Transformers本地部署与OpenAI嵌入切换](#6.2.1 嵌入模型管理:Sentence-Transformers本地部署与OpenAI嵌入切换)
[6.2.2 向量数据库选型:Milvus分布式部署与Qdrant轻量级对比](#6.2.2 向量数据库选型:Milvus分布式部署与Qdrant轻量级对比)
[6.2.3 混合检索:向量相似度 + BM25关键词搜索融合(RRF算法)](#6.2.3 混合检索:向量相似度 + BM25关键词搜索融合(RRF算法))
[6.2.4 索引优化:HNSW参数调优与量化(Quantization)压缩](#6.2.4 索引优化:HNSW参数调优与量化(Quantization)压缩)
[6.3 检索策略优化](#6.3 检索策略优化)
[6.3.1 查询重写:HyDE(Hypothetical Document Embedding)生成假设答案](#6.3.1 查询重写:HyDE(Hypothetical Document Embedding)生成假设答案)
[6.3.2 重排序(Rerank):Cohere Rerank与交叉编码器微调](#6.3.2 重排序(Rerank):Cohere Rerank与交叉编码器微调)
[6.3.3 多跳检索:GraphRAG构建文档关系图与社区摘要](#6.3.3 多跳检索:GraphRAG构建文档关系图与社区摘要)
[6.3.4 查询路由:元数据过滤与多索引联合查询](#6.3.4 查询路由:元数据过滤与多索引联合查询)
[6.4 生成与后处理](#6.4 生成与后处理)
[6.4.1 上下文组装:token预算管理与相关片段优先级排序](#6.4.1 上下文组装:token预算管理与相关片段优先级排序)
[6.4.2 引用溯源:检索结果高亮与原文链接生成](#6.4.2 引用溯源:检索结果高亮与原文链接生成)
[6.4.3 答案验证:Self-RAG反思机制与幻觉检测](#6.4.3 答案验证:Self-RAG反思机制与幻觉检测)
[6.4.4 缓存策略:语义缓存(Semantic Cache)与精确匹配缓存](#6.4.4 缓存策略:语义缓存(Semantic Cache)与精确匹配缓存)
[6.5 评估与监控](#6.5 评估与监控)
[6.5.1 离线评估:RAGAS指标(Faithfulness/Answer Relevancy)自动计算](#6.5.1 离线评估:RAGAS指标(Faithfulness/Answer Relevancy)自动计算)
[6.5.2 在线反馈:用户点赞/点踩与 thumbs 信号收集](#6.5.2 在线反馈:用户点赞/点踩与 thumbs 信号收集)
[6.5.3 A/B测试:不同分块策略与提示词模板效果对比](#6.5.3 A/B测试:不同分块策略与提示词模板效果对比)
[6.5.4 持续学习:Bad Case收集与模型微调触发机制](#6.5.4 持续学习:Bad Case收集与模型微调触发机制)
第一部分:原理详解
6.1 数据摄取管道
6.1.1 文档解析引擎:Unstructured.io处理PDF表格与层级标题
现代企业知识库的核心挑战在于处理异构文档格式,特别是PDF文件中嵌套的表格结构与层级化标题体系。Unstructured.io框架采用深度文档理解模型,通过计算机视觉与自然语言处理的融合架构,实现版面分析(Layout Analysis)与文本抽取的协同优化。
该引擎首先应用基于Transformer的版面检测网络,识别文档中的文本块、表格区域、图像及标题层级。对于表格解析,系统采用两阶段方法:第一阶段使用目标检测算法定位表格边界框,第二阶段通过序列到序列模型重建表格的行-列结构,将视觉表格转换为结构化的HTML或Markdown表示。层级标题的识别依赖于字体特征(字号、字重、字体家族)的统计分析,结合语义连贯性模型,构建文档的章节树(Table of Contents Tree)。
在技术实现层面,解析引擎维护一个文档对象模型(Document Object Model),其中每个节点包含坐标元数据(bounding box coordinates)、文本内容、元素类型(Title, NarrativeText, Table, ListItem等)以及与其他节点的层级关系。这种细粒度的文档表示为后续的语义分块提供了结构感知的基础。
6.1.2 网页爬取集成:Crawl4AI异步抓取与内容清洗
针对动态网页与企业内部Wiki系统的数据摄取,Crawl4AI采用异步I/O架构(基于Python的asyncio库)实现高并发爬取。该框架设计了智能请求调度器,通过令牌桶算法(Token Bucket Algorithm)控制爬取速率,避免对目标服务器造成过载。
内容清洗模块采用基于视觉的网页去噪技术,利用Readability算法与机器学习方法识别主要内容区域(Main Content Area),过滤导航栏、广告、页脚等噪声元素。对于JavaScript渲染的单页应用(SPA),系统集成Playwright或Selenium进行无头浏览器渲染,等待动态内容加载完成后执行DOM抽取。
清洗流程包括:HTML标签剥离(保留语义化标签如h1-h6, table, ul/ol)、CSS内联样式移除、编码规范化(统一转换为UTF-8)、以及重复内容检测(基于SimHash或MinHash的近似重复检测算法)。最终输出的是结构化的Markdown或纯文本,保留原始网页的语义层次结构。
6.1.3 分块策略:语义分块(Semantic Chunking)与递归字符切分对比
文本分块(Text Chunking)是RAG系统中平衡上下文窗口利用与检索精度的关键环节。递归字符切分(Recursive Character Text Splitting)采用分层分隔符策略,优先按段落(\n\n)、句子(\n)、单词(空格)的层级递归切割,确保块内文本的连续性。该方法计算效率高,但可能破坏语义边界,导致跨段落的上下文碎片化。
语义分块(Semantic Chunking)则基于嵌入向量的语义相似性动态确定切分点。算法首先计算句子级别的嵌入向量,然后检测相邻句子间余弦相似度的显著下降(阈值通常为0.7-0.8),在语义边界处执行切分。该方法的优势在于保持主题一致性,块内语义连贯性显著优于固定长度切分,但计算成本较高,需要预计算嵌入向量。
$$\text{similarity}(s_i, s_{i+1}) = \frac{e_i \cdot e_{i+1}}{\|e_i\|\,\|e_{i+1}\|}$$
其中 $e_i$ 表示第 $i$ 个句子的嵌入向量。当相似度低于阈值 $\tau$ 时,触发分块边界。
6.1.4 元数据提取:文件名、章节、时间戳增强检索过滤
元数据增强检索(Metadata-Enriched Retrieval)通过在向量存储中附加结构化属性,支持基于过滤条件的精确检索。系统自动提取的元数据包括:文档级属性(文件名、作者、创建时间、文档类型)、章节级属性(标题层级、章节编号、父章节引用)以及内容级属性(关键词标签、实体识别结果、摘要)。
时间戳元数据支持时间范围过滤(Temporal Filtering),适用于需要检索最新版本文档或特定时间段内更新的场景。章节层级元数据构建文档的导航路径(Breadcrumb),在检索结果中提供上下文定位。文件名与路径信息用于权限控制(Access Control List, ACL),确保检索结果符合用户的文档访问权限。
6.2 向量存储架构
6.2.1 嵌入模型管理:Sentence-Transformers本地部署与OpenAI嵌入切换
嵌入模型(Embedding Model)负责将文本映射到高维语义空间。Sentence-Transformers框架支持多种预训练架构(BERT, RoBERTa, MPNet, E5, GTE等)的本地部署,通过量化技术(INT8/INT4)与ONNX Runtime优化推理延迟。本地部署确保数据隐私,适用于金融、医疗等敏感领域。
模型选择依据MTEB(Massive Text Embedding Benchmark)排行榜的检索性能指标。对于多语言场景,采用支持跨语言对齐的模型(如paraphrase-multilingual-mpnet-base-v2)。向量维度通常为384至1024维,需在存储成本与语义表达能力间权衡。
OpenAI的text-embedding-3-large等API服务提供更高维度的嵌入(3072维)与更好的下游任务性能,但引入网络延迟与数据出境合规风险。系统应支持模型路由(Model Routing),根据数据敏感度与延迟要求动态选择本地或云端嵌入服务。
$$v = \text{Encoder}(\text{Tokenize}(\text{text})) \in \mathbb{R}^{d}$$
其中 $d$ 为嵌入维度,Encoder 为双向Transformer编码器。
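下面给出模型路由的一个极简示意(假设:敏感度判定逻辑、本地模型加载方式与 OpenAI 客户端调用均为示例草图,需按实际合规策略与依赖版本调整),展示如何按数据敏感度在本地 Sentence-Transformers 与云端嵌入服务之间切换:

```python
# 示意:按数据敏感度在本地嵌入与云端嵌入间路由(参数与判定逻辑为示例)
from sentence_transformers import SentenceTransformer

_local_model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")

def embed(texts, sensitive=True):
    """敏感数据走本地推理避免出境;非敏感数据可调用云端API获取更高维度嵌入"""
    if sensitive:
        return _local_model.encode(texts, normalize_embeddings=True).tolist()
    from openai import OpenAI
    client = OpenAI()  # 依赖环境变量 OPENAI_API_KEY
    resp = client.embeddings.create(model="text-embedding-3-large", input=texts)
    return [item.embedding for item in resp.data]
```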
6.2.2 向量数据库选型:Milvus分布式部署与Qdrant轻量级对比
Milvus采用云原生架构,支持水平扩展(Horizontal Scaling)与分布式向量索引。其存储-计算分离架构允许独立扩展查询节点(Query Node)与数据节点(Data Node),适用于十亿级向量的高并发检索。Milvus支持多种索引类型(FLAT, IVF-FLAT, IVF-PQ, HNSW, ANNOY),并提供基于Raft协议的分布式一致性保证。
Qdrant采用Rust编写,专注于单节点或轻量级集群部署的高性能。其特色在于混合查询能力(Hybrid Filtering),支持在向量相似度搜索的同时应用复杂的标量过滤条件(如 price > 100 AND category = "electronics")。Qdrant的内存映射(Memory Mapping)机制优化了大规模数据集的内存占用,适合资源受限的边缘部署。
$$\text{ANN}(q, k) = \arg\max_{D' \subset D,\ |D'|=k} \sum_{x \in D'} \frac{q \cdot x}{\|q\|\,\|x\|}$$
近似最近邻(ANN)检索在向量空间中寻找与查询向量 $q$ 最相似的 $k$ 个向量。
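针对上文提到的 Qdrant 混合过滤能力,下面给出一个假设性的示意片段(集合名 products、payload 字段 price/category 均为虚构示例;基于 qdrant-client 的常见用法,具体参数以官方文档为准):

```python
# 示意:Qdrant向量检索 + 标量过滤(price > 100 AND category = "electronics")
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range

client = QdrantClient(url="http://localhost:6333")
query_vec = [0.0] * 768          # 占位:应替换为实际的查询嵌入向量

hits = client.search(
    collection_name="products",  # 虚构的集合名,仅作演示
    query_vector=query_vec,
    query_filter=Filter(must=[
        FieldCondition(key="category", match=MatchValue(value="electronics")),
        FieldCondition(key="price", range=Range(gt=100)),
    ]),
    limit=10,
)
for hit in hits:
    print(hit.id, hit.score, hit.payload)
```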
6.2.3 混合检索:向量相似度 + BM25关键词搜索融合(RRF算法)
混合检索(Hybrid Retrieval)结合稠密向量检索(Dense Retrieval)的语义理解能力与稀疏检索(Sparse Retrieval,如BM25)的精确关键词匹配优势。稠密检索捕获语义相关性(如"汽车"与"车辆"的隐含关联),而BM25确保对特定术语、产品代码、人名等精确匹配。
倒数排序融合(Reciprocal Rank Fusion, RRF)算法融合两种检索结果列表。对于每个文档 d ,计算其在向量检索列表中的排名 r_v(d) 与在BM25列表中的排名 r_b(d) (若不存在则设为无穷大)。RRF得分公式为:
$$\text{RRF}(d) = \sum_{i \in \{v, b\}} \frac{1}{k + r_i(d)}$$
其中 $k$ 为常数(通常取60),用于平滑高排名文档的得分差异。最终按RRF得分降序排列,生成融合后的检索结果。
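以下是 RRF 融合公式的一个最小 Python 实现示意,直接按上式累加倒数排名得分:

```python
def rrf_fuse(ranked_lists, k=60, top_n=10):
    """ranked_lists: 若干按相关性降序排列的文档ID列表(如向量检索与BM25各一份)"""
    scores = {}
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)[:top_n]

# 用法:文档d1同时出现在两个列表的靠前位置,融合后排名最高
print(rrf_fuse([["d3", "d1", "d7"], ["d1", "d9", "d3"]], top_n=3))
```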
6.2.4 索引优化:HNSW参数调优与量化(Quantization)压缩
HNSW(Hierarchical Navigable Small World)图索引通过构建多层近似图结构实现对数级复杂度的最近邻搜索。关键参数包括:M(每层最大邻居数,控制图密度)、efConstruction(构建时的搜索范围,影响索引质量)与ef(查询时的搜索范围,影响召回率)。
$$\text{Recall@}k = \frac{|\text{ANN}_k(q) \cap \text{Exact}_k(q)|}{k}$$
量化压缩技术降低存储与计算开销。乘积量化(Product Quantization, PQ)将高维向量分解为子向量,对每个子空间训练码本(Codebook),用质心索引替代原始浮点值。标量量化(Scalar Quantization, SQ)将float32映射为int8,减少75%存储空间同时保持较高精度。二进制量化(Binary Quantization)进一步压缩为比特向量,通过汉明距离加速计算,但牺牲部分精度。
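下面用 NumPy 给出标量量化(float32 → int8)的一个简化示意(假设采用逐维最大绝对值缩放,仅用于说明原理,向量数据库内部的量化实现可能不同):

```python
# 标量量化示意:对称量化为int8,查询时反量化近似还原
import numpy as np

def sq_quantize(vectors):
    scale = np.abs(vectors).max(axis=0) / 127.0 + 1e-12   # 每维缩放因子
    codes = np.clip(np.round(vectors / scale), -127, 127).astype(np.int8)
    return codes, scale

def sq_dequantize(codes, scale):
    return codes.astype(np.float32) * scale

vecs = np.random.randn(1000, 384).astype(np.float32)
codes, scale = sq_quantize(vecs)          # 存储占用约为原始float32的1/4
approx = sq_dequantize(codes, scale)
print("平均重构误差:", float(np.mean(np.abs(vecs - approx))))
```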
6.3 检索策略优化
6.3.1 查询重写:HyDE(Hypothetical Document Embedding)生成假设答案
Hypothetical Document Embeddings(HyDE)技术利用大型语言模型的生成能力弥合查询与文档间的语义鸿沟。传统检索中,简短或模糊的查询(如"最佳实践")难以匹配详细的文档内容。HyDE通过提示工程(Prompt Engineering)引导语言模型生成假设性答案文档,该文档虽未基于真实知识库,但在语义空间上更接近目标文档。
具体流程为:首先将原始查询 q 输入LLM生成假设文档 d_{hyp} ;随后计算 d_{hyp} 的嵌入向量 e_{hyp} ;最后以 e_{hyp} 作为检索查询在向量数据库中搜索相似文档。该方法有效扩展了查询的语义表达,尤其适用于零样本(Zero-Shot)检索场景。
$$e_{hyde} = \text{Embed}(\text{LLM}(q \mid prompt_{hyde}))$$
$$\text{Results} = \text{TopK}(e_{hyde}, \text{Index})$$
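下面是 HyDE 流程的一个示意草图(其中 llm_generate 与 vector_search 为占位函数,需接入实际的 LLM 与向量检索接口;提示词模板与模型名亦为示例):

```python
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")

def hyde_retrieve(query, llm_generate, vector_search, k=10):
    # 1) 让LLM生成一段"假设性答案文档"(提示词模板为示例)
    prompt = f"请写一段可能回答以下问题的文档段落:\n问题:{query}\n段落:"
    hypothetical_doc = llm_generate(prompt)
    # 2) 以假设文档的嵌入作为检索查询,而非原始查询的嵌入
    e_hyde = encoder.encode(hypothetical_doc)
    return vector_search(e_hyde, top_k=k)
```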
6.3.2 重排序(Rerank):Cohere Rerank与交叉编码器微调
初始检索阶段通常采用双编码器(Bi-Encoder)架构,独立编码查询与文档,通过向量内积快速筛选候选集。然而,双编码器无法充分建模查询-文档间的细粒度交互。重排序(Reranking)阶段使用交叉编码器(Cross-Encoder)或专用重排序模型(如Cohere Rerank API)对候选文档进行精确排序。
交叉编码器将查询与文档文本拼接(如"[CLS] Query [SEP] Document [SEP]"),通过Transformer编码器生成相关性分数。这种全注意力机制捕获了词汇级对齐与语义交互,精度显著高于双编码器,但计算成本较高,仅适用于小规模候选集(通常Top-100)。
微调策略包括在领域特定数据(如MS MARCO或自定义标注数据)上训练重排序模型,优化二元分类(相关/不相关)或细粒度相关性评分目标。知识蒸馏(Knowledge Distillation)技术将大型交叉编码器的知识迁移到轻量级模型,平衡精度与延迟。
$$\text{Score}(q, d) = \text{MLP}(\text{Transformer}([q; d]))$$
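基于 sentence-transformers 的 CrossEncoder 可以快速实现这一重排序步骤,下面是一个示意(模型名为公开的 MS MARCO 交叉编码器,实际应按领域标注数据微调后替换):

```python
# 交叉编码器重排序示意:对初检的候选集逐一计算query-doc相关性分数后重排
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def rerank(query, candidates, top_k=10):
    scores = reranker.predict([(query, doc) for doc in candidates])
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in ranked[:top_k]]
```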
6.3.3 多跳检索:GraphRAG构建文档关系图与社区摘要
GraphRAG(Graph-based Retrieval-Augmented Generation)通过构建文档实体关系图支持多跳推理(Multi-hop Reasoning)。系统首先使用命名实体识别(NER)与关系抽取(RE)从文档中提取实体(Entity)与关系(Relation)三元组,构建知识图谱 G=(V, E) ,其中节点 V 表示实体,边 E 表示关系。
社区检测(Community Detection)算法(如Louvain或Leiden算法)识别图谱中的紧密连接子图(社区),每个社区对应特定主题或概念集群。对每个社区生成摘要(Community Summary),描述该社区的核心主题与关键实体。
多跳检索流程:首先识别查询中的种子实体,在图中执行广度优先搜索(BFS)或个性化PageRank(PPR)探索多跳邻居;然后检索相关社区摘要与关联文档;最后基于图结构路径生成带有溯源的推理链(Chain of Evidence)。
$$\text{PPR}(v) = \alpha \cdot e_v + (1 - \alpha) \cdot \sum_{u \in N(v)} \frac{\text{PPR}(u)}{|N(u)|}$$
其中 $\alpha$ 为随机跳转概率,$e_v$ 为种子实体的one-hot向量。
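下面给出个性化 PageRank 的一个纯 Python 幂迭代示意(graph 为邻接表、seeds 为查询中识别出的种子实体,均为假设的输入形式;为简化,假设邻接表的键覆盖全部节点且 seeds 非空):

```python
# 个性化PageRank幂迭代示意:按上式在实体图上围绕种子实体做多跳扩散
def personalized_pagerank(graph, seeds, alpha=0.15, iters=50):
    restart = {v: (1.0 / len(seeds) if v in seeds else 0.0) for v in graph}
    ppr = dict(restart)
    for _ in range(iters):
        nxt = {v: alpha * restart[v] for v in graph}    # alpha * e_v 项
        for u, out in graph.items():
            if not out:
                continue
            share = (1 - alpha) * ppr[u] / len(out)      # (1-alpha) * PPR(u)/|N(u)|
            for v in out:
                nxt[v] = nxt.get(v, 0.0) + share
        ppr = nxt
    return ppr

# 用法示例:围绕种子实体"A"对一个小型实体图做扩散
graph = {"A": ["B", "C"], "B": ["C"], "C": ["A"], "D": ["A"]}
print(personalized_pagerank(graph, seeds={"A"}))
```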
6.3.4 查询路由:元数据过滤与多索引联合查询
查询路由(Query Routing)机制根据查询特征动态选择检索策略或索引子集。元数据过滤路由分析查询中的结构化约束(如时间范围、文档类型、作者),将查询定向到特定分区(Partition)或集合(Collection)。例如,查询"2024年财务报告中的营收数据"被路由到2024年文档集合与财务类别索引。
多索引联合查询(Multi-Index Federation)在垂直领域知识库中尤为关键,其中不同文档类型(产品手册、技术规范、客户案例)存储于独立索引。路由分类器(基于轻量级BERT或关键词规则)预测查询应检索的索引子集,并行执行检索后融合结果。
自适应检索(Adaptive Retrieval)根据查询复杂度动态调整检索深度:简单事实查询(如"公司成立时间")仅需单跳检索,而复杂分析查询(如"比较Q1与Q2产品线表现差异")触发多跳或迭代检索。
6.4 生成与后处理
6.4.1 上下文组装:token预算管理与相关片段优先级排序
大语言模型的上下文窗口(Context Window)存在长度限制(如128K tokens),而检索返回的文档片段总和常超出该限制。上下文组装(Context Assembly)模块实施token预算管理(Token Budget Management),在约束条件下最大化上下文信息量。
优先级排序策略基于相关性分数与信息多样性。首先按检索相关性排序片段;然后应用最大边际相关性(Maximal Marginal Relevance, MMR)算法,在相关性与多样性间权衡:
$$\text{MMR}(d_i) = \lambda \cdot \text{Sim}(d_i, q) - (1 - \lambda) \cdot \max_{d_j \in S} \text{Sim}(d_i, d_j)$$
其中 $S$ 为已选片段集合,$\lambda$ 控制权衡系数。该公式选择既相关又与已选片段差异大的文档,减少冗余信息。
动态截断(Dynamic Truncation)根据文档结构在句子或段落边界处截断,避免切割语义单元。基于语义的压缩(Semantic Compression)使用较小语言模型提取片段核心句,进一步节省token预算。
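下面是在 token 预算约束下按 MMR 贪心选择片段的一个示意实现(假设片段与查询嵌入已归一化,count_tokens 为占位的 token 计数函数):

```python
# MMR贪心选择:在token预算内迭代挑选"与查询相关且与已选片段不冗余"的片段
import numpy as np

def assemble_context(chunks, chunk_embs, query_emb, count_tokens, max_tokens=4000, lam=0.5):
    selected, used = [], 0
    remaining = list(range(len(chunks)))
    while remaining:
        def mmr(i):
            rel = float(np.dot(chunk_embs[i], query_emb))
            if not selected:
                return rel
            div = max(float(np.dot(chunk_embs[i], chunk_embs[j])) for j in selected)
            return lam * rel - (1 - lam) * div
        best = max(remaining, key=mmr)
        cost = count_tokens(chunks[best])
        if used + cost > max_tokens:      # 预算不足时停止(也可改为在句子边界截断)
            break
        selected.append(best)
        used += cost
        remaining.remove(best)
    return [chunks[i] for i in selected]
```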
6.4.2 引用溯源:检索结果高亮与原文链接生成
答案溯源(Attribution)是确保RAG系统可解释性与可验证性的关键。系统通过语句级对齐(Sentence-Level Alignment)识别生成答案中的每个陈述(Claim)所支持的原文证据。
实现方法包括:在提示工程中加入指令要求模型为每个陈述添加引用标记(如[1], [2]);使用后处理算法将生成文本分割为原子陈述,通过自然语言推理(NLI)模型验证各陈述与检索片段的蕴含关系(Entailment);对支持特定陈述的原文片段进行高亮(Highlighting)处理。
原文链接生成将引用标记映射到文档元数据(文件名、页码、段落ID),生成可点击的超链接或结构化引用(如"根据《产品手册v2.0》第15页...")。对于PDF文档,通过坐标元数据生成精确到行的文本高亮区域。
6.4.3 答案验证:Self-RAG反思机制与幻觉检测
Self-RAG(Self-Reflective Retrieval-Augmented Generation)框架在生成过程中插入反思标记(Reflection Tokens),使模型动态决定是否需要检索、评估检索内容的相关性、以及验证生成内容的准确性。
反思机制通过特殊训练或提示工程实现,模型生成格式如:[Retrieve]、[No Retrieve]、[Relevant]、[Irrelevant]、[Supported]、[Contradictory]。在生成每个陈述后,模型评估该陈述是否需要外部验证;若需要,则触发检索并评估返回文档的相关性;最后验证生成内容与文档的一致性。
幻觉检测(Hallucination Detection)采用基于NLI的事实核查:将答案分解为事实陈述集合 $\{f_1, f_2, \dots, f_n\}$,对每个 $f_i$ 验证其与检索上下文 $C$ 的蕴含关系:
$$\text{Faithfulness} = \frac{|\{f_i \mid C \models f_i\}|}{n}$$
其中 $C \models f_i$ 表示上下文支持该事实。低置信度陈述触发警告或拒绝回答。
6.4.4 缓存策略:语义缓存(Semantic Cache)与精确匹配缓存
语义缓存(Semantic Cache)通过识别语义等价的查询(尽管字面不同)减少重复计算与API调用。系统维护缓存存储 $(Q_{cache}, V_{cache}, A_{cache})$,其中 $Q_{cache}$ 为历史查询,$V_{cache}$ 为嵌入向量,$A_{cache}$ 为缓存答案。
对于新查询 $q_{new}$,计算其嵌入 $v_{new}$,在缓存向量中搜索相似度超过阈值 $\theta$(如0.95)的条目。若存在,直接返回对应答案。近似最近邻搜索(ANN)加速缓存查找。
精确匹配缓存(Exact Match Cache)使用哈希表存储字面完全相同的查询,适用于高频重复问题。分层缓存策略结合精确匹配(L1缓存)与语义匹配(L2缓存),缓存未命中(Cache Miss)时触发完整RAG流程并将结果写入缓存。
缓存失效(Cache Invalidation)策略包括:基于时间的生存期(TTL)、文档更新事件驱动的主动失效、以及基于语义漂移检测的被动失效。
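下面给出语义缓存的一个最小示意实现(暴力余弦匹配 + TTL 失效;embed_fn 为外部传入的嵌入函数,生产环境应将线性扫描替换为 ANN 索引):

```python
# 语义缓存最小示意:命中条件为余弦相似度超过阈值且条目未过期
import time
import numpy as np

class SemanticCache:
    def __init__(self, embed_fn, threshold=0.95, ttl=3600.0):
        self.embed_fn, self.threshold, self.ttl = embed_fn, threshold, ttl
        self.entries = []   # 列表元素为 (嵌入向量, 缓存答案, 写入时间戳)

    def get(self, query):
        v = self.embed_fn(query)
        now = time.time()
        self.entries = [e for e in self.entries if now - e[2] <= self.ttl]  # TTL清理
        for vec, answer, _ in self.entries:
            sim = float(np.dot(v, vec) / (np.linalg.norm(v) * np.linalg.norm(vec)))
            if sim >= self.threshold:
                return answer
        return None

    def put(self, query, answer):
        self.entries.append((self.embed_fn(query), answer, time.time()))
```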
6.5 评估与监控
6.5.1 离线评估:RAGAS指标(Faithfulness/Answer Relevancy)自动计算
RAGAS(Retrieval-Augmented Generation Assessment)框架提供无参考(Reference-Free)的自动化评估指标,无需人工标注的标准答案即可评估RAG系统性能。
忠实度(Faithfulness)度量生成答案 $A$ 与检索上下文 $C$ 的事实一致性。首先使用LLM将答案分解为原子陈述集合 $S(A)=\{s_1, s_2, \dots, s_m\}$;然后验证每个陈述 $s_i$ 是否被 $C$ 支持(支持、矛盾或未知)。忠实度分数为被支持陈述的比例:
$$F = \frac{|\{s_i \mid \text{Supported}(s_i, C)\}|}{|S(A)|}$$
答案相关性(Answer Relevancy)评估答案 $A$ 对用户查询 $Q$ 的针对性。系统生成 $k$ 个潜在问题 $\{q_1, \dots, q_k\}$,这些问题应以 $A$ 为正确答案;然后计算这些生成问题与原始查询 $Q$ 的嵌入相似度平均值:
$$AR = \frac{1}{k} \sum_{i=1}^{k} \cos(\text{Embed}(Q), \text{Embed}(q_i))$$
上下文精确率(Context Precision)衡量检索片段中相关部分的比例;上下文召回率(Context Recall)评估检索是否覆盖了回答问题所需的全部信息。
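下面按上式手工计算 Answer Relevancy 的一个示意(generate_questions 为占位的 LLM 调用;RAGAS 库本身可自动完成类似流程,此处仅用于说明计算逻辑,模型名为示例):

```python
# Answer Relevancy手工计算示意:由答案反推问题,再与原查询做嵌入相似度
import numpy as np
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")

def answer_relevancy(query, answer, generate_questions, k=3):
    questions = generate_questions(answer, k)                 # 占位:LLM由答案生成k个潜在问题
    q_vec = encoder.encode(query, normalize_embeddings=True)
    g_vecs = encoder.encode(questions, normalize_embeddings=True)
    return float(np.mean(g_vecs @ q_vec))                     # 归一化后内积即余弦相似度
```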
6.5.2 在线反馈:用户点赞/点踩与 thumbs 信号收集
在线反馈机制捕获真实用户交互信号,用于持续优化检索与生成质量。显式反馈(Explicit Feedback)包括点赞(Thumbs Up)/点踩(Thumbs Down)按钮、星级评分(1-5星)、以及可选的自由文本反馈。
隐式反馈(Implicit Feedback)通过用户行为推断满意度:答案复制行为、停留时间(Dwell Time,答案展示后用户在页面的停留时长)、后续查询(若用户在获得答案后立即发起相关查询,可能表明答案不完整)、以及会话终止(成功解决用户问题后结束对话)。
反馈信号与检索上下文、生成答案、系统配置(使用的分块策略、模型版本)关联存储,构建反馈数据库用于后续的模型微调与策略优化。对抗性反馈(Adversarial Feedback)识别系统失效模式,如事实错误、未回答查询核心、或检索不相关内容。
6.5.3 A/B测试:不同分块策略与提示词模板效果对比
A/B测试框架对比RAG系统变体在实际流量下的性能差异。测试维度包括:分块策略(固定长度vs语义分块)、嵌入模型(E5 vs OpenAI)、检索算法(纯向量vs混合检索)、重排序模型(有无Cohere Rerank)、以及提示词模板(Zero-Shot vs Few-Shot vs Chain-of-Thought)。
流量分割(Traffic Splitting)采用用户ID哈希或会话级随机化,确保同一用户在会话期间体验一致的系统版本(避免混杂效应)。关键指标包括:答案接受率(基于点赞/点踩)、任务完成率(用户是否达成查询目标)、平均延迟、以及错误率。
统计显著性检验(如双样本t检验或Mann-Whitney U检验)确定观察到的差异是否显著。多臂老虎机(Multi-Armed Bandit)算法动态调整流量分配,将更多流量导向表现优异的变体,同时保持对次优变体的探索。
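下面是用 SciPy 进行双样本显著性检验的一个示意(两组观测值为虚构的演示数据,仅展示调用方式):

```python
# 双样本显著性检验示意:variant_a/variant_b为两组在线指标观测值(如答案接受率)
from scipy import stats

variant_a = [0.72, 0.68, 0.75, 0.71, 0.69]   # 示例数据,仅作演示
variant_b = [0.78, 0.74, 0.80, 0.77, 0.79]

t_stat, p_t = stats.ttest_ind(variant_a, variant_b, equal_var=False)        # Welch t检验
u_stat, p_u = stats.mannwhitneyu(variant_a, variant_b, alternative="two-sided")
print(f"t检验 p={p_t:.4f}, Mann-Whitney p={p_u:.4f}, 显著: {min(p_t, p_u) < 0.05}")
```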
6.5.4 持续学习:Bad Case收集与模型微调触发机制
持续学习(Continual Learning)机制识别系统的失效案例(Bad Cases),触发针对性改进。Bad Case检测标准包括:用户明确点踩、答案忠实度评分低于阈值、检测到幻觉、或检索结果为空。
Bad Case分类器将失效归因于特定组件:解析错误(文档未正确提取)、检索失败(相关文档未召回)、排序错误(相关文档排名过低)、生成错误(模型未正确利用上下文)、或提示不足(上下文未包含足够信息)。
当累积的Bad Case数量达到阈值(如100例),触发模型微调(Fine-Tuning)或检索索引更新。微调数据构建包括:对检索失败的案例,将正确文档标记为正样本,挖掘困难负样本(Hard Negatives);对生成错误的案例,构建偏好对(Preference Pairs),使用RLHF(Reinforcement Learning from Human Feedback)或DPO(Direct Preference Optimization)优化生成模型。
自动化再训练管道(Retraining Pipeline)执行数据验证、模型训练、离线评估、以及影子部署(Shadow Deployment),验证通过后才推送到生产环境。
第二部分:结构化伪代码
6.1 数据摄取管道
6.1.1 文档解析引擎
代码段
\begin{algorithm}
\caption{Unstructured Document Parsing Engine}
\begin{algorithmic}[1]
\Require Document file path $P$, Extraction schema $S$
\Ensure Structured document elements $E=\{e_1, e_2, \dots, e_n\}$
\State $doc \leftarrow \text{LoadDocument}(P)$
\State $elements \leftarrow \text{InitializeEmptyList}()$
\State $layout \leftarrow \text{DetectLayout}(doc)$
\For{each region $r \in layout$}
\If{$r.type = \text{Table}$}
\State $html\_table \leftarrow \text{ExtractTableStructure}(r)$
\State $e \leftarrow \text{CreateElement}(\text{type=Table}, \text{content}=html\_table)$
\ElsIf{$r.type = \text{Title}$}
\State $level \leftarrow \text{InferHeadingLevel}(r.font\_features)$
\State $e \leftarrow \text{CreateElement}(\text{type=Title}, \text{level}=level, \text{content}=r.text)$
\Else
\State $e \leftarrow \text{CreateElement}(\text{type}=r.type, \text{content}=r.text)$
\EndIf
\State $e.metadata \leftarrow \{bbox: r.coordinates, page: r.page\_num\}$
\State $elements.append(e)$
\EndFor
\State $hierarchy \leftarrow \text{BuildHierarchyTree}(elements)$
\State \Return $hierarchy$
\end{algorithmic}
\end{algorithm}
6.1.2 网页爬取集成
代码段
\begin{algorithm}
\caption{Asynchronous Web Crawling with Content Cleaning}
\begin{algorithmic}[1]
\Require Seed URLs $U$, Max depth $D$, Rate limit $\lambda$
\Ensure Cleaned web documents $W$
\State $frontier \leftarrow \text{PriorityQueue}(U)$
\State $visited \leftarrow \text{HashSet}()$
\State $results \leftarrow \text{ConcurrentQueue}()$
\State $semaphore \leftarrow \text{Semaphore}(\lambda)$
\While{$\neg frontier.empty() \land depth < D$}
\State $batch \leftarrow frontier.pop\_batch(B)$
\State $tasks \leftarrow \{\}$
\For{each $url \in batch$}
\If{$url \notin visited$}
\State $visited.add(url)$
\State $t \leftarrow asyncio.create\_task(\text{CrawlPage}(url, semaphore))$
\State $tasks \leftarrow tasks \cup \{t\}$
\EndIf
\EndFor
\State $pages \leftarrow asyncio.gather(tasks)$
\For{each $page \in pages$}
\If{$page.status = 200$}
\State $main\_content \leftarrow \text{ExtractContent}(\text{Readability}(page.html))$
\State $cleaned \leftarrow \text{CleanHTML}(main\_content)$
\State $markdown \leftarrow \text{HTML2Markdown}(cleaned)$
\State $results.put(\{url: page.url, content: markdown\})$
\State $new\_urls \leftarrow \text{ExtractLinks}(page.html)$
\State $frontier.extend(new\_urls)$
\EndIf
\EndFor
\EndWhile
\State \Return $results$
\end{algorithmic}
\end{algorithm}
6.1.3 分块策略
代码段
\begin{algorithm}
\caption{Semantic Chunking vs Recursive Character Splitting}
\begin{algorithmic}[1]
\Require Text $T$, Strategy $strategy$, Chunk size $C$, Overlap $O$, Similarity threshold $\tau$
\Ensure Text chunks $K=\{k_1, \dots, k_m\}$
\If{$strategy = \text{recursive}$}
\State \Comment{Recursive Character Splitting}
\State $separators \leftarrow [\text{"\textbackslash n\textbackslash n"}, \text{"\textbackslash n"}, \text{"."}, \text{","}, \text{" "}, \text{""}]$
\State $chunks \leftarrow \text{RecursiveSplit}(T, separators, C, O)$
\State \Return $chunks$
\EndIf
\State \Comment{Semantic Chunking}
\State $sentences \leftarrow \text{SentenceSegmentation}(T)$
\State $embeddings \leftarrow \text{Encode}(sentences)$
\State $chunks \leftarrow \text{InitializeEmptyList}()$
\State $current\_chunk \leftarrow [sentences_0]$
\State $current\_emb \leftarrow embeddings_0$
\For{$i \leftarrow 1$ \textbf{to} $|sentences| - 1$}
\State $sim \leftarrow \cos(embeddings_i, embeddings_{i-1})$
\If{$sim < \tau \land |current\_chunk| > 0$}
\State $chunks.append(\text{Join}(current\_chunk))$
\State $current\_chunk \leftarrow [sentences_i]$
\Else
\State $current\_chunk.append(sentences_i)$
\EndIf
\If{$|current\_chunk| \ge C$}
\State $chunks.append(\text{Join}(current\_chunk))$
\State $current\_chunk \leftarrow []$
\EndIf
\EndFor
\If{$current\_chunk \neq \emptyset$}
\State $chunks.append(\text{Join}(current\_chunk))$
\EndIf
\State \Return $chunks$
\end{algorithmic}
\end{algorithm}
6.1.4 元数据提取
代码段
\begin{algorithm}
\caption{Metadata Extraction and Enrichment}
\begin{algorithmic}[1]
\Require Document $D$, Extraction rules $R$
\Ensure Enriched chunks $\{(c_i, m_i)\}_{i=1}^n$
\State $file\_meta \leftarrow \{filename: D.name, created: D.ctime, author: D.author\}$
\State $structural\_meta \leftarrow \text{ParseHierarchy}(D.headings)$
\State $chunks \leftarrow \text{ChunkDocument}(D)$
\State $enriched \leftarrow \text{InitializeEmptyList}()$
\For{each $c \in chunks$}
\State $m \leftarrow file\_meta.copy()$
\State $m.section \leftarrow \text{FindNearestHeading}(c, structural\_meta)$
\State $m.timestamp \leftarrow \text{ExtractDate}(c.content)$
\State $m.entities \leftarrow \text{NER}(c.content)$
\State $m.keywords \leftarrow \text{TFIDF}(c.content, topk=5)$
\State $m.position \leftarrow c.index$
\State $enriched.append((c, m))$
\EndFor
\State \Return $enriched$
\end{algorithmic}
\end{algorithm}
6.2 向量存储架构
6.2.1 嵌入模型管理
代码段
\begin{algorithm}
\caption{Embedding Model Router and Local Deployment}
\begin{algorithmic}[1]
\Require Text batch $T$, Model configuration $M$, Privacy level $p$
\Ensure Embeddings $V=\{v_1, \dots, v_n\}$
\State $model \leftarrow \text{SelectModel}(M, p)$
\If{$p = high$}
\State $encoder \leftarrow \text{LoadLocalModel}(model.path)$
\If{$model.quantization = \text{INT8}$}
\State $encoder \leftarrow \text{Quantize}(encoder, bits=8)$
\EndIf
\State $embeddings \leftarrow encoder.encode(T, batch\_size=32)$
\Else
\State $embeddings \leftarrow \text{APIRequest}(\text{OpenAI}, T, model=model.id)$
\EndIf
\State \Return $embeddings$
\end{algorithmic}
\end{algorithm}
6.2.2 向量数据库选型
代码段
\begin{algorithm}
\caption{Vector Database Operations (Milvus/Qdrant)}
\begin{algorithmic}[1]
\Require Vector $v$, Collection name $C$, Top-k $k$, Filters $F$
\Ensure Search results $R$
\State \Comment{Insertion}
\State $id \leftarrow \text{GenerateUUID}()$
\State $payload \leftarrow \{vector: v, metadata: F\}$
\State $client.upsert(collection=C, points=[payload])$
\State \Comment{Hybrid Search}
\State $vector\_results \leftarrow client.search(collection=C, vector=v, limit=k \times 2, filter=F)$
\State $keyword\_results \leftarrow \text{BM25Search}(C, query\_text, k \times 2)$
\State $fused \leftarrow \text{RRFFusion}(vector\_results, keyword\_results, k)$
\State \Return $fused$
\end{algorithmic}
\end{algorithm}
6.2.3 混合检索
代码段
\begin{algorithm}
\caption{Reciprocal Rank Fusion (RRF) Algorithm}
\begin{algorithmic}[1]
\Require Ranked lists $L=\{L_1, L_2, \dots, L_m\}$, Constant $k=60$, Final top $n$
\Ensure Fused ranking $R$
\State $scores \leftarrow \text{DefaultDict}(0.0)$
\State $all\_docs \leftarrow \bigcup_{i=1}^m \{d \mid d \in L_i\}$
\For{each $L_i \in L$}
\For{each document $d \in all\_docs$}
\If{$d \in L_i$}
\State $rank \leftarrow L_i.index(d) + 1$
\State $scores[d] \leftarrow scores[d] + \frac{1}{k + rank}$
\Else
\State $scores[d] \leftarrow scores[d] + 0$
\EndIf
\EndFor
\EndFor
\State $sorted\_docs \leftarrow \text{SortByScoreDescending}(scores)$
\State $R \leftarrow \text{TopN}(sorted\_docs, n)$
\State \Return $R$
\end{algorithmic}
\end{algorithm}
6.2.4 索引优化
代码段
\begin{algorithm}
\caption{HNSW Index Construction and Quantization}
\begin{algorithmic}[1]
\Require Vector set $V$, $M$ parameter, $efConstruction$, Quantization type $Q$
\Ensure Optimized index $I$
\State \Comment{HNSW Graph Construction}
\State $graph \leftarrow \text{InitializeGraph}()$
\State $enter\_point \leftarrow \text{RandomSelect}(V)$
\For{each $v \in V$}
\State $layer \leftarrow \text{RandomLevel}(M)$
\State $neighbors \leftarrow \text{SearchLayer}(v, enter\_point, efConstruction, layer)$
\State $pruned \leftarrow \text{SelectNeighbors}(neighbors, M)$
\State $graph.add\_node(v, layer, pruned)$
\EndFor
\State \Comment{Product Quantization}
\If{$Q = \text{PQ}$}
\State $D \leftarrow \text{dimension}(V)$
\State $m \leftarrow num\_subspaces$
\State $subspaces \leftarrow \text{Split}(V, m)$
\For{$i \leftarrow 1$ \textbf{to} $m$}
\State $codebook_i \leftarrow \text{KMeans}(subspaces_i, clusters=256)$
\EndFor
\State $codes \leftarrow \text{Quantize}(V, \{codebook_i\})$
\State $I \leftarrow \{graph, codes, codebooks\}$
\EndIf
\State \Return $I$
\end{algorithmic}
\end{algorithm}
6.3 检索策略优化
6.3.1 查询重写
代码段
\begin{algorithm}
\caption{HyDE: Hypothetical Document Embedding}
\begin{algorithmic}[1]
\Require Query $q$, LLM $L$, Embedding model $E$, Prompt template $P_{hyde}$
\Ensure Enhanced query embedding $v_{hyde}$
\State $prompt \leftarrow P_{hyde}.format(query=q)$
\State $d_{hyp} \leftarrow L.generate(prompt)$
\State $v_{hyde} \leftarrow E.embed(d_{hyp})$
\State $v_{orig} \leftarrow E.embed(q)$
\State $v_{combined} \leftarrow \alpha \cdot v_{hyde} + (1 - \alpha) \cdot v_{orig}$
\State $candidates \leftarrow \text{VectorSearch}(v_{combined}, k)$
\State \Return $candidates$
\end{algorithmic}
\end{algorithm}
6.3.2 重排序
代码段
\begin{algorithm}
\caption{Cross-Encoder Reranking}
\begin{algorithmic}[1]
\Require Query $q$, Candidate documents $D=\{d_1, \dots, d_k\}$, Cross-encoder $C$
\Ensure Reranked documents $D'$
\State $pairs \leftarrow [(q, d_i) \mid d_i \in D]$
\State $scores \leftarrow \text{InitializeEmptyList}()$
\For{each $(q, d) \in pairs$}
\State $input \leftarrow \text{Concatenate}(q, \text{"[SEP]"}, d)$
\State $encoding \leftarrow C.tokenize(input)$
\State $logits \leftarrow C.forward(encoding)$
\State $relevance \leftarrow \text{Softmax}(logits)[1]$
\State $scores.append(relevance)$
\EndFor
\State $ranked \leftarrow \text{SortByScore}(\text{zip}(D, scores))$
\State $D' \leftarrow \{d \mid (d, s) \in ranked\}$
\State \Return $D'$
\end{algorithmic}
\end{algorithm}
6.3.3 多跳检索
代码段
\begin{algorithm}
\caption{GraphRAG Multi-hop Retrieval}
\begin{algorithmic}[1]
\Require Query $q$, Knowledge graph $G=(V, E)$, Entity extractor $X$, Hop depth $h$
\Ensure Retrieved subgraph $G'$
\State $entities \leftarrow X.extract(q)$
\State $frontier \leftarrow \{v \mid v \in V \land v.name \in entities\}$
\State $visited \leftarrow frontier$
\State $subgraph \leftarrow \text{InitializeSubgraph}()$
\For{$i \leftarrow 1$ \textbf{to} $h$}
\State $new\_frontier \leftarrow \emptyset$
\For{each $v \in frontier$}
\State $neighbors \leftarrow \{u \mid (v, u) \in E \lor (u, v) \in E\}$
\For{each $u \in neighbors$}
\If{$u \notin visited$}
\State $visited \leftarrow visited \cup \{u\}$
\State $new\_frontier \leftarrow new\_frontier \cup \{u\}$
\State $subgraph.add\_edge(v, u, E(v, u))$
\EndIf
\EndFor
\EndFor
\State $frontier \leftarrow new\_frontier$
\EndFor
\State $communities \leftarrow \text{CommunityDetection}(subgraph)$
\State $summaries \leftarrow \{\text{GenerateSummary}(c) \mid c \in communities\}$
\State $docs \leftarrow \text{RetrieveDocuments}(subgraph.nodes)$
\State \Return $(docs, summaries)$
\end{algorithmic}
\end{algorithm}
6.3.4 查询路由
代码段
\begin{algorithm}
\caption{Query Routing and Multi-Index Federation}
\begin{algorithmic}[1]
\Require Query $q$, Index registry $R=\{(I_1, F_1), \dots, (I_n, F_n)\}$, Router model $M$
\Ensure Aggregated results $A$
\State $intent \leftarrow M.classify(q)$
\State $metadata\_filters \leftarrow \text{ParseConstraints}(q)$
\State $selected\_indices \leftarrow \emptyset$
\For{each $(I, F) \in R$}
\If{$\text{IntentMatch}(intent, I.domain) \land \text{FilterCompatible}(metadata\_filters, F)$}
\State $selected\_indices \leftarrow selected\_indices \cup \{I\}$
\EndIf
\EndFor
\State $tasks \leftarrow \{\text{Search}(I, q, metadata\_filters) \mid I \in selected\_indices\}$
\State $results \leftarrow \text{ExecuteParallel}(tasks)$
\State $merged \leftarrow \text{MergeByScore}(results)$
\State $deduplicated \leftarrow \text{Deduplicate}(merged, threshold=0.95)$
\State \Return $deduplicated$
\end{algorithmic}
\end{algorithm}
6.4 生成与后处理
6.4.1 上下文组装
代码段
\begin{algorithm}
\caption{Token Budget Management with MMR}
\begin{algorithmic}[1]
\Require Query $q$, Retrieved chunks $C$, Max tokens $T_{max}$, LLM tokenizer $\mathcal{T}$
\Ensure Assembled context $C_{final}$
\State $selected \leftarrow \emptyset$
\State $remaining \leftarrow C$
\State $T_{used} \leftarrow |\mathcal{T}.\text{tokenize}(q)|$
\State $\lambda \leftarrow 0.5$
\While{$remaining \neq \emptyset \land T_{used} < T_{max}$}
\State $scores \leftarrow \emptyset$
\For{each $c \in remaining$}
\State $rel \leftarrow \text{CosineSim}(\text{Embed}(c), \text{Embed}(q))$
\If{$selected = \emptyset$}
\State $mmr \leftarrow rel$
\Else
\State $div \leftarrow \max_{s \in selected} \text{CosineSim}(\text{Embed}(c), \text{Embed}(s))$
\State $mmr \leftarrow \lambda \cdot rel - (1 - \lambda) \cdot div$
\EndIf
\State $scores[c] \leftarrow mmr$
\EndFor
\State $c_{best} \leftarrow \arg\max_{c} \text{scores}[c]$
\State $T_c \leftarrow |\mathcal{T}.\text{tokenize}(c_{best})|$
\If{$T_{used} + T_c > T_{max}$}
\State $c_{trunc} \leftarrow \text{TruncateAtSentence}(c_{best}, T_{max} - T_{used}, \mathcal{T})$
\If{$c_{trunc} \neq \emptyset$}
\State $selected \leftarrow selected \cup \{c_{trunc}\}$
\EndIf
\State \textbf{break}
\EndIf
\State $selected \leftarrow selected \cup \{c_{best}\}$
\State $T_{used} \leftarrow T_{used} + T_c$
\State $remaining \leftarrow remaining \setminus \{c_{best}\}$
\EndWhile
\State $C_{final} \leftarrow \text{Concatenate}(selected)$
\State \Return $C_{final}$
\end{algorithmic}
\end{algorithm}
6.4.2 引用溯源
代码段
\begin{algorithm}
\caption{Citation Generation and Source Attribution}
\begin{algorithmic}[1]
\Require Generated answer $A$, Retrieved contexts $C$, NLI model $N$
\Ensure Attributed answer $A'$ with citations
\State $claims \leftarrow \text{SegmentIntoSentences}(A)$
\State $citation\_map \leftarrow \text{Dictionary}()$
\For{each $claim \in claims$}
\State $supporting \leftarrow \emptyset$
\For{$i \leftarrow 1$ \textbf{to} $|C|$}
\State $premise \leftarrow C[i].content$
\State $label \leftarrow N.predict(premise, claim)$
\If{$label = \text{entailment}$}
\State $supporting \leftarrow supporting \cup \{i\}$
\EndIf
\EndFor
\If{$supporting \neq \emptyset$}
\State $citation\_map[claim] \leftarrow supporting$
\EndIf
\EndFor
\State $A' \leftarrow \text{""}$
\For{each $claim \in claims$}
\State $A' \leftarrow A' + claim$
\If{$claim \in citation\_map$}
\State $refs \leftarrow citation\_map[claim]$
\State $A' \leftarrow A' + \text{" ["} + \text{Join}(refs, \text{","}) + \text{"]"}$
\EndIf
\EndFor
\State \Return $A'$
\end{algorithmic}
\end{algorithm}
6.4.3 答案验证
代码段
\begin{algorithm}
\caption{Self-RAG Reflection Mechanism}
\begin{algorithmic}[1]
\Require Query $q$, Generator $G$, Retriever $R$, Reflection tokens $T$
\Ensure Verified answer $A$ with reflection traces
\State $output \leftarrow \text{""}$
\State $reflection\_log \leftarrow \text{InitializeEmptyList}()$
\While{\textbf{true}}
\State $token \leftarrow G.generate\_next(q, output)$
\If{$token \in T$}
\If{$token = \text{[Retrieve]}$}
\State $context \leftarrow R.retrieve(q)$
\State $reflection\_log.append(\{action: \text{retrieve}, context: context\})$
\State $output \leftarrow output + token + \text{FormatContext}(context)$
\ElsIf{$token = \text{[Verify]}$}
\State $claim \leftarrow \text{ExtractLastSentence}(output)$
\State $is\_supported \leftarrow \text{VerifyAgainstContext}(claim, context)$
\State $reflection\_log.append(\{action: \text{verify}, result: is\_supported\})$
\If{$\neg is\_supported$}
\State $output \leftarrow output + \text{[Correction]}$
\State $output \leftarrow output + \text{GenerateCorrection}(claim, context)$
\EndIf
\ElsIf{$token = \text{[EOS]}$}
\State \textbf{break}
\EndIf
\Else
\State $output \leftarrow output + token$
\EndIf
\EndWhile
\State $final\_check \leftarrow \text{HallucinationDetection}(output, context)$
\State \Return $(output, reflection\_log, final\_check)$
\end{algorithmic}
\end{algorithm}
6.4.4 缓存策略
代码段
\begin{algorithm}
\caption{Semantic Cache with TTL}
\begin{algorithmic}[1]
\Require Query $q$, Cache store $S$, Similarity threshold $\theta$, TTL $\Delta t$
\Ensure Cached answer $a$ or $\perp$
\State $v_q \leftarrow \text{Embed}(q)$
\State $candidates \leftarrow S.ann\_search(v_q, k=5)$
\For{each $(v_c, a_c, t_c) \in candidates$}
\If{$\text{CurrentTime}() - t_c > \Delta t$}
\State $S.delete(v_c)$
\State \textbf{continue}
\EndIf
\State $sim \leftarrow \cos(v_q, v_c)$
\If{$sim > \theta$}
\State \Return $a_c$
\EndIf
\EndFor
\State \Return $\perp$
\end{algorithmic}
\end{algorithm}
6.5 评估与监控
6.5.1 离线评估
代码段
\begin{algorithm}
\caption{RAGAS Metrics Computation}
\begin{algorithmic}[1]
\Require QA pairs $\{(q_i, a_i, c_i)\}_{i=1}^n$, LLM evaluator $\mathcal{L}$
\Ensure Metric scores $M$
\State $faithfulness\_scores \leftarrow \text{InitializeEmptyList}()$
\State $relevancy\_scores \leftarrow \text{InitializeEmptyList}()$
\For{each $(q, a, c) \in dataset$}
\State \Comment{Faithfulness}
\State $claims \leftarrow \mathcal{L}.extract\_statements(a)$
\State $supported \leftarrow 0$
\For{each $claim \in claims$}
\State $verdict \leftarrow \mathcal{L}.verify(claim, c)$
\If{$verdict = \text{supported}$}
\State $supported \leftarrow supported + 1$
\EndIf
\EndFor
\State $F \leftarrow \frac{supported}{|claims|}$
\State $faithfulness\_scores.append(F)$
\State \Comment{Answer Relevancy}
\State $artificial\_qs \leftarrow \mathcal{L}.generate\_questions(a, k=3)$
\State $\mathbf{v}_q \leftarrow \text{Embed}(q)$
\State $sims \leftarrow \emptyset$
\For{each $q_{gen} \in artificial\_qs$}
\State $\mathbf{v}_{gen} \leftarrow \text{Embed}(q_{gen})$
\State $sims \leftarrow sims \cup \{\cos(\mathbf{v}_q, \mathbf{v}_{gen})\}$
\EndFor
\State $AR \leftarrow \frac{1}{|sims|} \sum sims$
\State $relevancy\_scores.append(AR)$
\EndFor
\State $M \leftarrow \{faithfulness: \text{Mean}(faithfulness\_scores), relevancy: \text{Mean}(relevancy\_scores)\}$
\State \Return $M$
\end{algorithmic}
\end{algorithm}
6.5.2 在线反馈
代码段
\begin{algorithm}
\caption{Online Feedback Collection and Analysis}
\begin{algorithmic}[1]
\Require User interaction stream $U$, Feedback window $\Delta$
\Ensure Feedback statistics $F$
\State $buffer \leftarrow \text{RingBuffer}(size=\Delta)$
\State $feedback\_db \leftarrow \text{InitializeDatabase}()$
\For{each $u \in U$}
\If{$u.type = \text{explicit}$}
\State $record \leftarrow \{query: u.q, answer: u.a, rating: u.rating, timestamp: u.t\}$
\State $feedback\_db.insert(record)$
\ElsIf{$u.type = \text{implicit}$}
\State $signals \leftarrow \text{AnalyzeBehavior}(u.session)$
\State $inferred\_score \leftarrow \text{CalculateSatisfaction}(signals)$
\State $record \leftarrow \{query: u.q, implicit\_score: inferred\_score, signals: signals\}$
\State $feedback\_db.insert(record)$
\EndIf
\If{$feedback\_db.count() \bmod 100 = 0$}
\State $stats \leftarrow \text{ComputeStatistics}(feedback\_db)$
\State $\text{AlertIfAnomaly}(stats)$
\EndIf
\EndFor
\end{algorithmic}
\end{algorithm}
6.5.3 A/B测试
代码段
\begin{algorithm}
\caption{A/B Testing Framework for RAG}
\begin{algorithmic}[1]
\Require Variants $V=\{v_1, v_2\}$, Traffic split ratio $\rho$, Success metric $M$
\Ensure Statistical comparison $C$
\State $assignments \leftarrow \text{HashBasedAssignment}(\rho)$
\State $results \leftarrow \{v_1: [], v_2: []\}$
\For{each incoming query $q$}
\State $v \leftarrow assignments.get\_variant(q.user\_id)$
\State $response \leftarrow v.process(q)$
\State $metric\_value \leftarrow M(response, q)$
\State $results[v].append(metric\_value)$
\EndFor
\State $\mu_1, \sigma_1 \leftarrow \text{ComputeMeanStd}(results[v_1])$
\State $\mu_2, \sigma_2 \leftarrow \text{ComputeMeanStd}(results[v_2])$
\State $n_1, n_2 \leftarrow |results[v_1]|, |results[v_2]|$
\State $t \leftarrow \frac{\mu_1 - \mu_2}{\sqrt{\frac{\sigma_1^2}{n_1} + \frac{\sigma_2^2}{n_2}}}$
\State $p \leftarrow \text{CalculatePValue}(t, dof=n_1+n_2-2)$
\If{$p < 0.05$}
\State $winner \leftarrow v_1 \textbf{ if } \mu_1 > \mu_2 \textbf{ else } v_2$
\State $\text{Rollout}(winner, traffic=100\%)$
\EndIf
\State \Return $\{mean\_diff: \mu_1 - \mu_2, p\_value: p, significant: p < 0.05\}$
\end{algorithmic}
\end{algorithm}
6.5.4 持续学习
代码段
\begin{algorithm}
\caption{Continual Learning with Bad Case Mining}
\begin{algorithmic}[1]
\Require Bad case threshold $\tau$, Retraining dataset size $N$, Model $M$
\Ensure Updated model $M'$
\State $bad\_cases \leftarrow \text{QueryBadCases}(threshold=\tau)$
\State $categories \leftarrow \text{CategorizeFailures}(bad\_cases)$
\If{$|bad\_cases| < N$}
\State \Return $M$ \Comment{Insufficient data}
\EndIf
\State $training\_data \leftarrow \emptyset$
\For{each $case \in bad\_cases$}
\If{$case.type = \text{retrieval\_failure}$}
\State $positives \leftarrow case.ground\_truth\_docs$
\State $negatives \leftarrow \text{HardNegativeMining}(case.query, positives)$
\State $training\_data \leftarrow training\_data \cup \{(case.query, positives, negatives)\}$
\ElsIf{$case.type = \text{generation\_error}$}
\State $pair \leftarrow (case.context, case.bad\_answer, case.good\_answer)$
\State $training\_data \leftarrow training\_data \cup \{pair\}$
\EndIf
\EndFor
\State $M' \leftarrow \text{FineTune}(M, training\_data, epochs=3)$
\State $validation\_score \leftarrow \text{Evaluate}(M', holdout\_set)$
\If{$validation\_score > \text{CurrentScore}(M)$}
\State $\text{Deploy}(M')$
\State $\text{ClearBadCases}()$
\EndIf
\State \Return $M'$
\end{algorithmic}
\end{algorithm}
第三部分:代码实现
6.1.1 文档解析引擎
Python
#!/usr/bin/env python3
"""
Script: document_parser_unstructured.py
功能: 基于Unstructured.io的PDF/Word文档解析引擎,支持表格提取与层级标题识别
使用方式: python document_parser_unstructured.py --input path/to/doc.pdf --output parsed.json
"""
import os
import json
import argparse
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.docx import partition_docx
from unstructured.documents.elements import Table, Title, NarrativeText, ListItem
@dataclass
class ParsedElement:
"""文档元素数据结构"""
element_id: str
type: str
content: str
metadata: Dict[str, Any]
hierarchy_level: Optional[int] = None
parent_id: Optional[str] = None
class DocumentParserEngine:
"""Unstructured.io文档解析引擎"""
def __init__(self, strategy: str = "hi_res", model_name: str = "detectron2"):
"""
初始化解析引擎
Args:
strategy: 解析策略 ('fast', 'hi_res', 'ocr_only')
model_name: 版面检测模型 ('detectron2', 'yolox')
"""
self.strategy = strategy
self.model_name = model_name
self.elements = []
self.hierarchy_tree = []
def parse_pdf(self, file_path: str) -> List[ParsedElement]:
"""
解析PDF文件,提取表格、标题与文本
Args:
file_path: PDF文件路径
Returns:
结构化元素列表
"""
print(f"[INFO] 开始解析PDF: {file_path}")
# 使用hi_res策略进行版面分析
raw_elements = partition_pdf(
filename=file_path,
strategy=self.strategy,
hi_res_model_name=self.model_name,
infer_table_structure=True, # 启用表格结构推断
extract_images_in_pdf=False,
chunking_strategy="by_title",
max_characters=4000,
new_after_n_chars=3800,
combine_text_under_n_chars=2000
)
self.elements = self._convert_to_structured(raw_elements)
self._build_hierarchy()
print(f"[INFO] 解析完成,共提取 {len(self.elements)} 个元素")
return self.elements
def parse_word(self, file_path: str) -> List[ParsedElement]:
"""解析Word文档"""
print(f"[INFO] 开始解析Word: {file_path}")
raw_elements = partition_docx(filename=file_path)
self.elements = self._convert_to_structured(raw_elements)
self._build_hierarchy()
return self.elements
def _convert_to_structured(self, raw_elements) -> List[ParsedElement]:
"""将原始元素转换为结构化格式"""
structured = []
for idx, element in enumerate(raw_elements):
element_id = f"elem_{idx:06d}"
# 确定元素类型
if isinstance(element, Table):
elem_type = "table"
# 提取表格为HTML格式
content = element.metadata.text_as_html if hasattr(element.metadata, 'text_as_html') else str(element)
elif isinstance(element, Title):
elem_type = "title"
content = str(element)
elif isinstance(element, ListItem):
elem_type = "list_item"
content = str(element)
else:
elem_type = "text"
content = str(element)
# 构建元数据
metadata = {
"page_number": element.metadata.page_number if hasattr(element.metadata, 'page_number') else None,
"filename": element.metadata.filename if hasattr(element.metadata, 'filename') else None,
"coordinates": element.metadata.coordinates if hasattr(element.metadata, 'coordinates') else None,
"category_depth": element.metadata.category_depth if hasattr(element.metadata, 'category_depth') else None
}
parsed = ParsedElement(
element_id=element_id,
type=elem_type,
content=content,
metadata=metadata,
hierarchy_level=metadata.get("category_depth")
)
structured.append(parsed)
return structured
def _build_hierarchy(self):
"""构建文档层级树"""
stack = []
for elem in self.elements:
if elem.type == "title":
level = elem.hierarchy_level or 0
# 维护层级栈
while stack and (stack[-1].hierarchy_level or 0) >= level:
stack.pop()
if stack:
elem.parent_id = stack[-1].element_id
stack.append(elem)
else:
if stack:
elem.parent_id = stack[-1].element_id
def extract_tables(self) -> List[Dict]:
"""提取所有表格元素"""
tables = []
for elem in self.elements:
if elem.type == "table":
tables.append({
"element_id": elem.element_id,
"html_content": elem.content,
"page": elem.metadata.get("page_number"),
"coordinates": elem.metadata.get("coordinates")
})
return tables
def get_markdown_output(self) -> str:
"""生成Markdown格式输出"""
md_lines = []
for elem in self.elements:
if elem.type == "title":
level = min((elem.hierarchy_level or 0) + 1, 6)
md_lines.append(f"{'#' * level} {elem.content}\n")
elif elem.type == "table":
md_lines.append(f"\n{elem.content}\n")
elif elem.type == "list_item":
md_lines.append(f"- {elem.content}\n")
else:
md_lines.append(f"{elem.content}\n")
return "\n".join(md_lines)
def save_to_json(self, output_path: str):
"""保存解析结果到JSON"""
output = {
"elements": [asdict(e) for e in self.elements],
"statistics": {
"total_elements": len(self.elements),
"title_count": sum(1 for e in self.elements if e.type == "title"),
"table_count": sum(1 for e in self.elements if e.type == "table"),
"text_count": sum(1 for e in self.elements if e.type == "text")
}
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(output, f, ensure_ascii=False, indent=2)
print(f"[INFO] 结果已保存至: {output_path}")
def main():
parser = argparse.ArgumentParser(description="文档解析引擎")
parser.add_argument("--input", "-i", required=True, help="输入文件路径(PDF或Word)")
parser.add_argument("--output", "-o", default="parsed_output.json", help="输出JSON路径")
parser.add_argument("--md-output", "-m", default=None, help="可选的Markdown输出路径")
parser.add_argument("--strategy", "-s", default="hi_res", choices=["fast", "hi_res", "ocr_only"])
args = parser.parse_args()
# 初始化解析器
engine = DocumentParserEngine(strategy=args.strategy)
# 根据文件类型解析
ext = os.path.splitext(args.input)[1].lower()
if ext == '.pdf':
engine.parse_pdf(args.input)
elif ext in ['.docx', '.doc']:
engine.parse_word(args.input)
else:
raise ValueError(f"不支持的文件格式: {ext}")
# 保存结果
engine.save_to_json(args.output)
# 可选输出Markdown
if args.md_output:
md_content = engine.get_markdown_output()
with open(args.md_output, 'w', encoding='utf-8') as f:
f.write(md_content)
print(f"[INFO] Markdown已保存至: {args.md_output}")
# 打印统计信息
tables = engine.extract_tables()
print(f"[STATS] 提取表格数: {len(tables)}")
if tables:
print(f"[SAMPLE] 首个表格预览(前200字符): {tables[0]['html_content'][:200]}")
if __name__ == "__main__":
main()
6.1.2 网页爬取集成
Python
#!/usr/bin/env python3
"""
Script: web_crawler_crawl4ai.py
功能: 基于Crawl4AI的异步网页爬取与内容清洗系统
使用方式: python web_crawler_crawl4ai.py --urls "https://example.com" --max-depth 2 --output crawl_results.json
"""
import asyncio
import json
import hashlib
import argparse
from typing import List, Dict, Set, Optional
from dataclasses import dataclass, asdict
from urllib.parse import urljoin, urlparse
from datetime import datetime
import aiohttp
from bs4 import BeautifulSoup
from readability import Document
import trafilatura
@dataclass
class CrawledPage:
"""爬取页面数据结构"""
url: str
title: str
markdown_content: str
metadata: Dict
links: List[str]
crawl_timestamp: str
depth: int
class AsyncWebCrawler:
"""异步网页爬取器"""
def __init__(self, max_depth: int = 2, max_concurrent: int = 5,
rate_limit: float = 1.0, respect_robots: bool = True):
"""
初始化爬取器
Args:
max_depth: 最大爬取深度
max_concurrent: 最大并发数
rate_limit: 每秒请求数限制
respect_robots: 是否遵守robots.txt
"""
self.max_depth = max_depth
self.max_concurrent = max_concurrent
self.rate_limit = rate_limit
self.respect_robots = respect_robots
self.visited_urls: Set[str] = set()
self.crawled_pages: List[CrawledPage] = []
self.domain_counts: Dict[str, int] = {}
# 信号量控制并发
self.semaphore = asyncio.Semaphore(max_concurrent)
async def crawl(self, seed_urls: List[str]) -> List[CrawledPage]:
"""
执行爬取任务
Args:
seed_urls: 种子URL列表
Returns:
爬取页面列表
"""
queue = [(url, 0) for url in seed_urls]
tasks = []
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=30),
headers={'User-Agent': 'RAG-Crawler/1.0 (Research Bot)'}
) as session:
while queue or tasks:
# 启动新任务直到达到并发限制
while queue and len(tasks) < self.max_concurrent:
url, depth = queue.pop(0)
if url not in self.visited_urls and depth <= self.max_depth:
self.visited_urls.add(url)
task = asyncio.create_task(
self._fetch_and_parse(session, url, depth)
)
tasks.append(task)
if tasks:
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
tasks = list(pending)
for task in done:
result, new_links = await task
if result:
self.crawled_pages.append(result)
# 添加新链接到队列
for link in new_links:
if link not in self.visited_urls:
queue.append((link, result.depth + 1))
# 速率限制
await asyncio.sleep(1.0 / self.rate_limit)
return self.crawled_pages
async def _fetch_and_parse(self, session: aiohttp.ClientSession,
url: str, depth: int) -> tuple:
"""
获取并解析单个页面
Args:
session: aiohttp会话
url: 目标URL
depth: 当前深度
Returns:
(CrawledPage, 新链接列表)
"""
async with self.semaphore:
try:
print(f"[CRAWL] 正在获取 (深度{depth}): {url}")
async with session.get(url, ssl=False) as response:
if response.status != 200:
print(f"[WARN] HTTP {response.status}: {url}")
return None, []
html = await response.text()
content_type = response.headers.get('content-type', '')
# 只处理HTML内容
if 'text/html' not in content_type:
return None, []
# 使用readability提取正文
doc = Document(html)
title = doc.title()
summary = doc.summary()
# 转换为Markdown
markdown = self._html_to_markdown(summary)
# 提取所有链接
soup = BeautifulSoup(html, 'html.parser')
links = self._extract_links(soup, url)
# 使用trafilatura进行内容清洗(备用/增强)
cleaned = trafilatura.extract(html, include_comments=False,
include_tables=True,
deduplicate=True,
target_language="zh")
# 合并两种提取结果,优先使用trafilatura
final_content = cleaned if cleaned else markdown
page = CrawledPage(
url=url,
title=title or "Untitled",
markdown_content=final_content or "",
metadata={
"content_type": content_type,
"content_length": len(final_content) if final_content else 0,
"source_domain": urlparse(url).netloc
},
links=links,
crawl_timestamp=datetime.now().isoformat(),
depth=depth
)
return page, links
except Exception as e:
print(f"[ERROR] 爬取失败 {url}: {str(e)}")
return None, []
def _html_to_markdown(self, html: str) -> str:
"""将HTML转换为Markdown"""
if not html:
return ""
soup = BeautifulSoup(html, 'html.parser')
# 移除脚本和样式
for script in soup(["script", "style", "nav", "footer", "header"]):
script.decompose()
# 处理标题
for i in range(1, 7):
for tag in soup.find_all(f'h{i}'):
tag.insert_before('#' * i + ' ')
tag.insert_after('\n\n')
# 处理列表
for ul in soup.find_all('ul'):
for li in ul.find_all('li'):
li.insert_before('- ')
li.insert_after('\n')
for ol in soup.find_all('ol'):
for idx, li in enumerate(ol.find_all('li'), 1):
li.insert_before(f'{idx}. ')
li.insert_after('\n')
# 处理链接
for a in soup.find_all('a', href=True):
href = a['href']
text = a.get_text(strip=True)
if text and href:
a.replace_with(f'[{text}]({href})')
# 处理强调
for strong in soup.find_all(['strong', 'b']):
text = strong.get_text(strip=True)
strong.replace_with(f'**{text}**')
for em in soup.find_all(['em', 'i']):
text = em.get_text(strip=True)
em.replace_with(f'*{text}*')
# 提取文本并清理
text = soup.get_text(separator='\n')
lines = [line.strip() for line in text.splitlines() if line.strip()]
return '\n\n'.join(lines)
def _extract_links(self, soup: BeautifulSoup, base_url: str) -> List[str]:
"""提取页面中的所有链接"""
links = []
for a in soup.find_all('a', href=True):
href = a['href']
full_url = urljoin(base_url, href)
# 只保留HTTP(S)链接,过滤锚点与媒体文件
if full_url.startswith('http') and not any(ext in full_url.lower()
for ext in ['.pdf', '.jpg', '.png', '.gif', '.zip']):
links.append(full_url)
return list(set(links))[:20] # 限制每页链接数
def save_results(self, output_path: str):
"""保存爬取结果"""
output = {
"crawl_info": {
"total_pages": len(self.crawled_pages),
"unique_domains": len(set(p.metadata["source_domain"] for p in self.crawled_pages)),
"crawl_time": datetime.now().isoformat()
},
"pages": [asdict(p) for p in self.crawled_pages]
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(output, f, ensure_ascii=False, indent=2)
print(f"[INFO] 结果已保存至: {output_path}")
def main():
parser = argparse.ArgumentParser(description="异步网页爬取器")
parser.add_argument("--urls", "-u", nargs='+', required=True, help="种子URL列表")
parser.add_argument("--max-depth", "-d", type=int, default=2, help="最大爬取深度")
parser.add_argument("--max-concurrent", "-c", type=int, default=5, help="最大并发数")
parser.add_argument("--output", "-o", default="crawl_results.json", help="输出文件路径")
parser.add_argument("--rate-limit", "-r", type=float, default=1.0, help="每秒请求数限制")
args = parser.parse_args()
crawler = AsyncWebCrawler(
max_depth=args.max_depth,
max_concurrent=args.max_concurrent,
rate_limit=args.rate_limit
)
# 运行爬取
pages = asyncio.run(crawler.crawl(args.urls))
# 保存结果
crawler.save_results(args.output)
# 打印统计
print(f"\n[SUMMARY] 爬取完成:")
print(f" - 总页面数: {len(pages)}")
print(f" - 平均内容长度: {sum(len(p.markdown_content) for p in pages)/max(len(pages),1):.0f} 字符")
if pages:
domains = {}
for p in pages:
dom = p.metadata["source_domain"]
domains[dom] = domains.get(dom, 0) + 1
print(f" - 域名分布: {domains}")
if __name__ == "__main__":
main()
6.1.3 分块策略
Python
#!/usr/bin/env python3
"""
Script: semantic_chunking.py
功能: 实现递归字符切分与语义分块策略对比,包含可视化
使用方式: python semantic_chunking.py --input text.txt --method semantic --output chunks.json
"""
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
from dataclasses import dataclass
from sentence_transformers import SentenceTransformer
import re
@dataclass
class TextChunk:
"""文本块数据结构"""
content: str
index: int
method: str
metadata: Dict
embedding: np.ndarray = None
class RecursiveCharacterSplitter:
"""递归字符切分器"""
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.separators = ["\n\n", "\n", ". ", " ", ""]
def split(self, text: str) -> List[TextChunk]:
"""
递归切分文本
Args:
text: 输入文本
Returns:
文本块列表
"""
chunks = self._recursive_split(text, self.separators)
return [
TextChunk(
content=chunk,
index=i,
method="recursive_character",
metadata={"char_count": len(chunk)}
)
for i, chunk in enumerate(chunks)
]
def _recursive_split(self, text: str, separators: List[str]) -> List[str]:
"""递归切分实现"""
if not text:
return []
# 尝试使用当前分隔符
separator = separators[0]
next_separators = separators[1:]
if separator == "":
# 最后手段:按字符硬切
return [text[i:i+self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)]
# 按分隔符分割
splits = text.split(separator)
chunks = []
current_chunk = []
current_length = 0
for split in splits:
split_length = len(split)
if current_length + split_length + len(separator) > self.chunk_size:
if current_chunk:
# 保存当前块
chunk_text = separator.join(current_chunk)
chunks.append(chunk_text)
# 处理重叠
overlap_text = self._get_overlap(current_chunk)
current_chunk = [overlap_text, split] if overlap_text else [split]
current_length = len(overlap_text) + split_length + len(separator)
else:
# 单个片段就超过限制,需要递归切分
if next_separators:
sub_chunks = self._recursive_split(split, next_separators)
chunks.extend(sub_chunks)
else:
chunks.append(split)
else:
current_chunk.append(split)
current_length += split_length + len(separator)
# 处理最后一个块
if current_chunk:
chunks.append(separator.join(current_chunk))
return chunks
def _get_overlap(self, chunks: List[str]) -> str:
"""获取重叠部分"""
overlap = []
current_len = 0
for chunk in reversed(chunks):
if current_len + len(chunk) > self.chunk_overlap:
break
overlap.insert(0, chunk)
current_len += len(chunk)
return " ".join(overlap)
class SemanticChunker:
"""语义分块器"""
def __init__(self,
embedding_model: str = 'all-MiniLM-L6-v2',
max_chunk_size: int = 1000,
similarity_threshold: float = 0.7):
self.model = SentenceTransformer(embedding_model)
self.max_chunk_size = max_chunk_size
self.similarity_threshold = similarity_threshold
def split(self, text: str) -> List[TextChunk]:
"""
基于语义相似度的切分
Args:
text: 输入文本
Returns:
文本块列表
"""
# 首先按句子分割
sentences = self._split_into_sentences(text)
if not sentences:
return []
# 编码所有句子
print(f"[INFO] 正在编码 {len(sentences)} 个句子...")
embeddings = self.model.encode(sentences, convert_to_numpy=True)
chunks = []
current_chunk = [sentences[0]]
current_embeddings = [embeddings[0]]
current_length = len(sentences[0])
for i in range(1, len(sentences)):
current_sentence = sentences[i]
current_embedding = embeddings[i]
# 计算与前一句的相似度
prev_embedding = embeddings[i-1]
similarity = np.dot(current_embedding, prev_embedding) / (
np.linalg.norm(current_embedding) * np.linalg.norm(prev_embedding)
)
# 检查是否超过最大长度或语义断裂
if (current_length + len(current_sentence) > self.max_chunk_size or
(similarity < self.similarity_threshold and len(current_chunk) > 0)):
# 保存当前块
chunk_text = " ".join(current_chunk)
chunks.append(TextChunk(
content=chunk_text,
index=len(chunks),
method="semantic",
metadata={
"char_count": len(chunk_text),
"sentence_count": len(current_chunk),
"avg_similarity": self._calculate_internal_similarity(current_embeddings)
},
embedding=np.mean(current_embeddings, axis=0)
))
# 开始新块
current_chunk = [current_sentence]
current_embeddings = [current_embedding]
current_length = len(current_sentence)
else:
current_chunk.append(current_sentence)
current_embeddings.append(current_embedding)
current_length += len(current_sentence)
# 处理最后一个块
if current_chunk:
chunk_text = " ".join(current_chunk)
chunks.append(TextChunk(
content=chunk_text,
index=len(chunks),
method="semantic",
metadata={
"char_count": len(chunk_text),
"sentence_count": len(current_chunk),
"avg_similarity": self._calculate_internal_similarity(current_embeddings)
},
embedding=np.mean(current_embeddings, axis=0)
))
return chunks
def _split_into_sentences(self, text: str) -> List[str]:
"""将文本分割为句子"""
# 简单的句子分割逻辑(可根据需要改进)
sentences = re.split(r'(?<=[.!?。!?])\s+', text)
return [s.strip() for s in sentences if s.strip()]
def _calculate_internal_similarity(self, embeddings: List[np.ndarray]) -> float:
"""计算块内平均相似度"""
if len(embeddings) < 2:
return 1.0
similarities = []
for i in range(len(embeddings)-1):
sim = np.dot(embeddings[i], embeddings[i+1]) / (
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[i+1])
)
similarities.append(sim)
return float(np.mean(similarities))
class ChunkingVisualizer:
"""分块结果可视化"""
def __init__(self):
self.fig_size = (12, 8)
def visualize_comparison(self,
recursive_chunks: List[TextChunk],
semantic_chunks: List[TextChunk],
output_path: str = "chunking_comparison.png"):
"""可视化对比两种分块策略"""
fig, axes = plt.subplots(2, 2, figsize=self.fig_size)
# 1. 块长度分布
ax1 = axes[0, 0]
rec_lengths = [c.metadata["char_count"] for c in recursive_chunks]
sem_lengths = [c.metadata["char_count"] for c in semantic_chunks]
ax1.hist(rec_lengths, bins=20, alpha=0.5, label="Recursive", color="blue")
ax1.hist(sem_lengths, bins=20, alpha=0.5, label="Semantic", color="red")
ax1.set_xlabel("Chunk Length (characters)")
ax1.set_ylabel("Frequency")
ax1.set_title("Chunk Length Distribution")
ax1.legend()
# 2. 块数量对比
ax2 = axes[0, 1]
methods = ["Recursive", "Semantic"]
counts = [len(recursive_chunks), len(semantic_chunks)]
colors = ["blue", "red"]
ax2.bar(methods, counts, color=colors, alpha=0.6)
ax2.set_ylabel("Number of Chunks")
ax2.set_title("Total Chunks Generated")
for i, v in enumerate(counts):
ax2.text(i, v, str(v), ha='center', va='bottom')
# 3. 语义分块内部相似度
ax3 = axes[1, 0]
if semantic_chunks and "avg_similarity" in semantic_chunks[0].metadata:
similarities = [c.metadata["avg_similarity"] for c in semantic_chunks]
ax3.plot(range(len(similarities)), similarities, marker='o', color='red')
ax3.axhline(y=0.7, color='gray', linestyle='--', label="Threshold")
ax3.set_xlabel("Chunk Index")
ax3.set_ylabel("Internal Similarity")
ax3.set_title("Semantic Coherence within Chunks")
ax3.legend()
# 4. 累积长度分布
ax4 = axes[1, 1]
rec_cumsum = np.cumsum(rec_lengths)
sem_cumsum = np.cumsum(sem_lengths)
ax4.plot(range(len(rec_cumsum)), rec_cumsum, label="Recursive", color="blue")
ax4.plot(range(len(sem_cumsum)), sem_cumsum, label="Semantic", color="red")
ax4.set_xlabel("Chunk Index")
ax4.set_ylabel("Cumulative Length")
ax4.set_title("Text Coverage Progression")
ax4.legend()
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 可视化结果已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="文本分块策略对比")
parser.add_argument("--input", "-i", required=True, help="输入文本文件")
parser.add_argument("--method", "-m", choices=["recursive", "semantic", "both"],
default="both", help="分块方法")
parser.add_argument("--chunk-size", "-c", type=int, default=500, help="块大小")
parser.add_argument("--overlap", "-o", type=int, default=50, help="重叠大小")
parser.add_argument("--similarity-threshold", "-t", type=float, default=0.7,
help="语义相似度阈值")
parser.add_argument("--output", "-out", default="chunks.json", help="输出JSON")
parser.add_argument("--visualize", "-v", action="store_true", help="是否生成可视化")
parser.add_argument("--viz-output", "-vo", default="chunking_comparison.png",
help="可视化输出路径")
args = parser.parse_args()
# 读取输入文本
with open(args.input, 'r', encoding='utf-8') as f:
text = f.read()
print(f"[INFO] 输入文本长度: {len(text)} 字符")
recursive_chunks = []
semantic_chunks = []
# 递归字符切分
if args.method in ["recursive", "both"]:
print("[INFO] 执行递归字符切分...")
splitter = RecursiveCharacterSplitter(
chunk_size=args.chunk_size,
chunk_overlap=args.overlap
)
recursive_chunks = splitter.split(text)
print(f"[INFO] 递归切分生成 {len(recursive_chunks)} 个块")
# 语义分块
if args.method in ["semantic", "both"]:
print("[INFO] 执行语义分块...")
chunker = SemanticChunker(
max_chunk_size=args.chunk_size,
similarity_threshold=args.similarity_threshold
)
semantic_chunks = chunker.split(text)
print(f"[INFO] 语义分块生成 {len(semantic_chunks)} 个块")
# 保存结果
output_data = {
"recursive_character": [
{"content": c.content, "metadata": c.metadata}
for c in recursive_chunks
],
"semantic": [
{"content": c.content, "metadata": c.metadata}
for c in semantic_chunks
]
}
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(output_data, f, ensure_ascii=False, indent=2)
print(f"[INFO] 结果已保存至: {args.output}")
# 可视化
if args.visualize and args.method == "both":
visualizer = ChunkingVisualizer()
visualizer.visualize_comparison(recursive_chunks, semantic_chunks, args.viz_output)
# 打印示例
if recursive_chunks:
print(f"\n[示例] 递归切分首个块 (前200字符):")
print(recursive_chunks[0].content[:200])
if semantic_chunks:
print(f"\n[示例] 语义切分首个块 (前200字符):")
print(semantic_chunks[0].content[:200])
if __name__ == "__main__":
main()
6.1.4 元数据提取
Python
#!/usr/bin/env python3
"""
Script: metadata_extractor.py
功能: 文档元数据提取与增强系统,支持文件名、章节、时间戳、实体识别
使用方式: python metadata_extractor.py --input chunks.json --output enriched_chunks.json
"""
import json
import re
import argparse
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
import hashlib
# 模拟NER(实际应用中可使用spaCy或transformers)
class SimpleNER:
"""简单命名实体识别(演示用)"""
PATTERNS = {
"DATE": r'\b(\d{4}[-/]\d{1,2}[-/]\d{1,2}|\d{1,2}[-/]\d{1,2}[-/]\d{4})\b',
"EMAIL": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
"PHONE": r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
"MONEY": r'\$\d+(?:,\d{3})*(?:\.\d{2})?',
"PERCENT": r'\d+(?:\.\d+)?%',
"ORG": r'\b(?:[A-Z][a-z]*\s)*(?:Inc\.|Corp\.|LLC|Ltd\.|Company|Corporation)\b',
}
def extract(self, text: str) -> Dict[str, List[str]]:
"""提取实体"""
entities = {}
for label, pattern in self.PATTERNS.items():
matches = re.findall(pattern, text)
if matches:
entities[label] = list(set(matches))[:5] # 限制数量
return entities
class MetadataExtractor:
"""元数据提取器"""
def __init__(self):
self.ner = SimpleNER()
def extract_file_metadata(self, file_path: str) -> Dict[str, Any]:
"""
提取文件级元数据
Args:
file_path: 文件路径
Returns:
文件元数据字典
"""
path = Path(file_path)
stats = path.stat()
return {
"filename": path.name,
"file_extension": path.suffix,
"file_size_bytes": stats.st_size,
"created_timestamp": datetime.fromtimestamp(stats.st_ctime).isoformat(),
"modified_timestamp": datetime.fromtimestamp(stats.st_mtime).isoformat(),
"directory_hierarchy": str(path.parent),
"file_hash": self._calculate_hash(file_path)
}
def extract_content_metadata(self,
text: str,
chunk_index: int,
parent_sections: List[str] = None) -> Dict[str, Any]:
"""
提取内容级元数据
Args:
text: 文本内容
chunk_index: 块索引
parent_sections: 父章节列表
Returns:
内容元数据字典
"""
# 基础统计
metadata = {
"char_count": len(text),
"word_count": len(text.split()),
"sentence_count": len(re.split(r'[.!?。!?]+', text)),
"chunk_index": chunk_index,
"extraction_timestamp": datetime.now().isoformat(),
"content_hash": hashlib.md5(text.encode()).hexdigest()[:8]
}
# 时间戳提取
metadata["temporal_info"] = self._extract_temporal_info(text)
# 关键词提取(简单TF-IDF模拟)
metadata["keywords"] = self._extract_keywords(text)
# 实体识别
metadata["entities"] = self.ner.extract(text)
# 章节层级
if parent_sections:
metadata["section_hierarchy"] = parent_sections
metadata["section_level"] = len(parent_sections)
# 内容类型检测
metadata["content_type"] = self._detect_content_type(text)
return metadata
def _calculate_hash(self, file_path: str) -> str:
"""计算文件哈希"""
hash_obj = hashlib.sha256()
with open(file_path, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
hash_obj.update(chunk)
return hash_obj.hexdigest()[:16]
def _extract_temporal_info(self, text: str) -> Dict[str, Any]:
"""提取时间信息"""
dates = []
# ISO格式日期
iso_pattern = r'\b(\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})?)?)\b'
dates.extend(re.findall(iso_pattern, text))
# 中文日期
chinese_pattern = r'(\d{4}年\d{1,2}月\d{1,2}日)'
dates.extend(re.findall(chinese_pattern, text))
# 相对时间
        # 注意: \b对中文无效(中文字符同属\w, 不产生词边界), 这里直接匹配关键词
        relative_terms = re.findall(r'(上周|昨天|今天|明天|最近|过去|未来)', text)
return {
"explicit_dates": dates[:3],
"relative_terms": list(set(relative_terms)),
"temporal_relevance": len(dates) > 0 or len(relative_terms) > 0
}
def _extract_keywords(self, text: str, topk: int = 5) -> List[str]:
"""提取关键词(基于词频)"""
# 简单的词频统计(实际应用应使用TF-IDF或RAKE)
words = re.findall(r'\b[A-Za-z]{4,}\b', text)
word_freq = {}
for word in words:
word_lower = word.lower()
if word_lower not in {"this", "that", "with", "from", "they", "have", "were"}:
word_freq[word_lower] = word_freq.get(word_lower, 0) + 1
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
return [word for word, freq in sorted_words[:topk]]
def _detect_content_type(self, text: str) -> str:
"""检测内容类型"""
text_lower = text.lower()
indicators = {
"technical_spec": ["specification", "parameter", "configuration", "api", "endpoint"],
"legal_contract": ["agreement", "clause", "party", "terms", "conditions", "liability"],
"financial_report": ["revenue", "profit", "loss", "balance sheet", "cash flow", "quarter"],
"meeting_minutes": ["meeting", "attendee", "agenda", "discussion", "action item", "decided"],
"academic_paper": ["abstract", "introduction", "methodology", "conclusion", "references"]
}
scores = {}
for doc_type, keywords in indicators.items():
score = sum(1 for kw in keywords if kw in text_lower)
scores[doc_type] = score
if max(scores.values()) > 0:
return max(scores, key=scores.get)
return "general_document"
class EnrichedChunk:
"""增强型文本块"""
def __init__(self, content: str, metadata: Dict[str, Any]):
self.content = content
self.metadata = metadata
self.combined_text = self._build_combined_text()
def _build_combined_text(self) -> str:
"""构建用于检索的增强文本(元数据+内容)"""
meta_parts = []
if "filename" in self.metadata:
meta_parts.append(f"Source: {self.metadata['filename']}")
if "section_hierarchy" in self.metadata:
meta_parts.append(f"Section: {' > '.join(self.metadata['section_hierarchy'])}")
if "keywords" in self.metadata:
meta_parts.append(f"Keywords: {', '.join(self.metadata['keywords'])}")
meta_str = " | ".join(meta_parts)
return f"{meta_str}\n\n{self.content}" if meta_str else self.content
def main():
parser = argparse.ArgumentParser(description="元数据提取与增强")
parser.add_argument("--input", "-i", required=True, help="输入JSON文件(包含文本块)")
parser.add_argument("--output", "-o", default="enriched_chunks.json", help="输出文件")
parser.add_argument("--file-path", "-f", default=None, help="原始文件路径(用于提取文件元数据)")
parser.add_argument("--parent-sections", "-p", nargs='+', default=[], help="父章节层级")
args = parser.parse_args()
# 加载输入
with open(args.input, 'r', encoding='utf-8') as f:
data = json.load(f)
# 获取文本块列表(适配不同输入格式)
if isinstance(data, list):
chunks = data
elif isinstance(data, dict):
chunks = data.get("chunks", []) or data.get("semantic", []) or data.get("recursive_character", [])
else:
chunks = []
print(f"[INFO] 加载了 {len(chunks)} 个文本块")
# 初始化提取器
extractor = MetadataExtractor()
# 提取文件级元数据(如果提供)
file_metadata = {}
if args.file_path:
file_metadata = extractor.extract_file_metadata(args.file_path)
print(f"[INFO] 文件元数据: {file_metadata['filename']}")
# 处理每个块
enriched_chunks = []
for idx, chunk in enumerate(chunks):
if isinstance(chunk, dict):
content = chunk.get("content", "")
else:
content = str(chunk)
# 提取内容元数据
content_meta = extractor.extract_content_metadata(
content,
chunk_index=idx,
parent_sections=args.parent_sections
)
# 合并元数据
full_metadata = {**file_metadata, **content_meta}
# 创建增强块
enriched = EnrichedChunk(content, full_metadata)
enriched_chunks.append({
"content": content,
"metadata": full_metadata,
"combined_search_text": enriched.combined_text
})
if idx < 2: # 打印前两个作为示例
print(f"\n[块 {idx}] 元数据摘要:")
print(f" - 字符数: {content_meta['char_count']}")
print(f" - 关键词: {content_meta.get('keywords', [])}")
print(f" - 实体: {list(content_meta.get('entities', {}).keys())}")
# 保存结果
output_data = {
"extraction_info": {
"total_chunks": len(enriched_chunks),
"extraction_time": datetime.now().isoformat(),
"metadata_fields": list(enriched_chunks[0]["metadata"].keys()) if enriched_chunks else []
},
"chunks": enriched_chunks
}
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(output_data, f, ensure_ascii=False, indent=2)
print(f"\n[INFO] 增强后的数据已保存至: {args.output}")
print(f"[STATS] 平均块大小: {sum(c['metadata']['char_count'] for c in enriched_chunks)/len(enriched_chunks):.0f} 字符")
if __name__ == "__main__":
main()
6.2.1 嵌入模型管理
Python
#!/usr/bin/env python3
"""
Script: embedding_manager.py
功能: 嵌入模型管理系统,支持本地Sentence-Transformers与OpenAI API切换,包含量化优化
使用方式: python embedding_manager.py --input texts.json --model local --output embeddings.npy
"""
import os
import json
import argparse
import numpy as np
from typing import List, Union, Dict, Optional
from dataclasses import dataclass
from enum import Enum
import hashlib
import time
class ModelType(Enum):
"""模型类型枚举"""
LOCAL = "local"
OPENAI = "openai"
COHERE = "cohere"
@dataclass
class EmbeddingConfig:
"""嵌入配置"""
model_type: ModelType
model_name: str
dimension: int
batch_size: int
quantization: Optional[str] = None # 'int8', 'int4', None
api_key: Optional[str] = None
max_retries: int = 3
class LocalEmbeddingProvider:
"""本地嵌入模型提供者"""
def __init__(self, config: EmbeddingConfig):
self.config = config
self.model = None
self._load_model()
def _load_model(self):
"""加载模型(支持量化)"""
from sentence_transformers import SentenceTransformer
print(f"[INFO] 加载本地模型: {self.config.model_name}")
# 加载模型
self.model = SentenceTransformer(self.config.model_name)
# 应用量化
if self.config.quantization == "int8":
print("[INFO] 应用INT8量化...")
# 注意:实际量化需要使用optimum或类似库
# 这里简化处理,实际应使用ONNX Runtime或bitsandbytes
self.model = self.model.half() # 半精度作为演示
# 获取实际维度
sample_embedding = self.model.encode("test")
self.config.dimension = len(sample_embedding)
print(f"[INFO] 模型维度: {self.config.dimension}")
def encode(self, texts: List[str], show_progress: bool = True) -> np.ndarray:
"""
编码文本
Args:
texts: 文本列表
show_progress: 是否显示进度条
Returns:
嵌入向量数组
"""
embeddings = self.model.encode(
texts,
batch_size=self.config.batch_size,
show_progress_bar=show_progress,
convert_to_numpy=True,
normalize_embeddings=True
)
return embeddings
class OpenAIEmbeddingProvider:
"""OpenAI嵌入API提供者"""
def __init__(self, config: EmbeddingConfig):
self.config = config
self.client = None
self._init_client()
def _init_client(self):
"""初始化OpenAI客户端"""
try:
from openai import OpenAI
except ImportError:
raise ImportError("请安装openai库: pip install openai")
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("需要提供OpenAI API Key")
self.client = OpenAI(api_key=api_key)
print(f"[INFO] 初始化OpenAI客户端,模型: {self.config.model_name}")
def encode(self, texts: List[str], show_progress: bool = True) -> np.ndarray:
"""
调用OpenAI API编码
Args:
texts: 文本列表
show_progress: 是否显示进度(API调用中忽略)
Returns:
嵌入向量数组
"""
embeddings = []
# OpenAI限制batch size
batch_size = min(self.config.batch_size, 100)
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
for attempt in range(self.config.max_retries):
try:
response = self.client.embeddings.create(
model=self.config.model_name,
input=batch
)
batch_embeddings = [item.embedding for item in response.data]
embeddings.extend(batch_embeddings)
if show_progress:
print(f"[PROGRESS] {min(i+batch_size, len(texts))}/{len(texts)}")
# 速率限制
time.sleep(0.1)
break
except Exception as e:
print(f"[ERROR] API调用失败 (尝试 {attempt+1}/{self.config.max_retries}): {e}")
if attempt == self.config.max_retries - 1:
raise
time.sleep(2 ** attempt) # 指数退避
return np.array(embeddings)
class EmbeddingRouter:
"""嵌入模型路由器"""
# 预配置模型
PRESETS = {
"local-e5": EmbeddingConfig(
model_type=ModelType.LOCAL,
model_name="intfloat/e5-large-v2",
dimension=1024,
batch_size=32
),
"local-minilm": EmbeddingConfig(
model_type=ModelType.LOCAL,
model_name="all-MiniLM-L6-v2",
dimension=384,
batch_size=64
),
"openai-ada": EmbeddingConfig(
model_type=ModelType.OPENAI,
model_name="text-embedding-3-small",
dimension=1536,
batch_size=100
),
"openai-large": EmbeddingConfig(
model_type=ModelType.OPENAI,
model_name="text-embedding-3-large",
dimension=3072,
batch_size=100
)
}
def __init__(self, preset: str = None, config: EmbeddingConfig = None):
"""
初始化路由器
Args:
preset: 预配置名称
config: 自定义配置
"""
if preset:
self.config = self.PRESETS.get(preset)
if not self.config:
raise ValueError(f"未知预设: {preset},可用: {list(self.PRESETS.keys())}")
elif config:
self.config = config
else:
raise ValueError("需要提供preset或config")
self.provider = self._create_provider()
def _create_provider(self):
"""创建对应的提供者"""
if self.config.model_type == ModelType.LOCAL:
return LocalEmbeddingProvider(self.config)
elif self.config.model_type == ModelType.OPENAI:
return OpenAIEmbeddingProvider(self.config)
else:
raise ValueError(f"不支持的模型类型: {self.config.model_type}")
def embed(self, texts: Union[str, List[str]], show_progress: bool = True) -> np.ndarray:
"""
嵌入文本
Args:
texts: 单个文本或文本列表
show_progress: 是否显示进度
Returns:
嵌入向量
"""
if isinstance(texts, str):
texts = [texts]
# 空值处理
texts = [t if t and str(t).strip() else "empty" for t in texts]
return self.provider.encode(texts, show_progress)
def embed_with_cache(self,
texts: List[str],
cache_dir: str = ".embedding_cache",
show_progress: bool = True) -> np.ndarray:
"""
带缓存的嵌入
Args:
texts: 文本列表
cache_dir: 缓存目录
show_progress: 是否显示进度
Returns:
嵌入向量
"""
os.makedirs(cache_dir, exist_ok=True)
# 计算文本哈希
def get_hash(text):
return hashlib.md5(text.encode()).hexdigest()[:16]
# 检查缓存
cache_key = f"{self.config.model_name.replace('/', '_')}"
cache_file = os.path.join(cache_dir, f"{cache_key}.npz")
cached_embeddings = {}
if os.path.exists(cache_file):
cached_data = np.load(cache_file, allow_pickle=True)
cached_embeddings = dict(zip(cached_data['hashes'], cached_data['embeddings']))
# 确定需要编码的文本
to_encode = []
indices = []
results = [None] * len(texts)
for i, text in enumerate(texts):
h = get_hash(text)
if h in cached_embeddings:
results[i] = cached_embeddings[h]
else:
to_encode.append(text)
indices.append(i)
if to_encode:
print(f"[INFO] 缓存命中: {len(texts) - len(to_encode)}/{len(texts)}")
new_embeddings = self.embed(to_encode, show_progress)
# 更新缓存
for idx, emb, text in zip(indices, new_embeddings, to_encode):
results[idx] = emb
cached_embeddings[get_hash(text)] = emb
# 保存缓存
np.savez_compressed(
cache_file,
hashes=list(cached_embeddings.keys()),
embeddings=list(cached_embeddings.values())
)
return np.array(results)
def main():
parser = argparse.ArgumentParser(description="嵌入模型管理")
parser.add_argument("--input", "-i", required=True, help="输入JSON文件(包含文本列表)")
parser.add_argument("--model", "-m", choices=["local-e5", "local-minilm", "openai-ada", "openai-large"],
default="local-minilm", help="模型预设")
parser.add_argument("--output", "-o", default="embeddings.npy", help="输出numpy文件")
parser.add_argument("--quantization", "-q", choices=[None, "int8", "int4"],
default=None, help="量化类型(仅本地模型)")
parser.add_argument("--use-cache", "-c", action="store_true", help="使用缓存")
args = parser.parse_args()
# 加载文本
with open(args.input, 'r', encoding='utf-8') as f:
data = json.load(f)
# 适配不同格式
if isinstance(data, list):
texts = data
elif isinstance(data, dict):
texts = [item.get("content", item.get("text", str(item))) for item in data.get("chunks", [])]
else:
texts = [str(data)]
print(f"[INFO] 加载了 {len(texts)} 个文本")
print(f"[INFO] 平均长度: {sum(len(t) for t in texts)/len(texts):.0f} 字符")
# 初始化路由器
router = EmbeddingRouter(preset=args.model)
# 嵌入
start_time = time.time()
if args.use_cache:
embeddings = router.embed_with_cache(texts)
else:
embeddings = router.embed(texts)
elapsed = time.time() - start_time
print(f"[INFO] 嵌入完成:")
print(f" - 输出形状: {embeddings.shape}")
print(f" - 耗时: {elapsed:.2f}s")
print(f" - 平均速度: {len(texts)/elapsed:.2f} 文本/秒")
# 保存
np.save(args.output, embeddings)
print(f"[INFO] 嵌入已保存至: {args.output}")
# 保存元数据
meta = {
"model": args.model,
"dimension": embeddings.shape[1],
"count": len(texts),
"quantization": args.quantization
}
with open(args.output.replace('.npy', '_meta.json'), 'w') as f:
json.dump(meta, f)
# 计算相似度矩阵示例(前5个)
if len(texts) >= 5:
sample_emb = embeddings[:5]
similarity_matrix = np.dot(sample_emb, sample_emb.T)
print(f"\n[示例] 前5个文本的相似度矩阵:")
print(similarity_matrix.round(3))
if __name__ == "__main__":
main()
6.2.2 向量数据库选型
Python
#!/usr/bin/env python3
"""
Script: vector_database_manager.py
功能: Milvus与Qdrant向量数据库操作封装,支持混合检索与元数据过滤
使用方式: python vector_database_manager.py --db-type milvus --action insert --vectors embeddings.npy
"""
import os
import json
import argparse
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import uuid
class DBType(Enum):
"""数据库类型"""
MILVUS = "milvus"
QDRANT = "qdrant"
@dataclass
class VectorRecord:
"""向量记录"""
id: str
vector: np.ndarray
metadata: Dict[str, Any]
text: Optional[str] = None
class MilvusManager:
"""Milvus向量数据库管理器"""
def __init__(self,
host: str = "localhost",
port: str = "19530",
collection_name: str = "rag_collection"):
try:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
self.milvus = __import__('pymilvus')
except ImportError:
raise ImportError("请安装pymilvus: pip install pymilvus")
self.collection_name = collection_name
self.collection = None
# 连接
connections.connect(alias="default", host=host, port=port)
print(f"[INFO] 已连接到Milvus: {host}:{port}")
def create_collection(self,
dimension: int,
metadata_fields: Optional[Dict[str, str]] = None,
drop_existing: bool = False):
"""
创建集合
Args:
dimension: 向量维度
metadata_fields: 元数据字段定义 {"field_name": "DataType"}
drop_existing: 是否删除已存在的集合
"""
if drop_existing and self.milvus.utility.has_collection(self.collection_name):
self.milvus.utility.drop_collection(self.collection_name)
print(f"[INFO] 已删除旧集合: {self.collection_name}")
if self.milvus.utility.has_collection(self.collection_name):
self.collection = self.milvus.Collection(self.collection_name)
return
# 定义字段
fields = [
self.milvus.FieldSchema(name="id", dtype=self.milvus.DataType.VARCHAR, is_primary=True, max_length=64),
self.milvus.FieldSchema(name="vector", dtype=self.milvus.DataType.FLOAT_VECTOR, dim=dimension),
self.milvus.FieldSchema(name="text", dtype=self.milvus.DataType.VARCHAR, max_length=65535)
]
# 动态元数据字段
if metadata_fields:
for field_name, field_type in metadata_fields.items():
dtype = getattr(self.milvus.DataType, field_type.upper(), self.milvus.DataType.VARCHAR)
max_len = 1024 if dtype == self.milvus.DataType.VARCHAR else None
kwargs = {"max_length": max_len} if max_len else {}
fields.append(self.milvus.FieldSchema(name=field_name, dtype=dtype, **kwargs))
schema = self.milvus.CollectionSchema(fields, description="RAG Vector Collection", enable_dynamic_field=True)
self.collection = self.milvus.Collection(self.collection_name, schema)
# 创建索引
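        # 这里采用HNSW索引 + COSINE度量; M=16、efConstruction=200为常见经验默认值,
        # 可结合6.2.4节的评测脚本根据数据规模与延迟要求调参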
index_params = {
"metric_type": "COSINE",
"index_type": "HNSW",
"params": {"M": 16, "efConstruction": 200}
}
self.collection.create_index(field_name="vector", index_params=index_params)
print(f"[INFO] 已创建集合: {self.collection_name}, 维度: {dimension}")
def insert(self, records: List[VectorRecord]):
"""插入记录"""
if not records:
return
entities = {
"id": [r.id for r in records],
"vector": [r.vector.tolist() for r in records],
"text": [r.text or "" for r in records]
}
# 处理元数据
for key in records[0].metadata.keys():
entities[key] = [r.metadata.get(key) for r in records]
self.collection.insert(entities)
self.collection.flush()
print(f"[INFO] 已插入 {len(records)} 条记录")
def search(self,
query_vector: np.ndarray,
top_k: int = 10,
filters: Optional[str] = None,
output_fields: Optional[List[str]] = None) -> List[Dict]:
"""
向量搜索
Args:
query_vector: 查询向量
top_k: 返回数量
filters: 过滤表达式 (e.g., "file_type == 'pdf' and page > 10")
output_fields: 返回的字段列表
Returns:
搜索结果列表
"""
self.collection.load()
search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
results = self.collection.search(
data=[query_vector.tolist()],
anns_field="vector",
param=search_params,
limit=top_k,
expr=filters,
output_fields=output_fields or ["id", "text"]
)[0]
return [
{
"id": hit.id,
"distance": hit.distance,
"text": hit.entity.text,
**{field: getattr(hit.entity, field) for field in (output_fields or []) if hasattr(hit.entity, field)}
}
for hit in results
]
def hybrid_search(self,
query_vector: np.ndarray,
query_text: str,
top_k: int = 10,
vector_weight: float = 0.7,
bm25_weight: float = 0.3) -> List[Dict]:
"""
混合搜索(向量 + BM25关键词)
Args:
query_vector: 查询向量
query_text: 查询文本(用于BM25)
top_k: 返回数量
vector_weight: 向量搜索权重
bm25_weight: BM25权重
Returns:
融合后的结果
"""
# Milvus 2.4+支持稀疏向量,这里简化演示
# 实际应使用Sparse-BM25或两路召回+RRF
vector_results = self.search(query_vector, top_k=top_k*2)
# 简单关键词匹配作为BM25代理
keywords = query_text.lower().split()
text_results = []
# 获取所有文本进行匹配(实际应使用倒排索引)
all_data = self.collection.query(
expr="id != ''",
output_fields=["id", "text", "vector"]
)
for item in all_data:
score = sum(1 for kw in keywords if kw in item["text"].lower())
if score > 0:
text_results.append({
"id": item["id"],
"text": item["text"],
"bm25_score": score
})
# RRF融合
return self._rrf_fusion(vector_results, text_results, top_k, vector_weight, bm25_weight)
def _rrf_fusion(self,
vector_results: List[Dict],
text_results: List[Dict],
top_k: int,
vector_weight: float,
bm25_weight: float,
k_constant: int = 60) -> List[Dict]:
"""RRF融合"""
scores = {}
# 向量分数
for rank, hit in enumerate(vector_results):
doc_id = hit["id"]
scores[doc_id] = scores.get(doc_id, 0) + vector_weight * (1.0 / (k_constant + rank))
scores[doc_id + "_data"] = hit
# BM25分数
# 先排序
text_results.sort(key=lambda x: x["bm25_score"], reverse=True)
for rank, hit in enumerate(text_results):
doc_id = hit["id"]
scores[doc_id] = scores.get(doc_id, 0) + bm25_weight * (1.0 / (k_constant + rank))
if doc_id + "_data" not in scores:
scores[doc_id + "_data"] = hit
# 排序
sorted_docs = sorted(scores.items(),
key=lambda x: x[1] if not x[0].endswith("_data") else 0,
reverse=True)
return [scores[item[0] + "_data"] for item in sorted_docs[:top_k] if not item[0].endswith("_data")]
class QdrantManager:
"""Qdrant向量数据库管理器"""
def __init__(self,
host: str = "localhost",
port: int = 6333,
collection_name: str = "rag_collection"):
try:
from qdrant_client import QdrantClient
except ImportError:
raise ImportError("请安装qdrant-client: pip install qdrant-client")
self.client = QdrantClient(host=host, port=port)
self.collection_name = collection_name
print(f"[INFO] 已连接到Qdrant: {host}:{port}")
def create_collection(self,
dimension: int,
distance: str = "Cosine",
on_disk: bool = True):
"""
创建集合
Args:
dimension: 向量维度
distance: 距离度量
on_disk: 是否启用磁盘存储
"""
from qdrant_client.models import Distance, VectorParams
# 删除已存在的集合
self.client.recreate_collection(
collection_name=self.collection_name,
            vectors_config=VectorParams(size=dimension, distance=getattr(Distance, distance.upper(), Distance.COSINE)),
on_disk_payload=on_disk
)
print(f"[INFO] 已创建集合: {self.collection_name}")
def insert(self, records: List[VectorRecord]):
"""插入记录"""
from qdrant_client.models import PointStruct
points = [
PointStruct(
id=r.id,
vector=r.vector.tolist(),
payload={**r.metadata, "text": r.text}
)
for r in records
]
self.client.upsert(collection_name=self.collection_name, points=points)
print(f"[INFO] 已插入 {len(records)} 条记录")
def search(self,
query_vector: np.ndarray,
top_k: int = 10,
filters: Optional[Dict] = None,
with_payload: bool = True) -> List[Dict]:
"""
向量搜索
Args:
query_vector: 查询向量
top_k: 返回数量
filters: 过滤条件 {"must": [{"key": "file_type", "match": {"value": "pdf"}}]}
Returns:
搜索结果
"""
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector.tolist(),
limit=top_k,
query_filter=filters,
with_payload=with_payload
)
return [
{
"id": hit.id,
"score": hit.score,
**hit.payload
}
for hit in results
]
def hybrid_search(self,
query_vector: np.ndarray,
query_text: str,
top_k: int = 10) -> List[Dict]:
"""
Qdrant混合搜索(使用内置的BM25支持需要额外配置,这里演示两路召回)
"""
# 向量搜索
vector_results = self.search(query_vector, top_k=top_k*2)
# 关键词过滤增强
keywords = query_text.lower().split()
# 构造过滤条件(简单版本:至少匹配一个关键词)
# 实际应使用全文索引或稀疏向量
filtered_results = []
for r in vector_results:
text = r.get("text", "").lower()
match_count = sum(1 for kw in keywords if kw in text)
r["keyword_matches"] = match_count
filtered_results.append(r)
# 重排序:结合向量相似度和关键词匹配
filtered_results.sort(key=lambda x: (x["score"] * 0.7 + x["keyword_matches"] * 0.1), reverse=True)
return filtered_results[:top_k]
def main():
parser = argparse.ArgumentParser(description="向量数据库管理")
parser.add_argument("--db-type", choices=["milvus", "qdrant"], required=True, help="数据库类型")
parser.add_argument("--action", choices=["create", "insert", "search", "hybrid"],
default="search", help="操作类型")
parser.add_argument("--vectors", "-v", default=None, help="向量文件(.npy)")
parser.add_argument("--metadata", "-m", default=None, help="元数据文件(.json)")
parser.add_argument("--collection", "-c", default="rag_collection", help="集合名称")
parser.add_argument("--query", "-q", default=None, help="查询文本")
parser.add_argument("--top-k", "-k", type=int, default=5, help="返回数量")
parser.add_argument("--host", default=None, help="服务器地址")
args = parser.parse_args()
# 初始化管理器
if args.db_type == "milvus":
host = args.host or "localhost"
manager = MilvusManager(host=host, collection_name=args.collection)
else:
host = args.host or "localhost"
        port = 6333  # Qdrant默认REST端口
manager = QdrantManager(host=host, port=port, collection_name=args.collection)
# 执行操作
if args.action == "create":
if args.db_type == "milvus":
manager.create_collection(dimension=384, drop_existing=True)
else:
manager.create_collection(dimension=384)
elif args.action == "insert":
if not args.vectors or not args.metadata:
raise ValueError("插入操作需要--vectors和--metadata")
vectors = np.load(args.vectors)
with open(args.metadata, 'r') as f:
metadata_list = json.load(f)
# 创建记录
records = []
for i, (vec, meta) in enumerate(zip(vectors, metadata_list)):
rec = VectorRecord(
id=str(uuid.uuid4()),
vector=vec,
metadata=meta.get("metadata", meta),
text=meta.get("content", meta.get("text", ""))
)
records.append(rec)
manager.insert(records)
elif args.action in ["search", "hybrid"]:
if not args.query:
raise ValueError("搜索操作需要--query")
# 加载查询向量(这里简化处理,实际应使用嵌入模型)
# 假设查询文本已编码
query_vec = np.random.randn(384).astype(np.float32)
query_vec = query_vec / np.linalg.norm(query_vec)
if args.action == "search":
results = manager.search(query_vec, top_k=args.top_k)
else:
results = manager.hybrid_search(query_vec, args.query, top_k=args.top_k)
print(f"\n[搜索结果] 找到 {len(results)} 条:")
for i, r in enumerate(results[:3]):
text = r.get("text", r.get("entity", {}).get("text", ""))[:100]
score = r.get("distance", r.get("score", 0))
print(f"{i+1}. [得分: {score:.4f}] {text}...")
if __name__ == "__main__":
main()
6.2.3 混合检索
Python
#!/usr/bin/env python3
"""
Script: hybrid_retrieval_rrf.py
功能: 实现混合检索(稠密向量+BM25)与RRF融合算法,含可视化对比
使用方式: python hybrid_retrieval_rrf.py --dense-results dense.json --sparse-results sparse.json --output fused.json
"""
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
from dataclasses import dataclass
from collections import defaultdict
@dataclass
class RetrievalResult:
"""检索结果项"""
doc_id: str
score: float
rank: int
content: str = ""
source: str = "" # 'dense' or 'sparse'
class ReciprocalRankFusion:
"""倒数排序融合实现"""
def __init__(self, k: int = 60):
"""
初始化RRF
Args:
k: 平滑常数,通常取60
"""
self.k = k
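        # RRF分数: score(d) = Σ_r 1/(k + rank_r(d)), 其中rank_r(d)为文档d在第r路结果中的名次(从1起),
        # 未被某一路召回则该路不贡献; k越大, 头部名次的优势被压得越平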
def fuse(self,
dense_results: List[RetrievalResult],
sparse_results: List[RetrievalResult],
top_n: int = 10) -> List[RetrievalResult]:
"""
融合两组检索结果
Args:
dense_results: 稠密检索结果(向量相似度)
sparse_results: 稀疏检索结果(BM25)
top_n: 返回的最终数量
Returns:
融合后的排序结果
"""
# 构建文档集合
all_doc_ids = set(r.doc_id for r in dense_results + sparse_results)
# 构建排名映射
dense_ranks = {r.doc_id: r.rank for r in dense_results}
sparse_ranks = {r.doc_id: r.rank for r in sparse_results}
# 存储文档信息
doc_info = {r.doc_id: r for r in dense_results + sparse_results}
# 计算RRF分数
rrf_scores = {}
for doc_id in all_doc_ids:
score = 0.0
# 稠密检索贡献
if doc_id in dense_ranks:
score += 1.0 / (self.k + dense_ranks[doc_id])
# 稀疏检索贡献
if doc_id in sparse_ranks:
score += 1.0 / (self.k + sparse_ranks[doc_id])
rrf_scores[doc_id] = score
# 排序
sorted_docs = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
# 构建结果
fused_results = []
for new_rank, (doc_id, score) in enumerate(sorted_docs[:top_n], 1):
original = doc_info[doc_id]
fused_results.append(RetrievalResult(
doc_id=doc_id,
score=score,
rank=new_rank,
content=original.content,
source="fused"
))
return fused_results
class HybridRetriever:
"""混合检索器"""
def __init__(self,
vector_weight: float = 1.0,
bm25_weight: float = 1.0,
rrf_k: int = 60):
"""
初始化混合检索器
Args:
vector_weight: 向量检索权重
bm25_weight: BM25检索权重
rrf_k: RRF常数
"""
self.vector_weight = vector_weight
self.bm25_weight = bm25_weight
self.rrf = ReciprocalRankFusion(k=rrf_k)
self.documents = {}
def index(self,
texts: List[str],
embeddings: np.ndarray,
doc_ids: List[str] = None):
"""
索引文档
Args:
texts: 文档文本列表
embeddings: 文档嵌入向量
doc_ids: 文档ID列表(可选)
"""
if doc_ids is None:
doc_ids = [f"doc_{i}" for i in range(len(texts))]
# 构建倒排索引(简化版BM25)
self.inverted_index = defaultdict(list)
self.doc_freqs = defaultdict(int)
self.doc_lengths = []
for idx, (doc_id, text) in enumerate(zip(doc_ids, texts)):
self.documents[doc_id] = {
"text": text,
"embedding": embeddings[idx],
"length": len(text.split())
}
# 构建倒排索引
words = text.lower().split()
self.doc_lengths.append(len(words))
unique_words = set(words)
for word in unique_words:
self.inverted_index[word].append((doc_id, words.count(word)))
self.doc_freqs[word] += 1
self.avg_doc_length = np.mean(self.doc_lengths)
self.total_docs = len(texts)
self.embeddings = embeddings
self.doc_ids = doc_ids
print(f"[INFO] 索引完成: {len(texts)} 文档")
print(f"[INFO] 词汇表大小: {len(self.inverted_index)}")
def dense_search(self,
query_embedding: np.ndarray,
top_k: int = 20) -> List[RetrievalResult]:
"""
稠密向量检索
Args:
query_embedding: 查询向量
top_k: 返回数量
Returns:
检索结果
"""
# 计算余弦相似度
similarities = np.dot(self.embeddings, query_embedding)
similarities = similarities / (np.linalg.norm(self.embeddings, axis=1) *
np.linalg.norm(query_embedding) + 1e-10)
# 获取Top-K
top_indices = np.argsort(similarities)[::-1][:top_k]
results = []
for rank, idx in enumerate(top_indices, 1):
doc_id = self.doc_ids[idx]
results.append(RetrievalResult(
doc_id=doc_id,
score=float(similarities[idx]),
rank=rank,
content=self.documents[doc_id]["text"][:200],
source="dense"
))
return results
def bm25_search(self,
query: str,
top_k: int = 20,
k1: float = 1.5,
b: float = 0.75) -> List[RetrievalResult]:
"""
BM25关键词检索
Args:
query: 查询文本
top_k: 返回数量
k1: BM25参数
b: BM25参数
Returns:
检索结果
"""
query_words = query.lower().split()
scores = defaultdict(float)
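        # Okapi BM25:
        #   IDF(w) = ln((N - df_w + 0.5) / (df_w + 0.5) + 1)
        #   score(d) += IDF(w) * tf*(k1+1) / (tf + k1*(1 - b + b*|d|/avgdl))
        # 其中k1控制词频饱和速度, b控制文档长度归一化强度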
for word in query_words:
if word not in self.inverted_index:
continue
idf = np.log((self.total_docs - self.doc_freqs[word] + 0.5) /
(self.doc_freqs[word] + 0.5) + 1.0)
for doc_id, tf in self.inverted_index[word]:
doc_len = self.documents[doc_id]["length"]
norm = 1 - b + b * (doc_len / self.avg_doc_length)
score = idf * (tf * (k1 + 1)) / (tf + k1 * norm)
scores[doc_id] += score
# 排序
sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
results = []
for rank, (doc_id, score) in enumerate(sorted_docs[:top_k], 1):
results.append(RetrievalResult(
doc_id=doc_id,
score=score,
rank=rank,
content=self.documents[doc_id]["text"][:200],
source="sparse"
))
return results
def hybrid_search(self,
query: str,
query_embedding: np.ndarray,
top_k: int = 10) -> Tuple[List[RetrievalResult], Dict]:
"""
执行混合检索
Args:
query: 查询文本
query_embedding: 查询向量
top_k: 最终返回数量
Returns:
(融合结果, 中间结果字典)
"""
# 两路召回
dense_results = self.dense_search(query_embedding, top_k=top_k*2)
sparse_results = self.bm25_search(query, top_k=top_k*2)
# 调整分数(可选,用于加权)
for r in dense_results:
r.score *= self.vector_weight
for r in sparse_results:
r.score *= self.bm25_weight
# RRF融合
fused_results = self.rrf.fuse(dense_results, sparse_results, top_n=top_k)
# 收集中间结果用于分析
analysis = {
"dense_only": dense_results[:top_k],
"sparse_only": sparse_results[:top_k],
"fused": fused_results,
"overlap": len(set(r.doc_id for r in dense_results[:top_k]) &
set(r.doc_id for r in sparse_results[:top_k]))
}
return fused_results, analysis
class HybridSearchVisualizer:
"""混合检索可视化"""
def visualize_fusion(self,
analysis: Dict,
output_path: str = "hybrid_search_analysis.png"):
"""
可视化混合检索结果对比
Args:
analysis: 分析数据
output_path: 输出路径
"""
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
dense_results = analysis["dense_only"]
sparse_results = analysis["sparse_only"]
fused_results = analysis["fused"]
# 1. 排名对比(前5个)
ax1 = axes[0, 0]
doc_ids = [f"doc_{i}" for i in range(5)]
dense_ranks = [next((r.rank for r in dense_results if r.doc_id == did), 20)
for did in doc_ids]
sparse_ranks = [next((r.rank for r in sparse_results if r.doc_id == did), 20)
for did in doc_ids]
fused_ranks = [next((r.rank for r in fused_results if r.doc_id == did), 20)
for did in doc_ids]
x = np.arange(len(doc_ids))
width = 0.25
ax1.bar(x - width, dense_ranks, width, label='Dense', color='blue', alpha=0.6)
ax1.bar(x, sparse_ranks, width, label='Sparse', color='red', alpha=0.6)
ax1.bar(x + width, fused_ranks, width, label='Fused', color='green', alpha=0.6)
ax1.set_ylabel('Rank (lower is better)')
ax1.set_title('Ranking Comparison (Sample Docs)')
ax1.set_xticks(x)
ax1.set_xticklabels(doc_ids, rotation=45)
ax1.legend()
ax1.invert_yaxis() # 排名越小越好
# 2. RRF分数分布
ax2 = axes[0, 1]
fused_scores = [r.score for r in fused_results]
ax2.barh(range(len(fused_scores)), fused_scores, color='green', alpha=0.6)
ax2.set_xlabel('RRF Score')
ax2.set_ylabel('Result Rank')
ax2.set_title('RRF Score Distribution (Fused Results)')
# 3. 检索重叠度
ax3 = axes[1, 0]
overlap = analysis["overlap"]
total = len(dense_results)
categories = ['Dense Only', 'Sparse Only', 'Overlap']
values = [total - overlap, total - overlap, overlap]
colors = ['blue', 'red', 'purple']
ax3.pie(values, labels=categories, colors=colors, autopct='%1.1f%%')
ax3.set_title('Result Overlap Between Dense and Sparse')
# 4. 累计得分对比
ax4 = axes[1, 1]
dense_cumsum = np.cumsum([r.score for r in dense_results])
sparse_cumsum = np.cumsum([r.score for r in sparse_results])
fused_cumsum = np.cumsum([r.score for r in fused_results])
ax4.plot(range(1, len(dense_cumsum)+1), dense_cumsum,
marker='o', label='Dense', color='blue')
ax4.plot(range(1, len(sparse_cumsum)+1), sparse_cumsum,
marker='s', label='Sparse', color='red')
ax4.plot(range(1, len(fused_cumsum)+1), fused_cumsum,
marker='^', label='Fused', color='green', linewidth=2)
ax4.set_xlabel('Top-K')
ax4.set_ylabel('Cumulative Score')
ax4.set_title('Cumulative Relevance Score')
ax4.legend()
ax4.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 可视化已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="混合检索与RRF融合")
parser.add_argument("--mode", choices=["demo", "fuse"], default="demo", help="运行模式")
parser.add_argument("--dense-results", "-d", default=None, help="稠密检索结果JSON")
parser.add_argument("--sparse-results", "-s", default=None, help="稀疏检索结果JSON")
parser.add_argument("--output", "-o", default="fused_results.json", help="输出文件")
parser.add_argument("--visualize", "-v", action="store_true", help="生成可视化")
parser.add_argument("--rrf-k", "-k", type=int, default=60, help="RRF常数k")
args = parser.parse_args()
if args.mode == "fuse":
# 加载已有结果进行融合
with open(args.dense_results, 'r') as f:
dense_data = json.load(f)
with open(args.sparse_results, 'r') as f:
sparse_data = json.load(f)
dense_results = [RetrievalResult(**r) for r in dense_data]
sparse_results = [RetrievalResult(**r) for r in sparse_data]
rrf = ReciprocalRankFusion(k=args.rrf_k)
fused = rrf.fuse(dense_results, sparse_results)
# 保存
output = [vars(r) for r in fused]
with open(args.output, 'w') as f:
json.dump(output, f, indent=2)
print(f"[INFO] 融合结果已保存至: {args.output}")
else:
# 演示模式:创建模拟数据并展示
print("[DEMO] 运行混合检索演示...")
# 模拟文档和嵌入
np.random.seed(42)
n_docs = 100
dim = 384
texts = [
f"This is document {i} about machine learning and information retrieval" if i % 3 == 0
else f"Document {i} discusses deep learning neural networks" if i % 3 == 1
else f"Paper {i} on natural language processing and semantic search"
for i in range(n_docs)
]
embeddings = np.random.randn(n_docs, dim).astype(np.float32)
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
# 初始化检索器
retriever = HybridRetriever(vector_weight=1.0, bm25_weight=1.0, rrf_k=60)
retriever.index(texts, embeddings)
# 模拟查询
query = "machine learning semantic search"
query_vec = np.random.randn(dim).astype(np.float32)
query_vec = query_vec / np.linalg.norm(query_vec)
# 执行混合检索
results, analysis = retriever.hybrid_search(query, query_vec, top_k=10)
print(f"\n[结果] 查询: '{query}'")
print(f"[统计] 稠密/稀疏重叠文档数: {analysis['overlap']}/10")
print("\nTop-5 融合结果:")
for i, r in enumerate(results[:5]):
print(f"{i+1}. {r.doc_id} (RRF分数: {r.score:.4f})")
# 可视化
if args.visualize:
visualizer = HybridSearchVisualizer()
visualizer.visualize_fusion(analysis)
if __name__ == "__main__":
main()
6.2.4 索引优化
Python
#!/usr/bin/env python3
"""
Script: index_optimization.py
功能: HNSW索引参数调优与向量量化(Quantization)实现,含性能对比可视化
使用方式: python index_optimization.py --vectors embeddings.npy --benchmark
"""
import os
import time
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
from dataclasses import dataclass
import hnswlib # 使用hnswlib作为HNSW实现
@dataclass
class IndexConfig:
"""索引配置"""
method: str # 'hnsw', 'flat', 'quantized'
dim: int
max_elements: int
M: int = 16 # HNSW参数
ef_construction: int = 200
ef_search: int = 64
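    # 经验上: M越大图越稠密(召回高、内存大), ef_construction影响建图质量与构建耗时,
    # ef_search控制查询时召回率与延迟的权衡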
quantization: str = None # 'pq', 'sq', None
class VectorIndex:
"""向量索引基类"""
def __init__(self, config: IndexConfig):
self.config = config
self.index = None
self.vectors = None
def build(self, vectors: np.ndarray):
raise NotImplementedError
def search(self, query: np.ndarray, k: int = 10) -> Tuple[List[int], List[float]]:
raise NotImplementedError
def get_size(self) -> int:
"""获取索引大小(字节)"""
raise NotImplementedError
class HNSWIndex(VectorIndex):
"""HNSW索引实现"""
def build(self, vectors: np.ndarray):
"""构建HNSW索引"""
print(f"[HNSW] 构建索引: M={self.config.M}, efConstruction={self.config.ef_construction}")
num_elements, dim = vectors.shape
assert dim == self.config.dim
# 创建索引
index = hnswlib.Index(space='cosine', dim=dim)
index.init_index(
max_elements=self.config.max_elements,
ef_construction=self.config.ef_construction,
M=self.config.M
)
# 添加数据
index.add_items(vectors, ids=np.arange(num_elements))
index.set_ef(self.config.ef_search)
self.index = index
self.vectors = vectors
def search(self, query: np.ndarray, k: int = 10) -> Tuple[List[int], List[float]]:
"""搜索"""
labels, distances = self.index.knn_query(query, k=k)
return labels[0].tolist(), distances[0].tolist()
def get_size(self) -> int:
"""估计索引大小(简化计算)"""
# HNSW大小 ≈ 向量数据 + 图结构
vector_size = self.vectors.nbytes
graph_size = self.vectors.shape[0] * self.config.M * 4 * 2 # 近似
return vector_size + graph_size
def set_ef(self, ef: int):
"""调整搜索参数"""
self.config.ef_search = ef
self.index.set_ef(ef)
class ProductQuantizationIndex(VectorIndex):
"""乘积量化索引"""
def __init__(self, config: IndexConfig, n_subspaces: int = 8, n_clusters: int = 256):
super().__init__(config)
self.n_subspaces = n_subspaces
self.n_clusters = n_clusters
self.codebooks = None
self.codes = None
def build(self, vectors: np.ndarray):
"""构建PQ索引"""
print(f"[PQ] 构建乘积量化索引: {self.n_subspaces}子空间, {self.n_clusters}聚类中心")
n, dim = vectors.shape
sub_dim = dim // self.n_subspaces
self.codebooks = []
self.codes = np.zeros((n, self.n_subspaces), dtype=np.uint8)
# 对每个子空间训练k-means
for i in range(self.n_subspaces):
sub_vectors = vectors[:, i*sub_dim:(i+1)*sub_dim]
# 使用简化的k-means(实际应使用更高效实现)
# 这里随机选择中心点作为演示
indices = np.random.choice(n, self.n_clusters, replace=False)
centroids = sub_vectors[indices]
# 分配代码
distances = np.linalg.norm(sub_vectors[:, np.newaxis] - centroids, axis=2)
self.codes[:, i] = np.argmin(distances, axis=1)
self.codebooks.append(centroids)
self.vectors = vectors # 保留原始向量用于对比
def search(self, query: np.ndarray, k: int = 10) -> Tuple[List[int], List[float]]:
"""非对称距离计算(ADC)搜索"""
n = len(self.codes)
sub_dim = self.config.dim // self.n_subspaces
# 计算查询到每个子空间中心点的距离
all_distances = np.zeros((n, self.n_subspaces))
for i in range(self.n_subspaces):
sub_query = query[i*sub_dim:(i+1)*sub_dim]
centroids = self.codebooks[i]
# 查询到中心点的距离
query_to_centroids = np.linalg.norm(sub_query - centroids, axis=1)
# 根据代码获取距离
all_distances[:, i] = query_to_centroids[self.codes[:, i]]
# 汇总距离(近似)
total_distances = np.sum(all_distances, axis=1)
# 返回Top-K
top_k = np.argsort(total_distances)[:k]
return top_k.tolist(), total_distances[top_k].tolist()
def get_size(self) -> int:
"""计算压缩后大小"""
code_size = self.codes.nbytes
codebook_size = sum(cb.nbytes for cb in self.codebooks)
return code_size + codebook_size
class ScalarQuantizationIndex(VectorIndex):
"""标量量化(INT8)"""
def build(self, vectors: np.ndarray):
"""构建INT8量化索引"""
print("[SQ] 构建INT8标量量化索引")
# 计算最小值和最大值用于量化
self.min_vals = vectors.min(axis=0)
self.max_vals = vectors.max(axis=0)
self.scale = (self.max_vals - self.min_vals) / 255.0
self.scale[self.scale == 0] = 1e-10
# 量化为INT8
self.quantized = ((vectors - self.min_vals) / self.scale).astype(np.uint8)
self.vectors = vectors
def search(self, query: np.ndarray, k: int = 10) -> Tuple[List[int], List[float]]:
"""搜索(反量化后计算)"""
# 反量化(简化,实际应使用快速距离计算)
reconstructed = self.quantized.astype(np.float32) * self.scale + self.min_vals
# 计算距离
distances = np.linalg.norm(reconstructed - query, axis=1)
top_k = np.argsort(distances)[:k]
return top_k.tolist(), distances[top_k].tolist()
def get_size(self) -> int:
"""压缩后大小"""
return self.quantized.nbytes
class IndexBenchmark:
"""索引性能评测"""
def __init__(self, vectors: np.ndarray, queries: np.ndarray, ground_truth: np.ndarray = None):
"""
初始化评测
Args:
vectors: 文档向量
queries: 查询向量
ground_truth: 真实的最近邻(用于计算召回率)
"""
self.vectors = vectors
self.queries = queries
self.ground_truth = ground_truth
def benchmark(self, index: VectorIndex, k: int = 10) -> Dict:
"""
评测索引
Returns:
性能指标字典
"""
# 构建时间
start = time.time()
index.build(self.vectors)
build_time = time.time() - start
# 搜索时间
latencies = []
for q in self.queries:
start = time.time()
index.search(q, k=k)
latencies.append((time.time() - start) * 1000) # ms
avg_latency = np.mean(latencies)
p99_latency = np.percentile(latencies, 99)
# 召回率(如果有真实标签)
recall = 0.0
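        # Recall@K = |索引返回Top-K ∩ 暴力搜索真实Top-K| / K, 对所有查询取平均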
if self.ground_truth is not None:
recalls = []
for i, q in enumerate(self.queries):
ids, _ = index.search(q, k=k)
if len(ids) > 0:
hits = len(set(ids) & set(self.ground_truth[i]))
recalls.append(hits / k)
recall = np.mean(recalls)
# 索引大小
size_bytes = index.get_size()
size_mb = size_bytes / (1024 * 1024)
return {
"build_time_sec": build_time,
"avg_latency_ms": avg_latency,
"p99_latency_ms": p99_latency,
"recall_at_k": recall,
"index_size_mb": size_mb,
"compression_ratio": self.vectors.nbytes / size_bytes if size_bytes > 0 else 1.0
}
def visualize_benchmark(results: Dict[str, Dict], output_path: str = "index_benchmark.png"):
"""可视化评测结果"""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
methods = list(results.keys())
colors = ['blue', 'green', 'red', 'orange']
# 1. 延迟对比
ax1 = axes[0, 0]
latencies = [results[m]["avg_latency_ms"] for m in methods]
bars1 = ax1.bar(methods, latencies, color=colors[:len(methods)], alpha=0.6)
ax1.set_ylabel('Average Latency (ms)')
ax1.set_title('Query Latency Comparison')
for bar in bars1:
height = bar.get_height()
ax1.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.2f}', ha='center', va='bottom')
# 2. 召回率对比
ax2 = axes[0, 1]
recalls = [results[m]["recall_at_k"] for m in methods]
bars2 = ax2.bar(methods, recalls, color=colors[:len(methods)], alpha=0.6)
ax2.set_ylabel('Recall@K')
ax2.set_title('Recall Rate Comparison')
ax2.set_ylim([0, 1])
for bar in bars2:
height = bar.get_height()
ax2.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.3f}', ha='center', va='bottom')
# 3. 索引大小对比
ax3 = axes[1, 0]
sizes = [results[m]["index_size_mb"] for m in methods]
bars3 = ax3.bar(methods, sizes, color=colors[:len(methods)], alpha=0.6)
ax3.set_ylabel('Index Size (MB)')
ax3.set_title('Storage Footprint')
for bar in bars3:
height = bar.get_height()
ax3.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.1f}', ha='center', va='bottom')
# 4. 压缩比与召回率权衡
ax4 = axes[1, 1]
compression_ratios = [results[m]["compression_ratio"] for m in methods]
ax4.scatter(compression_ratios, recalls, s=200, c=colors[:len(methods)], alpha=0.6)
for i, method in enumerate(methods):
ax4.annotate(method, (compression_ratios[i], recalls[i]),
xytext=(5, 5), textcoords='offset points')
ax4.set_xlabel('Compression Ratio (higher is better)')
ax4.set_ylabel('Recall@K')
ax4.set_title('Compression vs Quality Trade-off')
ax4.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 评测图表已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="向量索引优化")
parser.add_argument("--vectors", "-v", default=None, help="向量文件(.npy)")
parser.add_argument("--benchmark", "-b", action="store_true", help="运行评测")
parser.add_argument("--n-vectors", "-n", type=int, default=10000, help="向量数量")
parser.add_argument("--dim", "-d", type=int, default=384, help="向量维度")
parser.add_argument("--hnsw-m", type=int, default=16, help="HNSW M参数")
parser.add_argument("--ef-construction", type=int, default=200, help="HNSW构建参数")
parser.add_argument("--quantization", "-q", choices=[None, "pq", "sq"],
default=None, help="量化方法")
parser.add_argument("--output", "-o", default="index_benchmark.png", help="输出图表")
args = parser.parse_args()
# 加载或生成数据
if args.vectors:
vectors = np.load(args.vectors)
print(f"[INFO] 加载向量: {vectors.shape}")
else:
print(f"[INFO] 生成随机测试数据: {args.n_vectors} x {args.dim}")
np.random.seed(42)
vectors = np.random.randn(args.n_vectors, args.dim).astype(np.float32)
vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
n, dim = vectors.shape
if args.benchmark:
# 生成查询
n_queries = 100
queries = vectors[np.random.choice(n, n_queries)]
# 计算真实最近邻(使用暴力搜索)
print("[INFO] 计算真实最近邻...")
ground_truth = []
for q in queries:
dists = np.linalg.norm(vectors - q, axis=1)
gt = np.argsort(dists)[1:11] # 排除自身
ground_truth.append(gt)
benchmark = IndexBenchmark(vectors, queries, ground_truth)
# 测试不同配置
results = {}
# 1. 标准HNSW
config1 = IndexConfig(
method="hnsw",
dim=dim,
max_elements=n,
M=args.hnsw_m,
ef_construction=args.ef_construction,
ef_search=64
)
hnsw_index = HNSWIndex(config1)
results["HNSW"] = benchmark.benchmark(hnsw_index, k=10)
# 2. HNSW + 高EF(质量优先)
config2 = IndexConfig(
method="hnsw_high",
dim=dim,
max_elements=n,
M=32,
ef_construction=400,
ef_search=128
)
hnsw_high = HNSWIndex(config2)
results["HNSW-High"] = benchmark.benchmark(hnsw_high, k=10)
# 3. 乘积量化
if args.quantization == "pq" or args.quantization is None:
config3 = IndexConfig(method="pq", dim=dim, max_elements=n)
pq_index = ProductQuantizationIndex(config3, n_subspaces=8, n_clusters=256)
results["PQ"] = benchmark.benchmark(pq_index, k=10)
# 4. 标量量化
if args.quantization == "sq" or args.quantization is None:
config4 = IndexConfig(method="sq", dim=dim, max_elements=n)
sq_index = ScalarQuantizationIndex(config4)
results["SQ-INT8"] = benchmark.benchmark(sq_index, k=10)
# 打印结果
print("\n" + "="*60)
print("索引性能评测结果")
print("="*60)
for method, metrics in results.items():
print(f"\n{method}:")
for metric, value in metrics.items():
if isinstance(value, float):
print(f" {metric}: {value:.4f}")
else:
print(f" {metric}: {value}")
# 保存JSON
with open("benchmark_results.json", 'w') as f:
json.dump(results, f, indent=2)
# 可视化
visualize_benchmark(results, args.output)
else:
print("[INFO] 使用 --benchmark 运行性能评测")
if __name__ == "__main__":
main()
6.3.1 查询重写
Python
#!/usr/bin/env python3
"""
Script: hyde_query_rewrite.py
功能: HyDE(Hypothetical Document Embedding)查询重写实现,支持假设文档生成与融合检索
使用方式: python hyde_query_rewrite.py --query "your question" --api-key YOUR_OPENAI_KEY
"""
import os
import json
import argparse
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass
from sentence_transformers import SentenceTransformer
@dataclass
class HyDERewriteResult:
"""HyDE重写结果"""
original_query: str
hypothetical_doc: str
hyde_embedding: np.ndarray
combined_embedding: np.ndarray
retrieval_results: List[Dict]
class HyDERewriter:
"""
Hypothetical Document Embeddings (HyDE) 实现
基于Gao et al. (2022)的论文实现,使用LLM生成假设答案文档,
通过假设文档的嵌入向量进行检索,而非直接使用查询嵌入。
"""
DEFAULT_PROMPT = """Given the question below, write a detailed hypothetical document that would answer this question.
The document should be detailed, factual, and contain specific information that would be found in relevant source documents.
Question: {query}
Hypothetical Document:"""
def __init__(self,
llm_provider: str = "openai",
embedding_model: str = 'all-MiniLM-L6-v2',
api_key: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 300):
"""
初始化HyDE重写器
Args:
llm_provider: LLM提供者 ('openai', 'local', 'azure')
embedding_model: 本地嵌入模型名称
api_key: API密钥
temperature: 生成温度
max_tokens: 最大生成token数
"""
self.llm_provider = llm_provider
self.temperature = temperature
self.max_tokens = max_tokens
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
# 初始化嵌入模型
print(f"[INFO] 加载嵌入模型: {embedding_model}")
self.embedder = SentenceTransformer(embedding_model)
self.embedding_dim = self.embedder.get_sentence_embedding_dimension()
# 初始化LLM客户端
self._init_llm()
def _init_llm(self):
"""初始化LLM客户端"""
if self.llm_provider == "openai":
try:
from openai import OpenAI
except ImportError:
raise ImportError("请安装openai库: pip install openai")
if not self.api_key:
raise ValueError("需要提供OpenAI API Key")
self.client = OpenAI(api_key=self.api_key)
self.model = "gpt-3.5-turbo-instruct" # 或其他模型
elif self.llm_provider == "local":
# 本地模型支持(如使用transformers)
print("[WARN] 本地LLM支持需要额外配置,当前使用模拟模式")
self.client = None
else:
raise ValueError(f"不支持的LLM提供者: {self.llm_provider}")
def generate_hypothetical_document(self, query: str, prompt_template: Optional[str] = None) -> str:
"""
生成假设文档
Args:
query: 原始查询
prompt_template: 自定义提示模板
Returns:
生成的假设文档
"""
prompt = (prompt_template or self.DEFAULT_PROMPT).format(query=query)
if self.llm_provider == "openai":
response = self.client.completions.create(
model=self.model,
prompt=prompt,
max_tokens=self.max_tokens,
temperature=self.temperature,
stop=None
)
hypothetical_doc = response.choices[0].text.strip()
else:
# 模拟模式:返回简单假设
hypothetical_doc = f"This document explains that {query} is an important topic with several key aspects..."
return hypothetical_doc
def rewrite(self,
query: str,
alpha: float = 0.8,
use_original: bool = True) -> HyDERewriteResult:
"""
执行HyDE重写
Args:
query: 原始查询
alpha: 假设文档嵌入的权重(0-1)
use_original: 是否结合原始查询嵌入
Returns:
HyDE重写结果
"""
# 生成假设文档
print(f"[HyDE] 生成假设文档...")
hypothetical_doc = self.generate_hypothetical_document(query)
print(f"[HyDE] 假设文档长度: {len(hypothetical_doc)} 字符")
# 计算嵌入
hyde_embedding = self.embedder.encode(hypothetical_doc, convert_to_numpy=True)
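        # HyDE假设: LLM生成的"假设文档"在嵌入空间中比原始查询更接近真实答案文档,
        # 因此用其嵌入(或与查询嵌入按alpha加权)做检索, 可缓解查询与文档间的表达差异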
if use_original:
original_embedding = self.embedder.encode(query, convert_to_numpy=True)
# 加权融合
combined = alpha * hyde_embedding + (1 - alpha) * original_embedding
combined = combined / np.linalg.norm(combined)
else:
combined = hyde_embedding
return HyDERewriteResult(
original_query=query,
hypothetical_doc=hypothetical_doc,
hyde_embedding=hyde_embedding,
combined_embedding=combined,
retrieval_results=[]
)
def retrieve_with_hyde(self,
query: str,
index_vectors: np.ndarray,
index_texts: List[str],
top_k: int = 10,
alpha: float = 0.8) -> HyDERewriteResult:
"""
使用HyDE进行检索(端到端)
Args:
query: 查询
index_vectors: 索引向量
index_texts: 索引文本
top_k: 返回数量
alpha: 融合权重
Returns:
包含检索结果的重写结果
"""
# 重写查询
result = self.rewrite(query, alpha=alpha)
# 向量检索
similarities = np.dot(index_vectors, result.combined_embedding)
top_indices = np.argsort(similarities)[::-1][:top_k]
# 构建结果
retrieval_results = []
for rank, idx in enumerate(top_indices, 1):
retrieval_results.append({
"rank": rank,
"index": int(idx),
"text": index_texts[idx][:200],
"similarity": float(similarities[idx])
})
result.retrieval_results = retrieval_results
return result
class HyDEComparison:
"""HyDE效果对比分析"""
def compare(self,
queries: List[str],
hyde_rewriter: HyDERewriter,
index_vectors: np.ndarray,
index_texts: List[str],
ground_truth_indices: List[List[int]]) -> Dict:
"""
对比原始查询与HyDE的检索效果
Returns:
对比指标
"""
results = {
"baseline": {"recalls": [], "mrrs": []},
"hyde": {"recalls": [], "mrrs": []}
}
for query, gt in zip(queries, ground_truth_indices):
# Baseline检索
query_emb = hyde_rewriter.embedder.encode(query)
baseline_sims = np.dot(index_vectors, query_emb)
baseline_top = np.argsort(baseline_sims)[::-1][:10]
# HyDE检索
hyde_result = hyde_rewriter.retrieve_with_hyde(query, index_vectors, index_texts, top_k=10)
hyde_top = [r["index"] for r in hyde_result.retrieval_results]
# 计算Recall@10
baseline_recall = len(set(baseline_top) & set(gt)) / len(gt)
hyde_recall = len(set(hyde_top) & set(gt)) / len(gt)
# 计算MRR
baseline_mrr = 0
hyde_mrr = 0
for i, idx in enumerate(baseline_top):
if idx in gt:
baseline_mrr = 1.0 / (i + 1)
break
for i, idx in enumerate(hyde_top):
if idx in gt:
hyde_mrr = 1.0 / (i + 1)
break
results["baseline"]["recalls"].append(baseline_recall)
results["baseline"]["mrrs"].append(baseline_mrr)
results["hyde"]["recalls"].append(hyde_recall)
results["hyde"]["mrrs"].append(hyde_mrr)
# 汇总
summary = {
"baseline": {
"avg_recall@10": np.mean(results["baseline"]["recalls"]),
"mrr": np.mean(results["baseline"]["mrrs"])
},
"hyde": {
"avg_recall@10": np.mean(results["hyde"]["recalls"]),
"mrr": np.mean(results["hyde"]["mrrs"])
}
}
summary["improvement"] = {
"recall": (summary["hyde"]["avg_recall@10"] - summary["baseline"]["avg_recall@10"]) /
(summary["baseline"]["avg_recall@10"] + 1e-10),
"mrr": (summary["hyde"]["mrr"] - summary["baseline"]["mrr"]) /
(summary["baseline"]["mrr"] + 1e-10)
}
return summary
def main():
parser = argparse.ArgumentParser(description="HyDE查询重写")
parser.add_argument("--query", "-q", default=None, help="查询文本")
parser.add_argument("--api-key", "-k", default=None, help="OpenAI API Key")
parser.add_argument("--vectors", "-v", default=None, help="索引向量文件(.npy)")
parser.add_argument("--texts", "-t", default=None, help="索引文本文件(.json)")
parser.add_argument("--alpha", "-a", type=float, default=0.8, help="HyDE权重")
parser.add_argument("--output", "-o", default="hyde_result.json", help="输出文件")
parser.add_argument("--demo", "-d", action="store_true", help="运行演示模式")
args = parser.parse_args()
if args.demo or (args.query and not args.vectors):
# 演示模式
print("[DEMO] HyDE查询重写演示")
# 模拟数据
np.random.seed(42)
n_docs = 100
dim = 384
texts = [
f"Document about machine learning and neural networks, specifically discussing {topic}"
for topic in ["supervised learning", "deep learning", "reinforcement learning",
"NLP", "computer vision", "optimization"] * 16 + ["other topics"] * 4
]
# 生成相关向量(模拟语义聚类)
vectors = np.random.randn(n_docs, dim).astype(np.float32)
vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
# 初始化HyDE(模拟模式)
rewriter = HyDERewriter(llm_provider="local", embedding_model='all-MiniLM-L6-v2')
# 测试查询
test_queries = [
"What are neural networks used for?",
"Explain deep learning applications",
"How does machine learning work?"
]
print(f"\n测试查询: '{test_queries[0]}'\n")
result = rewriter.retrieve_with_hyde(
test_queries[0],
vectors,
texts,
top_k=5,
alpha=args.alpha
)
print(f"假设文档:\n{result.hypothetical_doc}\n")
print("检索结果:")
for r in result.retrieval_results[:3]:
print(f"{r['rank']}. [{r['similarity']:.3f}] {r['text'][:100]}...")
# 保存
output_data = {
"query": result.original_query,
"hypothetical_document": result.hypothetical_doc,
"retrieval_results": result.retrieval_results
}
with open(args.output, 'w') as f:
json.dump(output_data, f, indent=2)
print(f"\n结果已保存至: {args.output}")
elif args.query and args.vectors:
# 实际运行模式
vectors = np.load(args.vectors)
if args.texts:
with open(args.texts, 'r') as f:
texts = json.load(f)
else:
texts = [f"Document {i}" for i in range(len(vectors))]
rewriter = HyDERewriter(
llm_provider="openai",
api_key=args.api_key,
embedding_model='all-MiniLM-L6-v2'
)
result = rewriter.retrieve_with_hyde(
args.query,
vectors,
texts,
top_k=10,
alpha=args.alpha
)
print(f"\n查询: {result.original_query}")
print(f"假设文档: {result.hypothetical_doc[:300]}...")
print(f"\nTop-3 结果:")
for r in result.retrieval_results[:3]:
print(f"{r['rank']}. 相似度: {r['similarity']:.4f}")
print(f" 文本: {r['text'][:150]}...\n")
# 保存
with open(args.output, 'w') as f:
json.dump({
"original_query": result.original_query,
"hypothetical_document": result.hypothetical_doc,
"hyde_embedding_shape": list(result.hyde_embedding.shape),
"retrieval_results": result.retrieval_results
}, f, indent=2)
else:
parser.print_help()
if __name__ == "__main__":
main()
6.3.2 重排序
Python
#!/usr/bin/env python3
"""
Script: reranking_system.py
功能: 交叉编码器重排序(Cross-Encoder Rerank)实现,支持Cohere API与本地模型微调
使用方式: python reranking_system.py --query "question" --candidates candidates.json --model cross-encoder
"""
import os
import json
import argparse
import numpy as np
from typing import List, Dict, Tuple
from dataclasses import dataclass
@dataclass
class RerankResult:
"""重排序结果"""
doc_id: str
original_rank: int
reranked_score: float
text: str
cross_encoder_score: float
class CrossEncoderReranker:
"""
交叉编码器重排序器
使用双塔架构的初始检索后,应用交叉编码器进行精确重排序。
交叉编码器将查询和文档拼接后输入Transformer,计算相关性分数。
"""
def __init__(self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: str = None,
max_length: int = 512):
"""
初始化重排序器
Args:
model_name: 交叉编码器模型名称
device: 计算设备 ('cuda', 'cpu', None为自动)
max_length: 最大序列长度
"""
try:
from sentence_transformers import CrossEncoder
except ImportError:
raise ImportError("请安装sentence-transformers: pip install sentence-transformers")
print(f"[INFO] 加载交叉编码器: {model_name}")
self.model = CrossEncoder(model_name, device=device, max_length=max_length)
self.model_name = model_name
def rerank(self,
query: str,
documents: List[Dict],
top_k: int = None) -> List[RerankResult]:
"""
重排序文档
Args:
query: 查询文本
documents: 候选文档列表,每个文档包含'id', 'text', 'score'
top_k: 返回的最终数量(None表示返回全部重排序)
Returns:
重排序后的结果列表
"""
if not documents:
return []
# 构建输入对 [(query, doc1), (query, doc2), ...]
pairs = [(query, doc.get("text", doc.get("content", ""))) for doc in documents]
# 预测相关性分数
print(f"[Rerank] 对 {len(pairs)} 个文档进行交叉编码...")
scores = self.model.predict(pairs, convert_to_numpy=True, show_progress_bar=False)
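        # 交叉编码器对每个(query, doc)对联合编码, 输出的相关性分数通常是未归一化的logit,
        # 只适合在同一查询的候选集内做相对排序, 不宜跨查询直接比较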
# 构建结果
results = []
for i, (doc, score) in enumerate(zip(documents, scores)):
results.append(RerankResult(
doc_id=doc.get("id", f"doc_{i}"),
original_rank=i + 1,
reranked_score=float(score),
text=doc.get("text", doc.get("content", ""))[:200],
cross_encoder_score=float(score)
))
# 按交叉编码器分数排序
results.sort(key=lambda x: x.reranked_score, reverse=True)
if top_k:
results = results[:top_k]
# 更新重排序后的排名
for new_rank, result in enumerate(results, 1):
result.reranked_position = new_rank
return results
def rerank_batch(self,
queries: List[str],
                     doc_lists: List[List[Dict]]) -> List[List[RerankResult]]:
"""
批量重排序
Args:
queries: 查询列表
doc_lists: 每个查询对应的候选文档列表
Returns:
重排序结果列表的列表
"""
all_results = []
for query, docs in zip(queries, doc_lists):
results = self.rerank(query, docs)
all_results.append(results)
return all_results
class CohereReranker:
"""Cohere Rerank API封装"""
def __init__(self, api_key: str = None, model: str = "rerank-english-v2.0"):
"""
初始化Cohere重排序器
Args:
api_key: Cohere API密钥
model: 重排序模型版本
"""
try:
import cohere
except ImportError:
raise ImportError("请安装cohere: pip install cohere")
self.api_key = api_key or os.getenv("COHERE_API_KEY")
if not self.api_key:
raise ValueError("需要提供Cohere API Key")
self.client = cohere.Client(self.api_key)
self.model = model
print(f"[INFO] 初始化Cohere Rerank: {model}")
    def rerank(self,
               query: str,
               documents: List[Dict],
               top_k: int = None,
               max_chunks_per_doc: int = 10) -> List[RerankResult]:
"""
使用Cohere API重排序
Args:
query: 查询
documents: 候选文档
top_k: 返回数量
            max_chunks_per_doc: 每文档的最大分块数(长文档会被切块后分别打分)
Returns:
重排序结果
"""
docs = [doc.get("text", doc.get("content", "")) for doc in documents]
response = self.client.rerank(
model=self.model,
query=query,
documents=docs,
top_n=top_k or len(documents),
            max_chunks_per_doc=max_chunks_per_doc
)
results = []
for r in response.results:
original_doc = documents[r.index]
results.append(RerankResult(
doc_id=original_doc.get("id", f"doc_{r.index}"),
original_rank=r.index + 1,
reranked_score=r.relevance_score,
text=docs[r.index][:200],
cross_encoder_score=r.relevance_score
))
return results
class RerankerTrainer:
"""
交叉编码器微调(简化演示)
实际微调需要准备标注数据(查询-文档-相关性三元组),
使用sentence-transformers的CrossEncoder进行训练。
"""
def prepare_training_data(self,
queries: List[str],
documents: List[str],
labels: List[float]) -> List[List]:
"""
准备训练数据
Args:
queries: 查询列表
documents: 文档列表
labels: 相关性标签(0-1或离散等级)
Returns:
训练样本列表
"""
# 转换为 sentence-transformers 格式
train_samples = []
for q, d, label in zip(queries, documents, labels):
train_samples.append([q, d, float(label)])
return train_samples
def train(self,
train_samples: List[List],
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
output_path: str = "./fine_tuned_reranker",
epochs: int = 3,
batch_size: int = 16):
"""
微调模型
Args:
train_samples: 训练样本
model_name: 基础模型
output_path: 输出路径
epochs: 训练轮数
batch_size: 批次大小
"""
try:
from sentence_transformers import CrossEncoder, InputExample
from torch.utils.data import DataLoader
except ImportError:
raise ImportError("需要sentence-transformers和torch")
# 加载模型
model = CrossEncoder(model_name, num_labels=1)
# 准备数据
train_examples = [InputExample(texts=[t[0], t[1]], label=t[2]) for t in train_samples]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
# 训练
model.fit(
train_dataloader=train_dataloader,
epochs=epochs,
warmup_steps=int(len(train_dataloader) * 0.1),
output_path=output_path
)
print(f"[INFO] 模型已保存至: {output_path}")
return model
class RerankingVisualizer:
"""重排序可视化"""
def visualize_comparison(self,
query: str,
original_results: List[Dict],
reranked_results: List[RerankResult],
output_path: str = "rerank_comparison.png"):
"""
可视化重排序前后对比
Args:
query: 查询
original_results: 原始检索结果
reranked_results: 重排序后结果
output_path: 输出路径
"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# 1. 排名变化对比(前10个)
ax1 = axes[0, 0]
n_show = min(10, len(reranked_results))
doc_ids = [r.doc_id for r in reranked_results[:n_show]]
orig_ranks = [r.original_rank for r in reranked_results[:n_show]]
new_ranks = list(range(1, n_show + 1))
for i, (old, new) in enumerate(zip(orig_ranks, new_ranks)):
color = 'green' if new < old else ('red' if new > old else 'gray')
ax1.arrow(old, i, new-old, 0, head_width=0.1, head_length=0.2,
fc=color, ec=color, alpha=0.6, length_includes_head=True)
ax1.scatter([old, new], [i, i], color=color, s=50)
ax1.set_yticks(range(n_show))
ax1.set_yticklabels([f"Doc {d[-8:]}" for d in doc_ids])
ax1.set_xlabel('Rank Position')
ax1.set_title(f'Ranking Changes (Query: {query[:30]}...)')
ax1.grid(True, alpha=0.3)
# 2. 分数分布
ax2 = axes[0, 1]
orig_scores = [r.get("score", 0) for r in original_results[:n_show]]
rerank_scores = [r.reranked_score for r in reranked_results[:n_show]]
x = np.arange(len(orig_scores))
width = 0.35
ax2.bar(x - width/2, orig_scores, width, label='Original (Bi-Encoder)', alpha=0.6)
ax2.bar(x + width/2, rerank_scores, width, label='Reranked (Cross-Encoder)', alpha=0.6)
ax2.set_xlabel('Document Index')
ax2.set_ylabel('Relevance Score')
ax2.set_title('Score Distribution Comparison')
ax2.legend()
# 3. Top-K位置变化热力图
ax3 = axes[1, 0]
rank_changes = np.zeros((n_show, n_show))
for i, r in enumerate(reranked_results[:n_show]):
old_pos = r.original_rank - 1
if old_pos < n_show:
rank_changes[old_pos, i] = 1
im = ax3.imshow(rank_changes, cmap='Blues', aspect='auto')
ax3.set_xlabel('New Rank')
ax3.set_ylabel('Original Rank')
ax3.set_title('Rank Position Transition Matrix')
plt.colorbar(im, ax=ax3)
# 4. 提升统计
ax4 = axes[1, 1]
improvements = [r.original_rank - (i+1) for i, r in enumerate(reranked_results[:n_show])]
colors = ['green' if x > 0 else 'red' for x in improvements]
ax4.barh(range(n_show), improvements, color=colors, alpha=0.6)
ax4.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
ax4.set_xlabel('Rank Improvement')
ax4.set_ylabel('Document')
ax4.set_title('Per-Document Rank Improvement')
ax4.set_yticks(range(n_show))
ax4.set_yticklabels([f"Doc {i+1}" for i in range(n_show)])
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 可视化已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="重排序系统")
parser.add_argument("--query", "-q", default="What is machine learning?", help="查询文本")
parser.add_argument("--candidates", "-c", default=None, help="候选文档JSON")
parser.add_argument("--model", "-m", choices=["cross-encoder", "cohere"],
default="cross-encoder", help="重排序模型")
parser.add_argument("--api-key", "-k", default=None, help="API Key (Cohere)")
parser.add_argument("--top-k", type=int, default=10, help="重排序后的Top-K")
parser.add_argument("--visualize", "-v", action="store_true", help="生成可视化")
args = parser.parse_args()
# 加载或生成候选
if args.candidates:
with open(args.candidates, 'r') as f:
candidates = json.load(f)
else:
# 模拟候选
candidates = [
{"id": f"doc_{i}", "text": f"This is a document about {'neural networks' if i % 2 == 0 else 'cooking recipes'} and related topics.",
"score": 0.9 - i*0.05}
for i in range(20)
]
print(f"[INFO] 加载了 {len(candidates)} 个候选文档")
# 初始化重排序器
if args.model == "cross-encoder":
reranker = CrossEncoderReranker()
else:
reranker = CohereReranker(api_key=args.api_key)
# 执行重排序
results = reranker.rerank(args.query, candidates, top_k=args.top_k)
print(f"\n查询: {args.query}")
print("重排序结果 (Top-5):")
for i, r in enumerate(results[:5]):
change = r.original_rank - (i+1)
change_str = f"+{change}" if change > 0 else str(change)
print(f"{i+1}. [原排名: {r.original_rank}, 变化: {change_str}] 分数: {r.reranked_score:.4f}")
print(f" 文本: {r.text[:100]}...")
# 可视化
if args.visualize:
visualizer = RerankingVisualizer()
visualizer.visualize_comparison(args.query, candidates, results)
# 保存结果
output = {
"query": args.query,
"reranked_results": [
{
"doc_id": r.doc_id,
"original_rank": r.original_rank,
"new_rank": i+1,
"score": r.reranked_score,
"text": r.text
}
for i, r in enumerate(results)
]
}
with open("rerank_results.json", 'w') as f:
json.dump(output, f, indent=2)
print(f"\n结果已保存至: rerank_results.json")
if __name__ == "__main__":
main()
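上面的脚本直接对给定候选列表重排;为了展示"先粗排、后精排"的完整两阶段流程,下面补充一个最小示意(语料与查询均为演示用的假设数据):第一阶段用双塔模型(Bi-Encoder)做向量召回,第二阶段用与上文相同的ms-marco交叉编码器逐对打分并重排。
Python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

corpus = [
    "Cross-encoders jointly encode the query and document to score relevance.",
    "Boil water before adding the pasta.",
    "Bi-encoders embed query and documents independently for fast retrieval.",
]
query = "How does a cross-encoder score relevance?"

# 第一阶段:双塔模型粗排,取Top-3候选
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
hits = util.semantic_search(bi_encoder.encode(query), bi_encoder.encode(corpus), top_k=3)[0]

# 第二阶段:交叉编码器对(query, doc)逐对打分,并按新分数重排
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [(query, corpus[h["corpus_id"]]) for h in hits]
scores = cross_encoder.predict(pairs)
for h, s in sorted(zip(hits, scores), key=lambda x: x[1], reverse=True):
    print(f"{s:.4f}  {corpus[h['corpus_id']]}")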
6.3.3 多跳检索
Python
#!/usr/bin/env python3
"""
Script: graphrag_multihop.py
功能: GraphRAG多跳检索实现,包含知识图谱构建、社区检测与多跳推理
使用方式: python graphrag_multihop.py --build --docs documents.json 或 --query "your question"
"""
import os
import json
import argparse
import numpy as np
import networkx as nx
from typing import List, Dict, Set, Tuple, Optional
from dataclasses import dataclass, field
from collections import defaultdict
import community.community_louvain as community_louvain  # 社区发现依赖 python-louvain 包
@dataclass
class Entity:
"""知识图谱实体"""
id: str
name: str
type: str # 'PERSON', 'ORG', 'CONCEPT', etc.
embedding: Optional[np.ndarray] = None
source_chunks: List[str] = field(default_factory=list)
@dataclass
class Relation:
"""关系"""
source: str
target: str
relation_type: str
weight: float = 1.0
evidence: List[str] = field(default_factory=list)
class KnowledgeGraphBuilder:
"""知识图谱构建器"""
def __init__(self, embedding_model=None):
self.graph = nx.DiGraph()
self.entities = {}
self.embedding_model = embedding_model
def extract_entities_relations(self, text: str, chunk_id: str) -> Tuple[List[Entity], List[Relation]]:
"""
从文本中提取实体和关系
实际应用应使用spaCy、Transformers NER或LLM进行抽取。
这里使用基于规则的模拟实现。
"""
entities = []
relations = []
# 模拟实体抽取(关键词匹配)
import re
# 简单模式匹配作为演示
patterns = {
"CONCEPT": r'\b(machine learning|deep learning|neural network|algorithm|model|dataset)\b',
"PERSON": r'\b([A-Z][a-z]+ [A-Z][a-z]+)\b',
"ORG": r'\b(Google|OpenAI|Microsoft|Meta|University of [A-Z][a-z]+)\b',
"TECH": r'\b(Transformer|BERT|GPT|CNN|RNN|LSTM)\b'
}
found_entities = {}
for ent_type, pattern in patterns.items():
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
name = match.group(0)
ent_id = f"{ent_type}_{hash(name) % 10000}"
if ent_id not in found_entities:
ent = Entity(
id=ent_id,
name=name,
type=ent_type,
source_chunks=[chunk_id]
)
found_entities[ent_id] = ent
entities.append(ent)
# 模拟关系抽取(共现实体间建立关系)
ent_list = list(found_entities.values())
for i, ent1 in enumerate(ent_list):
for ent2 in ent_list[i+1:]:
# 计算共现权重(基于距离)
rel = Relation(
source=ent1.id,
target=ent2.id,
relation_type="CO_OCCURS_WITH",
weight=1.0,
evidence=[chunk_id]
)
relations.append(rel)
return entities, relations
def add_document(self, doc_id: str, text: str, embedding: Optional[np.ndarray] = None):
"""添加文档到图谱"""
entities, relations = self.extract_entities_relations(text, doc_id)
# 添加实体节点
for ent in entities:
if ent.id not in self.graph:
self.graph.add_node(
ent.id,
name=ent.name,
type=ent.type,
embedding=embedding,
chunks=[doc_id]
)
self.entities[ent.id] = ent
else:
self.graph.nodes[ent.id]["chunks"].append(doc_id)
# 添加关系边
for rel in relations:
if self.graph.has_edge(rel.source, rel.target):
self.graph[rel.source][rel.target]["weight"] += rel.weight
self.graph[rel.source][rel.target]["evidence"].extend(rel.evidence)
else:
self.graph.add_edge(
rel.source,
rel.target,
relation_type=rel.relation_type,
weight=rel.weight,
evidence=rel.evidence
)
def detect_communities(self) -> Dict[str, int]:
"""
使用Louvain算法检测社区
Returns:
节点到社区的映射
"""
# 转换为无向图进行社区检测
undirected = self.graph.to_undirected()
partition = community_louvain.best_partition(undirected)
# 为每个社区生成摘要
communities = defaultdict(list)
for node, comm_id in partition.items():
communities[comm_id].append(node)
print(f"[INFO] 检测到 {len(communities)} 个社区")
for comm_id, nodes in list(communities.items())[:3]:
print(f" 社区 {comm_id}: {len(nodes)} 实体")
return partition
def generate_community_summaries(self, partition: Dict[str, int]) -> Dict[int, str]:
"""
为每个社区生成摘要
实际应使用LLM生成自然语言摘要。
"""
summaries = {}
communities = defaultdict(list)
for node, comm_id in partition.items():
communities[comm_id].append(self.entities[node].name)
for comm_id, entities in communities.items():
# 简单摘要:列出主要实体
summaries[comm_id] = f"Community {comm_id}: " + ", ".join(entities[:5])
return summaries
class MultiHopRetriever:
"""多跳检索器"""
def __init__(self, graph: nx.DiGraph, entity_embeddings: Dict[str, np.ndarray] = None):
self.graph = graph
self.entity_embeddings = entity_embeddings or {}
def seed_entity_extraction(self, query: str) -> List[str]:
"""
从查询中提取种子实体
实际应使用NER模型。
"""
# 简单匹配
seeds = []
for node, data in self.graph.nodes(data=True):
if data.get("name", "").lower() in query.lower():
seeds.append(node)
return seeds if seeds else list(self.graph.nodes())[:3]
def personalized_ppr(self,
seed_entities: List[str],
alpha: float = 0.85,
max_iter: int = 100,
tol: float = 1e-6) -> Dict[str, float]:
"""
个性化PageRank多跳扩散
Args:
seed_entities: 种子实体ID
alpha: 随机游走概率
max_iter: 最大迭代次数
tol: 收敛阈值
Returns:
实体重要性分数
"""
# 构建转移矩阵(简化版)
nodes = list(self.graph.nodes())
node_idx = {n: i for i, n in enumerate(nodes)}
n = len(nodes)
# 初始化分布(从种子实体均匀开始)
r = np.zeros(n)
for seed in seed_entities:
if seed in node_idx:
r[node_idx[seed]] = 1.0 / len(seed_entities)
# 构建邻接矩阵
A = nx.to_numpy_array(self.graph, weight='weight')
# 归一化
row_sums = A.sum(axis=1)
row_sums[row_sums == 0] = 1
P = A / row_sums[:, np.newaxis]
# 迭代计算
teleport = r.copy()
for _ in range(max_iter):
r_new = alpha * P.T @ r + (1 - alpha) * teleport
if np.linalg.norm(r_new - r) < tol:
break
r = r_new
return {nodes[i]: score for i, score in enumerate(r)}
def multi_hop_retrieve(self,
query: str,
depth: int = 2,
top_k: int = 10) -> Dict:
"""
执行多跳检索
Args:
query: 查询
depth: 跳数
top_k: 返回文档数
Returns:
检索结果及相关子图
"""
# 提取种子实体
seeds = self.seed_entity_extraction(query)
print(f"[INFO] 种子实体: {[self.graph.nodes[s]['name'] for s in seeds[:3]]}")
# 执行PPR扩散
ppr_scores = self.personalized_ppr(seeds)
# 基于PPR分数检索相关实体
sorted_entities = sorted(ppr_scores.items(), key=lambda x: x[1], reverse=True)
top_entities = [e for e, _ in sorted_entities[:20]]
# 检索关联文档
retrieved_docs = set()
for ent_id in top_entities:
chunks = self.graph.nodes[ent_id].get("chunks", [])
retrieved_docs.update(chunks)
# 构建子图
subgraph = self.graph.subgraph(top_entities)
# 生成推理路径(简单版:种子到高排名实体的路径)
paths = []
for target, score in sorted_entities[:5]:
for seed in seeds:
if nx.has_path(self.graph, seed, target):
path = nx.shortest_path(self.graph, seed, target)
path_names = [self.graph.nodes[n]["name"] for n in path]
paths.append({
"path": path_names,
"confidence": score
})
return {
"retrieved_documents": list(retrieved_docs)[:top_k],
"relevant_entities": [self.graph.nodes[e] for e in top_entities[:10]],
"subgraph": subgraph,
"reasoning_paths": paths,
"ppr_scores": {k: v for k, v in sorted_entities[:20]}
}
class GraphRAGPipeline:
"""完整的GraphRAG流水线"""
def __init__(self):
self.kg_builder = KnowledgeGraphBuilder()
self.retriever = None
self.communities = {}
self.summaries = {}
def build_index(self, documents: List[Dict]):
"""
构建索引
Args:
documents: 文档列表,每个包含'id', 'text', 'embedding'
"""
print("[INFO] 构建知识图谱...")
for doc in documents:
self.kg_builder.add_document(
doc["id"],
doc["text"],
doc.get("embedding")
)
# 社区检测
self.communities = self.kg_builder.detect_communities()
self.summaries = self.kg_builder.generate_community_summaries(self.communities)
# 初始化检索器
self.retriever = MultiHopRetriever(self.kg_builder.graph)
print(f"[INFO] 索引构建完成: {self.kg_builder.graph.number_of_nodes()} 实体, "
f"{self.kg_builder.graph.number_of_edges()} 关系")
def query(self, query_text: str, depth: int = 2) -> Dict:
"""
执行查询
Args:
query_text: 查询文本
depth: 多跳深度
Returns:
包含检索结果、推理路径和社区摘要的答案
"""
if not self.retriever:
raise ValueError("索引未构建")
# 多跳检索
retrieval_results = self.retriever.multi_hop_retrieve(query_text, depth=depth)
# 获取相关社区摘要
relevant_communities = set()
for ent_id in retrieval_results["relevant_entities"]:
if isinstance(ent_id, dict):
ent_id = ent_id["id"]
if ent_id in self.communities:
relevant_communities.add(self.communities[ent_id])
community_contexts = [self.summaries[c] for c in relevant_communities]
return {
"query": query_text,
"retrieval": retrieval_results,
"community_summaries": community_contexts,
"graph_stats": {
"total_entities": self.kg_builder.graph.number_of_nodes(),
"total_relations": self.kg_builder.graph.number_of_edges()
}
}
def main():
parser = argparse.ArgumentParser(description="GraphRAG多跳检索")
parser.add_argument("--build", "-b", action="store_true", help="构建索引模式")
parser.add_argument("--query", "-q", default=None, help="查询文本")
parser.add_argument("--docs", "-d", default=None, help="文档JSON文件")
parser.add_argument("--depth", type=int, default=2, help="多跳深度")
parser.add_argument("--output", "-o", default="graphrag_result.json", help="输出文件")
args = parser.parse_args()
pipeline = GraphRAGPipeline()
if args.build:
if not args.docs:
# 生成演示数据
print("[DEMO] 生成演示文档...")
docs = [
{
"id": f"doc_{i}",
"text": f"Machine learning is closely related to neural networks. Deep learning uses multi-layer neural networks. "
f"Transformers are a type of neural network architecture used in NLP. BERT is a transformer model. "
f"GPT is also based on transformers. {i}"
}
for i in range(10)
]
else:
with open(args.docs, 'r') as f:
docs = json.load(f)
pipeline.build_index(docs)
# 保存图谱
graph_data = nx.node_link_data(pipeline.kg_builder.graph)
with open("knowledge_graph.json", 'w') as f:
json.dump(graph_data, f)
print("[INFO] 知识图谱已保存至: knowledge_graph.json")
elif args.query:
# 加载已有图谱
if os.path.exists("knowledge_graph.json"):
with open("knowledge_graph.json", 'r') as f:
graph_data = json.load(f)
pipeline.kg_builder.graph = nx.node_link_graph(graph_data)
pipeline.retriever = MultiHopRetriever(pipeline.kg_builder.graph)
print("[INFO] 已加载知识图谱")
# 执行查询
result = pipeline.query(args.query, depth=args.depth)
print(f"\n查询: {args.query}")
print(f"检索到 {len(result['retrieval']['retrieved_documents'])} 个文档")
print(f"相关实体: {[e.get('name', e) for e in result['retrieval']['relevant_entities'][:5]]}")
print(f"推理路径数: {len(result['retrieval']['reasoning_paths'])}")
# 保存结果
# 转换不可序列化对象
result["retrieval"]["subgraph"] = nx.node_link_data(result["retrieval"]["subgraph"])
with open(args.output, 'w') as f:
json.dump(result, f, indent=2, default=str)
print(f"\n结果已保存至: {args.output}")
else:
parser.print_help()
if __name__ == "__main__":
main()
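上面的 personalized_ppr 用显式转移矩阵迭代实现了个性化PageRank;作为参照,networkx 自带的 pagerank 函数通过 personalization 参数即可完成同样的多跳扩散。下面是一个最小示意,图中的实体与边权均为假设数据:
Python
import networkx as nx

G = nx.DiGraph()
G.add_weighted_edges_from([
    ("machine learning", "neural network", 2.0),
    ("neural network", "transformer", 1.0),
    ("transformer", "BERT", 1.0),
    ("transformer", "GPT", 1.0),
])

# 以种子实体作为个性化(重启)分布,alpha 与手写实现中的随机游走概率一致
seeds = {"machine learning": 1.0}
scores = nx.pagerank(G, alpha=0.85, personalization=seeds, weight="weight")
for node, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
    print(f"{score:.4f}  {node}")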
6.3.4 查询路由
Python
#!/usr/bin/env python3
"""
Script: query_router.py
功能: 查询路由系统,支持元数据过滤与多索引联合查询,含A/B测试支持
使用方式: python query_router.py --query "question about revenue 2024" --metadata "year=2024,type=finance"
"""
import re
import json
import argparse
from typing import List, Dict, Optional, Callable
from dataclasses import dataclass, field
from enum import Enum
import numpy as np
class QueryType(Enum):
"""查询类型"""
FACTUAL = "factual" # 事实查询
ANALYTICAL = "analytical" # 分析查询
COMPARATIVE = "comparative" # 比较查询
TEMPORAL = "temporal" # 时间相关
ENTITY = "entity" # 实体查询
@dataclass
class IndexDefinition:
"""索引定义"""
name: str
description: str
metadata_schema: Dict[str, type]
embedding_dim: int
filterable_fields: List[str]
priority: int = 0
@dataclass
class RoutingDecision:
"""路由决策"""
selected_indices: List[str]
filters: Dict[str, any]
query_rewrite: Optional[str] = None
strategy: str = "unknown"
class QueryClassifier:
"""查询分类器"""
# 关键词模式
PATTERNS = {
QueryType.TEMPORAL: r'\b(202\d|last year|this quarter|January|February|March|April|May|June|July|August|September|October|November|December)\b',
QueryType.COMPARATIVE: r'\b(compare|versus|vs|difference between|better than|worse than)\b',
QueryType.ANALYTICAL: r'\b(why|how|explain|analyze|cause|reason|impact|effect)\b',
QueryType.FACTUAL: r'\b(what is|who is|when|where|how many|how much|list of)\b',
QueryType.ENTITY: r'\b(who|where is|company|person|organization)\b'
}
def classify(self, query: str) -> QueryType:
"""分类查询"""
query_lower = query.lower()
scores = {}
for qtype, pattern in self.PATTERNS.items():
matches = len(re.findall(pattern, query_lower, re.IGNORECASE))
scores[qtype] = matches
if max(scores.values()) > 0:
return max(scores, key=scores.get)
return QueryType.FACTUAL
def extract_temporal_constraints(self, query: str) -> Optional[Dict]:
"""提取时间约束"""
# 年份匹配
years = re.findall(r'\b(202\d|201\d)\b', query)
if years:
return {"year": int(years[0])}
# 相对时间
if "last year" in query.lower():
return {"relative": "last_year"}
if "this year" in query.lower():
return {"relative": "current_year"}
return None
def extract_entity_constraints(self, query: str) -> List[str]:
"""提取实体约束"""
# 简单实现:引号内的内容或特定格式
entities = re.findall(r'"([^"]+)"', query)
return entities
class MetadataFilterParser:
"""元数据过滤解析器"""
OPERATORS = {
'eq': lambda x, y: x == y,
'gt': lambda x, y: x > y,
'lt': lambda x, y: x < y,
'gte': lambda x, y: x >= y,
'lte': lambda x, y: x <= y,
'in': lambda x, y: x in y,
'contains': lambda x, y: y in x if isinstance(x, str) else False
}
def parse(self, filter_string: str) -> Dict:
"""
解析过滤字符串
格式: "field1=value1,field2>value2"
"""
if not filter_string:
return {}
filters = {}
conditions = filter_string.split(',')
for cond in conditions:
cond = cond.strip()
if '=' in cond:
field, value = cond.split('=', 1)
filters[field.strip()] = {"op": "eq", "value": self._convert_value(value.strip())}
elif '>' in cond:
field, value = cond.split('>', 1)
filters[field.strip()] = {"op": "gt", "value": self._convert_value(value.strip())}
elif '<' in cond:
field, value = cond.split('<', 1)
filters[field.strip()] = {"op": "lt", "value": self._convert_value(value.strip())}
return filters
def _convert_value(self, value: str):
"""转换值类型"""
try:
return int(value)
except ValueError:
try:
return float(value)
except ValueError:
return value
def apply_filter(self, metadata: Dict, filters: Dict) -> bool:
"""应用过滤条件"""
for field, condition in filters.items():
if field not in metadata:
return False
op_func = self.OPERATORS.get(condition["op"])
if not op_func:
continue
if not op_func(metadata[field], condition["value"]):
return False
return True
class QueryRouter:
"""查询路由器"""
def __init__(self):
self.indices: Dict[str, IndexDefinition] = {}
self.classifier = QueryClassifier()
self.filter_parser = MetadataFilterParser()
self.routing_rules: List[Callable] = []
self.query_cache = {}
def register_index(self, index_def: IndexDefinition):
"""注册索引"""
self.indices[index_def.name] = index_def
print(f"[INFO] 注册索引: {index_def.name}")
def add_routing_rule(self, rule: Callable):
"""添加路由规则"""
self.routing_rules.append(rule)
def route(self,
query: str,
explicit_filters: Optional[str] = None,
user_context: Optional[Dict] = None) -> RoutingDecision:
"""
路由查询到适当索引
Args:
query: 查询文本
explicit_filters: 显式过滤条件
user_context: 用户上下文(权限、偏好等)
Returns:
路由决策
"""
# 解析显式过滤
filters = self.filter_parser.parse(explicit_filters) if explicit_filters else {}
# 查询分类
query_type = self.classifier.classify(query)
print(f"[ROUTE] 查询类型: {query_type.value}")
# 提取隐含约束
temporal = self.classifier.extract_temporal_constraints(query)
if temporal:
filters.update(temporal)
# 选择索引
selected = []
# 基于类型的路由
if query_type == QueryType.TEMPORAL and "year" in filters:
# 路由到时间序列索引
selected.append("temporal_index")
if any(field in filters for field in ["company", "organization"]):
selected.append("entity_index")
# 默认索引
if not selected:
selected = ["default_knowledge_base"]
# 应用自定义规则
for rule in self.routing_rules:
result = rule(query, query_type, filters, user_context)
if result:
selected = result.get("indices", selected)
filters.update(result.get("additional_filters", {}))
# 去重
selected = list(set(selected))
return RoutingDecision(
selected_indices=selected,
filters=filters,
query_rewrite=self._rewrite_query(query, query_type),
strategy="type_based"
)
def _rewrite_query(self, query: str, query_type: QueryType) -> Optional[str]:
"""根据类型重写查询"""
if query_type == QueryType.ANALYTICAL:
return f"Explain in detail: {query}"
elif query_type == QueryType.COMPARATIVE:
return f"Compare and contrast: {query}"
return None
def execute_routed_query(self,
query: str,
decision: RoutingDecision,
retrievers: Dict[str, Callable]) -> Dict:
"""
执行路由后的查询
Args:
query: 原始查询
decision: 路由决策
retrievers: 索引名称到检索函数的映射
Returns:
聚合结果
"""
all_results = []
for idx_name in decision.selected_indices:
if idx_name not in retrievers:
continue
print(f"[EXEC] 查询索引: {idx_name}")
retriever = retrievers[idx_name]
# 应用过滤
results = retriever(
query=decision.query_rewrite or query,
filters=decision.filters
)
all_results.extend([
{**r, "source_index": idx_name}
for r in results
])
# 合并与去重
seen = set()
unique_results = []
for r in sorted(all_results, key=lambda x: x.get("score", 0), reverse=True):
doc_id = r.get("id") or r.get("doc_id")
if doc_id not in seen:
seen.add(doc_id)
unique_results.append(r)
return {
"routed_indices": decision.selected_indices,
"applied_filters": decision.filters,
"results": unique_results[:10],
"total_candidates": len(all_results)
}
class ABRouter(QueryRouter):
"""支持A/B测试的路由器"""
def __init__(self, experiment_config: Optional[Dict] = None):
super().__init__()
self.experiments = experiment_config or {}
self.user_assignments = {}
def assign_experiment(self, user_id: str, experiment_id: str) -> str:
"""为用户分配实验组"""
if experiment_id not in self.experiments:
return "control"
# 基于用户ID哈希分配(确保一致性)
import hashlib
hash_val = int(hashlib.md5(f"{user_id}_{experiment_id}".encode()).hexdigest(), 16)
variants = self.experiments[experiment_id]["variants"]
assignment = variants[hash_val % len(variants)]
self.user_assignments[(user_id, experiment_id)] = assignment
return assignment
def route(self, query: str, user_id: str = "anonymous", **kwargs) -> RoutingDecision:
"""带A/B测试的路由"""
# 检查实验
for exp_id, config in self.experiments.items():
if config.get("query_filter", lambda x: True)(query):
variant = self.assign_experiment(user_id, exp_id)
# 应用实验配置
if variant != "control" and "routing_override" in config:
override = config["routing_override"]
return RoutingDecision(
selected_indices=override.get("indices", ["default"]),
filters=override.get("filters", {}),
strategy=f"experiment_{exp_id}_{variant}"
)
return super().route(query, **kwargs)
def main():
parser = argparse.ArgumentParser(description="查询路由系统")
parser.add_argument("--query", "-q", required=True, help="查询文本")
parser.add_argument("--metadata", "-m", default=None, help="元数据过滤")
parser.add_argument("--user-id", "-u", default="anonymous", help="用户ID")
parser.add_argument("--demo", "-d", action="store_true", help="运行演示")
args = parser.parse_args()
# 初始化路由器
router = QueryRouter()
# 注册示例索引
router.register_index(IndexDefinition(
name="finance_docs",
description="财务文档",
metadata_schema={"year": int, "type": str},
embedding_dim=384,
filterable_fields=["year", "department", "type"]
))
router.register_index(IndexDefinition(
name="hr_policies",
description="人力资源政策",
metadata_schema={"effective_date": str, "region": str},
embedding_dim=384,
filterable_fields=["region", "status"]
))
# 添加自定义路由规则
def finance_rule(query, qtype, filters, context):
if "revenue" in query.lower() or "profit" in query.lower():
return {"indices": ["finance_docs"], "additional_filters": {"type": "financial_report"}}
return None
router.add_routing_rule(finance_rule)
# 模拟检索函数
mock_retrievers = {
"finance_docs": lambda q, f: [
{"id": f"fin_{i}", "score": 0.9-i*0.05, "text": f"Financial data for {f}"}
for i in range(5)
],
"hr_policies": lambda q, f: [
{"id": f"hr_{i}", "score": 0.85-i*0.05, "text": f"HR policy regarding {q}"}
for i in range(5)
],
"default_knowledge_base": lambda q, f: [
{"id": f"gen_{i}", "score": 0.8-i*0.05, "text": f"General info about {q}"}
for i in range(5)
]
}
# 执行路由
print(f"\n查询: {args.query}")
print(f"用户: {args.user_id}")
print(f"元数据过滤: {args.metadata}")
decision = router.route(args.query, args.metadata, {"user_id": args.user_id})
print(f"\n路由决策:")
print(f" 选中索引: {decision.selected_indices}")
print(f" 应用过滤: {decision.filters}")
print(f" 查询重写: {decision.query_rewrite or '无'}")
print(f" 策略: {decision.strategy}")
# 执行查询
results = router.execute_routed_query(args.query, decision, mock_retrievers)
print(f"\n检索结果 (Top-3):")
for i, r in enumerate(results["results"][:3]):
print(f" {i+1}. [{r['source_index']}] {r['text'][:60]}... (得分: {r['score']:.3f})")
# 保存结果
output = {
"query": args.query,
"routing_decision": {
"selected_indices": decision.selected_indices,
"filters": decision.filters,
"strategy": decision.strategy
},
"retrieval_results": results["results"]
}
with open("routing_result.json", 'w') as f:
json.dump(output, f, indent=2)
print(f"\n结果已保存至: routing_result.json")
if __name__ == "__main__":
main()
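上面的 main() 只演示了基础路由;ABRouter 的实验配置可以按下面的方式提供(实验名、索引名与过滤条件均为假设的演示值)。同一 user_id 会被哈希到固定分组,保证实验期内路由行为稳定,便于统计组间差异:
Python
from query_router import ABRouter  # 假设上文脚本保存为 query_router.py

experiments = {
    "finance_routing_v2": {
        "variants": ["control", "multi_index"],
        # 只对包含 revenue 的查询开启实验
        "query_filter": lambda q: "revenue" in q.lower(),
        # 非 control 组强制改写路由目标
        "routing_override": {
            "indices": ["finance_docs", "default_knowledge_base"],
            "filters": {"type": "financial_report"},
        },
    }
}

router = ABRouter(experiment_config=experiments)
decision = router.route("What was the revenue growth in 2024?", user_id="user_42")
print(decision.selected_indices, decision.strategy)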
6.4.1 上下文组装
Python
#!/usr/bin/env python3
"""
Script: context_assembler.py
功能: 上下文组装系统,实现Token预算管理、MMR多样性排序与动态截断
使用方式: python context_assembler.py --query "question" --chunks chunks.json --max-tokens 2000
"""
import json
import argparse
import numpy as np
from typing import List, Dict, Tuple
from dataclasses import dataclass
@dataclass
class Chunk:
"""文本块"""
id: str
content: str
score: float # 相关性分数
metadata: Dict
embedding: np.ndarray = None
class TokenBudgetManager:
"""Token预算管理器"""
def __init__(self,
max_tokens: int = 4000,
reserve_tokens: int = 500,
tokenizer_name: str = "gpt2"):
"""
初始化
Args:
max_tokens: 总Token预算
reserve_tokens: 为生成预留的Token数
tokenizer_name: 分词器名称
"""
self.max_tokens = max_tokens
self.reserve_tokens = reserve_tokens
self.available_tokens = max_tokens - reserve_tokens
        # 加载分词器(transformers 不可用时退化为按字符数估算)
        self.char_per_token = 4  # 经验值:约4字符/token
        try:
            from transformers import GPT2TokenizerFast
            self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
        except Exception:
            print("[WARN] 无法加载transformers分词器,改用字符数估算")
            self.tokenizer = None
def count_tokens(self, text: str) -> int:
"""计算文本的Token数"""
if self.tokenizer:
return len(self.tokenizer.encode(text))
return len(text) // self.char_per_token
def fit_chunks(self,
chunks: List[Chunk],
query: str,
strategy: str = "greedy") -> Tuple[List[Chunk], int]:
"""
在预算内适配文本块
Args:
chunks: 候选块
query: 查询文本(也占用预算)
strategy: 适配策略 ('greedy', 'mmr', 'diverse')
Returns:
(选中的块, 使用的Token数)
"""
query_tokens = self.count_tokens(query)
remaining = self.available_tokens - query_tokens
if remaining <= 0:
return [], 0
selected = []
total_tokens = 0
if strategy == "greedy":
# 贪婪选择:按相关性排序,依次加入直到预算耗尽
sorted_chunks = sorted(chunks, key=lambda x: x.score, reverse=True)
for chunk in sorted_chunks:
chunk_tokens = self.count_tokens(chunk.content)
if total_tokens + chunk_tokens <= remaining:
selected.append(chunk)
total_tokens += chunk_tokens
else:
# 尝试截断
if remaining - total_tokens > 50: # 至少保留50 tokens
truncated = self._truncate_at_sentence(
chunk.content,
remaining - total_tokens
)
if truncated:
chunk.content = truncated
chunk.metadata["truncated"] = True
selected.append(chunk)
total_tokens += self.count_tokens(truncated)
break
elif strategy == "mmr":
# 使用MMR选择
selected = self._select_mmr(chunks, remaining)
total_tokens = sum(self.count_tokens(c.content) for c in selected)
return selected, total_tokens + query_tokens
def _truncate_at_sentence(self, text: str, max_tokens: int) -> str:
"""在句子边界处截断"""
        target_chars = max_tokens * self.char_per_token  # 按字符/Token比例估算目标长度
# 寻找句子边界
sentences = []
current = ""
for char in text:
current += char
if char in ".!?。!?":
if len(current) + sum(len(s) for s in sentences) < target_chars:
sentences.append(current)
current = ""
else:
break
return "".join(sentences) if sentences else text[:target_chars]
def _select_mmr(self,
candidates: List[Chunk],
budget: int,
lambda_param: float = 0.5) -> List[Chunk]:
"""
最大边际相关性选择
MMR = λ * Relevance - (1-λ) * max(Similarity with selected)
"""
selected = []
remaining = candidates.copy()
used_budget = 0
while remaining and used_budget < budget:
mmr_scores = []
for chunk in remaining:
chunk_tokens = self.count_tokens(chunk.content)
if used_budget + chunk_tokens > budget:
continue
# 相关性
relevance = chunk.score
# 多样性(与已选块的最大相似度)
if selected and chunk.embedding is not None:
similarities = [
self._cosine_sim(chunk.embedding, s.embedding)
for s in selected if s.embedding is not None
]
max_sim = max(similarities) if similarities else 0
else:
max_sim = 0
mmr_score = lambda_param * relevance - (1 - lambda_param) * max_sim
mmr_scores.append((chunk, mmr_score, chunk_tokens))
if not mmr_scores:
break
# 选择MMR分数最高的
best_chunk, best_score, best_tokens = max(mmr_scores, key=lambda x: x[1])
selected.append(best_chunk)
used_budget += best_tokens
remaining.remove(best_chunk)
return selected
def _cosine_sim(self, a: np.ndarray, b: np.ndarray) -> float:
"""计算余弦相似度"""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)
class PriorityAssembler:
"""优先级组装器"""
def __init__(self, token_budget_manager: TokenBudgetManager):
self.tbm = token_budget_manager
def assemble(self,
query: str,
chunks: List[Chunk],
priority_rules: List[Dict] = None) -> Dict:
"""
组装上下文
Args:
query: 查询
chunks: 候选块(已按某种方式排序)
priority_rules: 优先级规则
Returns:
组装结果
"""
# 按优先级排序(如果有)
if priority_rules:
chunks = self._apply_priority_rules(chunks, priority_rules)
# 适配预算
selected, used_tokens = self.tbm.fit_chunks(chunks, query, strategy="mmr")
# 组装最终文本
context_parts = []
for i, chunk in enumerate(selected):
header = f"[{i+1}] Source: {chunk.metadata.get('source', 'unknown')}"
if chunk.metadata.get("truncated"):
header += " (truncated)"
context_parts.append(f"{header}\n{chunk.content}")
assembled_context = "\n\n".join(context_parts)
return {
"assembled_context": assembled_context,
"selected_chunks": [
{
"id": c.id,
"content": c.content[:200] + "..." if len(c.content) > 200 else c.content,
"original_score": c.score,
"metadata": c.metadata
}
for c in selected
],
"usage": {
"total_tokens": used_tokens,
"max_tokens": self.tbm.max_tokens,
"reserved_tokens": self.tbm.reserve_tokens,
"chunk_count": len(selected)
},
"truncated": any(c.metadata.get("truncated") for c in selected)
}
def _apply_priority_rules(self, chunks: List[Chunk], rules: List[Dict]) -> List[Chunk]:
"""应用优先级规则"""
def get_priority(chunk):
for rule in rules:
if self._matches_rule(chunk, rule["condition"]):
return rule["priority"]
return 0
return sorted(chunks, key=lambda x: (get_priority(x), x.score), reverse=True)
def _matches_rule(self, chunk: Chunk, condition: Dict) -> bool:
"""检查块是否匹配条件"""
for key, value in condition.items():
if chunk.metadata.get(key) != value:
return False
return True
class ContextAssemblerPipeline:
"""完整组装流水线"""
def __init__(self, max_tokens: int = 4000):
self.tbm = TokenBudgetManager(max_tokens=max_tokens)
self.assembler = PriorityAssembler(self.tbm)
def process(self,
query: str,
retrieved_chunks: List[Dict],
embeddings: Dict[str, np.ndarray] = None) -> Dict:
"""
处理检索结果并组装上下文
Args:
query: 查询
retrieved_chunks: 检索到的块列表
embeddings: 块的嵌入向量(用于MMR)
Returns:
组装结果
"""
# 转换为Chunk对象
chunks = []
for i, rc in enumerate(retrieved_chunks):
chunk_id = rc.get("id") or rc.get("doc_id") or f"chunk_{i}"
emb = embeddings.get(chunk_id) if embeddings else None
chunks.append(Chunk(
id=chunk_id,
content=rc.get("text") or rc.get("content", ""),
score=rc.get("score") or rc.get("similarity", 0),
metadata=rc.get("metadata", {}),
embedding=emb
))
# 组装
result = self.assembler.assemble(query, chunks)
return result
def main():
parser = argparse.ArgumentParser(description="上下文组装")
parser.add_argument("--query", "-q", required=True, help="查询文本")
parser.add_argument("--chunks", "-c", default=None, help="检索块JSON文件")
parser.add_argument("--max-tokens", "-t", type=int, default=2000, help="最大Token数")
parser.add_argument("--reserve", "-r", type=int, default=500, help="预留Token数")
parser.add_argument("--output", "-o", default="assembled_context.json", help="输出文件")
args = parser.parse_args()
# 加载或生成示例数据
if args.chunks:
with open(args.chunks, 'r') as f:
data = json.load(f)
if isinstance(data, dict):
retrieved = data.get("results", [])
else:
retrieved = data
else:
# 生成示例
retrieved = [
{
"id": f"doc_{i}",
"content": f"This is a detailed explanation about topic {i}. " * 20 +
f"It contains important information relevant to the query.",
"score": 0.95 - i*0.05,
"metadata": {"source": f"document_{i}.pdf", "page": i+1}
}
for i in range(10)
]
print(f"[INFO] 加载了 {len(retrieved)} 个候选块")
print(f"[INFO] 查询: {args.query}")
# 生成模拟嵌入(实际应从向量数据库获取)
np.random.seed(42)
embeddings = {}
for r in retrieved:
emb = np.random.randn(384)
emb = emb / np.linalg.norm(emb)
embeddings[r["id"]] = emb
# 组装
pipeline = ContextAssemblerPipeline(max_tokens=args.max_tokens)
result = pipeline.process(args.query, retrieved, embeddings)
print(f"\n组装完成:")
print(f" 选中块数: {result['usage']['chunk_count']}")
print(f" 使用Token: {result['usage']['total_tokens']}/{result['usage']['max_tokens']}")
print(f" 是否截断: {result['truncated']}")
print(f"\n组装后的上下文 (前500字符):\n{result['assembled_context'][:500]}...")
# 保存
with open(args.output, 'w') as f:
json.dump(result, f, indent=2, ensure_ascii=False)
print(f"\n结果已保存至: {args.output}")
if __name__ == "__main__":
main()
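TokenBudgetManager 默认用 GPT-2 分词器估算token;若下游生成模型是OpenAI系列,可以换用 tiktoken 做更贴近实际计费的计数。下面是一个最小替换示意(假设已 pip install tiktoken):
Python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4")

def count_tokens(text: str) -> int:
    """与 TokenBudgetManager.count_tokens 等价的 tiktoken 实现"""
    return len(enc.encode(text))

print(count_tokens("Retrieval-augmented generation assembles context under a token budget."))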
6.4.2 引用溯源
Python
#!/usr/bin/env python3
"""
Script: citation_attribution.py
功能: 答案引用溯源系统,实现检索结果高亮、原文链接生成与NLI验证
使用方式: python citation_attribution.py --answer generated_answer.txt --context assembled_context.json
"""
import os
import json
import re
import argparse
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import numpy as np
from sentence_transformers import CrossEncoder
@dataclass
class AttributionResult:
"""溯源结果"""
claim: str
supporting_evidence: List[Dict]
confidence: float
citation_markers: List[str]
class NLIVerifier:
"""自然语言推理验证器"""
def __init__(self, model_name: str = "cross-encoder/nli-deberta-v3-base"):
"""
初始化NLI模型
Args:
model_name: 交叉编码器模型
"""
print(f"[INFO] 加载NLI模型: {model_name}")
self.model = CrossEncoder(model_name)
self.labels = ["contradiction", "entailment", "neutral"]
def verify(self, premise: str, hypothesis: str) -> Dict:
"""
验证假设是否被前提支持
Args:
premise: 前提(检索到的文本)
hypothesis: 假设(生成的陈述)
Returns:
验证结果
"""
scores = self.model.predict([[premise, hypothesis]], apply_softmax=True)[0]
label_idx = np.argmax(scores)
return {
"label": self.labels[label_idx],
"confidence": float(scores[label_idx]),
"scores": {
"contradiction": float(scores[0]),
"entailment": float(scores[1]),
"neutral": float(scores[2])
}
}
def batch_verify(self,
premise_hypothesis_pairs: List[Tuple[str, str]]) -> List[Dict]:
"""批量验证"""
if not premise_hypothesis_pairs:
return []
scores = self.model.predict(premise_hypothesis_pairs, apply_softmax=True)
results = []
for score_vec in scores:
label_idx = np.argmax(score_vec)
results.append({
"label": self.labels[label_idx],
"confidence": float(score_vec[label_idx]),
"entailment_score": float(score_vec[1])
})
return results
class ClaimExtractor:
"""陈述提取器"""
def extract(self, text: str) -> List[str]:
"""
将文本分割为原子陈述
Args:
text: 输入文本
Returns:
陈述列表
"""
# 按句子分割
sentences = re.split(r'(?<=[.!?。!?])\s+', text)
# 过滤太短或太长的句子
claims = []
for sent in sentences:
sent = sent.strip()
if 10 < len(sent) < 500:
# 移除引用标记如 [1], [2], etc.
clean_sent = re.sub(r'\[\d+\]', '', sent).strip()
if clean_sent:
claims.append(clean_sent)
return claims
class CitationGenerator:
"""引用生成器"""
def __init__(self, use_brackets: bool = True):
self.use_brackets = use_brackets
self.citation_counter = 0
self.citation_map = {}
def generate_marker(self, source_id: str, source_meta: Dict) -> str:
"""生成引用标记"""
if source_id not in self.citation_map:
self.citation_counter += 1
self.citation_map[source_id] = {
"number": self.citation_counter,
"metadata": source_meta
}
num = self.citation_map[source_id]["number"]
return f"[{num}]" if self.use_brackets else f"({num})"
def generate_bibliography(self) -> List[Dict]:
"""生成参考文献列表"""
return [
{
"citation_number": info["number"],
"source_id": sid,
**info["metadata"]
}
for sid, info in sorted(self.citation_map.items(), key=lambda x: x[1]["number"])
]
class AttributionEngine:
"""溯源引擎"""
def __init__(self,
nli_model: str = "cross-encoder/nli-deberta-v3-base",
entailment_threshold: float = 0.7):
"""
初始化
Args:
nli_model: NLI模型名称
entailment_threshold: 蕴含判定阈值
"""
self.nli = NLIVerifier(nli_model)
self.claim_extractor = ClaimExtractor()
self.citation_gen = CitationGenerator()
self.threshold = entailment_threshold
def attribute(self,
generated_answer: str,
retrieved_contexts: List[Dict],
link_format: str = "markdown") -> Dict:
"""
执行溯源
Args:
generated_answer: 生成的答案
retrieved_contexts: 检索上下文(包含id, text, metadata)
link_format: 链接格式 ('markdown', 'html', 'plain')
Returns:
溯源结果
"""
# 提取陈述
claims = self.claim_extractor.extract(generated_answer)
print(f"[INFO] 提取了 {len(claims)} 个陈述")
# 为每个陈述寻找证据
attributions = []
all_verifications = []
for claim in claims:
# 验证与每个上下文的关系
evidence_list = []
verification_pairs = [
(ctx.get("text") or ctx.get("content", ""), claim)
for ctx in retrieved_contexts
]
verifications = self.nli.batch_verify(verification_pairs)
for ctx, verify_result in zip(retrieved_contexts, verifications):
if verify_result["label"] == "entailment" and verify_result["confidence"] > 0.5:
evidence_list.append({
"context_id": ctx.get("id") or ctx.get("chunk_id", "unknown"),
"context_text": (ctx.get("text") or ctx.get("content", ""))[:200],
"confidence": verify_result["confidence"],
"verification_score": verify_result["entailment_score"],
"metadata": ctx.get("metadata", {})
})
# 排序证据
evidence_list.sort(key=lambda x: x["confidence"], reverse=True)
# 生成引用标记
citations = []
for ev in evidence_list[:3]: # 最多3个引用
marker = self.citation_gen.generate_marker(
ev["context_id"],
ev["metadata"]
)
citations.append(marker)
attributions.append(AttributionResult(
claim=claim,
supporting_evidence=evidence_list[:3],
confidence=max([e["confidence"] for e in evidence_list]) if evidence_list else 0.0,
citation_markers=citations
))
# 生成带引用的答案
attributed_answer = self._insert_citations(generated_answer, attributions, link_format)
# 计算整体忠实度
faithfulness = sum(1 for a in attributions if a.confidence > self.threshold) / len(claims) if claims else 0
return {
"attributed_answer": attributed_answer,
"claims": [
{
"text": a.claim,
"citations": a.citation_markers,
"evidence_count": len(a.supporting_evidence),
"confidence": a.confidence
}
for a in attributions
],
"bibliography": self.citation_gen.generate_bibliography(),
"faithfulness_score": faithfulness,
"supported_claims": sum(1 for a in attributions if a.supporting_evidence),
"total_claims": len(claims)
}
def _insert_citations(self,
answer: str,
attributions: List[AttributionResult],
format: str) -> str:
"""在答案中插入引用标记"""
result = answer
# 简单实现:在句子末尾插入引用
for attr in attributions:
if attr.citation_markers:
citation_str = "".join(attr.citation_markers)
# 在陈述后插入引用(简化处理)
escaped_claim = re.escape(attr.claim[:50])
pattern = f"({escaped_claim}.*?)([.!?])"
replacement = f"\\1{citation_str}\\2"
result = re.sub(pattern, replacement, result, count=1)
return result
def highlight_evidence(self,
context: str,
claim: str,
output_format: str = "html") -> str:
"""
高亮证据文本
Args:
context: 原始上下文
claim: 支持的陈述
output_format: 输出格式
Returns:
带高亮的文本
"""
# 简单实现:将相关句子标记为高亮
sentences = re.split(r'(?<=[.!?])\s+', context)
# 找到最相似的句子(简化版)
highlighted = []
for sent in sentences:
# 简单重叠检测
claim_words = set(claim.lower().split())
sent_words = set(sent.lower().split())
overlap = len(claim_words & sent_words) / len(claim_words) if claim_words else 0
if overlap > 0.3:
if output_format == "html":
sent = f'<mark style="background-color: yellow;">{sent}</mark>'
elif output_format == "markdown":
sent = f"**{sent}**"
highlighted.append(sent)
return " ".join(highlighted)
def main():
parser = argparse.ArgumentParser(description="引用溯源")
parser.add_argument("--answer", "-a", required=True, help="生成的答案文本或文件")
parser.add_argument("--context", "-c", required=True, help="上下文JSON文件")
parser.add_argument("--output", "-o", default="attribution_result.json", help="输出文件")
parser.add_argument("--format", "-f", choices=["markdown", "html", "plain"],
default="markdown", help="引用格式")
parser.add_argument("--threshold", "-t", type=float, default=0.7, help="蕴含阈值")
args = parser.parse_args()
# 加载答案
if os.path.exists(args.answer):
with open(args.answer, 'r') as f:
answer = f.read()
else:
answer = args.answer
# 加载上下文
with open(args.context, 'r') as f:
contexts = json.load(f)
if isinstance(contexts, dict):
contexts = contexts.get("selected_chunks", [])
print(f"[INFO] 答案长度: {len(answer)} 字符")
print(f"[INFO] 上下文块数: {len(contexts)}")
# 执行溯源
engine = AttributionEngine(entailment_threshold=args.threshold)
result = engine.attribute(answer, contexts, link_format=args.format)
print(f"\n溯源结果:")
print(f" 忠实度分数: {result['faithfulness_score']:.2f}")
print(f" 支持陈述: {result['supported_claims']}/{result['total_claims']}")
print(f"\n带引用的答案:\n{result['attributed_answer'][:500]}...")
print(f"\n参考文献:")
for bib in result['bibliography'][:3]:
print(f" [{bib['citation_number']}] {bib.get('source', 'unknown')}")
# 保存
with open(args.output, 'w') as f:
json.dump(result, f, indent=2, ensure_ascii=False)
print(f"\n结果已保存至: {args.output}")
if __name__ == "__main__":
import os
main()
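AttributionEngine 输出的 bibliography 只是结构化条目;若要生成"原文链接",可以按统一约定把 source 与 page 渲染成 Markdown 链接。下面是一个假设性的渲染函数,URL 拼接规则 docs/<文件名>#page=<页码> 仅为演示约定,需按实际的文档服务地址调整:
Python
def to_markdown_links(bibliography):
    """把参考文献条目渲染成带页码锚点的Markdown链接(URL规则为假设约定)"""
    lines = []
    for item in bibliography:
        source = item.get("source", "unknown")
        page = item.get("page", 1)
        url = f"docs/{source}#page={page}"
        lines.append(f"[{item['citation_number']}] [{source}]({url}) (p.{page})")
    return "\n".join(lines)

# 示例条目(字段与AttributionEngine输出的bibliography一致,数值为演示数据)
demo_bib = [{"citation_number": 1, "source_id": "doc_3", "source": "annual_report_2024.pdf", "page": 12}]
print(to_markdown_links(demo_bib))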
6.4.3 答案验证
Python
#!/usr/bin/env python3
"""
Script: self_rag_verification.py
功能: Self-RAG实现,包含反思机制、幻觉检测与自我修正
使用方式: python self_rag_verification.py --query "question" --context context.json --model gpt-4
"""
import json
import re
import argparse
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
class ReflectionToken(Enum):
    """反思标记"""
    RETRIEVE = "[Retrieve]"
    NO_RETRIEVE = "[No Retrieve]"
    RELEVANT = "[Relevant]"
    IRRELEVANT = "[Irrelevant]"
    VERIFY = "[Verify]"
    SUPPORTED = "[Supported]"
    CONTRADICTORY = "[Contradictory]"
    CORRECTION = "[Correction]"
    FINISH = "[Finish]"
@dataclass
class ReflectionTrace:
"""反思轨迹"""
step: int
action: str
content: str
verification_result: Optional[str] = None
class SelfRAG:
"""
Self-Reflective Retrieval-Augmented Generation
基于Asai et al. (2023)的Self-RAG框架实现,
支持自适应检索与自我验证。
"""
def __init__(self,
llm_provider: str = "openai",
api_key: Optional[str] = None,
model: str = "gpt-4",
reflection_trigger: str = "adaptive"):
"""
初始化Self-RAG
Args:
llm_provider: LLM提供者
api_key: API密钥
model: 模型名称
reflection_trigger: 反思触发策略 ('always', 'adaptive', 'never')
"""
self.llm_provider = llm_provider
self.model = model
self.reflection_trigger = reflection_trigger
self.api_key = api_key
self._init_llm()
self.traces = []
def _init_llm(self):
"""初始化LLM"""
if self.llm_provider == "openai":
try:
from openai import OpenAI
except ImportError:
raise ImportError("请安装openai")
import os
self.client = OpenAI(api_key=self.api_key or os.getenv("OPENAI_API_KEY"))
else:
raise ValueError(f"不支持的提供者: {self.llm_provider}")
def generate_with_reflection(self,
query: str,
contexts: List[Dict],
max_iterations: int = 3) -> Dict:
"""
带反思的生成
Args:
query: 查询
contexts: 检索上下文
max_iterations: 最大迭代次数
Returns:
生成结果与反思轨迹
"""
self.traces = []
current_answer = ""
iteration = 0
# 初始反思:是否需要检索
if self.reflection_trigger == "always" or self._needs_retrieval(query):
self.traces.append(ReflectionTrace(0, "initial", ReflectionToken.RETRIEVE.value))
current_context = self._format_contexts(contexts)
else:
self.traces.append(ReflectionTrace(0, "initial", ReflectionToken.NO_RETRIEVE.value))
current_context = ""
while iteration < max_iterations:
iteration += 1
# 生成
prompt = self._build_prompt(query, current_context, current_answer)
response = self._call_llm(prompt)
# 解析反思标记
parsed = self._parse_reflection_tokens(response)
if parsed["action"] == "finish":
current_answer += parsed["content"]
self.traces.append(ReflectionTrace(
iteration, "generate", parsed["content"], "finished"
))
break
elif parsed["action"] == "retrieve":
# 请求更多检索(简化处理,实际应触发新的检索)
self.traces.append(ReflectionTrace(
iteration, "retrieve", "Requesting more information", None
))
break
elif parsed["action"] == "verify":
# 验证生成的内容
claim = parsed["content"]
verification = self._verify_claim(claim, contexts)
self.traces.append(ReflectionTrace(
iteration, "verify", claim, verification["result"]
))
if verification["result"] == "unsupported":
# 需要修正
correction = self._generate_correction(claim, contexts)
current_answer += correction + " "
self.traces.append(ReflectionTrace(
iteration, "correct", correction, None
))
else:
current_answer += claim + " "
else:
current_answer += parsed["content"] + " "
# 最终验证
final_check = self._hallucination_check(current_answer, contexts)
return {
"answer": current_answer.strip(),
"reflection_traces": [
{
"step": t.step,
"action": t.action,
"content": t.content[:100] + "..." if len(t.content) > 100 else t.content,
"verification": t.verification_result
}
for t in self.traces
],
"hallucination_detected": final_check["has_hallucination"],
"faithfulness_score": final_check["faithfulness"],
"iterations": iteration
}
def _needs_retrieval(self, query: str) -> bool:
"""判断是否需要检索"""
# 启发式规则:事实性问题通常需要检索
factual_patterns = r'\b(what|who|when|where|how many|how much|is|are|was|were)\b'
return bool(re.search(factual_patterns, query.lower()))
def _format_contexts(self, contexts: List[Dict]) -> str:
"""格式化上下文"""
parts = []
for i, ctx in enumerate(contexts, 1):
text = ctx.get("text") or ctx.get("content", "")
source = ctx.get("metadata", {}).get("source", f"doc_{i}")
parts.append(f"[{i}] Source: {source}\n{text}")
return "\n\n".join(parts)
def _build_prompt(self, query: str, context: str, partial_answer: str) -> str:
"""构建提示"""
reflection_instructions = """
You are a Self-RAG system. Follow these rules:
1. Generate the answer step by step
2. Use [Retrieve] if you need more information
3. Use [Verify] before stating a fact, then wait for feedback
4. Use [Supported] or [Contradictory] based on verification
5. Use [Finish] when the answer is complete
Available context:
{context}
Current partial answer: {partial}
""".format(context=context, partial=partial_answer or "None")
return f"{reflection_instructions}\n\nQuestion: {query}\nAnswer:"
def _call_llm(self, prompt: str) -> str:
"""调用LLM"""
if self.llm_provider == "openai":
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful assistant with reflection capabilities."},
{"role": "user", "content": prompt}
],
temperature=0.7,
max_tokens=500
)
return response.choices[0].message.content
return ""
def _parse_reflection_tokens(self, text: str) -> Dict:
"""解析反思标记"""
# 检测特殊标记
tokens = [t.value for t in ReflectionToken]
for token in tokens:
if token in text:
if token == ReflectionToken.FINISH.value:
content = text.split(token)[0].strip()
return {"action": "finish", "content": content}
elif token == ReflectionToken.RETRIEVE.value:
return {"action": "retrieve", "content": text}
elif token == ReflectionToken.VERIFY.value:
content = text.split(token)[1].split(".")[0].strip()
return {"action": "verify", "content": content}
elif token in [ReflectionToken.SUPPORTED.value, ReflectionToken.CONTRADICTORY.value]:
return {"action": "verified", "content": text}
return {"action": "continue", "content": text}
def _verify_claim(self, claim: str, contexts: List[Dict]) -> Dict:
"""验证陈述"""
        # 简化实现:检查与上下文的嵌入相似度(模型按需加载并缓存在实例上,避免逐句重复初始化)
        from sentence_transformers import SentenceTransformer, util
        if not hasattr(self, "_verify_model"):
            self._verify_model = SentenceTransformer('all-MiniLM-L6-v2')
        model = self._verify_model
        claim_emb = model.encode(claim, convert_to_tensor=True)
best_score = 0
for ctx in contexts:
ctx_text = ctx.get("text") or ctx.get("content", "")
ctx_emb = model.encode(ctx_text, convert_to_tensor=True)
score = util.pytorch_cos_sim(claim_emb, ctx_emb).item()
best_score = max(best_score, score)
threshold = 0.7
if best_score > threshold:
return {"result": "supported", "confidence": best_score}
else:
return {"result": "unsupported", "confidence": best_score}
def _generate_correction(self, wrong_claim: str, contexts: List[Dict]) -> str:
"""生成修正"""
context_text = " ".join([
(c.get("text") or c.get("content", ""))[:500]
for c in contexts[:2]
])
prompt = f"""Based on the following context, correct the previous statement:
Context: {context_text}
Incorrect statement: {wrong_claim}
Corrected statement:"""
return self._call_llm(prompt)
def _hallucination_check(self, answer: str, contexts: List[Dict]) -> Dict:
"""幻觉检测"""
# 提取陈述并验证
sentences = re.split(r'(?<=[.!?])\s+', answer)
verified = 0
for sent in sentences:
if len(sent.strip()) < 10:
continue
result = self._verify_claim(sent, contexts)
if result["result"] == "supported":
verified += 1
total = len([s for s in sentences if len(s.strip()) > 10])
faithfulness = verified / total if total > 0 else 1.0
return {
"has_hallucination": faithfulness < 0.8,
"faithfulness": faithfulness,
"verified_statements": verified,
"total_statements": total
}
def main():
parser = argparse.ArgumentParser(description="Self-RAG验证")
parser.add_argument("--query", "-q", required=True, help="查询")
parser.add_argument("--context", "-c", required=True, help="上下文JSON")
parser.add_argument("--api-key", "-k", default=None, help="API Key")
parser.add_argument("--model", "-m", default="gpt-4", help="模型")
parser.add_argument("--output", "-o", default="self_rag_result.json", help="输出")
args = parser.parse_args()
# 加载上下文
with open(args.context, 'r') as f:
contexts = json.load(f)
if isinstance(contexts, dict):
contexts = contexts.get("selected_chunks", [])
# 初始化Self-RAG
self_rag = SelfRAG(
llm_provider="openai",
api_key=args.api_key,
model=args.model,
reflection_trigger="adaptive"
)
# 生成
print(f"[INFO] 执行Self-RAG生成...")
result = self_rag.generate_with_reflection(args.query, contexts)
print(f"\n最终答案:\n{result['answer']}")
print(f"\n忠实度: {result['faithfulness_score']:.2f}")
print(f"幻觉检测: {'是' if result['hallucination_detected'] else '否'}")
print(f"迭代次数: {result['iterations']}")
# 保存
with open(args.output, 'w') as f:
json.dump(result, f, indent=2, ensure_ascii=False)
print(f"\n结果已保存至: {args.output}")
if __name__ == "__main__":
main()
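_verify_claim 用嵌入余弦相似度做了简化验证,容易把"主题相关但事实相反"的陈述误判为支持;更严格的做法是像6.4.2节那样用NLI交叉编码器判断蕴含关系。下面是一个最小示意,前提与陈述均为虚构示例:
Python
from sentence_transformers import CrossEncoder
import numpy as np

nli = CrossEncoder("cross-encoder/nli-deberta-v3-base")
premise = "The company reported revenue of 5 million dollars in 2024."
claim = "In 2024 the company's revenue was about five million dollars."

scores = nli.predict([(premise, claim)], apply_softmax=True)[0]
labels = ["contradiction", "entailment", "neutral"]  # 与上文NLIVerifier的标签顺序一致
print(labels[int(np.argmax(scores))], float(np.max(scores)))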
6.4.4 缓存策略
Python
#!/usr/bin/env python3
"""
Script: semantic_cache.py
功能: 语义缓存(Semantic Cache)与精确匹配缓存实现,含缓存命中率可视化
使用方式: python semantic_cache.py --query "your question" --cache-file cache.db
"""
import json
import hashlib
import argparse
import numpy as np
import pickle
import time
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
import sqlite3
@dataclass
class CacheEntry:
"""缓存条目"""
query_hash: str
query_embedding: np.ndarray
answer: str
timestamp: float
hit_count: int = 0
ttl: int = 3600 # 默认1小时
class SemanticCache:
"""
语义缓存
基于查询嵌入向量的相似性进行缓存匹配,
支持TTL(生存时间)与LRU(最近最少使用)策略。
"""
def __init__(self,
embedding_dim: int = 384,
similarity_threshold: float = 0.95,
ttl_seconds: int = 3600,
max_size: int = 10000):
"""
初始化语义缓存
Args:
embedding_dim: 嵌入维度
similarity_threshold: 相似度阈值
ttl_seconds: 生存时间
max_size: 最大条目数
"""
self.embedding_dim = embedding_dim
self.threshold = similarity_threshold
self.ttl = ttl_seconds
self.max_size = max_size
# 内存缓存
self.cache: Dict[str, CacheEntry] = {}
self.query_hashes = [] # 用于LRU
# 嵌入模型(用于新查询编码)
self.embedder = None
self._init_embedder()
# 统计
self.stats = {"hits": 0, "misses": 0, "semantic_hits": 0}
def _init_embedder(self):
"""初始化嵌入模型"""
try:
from sentence_transformers import SentenceTransformer
self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception:
print("[WARN] 无法加载嵌入模型,使用模拟模式")
self.embedder = None
def _get_embedding(self, query: str) -> np.ndarray:
"""获取查询嵌入"""
if self.embedder:
return self.embedder.encode(query)
else:
# 模拟嵌入(基于哈希)
np.random.seed(hash(query) % 2**32)
vec = np.random.randn(self.embedding_dim)
return vec / np.linalg.norm(vec)
def _compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
"""计算余弦相似度"""
return np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-10)
def _is_expired(self, entry: CacheEntry) -> bool:
"""检查是否过期"""
return time.time() - entry.timestamp > self.ttl
def get(self, query: str) -> Tuple[Optional[str], str]:
"""
获取缓存
Args:
query: 查询文本
Returns:
(缓存的答案, 命中类型)
"""
query_hash = hashlib.md5(query.encode()).hexdigest()
# L1: 精确匹配
if query_hash in self.cache:
entry = self.cache[query_hash]
if not self._is_expired(entry):
entry.hit_count += 1
entry.timestamp = time.time() # 更新LRU
self.stats["hits"] += 1
return entry.answer, "exact"
else:
# 过期删除
del self.cache[query_hash]
self.query_hashes.remove(query_hash)
# L2: 语义匹配
query_emb = self._get_embedding(query)
best_match = None
best_sim = 0
for key, entry in self.cache.items():
if self._is_expired(entry):
continue
sim = self._compute_similarity(query_emb, entry.query_embedding)
if sim > best_sim and sim >= self.threshold:
best_sim = sim
best_match = entry
if best_match:
best_match.hit_count += 1
self.stats["semantic_hits"] += 1
self.stats["hits"] += 1
return best_match.answer, f"semantic({best_sim:.3f})"
self.stats["misses"] += 1
return None, "miss"
def put(self, query: str, answer: str, embedding: Optional[np.ndarray] = None):
"""
存入缓存
Args:
query: 查询
answer: 答案
embedding: 嵌入向量(可选,自动计算)
"""
# 检查容量
if len(self.cache) >= self.max_size:
# LRU淘汰
oldest = min(self.cache.values(), key=lambda x: x.timestamp)
del self.cache[oldest.query_hash]
self.query_hashes.remove(oldest.query_hash)
query_hash = hashlib.md5(query.encode()).hexdigest()
emb = embedding if embedding is not None else self._get_embedding(query)
entry = CacheEntry(
query_hash=query_hash,
query_embedding=emb,
answer=answer,
timestamp=time.time()
)
self.cache[query_hash] = entry
self.query_hashes.append(query_hash)
def get_stats(self) -> Dict:
"""获取统计信息"""
total = self.stats["hits"] + self.stats["misses"]
hit_rate = self.stats["hits"] / total if total > 0 else 0
return {
"total_queries": total,
"cache_hits": self.stats["hits"],
"exact_hits": self.stats["hits"] - self.stats["semantic_hits"],
"semantic_hits": self.stats["semantic_hits"],
"cache_misses": self.stats["misses"],
"hit_rate": hit_rate,
"semantic_hit_rate": self.stats["semantic_hits"] / total if total > 0 else 0,
"current_size": len(self.cache),
"max_size": self.max_size
}
class PersistentSemanticCache(SemanticCache):
"""持久化语义缓存(基于SQLite)"""
def __init__(self, db_path: str = "semantic_cache.db", **kwargs):
super().__init__(**kwargs)
self.db_path = db_path
self._init_db()
self._load_from_db()
def _init_db(self):
"""初始化数据库"""
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = self.conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS cache_entries (
query_hash TEXT PRIMARY KEY,
query_embedding BLOB,
answer TEXT,
timestamp REAL,
hit_count INTEGER DEFAULT 0,
ttl INTEGER
)
""")
self.conn.commit()
def _load_from_db(self):
"""从数据库加载缓存"""
cursor = self.conn.cursor()
cursor.execute("SELECT * FROM cache_entries")
for row in cursor.fetchall():
hash_val, emb_blob, answer, ts, hits, ttl = row
emb = np.frombuffer(emb_blob, dtype=np.float32)
# 检查过期
if time.time() - ts > ttl:
continue
entry = CacheEntry(
query_hash=hash_val,
query_embedding=emb,
answer=answer,
timestamp=ts,
hit_count=hits,
ttl=ttl
)
            self.cache[hash_val] = entry
            self.query_hashes.append(hash_val)  # 与LRU追踪列表保持一致,避免后续remove()报错
def put(self, query: str, answer: str, embedding: Optional[np.ndarray] = None):
"""存入缓存(覆盖父类以持久化)"""
super().put(query, answer, embedding)
# 保存到数据库
entry = self.cache[hashlib.md5(query.encode()).hexdigest()]
cursor = self.conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO cache_entries
(query_hash, query_embedding, answer, timestamp, hit_count, ttl)
VALUES (?, ?, ?, ?, ?, ?)
""", (
entry.query_hash,
entry.query_embedding.astype(np.float32).tobytes(),
entry.answer,
entry.timestamp,
entry.hit_count,
entry.ttl
))
self.conn.commit()
class CacheVisualizer:
"""缓存可视化"""
def visualize_stats(self, stats: Dict, output_path: str = "cache_stats.png"):
"""可视化统计"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# 1. 命中率饼图
ax1 = axes[0]
sizes = [stats["exact_hits"], stats["semantic_hits"], stats["cache_misses"]]
labels = ['Exact Match', 'Semantic Match', 'Misses']
colors = ['#66b3ff', '#99ff99', '#ff9999']
explode = (0.05, 0.05, 0)
ax1.pie(sizes, explode=explode, labels=labels, colors=colors, autopct='%1.1f%%',
shadow=True, startangle=90)
ax1.set_title(f'Cache Hit Rate (Total: {stats["total_queries"]})')
# 2. 缓存使用情况
ax2 = axes[1]
usage = stats["current_size"] / stats["max_size"] * 100
ax2.barh(['Cache Usage'], [usage], color='skyblue')
ax2.set_xlim([0, 100])
ax2.set_xlabel('Percentage (%)')
ax2.set_title(f'Cache Capacity: {stats["current_size"]}/{stats["max_size"]}')
ax2.text(usage + 2, 0, f'{usage:.1f}%', va='center')
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 统计图表已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="语义缓存")
parser.add_argument("--query", "-q", default=None, help="查询文本")
parser.add_argument("--answer", "-a", default=None, help="答案文本")
parser.add_argument("--db-file", "-d", default="semantic_cache.db", help="数据库文件")
parser.add_argument("--threshold", "-t", type=float, default=0.95, help="相似度阈值")
parser.add_argument("--demo", action="store_true", help="运行演示")
args = parser.parse_args()
# 初始化缓存
cache = PersistentSemanticCache(
db_path=args.db_file,
similarity_threshold=args.threshold
)
if args.demo:
print("[DEMO] 运行缓存演示...")
# 模拟查询序列
queries = [
"What is machine learning?",
"What is machine learning", # 近似重复(语义匹配)
"How does deep learning work?",
"What is machine learning?", # 精确重复
"Explain neural networks",
"How does deep learning work" # 近似重复
]
answers = [
"Machine learning is a subset of AI...",
"Deep learning uses neural networks with multiple layers...",
"Neural networks are computing systems inspired by biological neural networks..."
]
for i, q in enumerate(queries):
# 尝试获取
cached, hit_type = cache.get(q)
if cached:
print(f"Query {i+1}: '{q[:40]}...' -> [HIT: {hit_type}]")
            else:
                # 模拟生成:按查询主题匹配对应答案,保证缓存内容与查询一致
                topic_keys = ["machine learning", "deep learning", "neural"]
                ans = next((a for k, a in zip(topic_keys, answers) if k in q.lower()), answers[0])
                cache.put(q, ans)
                print(f"Query {i+1}: '{q[:40]}...' -> [MISS] (Cached)")
# 打印统计
stats = cache.get_stats()
print(f"\n缓存统计:")
print(f" 命中率: {stats['hit_rate']:.2%}")
print(f" 精确匹配: {stats['exact_hits']}")
print(f" 语义匹配: {stats['semantic_hits']}")
# 可视化
visualizer = CacheVisualizer()
visualizer.visualize_stats(stats)
elif args.query and args.answer:
# 存入
cache.put(args.query, args.answer)
print(f"[INFO] 已缓存查询")
elif args.query:
# 查询
cached, hit_type = cache.get(args.query)
if cached:
print(f"[HIT: {hit_type}] 答案: {cached[:200]}...")
else:
print("[MISS] 缓存未命中")
else:
parser.print_help()
if __name__ == "__main__":
main()
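在实际服务中,上述语义缓存通常作为RAG问答入口的前置层:先查缓存,未命中时才执行检索与生成,并将结果回写供后续相似查询复用。下面是一个最小接入草图,其中 rag_answer 为假设的问答函数(占位实现),PersistentSemanticCache 的接口以上文脚本为准:
Python
from typing import Callable

def cached_rag_answer(query: str,
                      cache: "PersistentSemanticCache",
                      rag_answer: Callable[[str], str]) -> str:
    """先查语义缓存,未命中再调用RAG流水线并回写缓存(示意)"""
    cached, hit_type = cache.get(query)
    if cached is not None:
        print(f"[CACHE] 命中({hit_type}),跳过检索与生成")
        return cached
    answer = rag_answer(query)   # 假设的RAG问答函数:检索 + 生成
    cache.put(query, answer)     # 回写缓存,后续相似查询可直接命中
    return answer

if __name__ == "__main__":
    cache = PersistentSemanticCache(db_path="semantic_cache.db", similarity_threshold=0.9)
    print(cached_rag_answer("什么是RAG?", cache, lambda q: f"针对『{q}』的生成答案..."))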
6.5.1 离线评估
Python
#!/usr/bin/env python3
"""
Script: ragas_evaluation.py
功能: RAGAS指标计算(Faithfulness, Answer Relevancy, Context Precision/Recall),含可视化报告
使用方式: python ragas_evaluation.py --qa-pairs qa.json --output evaluation_report.json
"""
import json
import argparse
import numpy as np
from typing import List, Dict, Tuple
from dataclasses import dataclass
from sentence_transformers import SentenceTransformer, util
@dataclass
class QAPair:
"""问答对"""
question: str
answer: str
contexts: List[str]
ground_truth: str = None
class RAGASEvaluator:
"""
RAGAS评估器
实现RAGAS框架的核心指标:
- Faithfulness: 答案对上下文的事实一致性
- Answer Relevancy: 答案与问题的相关性
- Context Precision: 检索上下文的相关比例
- Context Recall: 检索上下文对答案的覆盖度
"""
def __init__(self,
llm_provider: str = "local",
embedding_model: str = "all-MiniLM-L6-v2"):
"""
初始化评估器
Args:
llm_provider: LLM提供者(用于NLI和生成)
embedding_model: 嵌入模型
"""
self.llm_provider = llm_provider
self.embedder = SentenceTransformer(embedding_model)
if llm_provider == "openai":
try:
from openai import OpenAI
self.client = OpenAI()
            except Exception:
                # openai库缺失或初始化失败时退化为本地模式
                self.client = None
def faithfulness(self, answer: str, contexts: List[str]) -> Dict:
"""
计算忠实度
步骤:
1. 将答案分解为原子陈述
2. 验证每个陈述是否被上下文支持
Args:
answer: 生成的答案
contexts: 检索上下文列表
Returns:
忠实度分数和详细分析
"""
# 陈述提取(简化:按句子分割)
statements = self._extract_statements(answer)
if not statements:
return {"score": 0.0, "reason": "No statements extracted"}
# 验证每个陈述
supported = 0
details = []
for stmt in statements:
# 检查是否有上下文支持该陈述
is_supported, best_evidence = self._verify_statement(stmt, contexts)
if is_supported:
supported += 1
details.append({
"statement": stmt,
"supported": is_supported,
"evidence": best_evidence[:200] if best_evidence else None
})
score = supported / len(statements)
return {
"score": score,
"supported_statements": supported,
"total_statements": len(statements),
"details": details
}
def answer_relevancy(self, question: str, answer: str) -> Dict:
"""
计算答案相关性
方法:生成潜在问题并计算与原始问题的相似度
Args:
question: 原始问题
answer: 答案
Returns:
相关性分数
"""
# 生成潜在问题(实际应使用LLM)
artificial_questions = self._generate_questions_from_answer(answer, n=3)
if not artificial_questions:
return {"score": 0.0, "reason": "No questions generated"}
# 编码
q_emb = self.embedder.encode(question, convert_to_tensor=True)
aq_embs = self.embedder.encode(artificial_questions, convert_to_tensor=True)
# 计算平均相似度
similarities = util.pytorch_cos_sim(q_emb, aq_embs)[0]
mean_sim = float(np.mean(similarities.cpu().numpy()))
return {
"score": mean_sim,
"generated_questions": artificial_questions,
"similarities": similarities.tolist()
}
def context_precision(self,
question: str,
contexts: List[str],
ground_truth: str = None) -> Dict:
"""
计算上下文精确率
评估检索的上下文中相关部分的比例
Args:
question: 问题
contexts: 检索上下文(已排序)
ground_truth: 标准答案(可选)
Returns:
精确率分数
"""
if not contexts:
return {"score": 0.0}
        # 计算每个上下文与问题的余弦相似度(归一化后点积即余弦)
        q_emb = self.embedder.encode(question, normalize_embeddings=True)
        relevant_count = 0
        precisions = []
        for i, ctx in enumerate(contexts, 1):
            ctx_emb = self.embedder.encode(ctx, normalize_embeddings=True)
            similarity = float(np.dot(q_emb, ctx_emb))
            # 假设相似度>0.7为相关
            is_relevant = similarity > 0.7
if is_relevant:
relevant_count += 1
# 计算Precision@k
precisions.append(relevant_count / i)
# 平均精确率(AP)
ap = np.mean(precisions) if precisions else 0.0
return {
"score": ap,
"relevant_chunks": relevant_count,
"total_chunks": len(contexts),
"precision_at_k": precisions
}
def context_recall(self,
question: str,
answer: str,
contexts: List[str]) -> Dict:
"""
计算上下文召回率
评估答案中有多少信息能在上下文中找到支持
Args:
question: 问题
answer: 答案
contexts: 上下文
Returns:
召回率分数
"""
# 提取答案陈述
answer_statements = self._extract_statements(answer)
if not answer_statements:
return {"score": 0.0}
        # 合并上下文并一次性编码(避免在循环内重复编码)
        all_context = " ".join(contexts)
        ctx_emb = self.embedder.encode(all_context, normalize_embeddings=True)
        # 检查每个陈述是否能在上下文中找到语义支持
        retrieved = 0
        for stmt in answer_statements:
            stmt_emb = self.embedder.encode(stmt, normalize_embeddings=True)
            similarity = float(np.dot(stmt_emb, ctx_emb))  # 归一化后点积即余弦相似度
            if similarity > 0.6:  # 支持阈值
                retrieved += 1
score = retrieved / len(answer_statements)
return {
"score": score,
"retrieved_statements": retrieved,
"total_statements": len(answer_statements)
}
def evaluate(self, qa_pairs: List[QAPair]) -> Dict:
"""
批量评估
Args:
qa_pairs: 问答对列表
Returns:
综合评估结果
"""
results = {
"faithfulness": [],
"answer_relevancy": [],
"context_precision": [],
"context_recall": []
}
for i, qa in enumerate(qa_pairs):
print(f"[EVAL] 评估样本 {i+1}/{len(qa_pairs)}...")
# 计算各项指标
f = self.faithfulness(qa.answer, qa.contexts)
ar = self.answer_relevancy(qa.question, qa.answer)
cp = self.context_precision(qa.question, qa.contexts, qa.ground_truth)
cr = self.context_recall(qa.question, qa.answer, qa.contexts)
results["faithfulness"].append(f["score"])
results["answer_relevancy"].append(ar["score"])
results["context_precision"].append(cp["score"])
results["context_recall"].append(cr["score"])
# 汇总
summary = {
"faithfulness": {
"mean": np.mean(results["faithfulness"]),
"std": np.std(results["faithfulness"])
},
"answer_relevancy": {
"mean": np.mean(results["answer_relevancy"]),
"std": np.std(results["answer_relevancy"])
},
"context_precision": {
"mean": np.mean(results["context_precision"]),
"std": np.std(results["context_precision"])
},
"context_recall": {
"mean": np.mean(results["context_recall"]),
"std": np.std(results["context_recall"])
}
}
# 综合RAGAS分数(加权平均)
summary["ragas_score"] = (
0.3 * summary["faithfulness"]["mean"] +
0.3 * summary["answer_relevancy"]["mean"] +
0.2 * summary["context_precision"]["mean"] +
0.2 * summary["context_recall"]["mean"]
)
return {
"summary": summary,
"detailed_scores": results
}
def _extract_statements(self, text: str) -> List[str]:
"""提取原子陈述"""
# 按句子分割
import re
sentences = re.split(r'(?<=[.!?。!?])\s+', text)
return [s.strip() for s in sentences if len(s.strip()) > 10]
def _verify_statement(self, statement: str, contexts: List[str]) -> Tuple[bool, str]:
"""验证陈述是否被支持"""
        stmt_emb = self.embedder.encode(statement, normalize_embeddings=True)
        best_sim = 0.0
        best_ctx = ""
        for ctx in contexts:
            ctx_emb = self.embedder.encode(ctx, normalize_embeddings=True)
            sim = float(np.dot(stmt_emb, ctx_emb))  # 归一化后点积即余弦相似度
if sim > best_sim:
best_sim = sim
best_ctx = ctx
# 阈值判断
is_supported = best_sim > 0.65
return is_supported, best_ctx
def _generate_questions_from_answer(self, answer: str, n: int = 3) -> List[str]:
"""从答案生成问题(简化实现)"""
# 实际应使用LLM生成
# 这里使用模板生成
questions = []
# 提取关键句生成问题
sentences = self._extract_statements(answer)[:n]
        for sent in sentences:
            # 简单启发式转换(仅作占位,生产环境应由LLM从答案反向生成问题)
            if " is " in sent:
                subject = sent.split(" is ")[0].strip()
                q = f"What is {subject}?"
            else:
                q = f"What can you tell me about: {sent[:50]}?"
            questions.append(q)
return questions
class EvaluationVisualizer:
"""评估可视化"""
def visualize(self, results: Dict, output_path: str = "ragas_report.png"):
"""生成可视化报告"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
summary = results["summary"]
detailed = results["detailed_scores"]
metrics = ["faithfulness", "answer_relevancy", "context_precision", "context_recall"]
titles = ["Faithfulness", "Answer Relevancy", "Context Precision", "Context Recall"]
for idx, (metric, title) in enumerate(zip(metrics, titles)):
ax = axes[idx // 2, idx % 2]
scores = detailed[metric]
# 箱线图
bp = ax.boxplot(scores, patch_artist=True)
bp['boxes'][0].set_facecolor('lightblue')
# 添加均值线
mean_val = summary[metric]["mean"]
ax.axhline(y=mean_val, color='r', linestyle='--', label=f'Mean: {mean_val:.3f}')
ax.set_ylabel('Score')
ax.set_title(title)
ax.set_ylim([0, 1])
ax.legend()
ax.grid(True, alpha=0.3)
plt.suptitle(f"RAGAS Evaluation Report (Overall Score: {summary['ragas_score']:.3f})",
fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 评估报告已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="RAGAS评估")
parser.add_argument("--qa-pairs", "-q", required=True, help="问答对JSON文件")
parser.add_argument("--output", "-o", default="ragas_results.json", help="输出文件")
parser.add_argument("--visualize", "-v", action="store_true", help="生成可视化")
args = parser.parse_args()
# 加载数据
with open(args.qa_pairs, 'r') as f:
data = json.load(f)
# 转换为QAPair
qa_pairs = []
for item in data:
qa = QAPair(
question=item["question"],
answer=item["answer"],
contexts=item.get("contexts", []),
ground_truth=item.get("ground_truth")
)
qa_pairs.append(qa)
print(f"[INFO] 加载了 {len(qa_pairs)} 个问答对")
# 评估
evaluator = RAGASEvaluator()
results = evaluator.evaluate(qa_pairs)
# 打印摘要
print("\n评估结果摘要:")
print(f" RAGAS总分: {results['summary']['ragas_score']:.3f}")
print(f" 忠实度: {results['summary']['faithfulness']['mean']:.3f} ± {results['summary']['faithfulness']['std']:.3f}")
print(f" 答案相关性: {results['summary']['answer_relevancy']['mean']:.3f}")
print(f" 上下文精确率: {results['summary']['context_precision']['mean']:.3f}")
print(f" 上下文召回率: {results['summary']['context_recall']['mean']:.3f}")
# 保存
with open(args.output, 'w') as f:
json.dump(results, f, indent=2)
print(f"\n详细结果已保存至: {args.output}")
# 可视化
if args.visualize:
visualizer = EvaluationVisualizer()
visualizer.visualize(results)
if __name__ == "__main__":
main()
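脚本要求的 --qa-pairs 输入为JSON数组,每个元素包含 question、answer、contexts 以及可选的 ground_truth 字段(与上文QAPair的字段一致)。下面是一个构造示例输入文件的小段草图,其中的问答内容为假设数据,仅用于演示格式:
Python
import json

# 构造符合 ragas_evaluation.py 输入格式的示例文件(内容为假设数据)
sample = [
    {
        "question": "What is retrieval-augmented generation?",
        "answer": "RAG combines a retriever with a generator to ground answers in documents.",
        "contexts": [
            "Retrieval-augmented generation (RAG) retrieves relevant documents and conditions generation on them."
        ],
        "ground_truth": "RAG grounds LLM answers in retrieved documents."
    }
]

with open("qa.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)

# 之后即可运行: python ragas_evaluation.py --qa-pairs qa.json --visualize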
6.5.2 在线反馈
Python
#!/usr/bin/env python3
"""
Script: online_feedback_system.py
功能: 在线反馈收集系统,支持显式反馈(点赞/点踩)与隐式信号(停留时间、复制行为)
使用方式: python online_feedback_system.py --collect --session-id "sess_001"
"""
import json
import time
import argparse
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
import sqlite3
import threading
from collections import defaultdict
@dataclass
class FeedbackRecord:
"""反馈记录"""
session_id: str
query: str
answer: str
timestamp: str
explicit_rating: Optional[int] = None # 1-5星,或-1/1表示点踩/点赞
implicit_signals: Dict = None
metadata: Dict = None
def __post_init__(self):
if self.implicit_signals is None:
self.implicit_signals = {}
if self.metadata is None:
self.metadata = {}
class ImplicitSignalTracker:
"""
隐式反馈信号追踪器
收集用户行为信号推断满意度:
- Dwell time: 答案展示后的停留时间
- Copy events: 复制答案内容
- Click-through: 点击引用链接
- Follow-up queries: 后续查询(可能表示答案不完整)
"""
def __init__(self):
self.sessions: Dict[str, Dict] = {}
self.lock = threading.Lock()
def start_session(self, session_id: str, query: str, answer: str):
"""开始追踪会话"""
with self.lock:
self.sessions[session_id] = {
"start_time": time.time(),
"query": query,
"answer": answer,
"events": [],
"copy_count": 0,
"click_count": 0,
"scroll_depth": 0
}
def record_event(self, session_id: str, event_type: str, data: Dict = None):
"""记录事件"""
with self.lock:
if session_id not in self.sessions:
return
event = {
"type": event_type,
"timestamp": time.time(),
"data": data or {}
}
self.sessions[session_id]["events"].append(event)
if event_type == "copy":
self.sessions[session_id]["copy_count"] += 1
elif event_type == "click_citation":
self.sessions[session_id]["click_count"] += 1
elif event_type == "scroll":
self.sessions[session_id]["scroll_depth"] = data.get("depth", 0)
def end_session(self, session_id: str) -> Dict:
"""结束会话并计算隐式分数"""
with self.lock:
if session_id not in self.sessions:
return {}
sess = self.sessions[session_id]
duration = time.time() - sess["start_time"]
# 计算满意度分数(启发式)
score = 0.5 # 基准
# 停留时间(30秒以上加分)
if duration > 30:
score += 0.2
elif duration < 5:
score -= 0.2
# 复制行为(强烈正信号)
if sess["copy_count"] > 0:
score += 0.15 * min(sess["copy_count"], 2)
# 点击引用(正信号)
if sess["click_count"] > 0:
score += 0.1 * min(sess["click_count"], 2)
# 滚动深度
if sess["scroll_depth"] > 0.8:
score += 0.1
# 检查是否有后续查询(负信号:可能未解决)
follow_up = any(e["type"] == "new_query" for e in sess["events"])
if follow_up:
score -= 0.15
signals = {
"dwell_time_seconds": duration,
"copy_events": sess["copy_count"],
"citation_clicks": sess["click_count"],
"scroll_depth": sess["scroll_depth"],
"inferred_satisfaction": max(0, min(1, score))
}
# 清理
del self.sessions[session_id]
return signals
class FeedbackDatabase:
"""反馈数据库"""
def __init__(self, db_path: str = "feedback.db"):
self.db_path = db_path
self._init_db()
def _init_db(self):
"""初始化数据库表"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS feedback (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT,
query TEXT,
answer TEXT,
timestamp TEXT,
explicit_rating INTEGER,
implicit_satisfaction REAL,
dwell_time REAL,
copy_events INTEGER,
citation_clicks INTEGER,
metadata TEXT
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_timestamp ON feedback(timestamp)
""")
conn.commit()
conn.close()
def save_feedback(self, record: FeedbackRecord):
"""保存反馈"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
implicit = record.implicit_signals or {}
cursor.execute("""
INSERT INTO feedback
(session_id, query, answer, timestamp, explicit_rating,
implicit_satisfaction, dwell_time, copy_events, citation_clicks, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
record.session_id,
record.query,
record.answer,
record.timestamp,
record.explicit_rating,
implicit.get("inferred_satisfaction"),
implicit.get("dwell_time_seconds", 0),
implicit.get("copy_events", 0),
implicit.get("citation_clicks", 0),
json.dumps(record.metadata)
))
conn.commit()
conn.close()
def get_statistics(self, days: int = 7) -> Dict:
"""获取统计信息"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
since = (datetime.now() - timedelta(days=days)).isoformat()
# 基础统计
cursor.execute("""
SELECT
COUNT(*) as total,
AVG(explicit_rating) as avg_rating,
AVG(implicit_satisfaction) as avg_implicit,
AVG(dwell_time) as avg_dwell
FROM feedback
WHERE timestamp > ?
""", (since,))
row = cursor.fetchone()
# 显式反馈分布
cursor.execute("""
SELECT explicit_rating, COUNT(*)
FROM feedback
WHERE timestamp > ? AND explicit_rating IS NOT NULL
GROUP BY explicit_rating
""", (since,))
rating_dist = {r[0]: r[1] for r in cursor.fetchall()}
conn.close()
return {
"total_interactions": row[0],
"average_explicit_rating": row[1],
"average_implicit_satisfaction": row[2],
"average_dwell_time": row[3],
"rating_distribution": rating_dist,
"period_days": days
}
class FeedbackCollector:
"""反馈收集器"""
def __init__(self):
self.db = FeedbackDatabase()
self.tracker = ImplicitSignalTracker()
self.active_sessions = {}
def start_interaction(self, session_id: str, query: str, answer: str):
"""开始交互追踪"""
self.tracker.start_session(session_id, query, answer)
self.active_sessions[session_id] = {
"query": query,
"answer": answer,
"start_time": datetime.now().isoformat()
}
def record_explicit_feedback(self, session_id: str, rating: int, comment: str = None):
"""
记录显式反馈
Args:
session_id: 会话ID
rating: 1-5星,或-1/0/1(点踩/中立/点赞)
comment: 可选评论
"""
if session_id not in self.active_sessions:
return
# 结束隐式追踪并获取信号
implicit = self.tracker.end_session(session_id)
# 构建记录
sess = self.active_sessions[session_id]
record = FeedbackRecord(
session_id=session_id,
query=sess["query"],
answer=sess["answer"],
timestamp=datetime.now().isoformat(),
explicit_rating=rating,
implicit_signals=implicit,
metadata={"comment": comment} if comment else {}
)
# 保存
self.db.save_feedback(record)
# 清理
del self.active_sessions[session_id]
print(f"[INFO] 反馈已保存: Session={session_id}, Rating={rating}")
def record_event(self, session_id: str, event_type: str, data: Dict = None):
"""记录交互事件"""
self.tracker.record_event(session_id, event_type, data)
class FeedbackDashboard:
"""反馈仪表板"""
def generate_report(self, stats: Dict, output_path: str = "feedback_report.png"):
"""生成可视化报告"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# 1. 评分分布
ax1 = axes[0, 0]
if stats["rating_distribution"]:
ratings = list(stats["rating_distribution"].keys())
counts = list(stats["rating_distribution"].values())
ax1.bar(ratings, counts, color='skyblue')
ax1.set_xlabel('Rating')
ax1.set_ylabel('Count')
ax1.set_title('Explicit Rating Distribution')
        # 2. 显式评分与隐式满意度对比(原hist写法对两个标量数据与双色参数不兼容,改用柱状图)
        ax2 = axes[0, 1]
        ax2.bar(['Explicit', 'Implicit'],
                [stats["average_explicit_rating"] or 0,
                 stats["average_implicit_satisfaction"] or 0],
                color=['blue', 'green'], alpha=0.6)
        ax2.set_ylabel('Average Score')
        ax2.set_title('Satisfaction Comparison')
# 3. 平均停留时间
ax3 = axes[1, 0]
dwell = stats["average_dwell_time"] or 0
ax3.barh(['Average Dwell Time'], [dwell], color='orange')
ax3.set_xlabel('Seconds')
ax3.set_title(f'Avg Dwell Time: {dwell:.1f}s')
# 4. 总体指标
ax4 = axes[1, 1]
ax4.axis('off')
        # 先处理None,避免对字符串'N/A'应用:.2f格式化导致异常
        avg_rating = (f"{stats['average_explicit_rating']:.2f}"
                      if stats['average_explicit_rating'] is not None else "N/A")
        avg_implicit = (f"{stats['average_implicit_satisfaction']:.2f}"
                        if stats['average_implicit_satisfaction'] is not None else "N/A")
        metrics_text = f"""
        Total Interactions: {stats['total_interactions']}
        Avg Explicit Rating: {avg_rating}
        Avg Implicit Sat: {avg_implicit}
        Avg Dwell Time: {stats['average_dwell_time'] or 0:.1f}s
        Period: Last {stats['period_days']} days
        """
ax4.text(0.1, 0.5, metrics_text, fontsize=12, verticalalignment='center',
family='monospace')
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 报告已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="在线反馈系统")
parser.add_argument("--collect", "-c", action="store_true", help="收集反馈模式")
parser.add_argument("--stats", "-s", action="store_true", help="显示统计")
parser.add_argument("--session-id", default="demo_session", help="会话ID")
parser.add_argument("--rating", "-r", type=int, default=None, help="评分 (1-5)")
parser.add_argument("--event", "-e", default=None, help="事件类型")
parser.add_argument("--query", "-q", default="What is RAG?", help="查询")
parser.add_argument("--answer", "-a", default="RAG stands for...", help="答案")
args = parser.parse_args()
collector = FeedbackCollector()
dashboard = FeedbackDashboard()
if args.collect:
# 模拟交互流程
print(f"[INFO] 开始会话: {args.session_id}")
collector.start_interaction(args.session_id, args.query, args.answer)
# 模拟一些事件
time.sleep(1)
collector.record_event(args.session_id, "scroll", {"depth": 0.8})
collector.record_event(args.session_id, "copy")
if args.rating is not None:
collector.record_explicit_feedback(args.session_id, args.rating)
print("[INFO] 反馈已记录")
else:
print("[INFO] 交互进行中... 使用 --rating 提交评分")
elif args.stats:
stats = collector.db.get_statistics(days=7)
print(json.dumps(stats, indent=2))
dashboard.generate_report(stats)
else:
parser.print_help()
if __name__ == "__main__":
main()
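上述收集器需要由前端事件驱动才能发挥作用。下面给出一个将FeedbackCollector接入Web服务的最小草图,假设使用FastAPI;端点路径与请求字段名均为示例假设,可按实际前端协议调整:
Python
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
collector = FeedbackCollector()  # 复用上文的反馈收集器

class StartPayload(BaseModel):
    session_id: str
    query: str
    answer: str

class RatingPayload(BaseModel):
    session_id: str
    rating: int                    # 1-5星,或-1/0/1(点踩/中立/点赞)
    comment: Optional[str] = None

@app.post("/feedback/start")
def start_feedback(p: StartPayload):
    # 返回答案的同时开启隐式信号追踪
    collector.start_interaction(p.session_id, p.query, p.answer)
    return {"status": "tracking"}

@app.post("/feedback/rating")
def submit_rating(p: RatingPayload):
    # 用户点赞/点踩时落库,并结算停留时间、复制等隐式信号
    collector.record_explicit_feedback(p.session_id, p.rating, p.comment)
    return {"status": "saved"}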
6.5.3 A/B测试
Python
#!/usr/bin/env python3
"""
Script: ab_testing_framework.py
功能: RAG系统A/B测试框架,支持分块策略、提示词、模型版本的对比实验
使用方式: python ab_testing_framework.py --create-experiment --name chunking_test --variants semantic,recursive
"""
import json
import hashlib
import random
import argparse
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import sqlite3
import numpy as np
from scipy import stats
@dataclass
class Experiment:
"""实验定义"""
id: str
name: str
variants: List[str] # 包括control
traffic_split: List[float] # 各变体流量比例
target_metric: str
min_sample_size: int
status: str = "running" # running, paused, completed
@dataclass
class ExperimentEvent:
"""实验事件"""
experiment_id: str
user_id: str
variant: str
event_type: str # impression, conversion, feedback
metric_value: float
timestamp: str
metadata: Dict
class ABTestManager:
"""A/B测试管理器"""
def __init__(self, db_path: str = "ab_tests.db"):
self.db_path = db_path
self.experiments: Dict[str, Experiment] = {}
self._init_db()
self._load_experiments()
def _init_db(self):
"""初始化数据库"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS experiments (
id TEXT PRIMARY KEY,
config TEXT,
created_at TEXT,
status TEXT
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
experiment_id TEXT,
user_id TEXT,
variant TEXT,
event_type TEXT,
metric_value REAL,
timestamp TEXT,
metadata TEXT
)
""")
conn.commit()
conn.close()
def _load_experiments(self):
"""加载实验"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("SELECT id, config FROM experiments WHERE status='running'")
for row in cursor.fetchall():
exp_id, config_json = row
config = json.loads(config_json)
self.experiments[exp_id] = Experiment(**config)
conn.close()
def create_experiment(self,
name: str,
variants: List[str],
traffic_split: Optional[List[float]] = None,
target_metric: str = "feedback_rating",
min_sample_size: int = 100) -> str:
"""
创建实验
Args:
name: 实验名称
variants: 变体列表(第一个为control)
traffic_split: 流量分配(默认均等)
target_metric: 目标指标
min_sample_size: 最小样本数
Returns:
实验ID
"""
if traffic_split is None:
traffic_split = [1.0 / len(variants)] * len(variants)
assert len(variants) == len(traffic_split)
assert abs(sum(traffic_split) - 1.0) < 0.01
exp_id = f"exp_{name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
exp = Experiment(
id=exp_id,
name=name,
variants=variants,
traffic_split=traffic_split,
target_metric=target_metric,
min_sample_size=min_sample_size
)
# 保存
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT INTO experiments (id, config, created_at, status)
VALUES (?, ?, ?, ?)
""", (exp_id, json.dumps(asdict(exp)), datetime.now().isoformat(), "running"))
conn.commit()
conn.close()
self.experiments[exp_id] = exp
print(f"[INFO] 创建实验: {exp_id}")
print(f" 变体: {variants}")
print(f" 流量分配: {traffic_split}")
return exp_id
def assign_variant(self, experiment_id: str, user_id: str) -> str:
"""
为用户分配变体(一致性哈希)
Args:
experiment_id: 实验ID
user_id: 用户ID
Returns:
分配的变体名称
"""
if experiment_id not in self.experiments:
return "control"
exp = self.experiments[experiment_id]
# 一致性哈希
hash_val = int(hashlib.md5(f"{exp.id}_{user_id}".encode()).hexdigest(), 16)
normalized = hash_val / (2**128)
# 根据流量分配选择
cumulative = 0
for variant, split in zip(exp.variants, exp.traffic_split):
cumulative += split
if normalized <= cumulative:
return variant
return exp.variants[-1]
def record_event(self,
experiment_id: str,
user_id: str,
variant: str,
event_type: str,
metric_value: float = 0.0,
metadata: Dict = None):
"""记录实验事件"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT INTO events
(experiment_id, user_id, variant, event_type, metric_value, timestamp, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)
""", (
experiment_id,
user_id,
variant,
event_type,
metric_value,
datetime.now().isoformat(),
json.dumps(metadata or {})
))
conn.commit()
conn.close()
def get_results(self, experiment_id: str) -> Dict:
"""
获取实验结果与统计分析
Args:
experiment_id: 实验ID
Returns:
实验结果与统计检验
"""
if experiment_id not in self.experiments:
return {"error": "Experiment not found"}
exp = self.experiments[experiment_id]
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
        # 获取各变体指标(SQLite默认没有STDDEV聚合函数,取原始值后用numpy计算)
        results = {}
        for variant in exp.variants:
            cursor.execute("""
                SELECT metric_value FROM events
                WHERE experiment_id = ? AND variant = ? AND event_type = ?
            """, (experiment_id, variant, exp.target_metric))
            values = [r[0] for r in cursor.fetchall()]
            results[variant] = {
                "mean": float(np.mean(values)) if values else 0.0,
                "n": len(values),
                "std": float(np.std(values, ddof=1)) if len(values) > 1 else 0.0
            }
conn.close()
# 统计检验(Control vs Treatment)
control = exp.variants[0]
control_data = results[control]
comparisons = {}
for variant in exp.variants[1:]:
treat_data = results[variant]
# 两样本t检验(简化,实际应获取原始数据)
if treat_data["n"] > 30 and control_data["n"] > 30:
# 使用均值和标准差进行近似t检验
se = np.sqrt(
control_data["std"]**2 / control_data["n"] +
treat_data["std"]**2 / treat_data["n"]
)
                t_stat = (treat_data["mean"] - control_data["mean"]) / (se + 1e-10)
                # 自由度:Welch-Satterthwaite近似(不假设两组方差相等)
                v_c = control_data["std"]**2 / control_data["n"]
                v_t = treat_data["std"]**2 / treat_data["n"]
                df = max((v_c + v_t)**2 / (
                    v_c**2 / (control_data["n"] - 1) + v_t**2 / (treat_data["n"] - 1) + 1e-10
                ), 1.0)
                # p值(双尾)
                p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df))
# 置信区间
ci_low = (treat_data["mean"] - control_data["mean"]) - 1.96 * se
ci_high = (treat_data["mean"] - control_data["mean"]) + 1.96 * se
comparisons[f"{control}_vs_{variant}"] = {
"control_mean": control_data["mean"],
"treatment_mean": treat_data["mean"],
"absolute_diff": treat_data["mean"] - control_data["mean"],
"relative_lift": (treat_data["mean"] - control_data["mean"]) /
(control_data["mean"] + 1e-10),
"t_statistic": t_stat,
"p_value": p_value,
"significant": p_value < 0.05,
"confidence_interval": [ci_low, ci_high],
"sample_sizes": {
"control": control_data["n"],
"treatment": treat_data["n"]
}
}
return {
"experiment_id": experiment_id,
"experiment_name": exp.name,
"target_metric": exp.target_metric,
"variant_stats": results,
"comparisons": comparisons,
"recommendation": self._generate_recommendation(comparisons)
}
def _generate_recommendation(self, comparisons: Dict) -> str:
"""生成实验建议"""
if not comparisons:
return "数据不足"
significant_wins = []
        for comp_name, comp in comparisons.items():
            # 循环变量不命名为stats,避免遮蔽scipy.stats模块
            if comp["significant"] and comp["relative_lift"] > 0:
                significant_wins.append((comp_name, comp["relative_lift"]))
if significant_wins:
best = max(significant_wins, key=lambda x: x[1])
return f"建议采用变体: {best[0].split('_vs_')[1]} (提升 {best[1]:.2%})"
return "当前无显著差异,建议继续实验或检查样本量"
def visualize_results(self, results: Dict, output_path: str = "ab_test_results.png"):
"""可视化实验结果"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# 1. 各变体表现对比
ax1 = axes[0]
variants = list(results["variant_stats"].keys())
means = [results["variant_stats"][v]["mean"] for v in variants]
stds = [results["variant_stats"][v]["std"] for v in variants]
x = np.arange(len(variants))
bars = ax1.bar(x, means, yerr=stds, capsize=5, color=['blue', 'green', 'red', 'orange'][:len(variants)], alpha=0.6)
ax1.set_xlabel('Variant')
ax1.set_ylabel(f'Mean {results["target_metric"]}')
ax1.set_title('Variant Performance Comparison')
ax1.set_xticks(x)
ax1.set_xticklabels(variants, rotation=45)
# 添加数值标签
for bar in bars:
height = bar.get_height()
ax1.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.3f}', ha='center', va='bottom')
# 2. 效应量与置信区间
ax2 = axes[1]
        comp_names = list(results["comparisons"].keys())
        lifts = [results["comparisons"][c]["relative_lift"] for c in comp_names]
        # 置信区间为绝对差值,换算到相对提升尺度以与横轴一致
        cis = [[v / (results["comparisons"][c]["control_mean"] + 1e-10)
                for v in results["comparisons"][c]["confidence_interval"]] for c in comp_names]
y_pos = np.arange(len(comp_names))
ax2.barh(y_pos, lifts, color=['green' if l > 0 else 'red' for l in lifts], alpha=0.6)
# 添加CI误差线
for i, (lift, ci) in enumerate(zip(lifts, cis)):
ax2.plot([ci[0], ci[1]], [i, i], 'k-', linewidth=2)
ax2.plot([ci[0], ci[1]], [i, i], 'k|', markersize=10)
ax2.axvline(x=0, color='black', linestyle='--', linewidth=0.8)
ax2.set_yticks(y_pos)
ax2.set_yticklabels(comp_names)
ax2.set_xlabel('Relative Lift')
ax2.set_title('Treatment Effect with 95% CI')
plt.suptitle(f"A/B Test Results: {results['experiment_name']}", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"[INFO] 结果图表已保存至: {output_path}")
plt.close()
def main():
parser = argparse.ArgumentParser(description="A/B测试框架")
parser.add_argument("--create-experiment", "-c", action="store_true", help="创建实验")
parser.add_argument("--name", "-n", default="test_exp", help="实验名称")
parser.add_argument("--variants", "-v", default="control,treatment", help="变体列表(逗号分隔)")
parser.add_argument("--traffic-split", "-t", default=None, help="流量分配(逗号分隔)")
parser.add_argument("--assign", "-a", action="store_true", help="分配变体")
parser.add_argument("--experiment-id", "-e", default=None, help="实验ID")
parser.add_argument("--user-id", "-u", default="user_001", help="用户ID")
parser.add_argument("--record", "-r", action="store_true", help="记录事件")
parser.add_argument("--metric-value", "-m", type=float, default=0.0, help="指标值")
parser.add_argument("--results", action="store_true", help="查看结果")
parser.add_argument("--visualize", action="store_true", help="可视化")
args = parser.parse_args()
manager = ABTestManager()
if args.create_experiment:
variants = args.variants.split(',')
split = None
if args.traffic_split:
split = [float(x) for x in args.traffic_split.split(',')]
exp_id = manager.create_experiment(
name=args.name,
variants=variants,
traffic_split=split,
target_metric="feedback_rating",
min_sample_size=100
)
print(f"实验ID: {exp_id}")
elif args.assign:
if not args.experiment_id:
print("请提供 --experiment-id")
return
variant = manager.assign_variant(args.experiment_id, args.user_id)
print(f"用户 {args.user_id} 分配到变体: {variant}")
elif args.record:
if not args.experiment_id:
print("请提供 --experiment-id")
return
variant = manager.assign_variant(args.experiment_id, args.user_id)
manager.record_event(
args.experiment_id,
args.user_id,
variant,
"feedback_rating",
args.metric_value,
{"query": "test_query"}
)
print(f"已记录事件: {variant} -> {args.metric_value}")
elif args.results:
if not args.experiment_id:
# 列出所有实验
print("活跃实验:")
for exp_id, exp in manager.experiments.items():
print(f" {exp_id}: {exp.name}")
else:
results = manager.get_results(args.experiment_id)
print(json.dumps(results, indent=2))
if args.visualize:
manager.visualize_results(results)
else:
parser.print_help()
if __name__ == "__main__":
main()
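实验上线前通常需要估算每组最小样本量,以免在样本不足时误读p值。下面是一个基于双样本均值检验、用正态分位数近似的功效分析小草图,其中基线标准差、期望检测的提升幅度等参数均为假设值,可据此设置 Experiment.min_sample_size:
Python
import math
from scipy import stats

def required_sample_size(baseline_std: float,
                         min_detectable_diff: float,
                         alpha: float = 0.05,
                         power: float = 0.8) -> int:
    """双样本均值检验下每组所需的最小样本量(正态近似)"""
    z_alpha = stats.norm.ppf(1 - alpha / 2)   # 双尾显著性水平对应分位数
    z_beta = stats.norm.ppf(power)            # 统计功效对应分位数
    n = 2 * ((z_alpha + z_beta) * baseline_std / min_detectable_diff) ** 2
    return math.ceil(n)

# 示例:反馈评分标准差约1.0,希望检测出0.2分的平均提升(参数均为假设)
print(required_sample_size(baseline_std=1.0, min_detectable_diff=0.2))  # 约393/组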
6.5.4 持续学习
Python
#!/usr/bin/env python3
"""
Script: continual_learning.py
功能: 持续学习系统,实现Bad Case收集、自动标注与模型微调触发机制
使用方式: python continual_learning.py --collect-bad-case --query "q" --bad-answer "wrong" --correct-answer "right"
"""
import json
import os
import argparse
from typing import List, Dict, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import sqlite3
import numpy as np
from collections import defaultdict
@dataclass
class BadCase:
"""Bad Case记录"""
id: str
query: str
retrieved_contexts: List[str]
generated_answer: str
correct_answer: Optional[str] # 用户反馈或标注
failure_type: str # 'retrieval', 'generation', 'hallucination', 'incomplete'
timestamp: str
user_feedback: Optional[str] = None
severity: int = 1 # 1-5
class BadCaseMiner:
"""Bad Case挖掘器"""
FAILURE_PATTERNS = {
"retrieval_failure": {
"indicators": ["不知道", "无法找到", "没有相关信息", "检索失败"],
"description": "未能检索到相关文档"
},
"hallucination": {
"indicators": ["事实上", "实际上", "错误信息"],
"description": "生成内容包含幻觉"
},
"incomplete": {
"indicators": ["部分正确", "不完整", "缺少"],
"description": "答案不完整"
}
}
def classify_failure(self,
query: str,
answer: str,
contexts: List[str],
user_feedback: str = None) -> str:
"""
分类失效类型
Args:
query: 查询
answer: 生成的答案
contexts: 检索上下文
user_feedback: 用户反馈
Returns:
失效类型
"""
if user_feedback:
# 基于用户反馈分类
feedback_lower = user_feedback.lower()
for ftype, info in self.FAILURE_PATTERNS.items():
if any(ind in feedback_lower for ind in info["indicators"]):
return ftype
# 自动检测
if not contexts or all(len(c.strip()) < 10 for c in contexts):
return "retrieval_failure"
# 检查答案长度(过短可能不完整)
if len(answer) < 50:
return "incomplete"
# 检查是否包含不确定表述
uncertainty_phrases = ["我不确定", "可能没有", "也许是", "可能"]
if any(p in answer for p in uncertainty_phrases):
return "hallucination"
return "unknown"
def mine_hard_negatives(self,
query: str,
positive_contexts: List[str],
all_documents: List[str],
top_k: int = 5) -> List[str]:
"""
挖掘困难负样本
Args:
query: 查询
positive_contexts: 正样本(检索到的相关文档)
all_documents: 所有文档
top_k: 返回数量
Returns:
困难负样本列表
"""
# 简单实现:随机选择非正样本的文档
# 实际应使用BM25或向量相似度选择"接近但不相关"的文档
positive_set = set(positive_contexts)
negatives = [d for d in all_documents if d not in positive_set]
if len(negatives) <= top_k:
return negatives
# 随机选择(实际应基于相似度选择困难负样本)
import random
return random.sample(negatives, top_k)
class ContinuousLearningPipeline:
"""持续学习流水线"""
def __init__(self, db_path: str = "continual_learning.db"):
self.db_path = db_path
self.miner = BadCaseMiner()
self.bad_cases: List[BadCase] = []
self._init_db()
self.retraining_threshold = 100 # 触发微调的Bad Case数量阈值
def _init_db(self):
"""初始化数据库"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS bad_cases (
id TEXT PRIMARY KEY,
query TEXT,
retrieved_contexts TEXT,
generated_answer TEXT,
correct_answer TEXT,
failure_type TEXT,
timestamp TEXT,
user_feedback TEXT,
severity INTEGER,
processed INTEGER DEFAULT 0
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS training_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
query TEXT,
positive_doc TEXT,
negative_doc TEXT,
label INTEGER, -- 1 for positive, 0 for negative
source_bad_case TEXT,
created_at TEXT
)
""")
conn.commit()
conn.close()
def collect_bad_case(self,
query: str,
retrieved_contexts: List[str],
generated_answer: str,
correct_answer: Optional[str] = None,
user_feedback: Optional[str] = None) -> str:
"""
收集Bad Case
Returns:
Bad Case ID
"""
case_id = f"bc_{datetime.now().strftime('%Y%m%d%H%M%S')}_{hash(query) % 10000}"
failure_type = self.miner.classify_failure(
query, generated_answer, retrieved_contexts, user_feedback
)
bad_case = BadCase(
id=case_id,
query=query,
retrieved_contexts=retrieved_contexts,
generated_answer=generated_answer,
correct_answer=correct_answer,
failure_type=failure_type,
timestamp=datetime.now().isoformat(),
user_feedback=user_feedback,
severity=3 if failure_type == "hallucination" else 2
)
# 保存到数据库
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT INTO bad_cases
(id, query, retrieved_contexts, generated_answer, correct_answer,
failure_type, timestamp, user_feedback, severity)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
bad_case.id,
bad_case.query,
json.dumps(bad_case.retrieved_contexts),
bad_case.generated_answer,
bad_case.correct_answer,
bad_case.failure_type,
bad_case.timestamp,
bad_case.user_feedback,
bad_case.severity
))
conn.commit()
conn.close()
self.bad_cases.append(bad_case)
print(f"[INFO] 收集Bad Case: {case_id}, 类型: {failure_type}")
        # 检查是否触发再训练:按数据库中未处理的Bad Case数量判断(跨进程累计,而非仅当前进程)
        conn = sqlite3.connect(self.db_path)
        unprocessed = conn.execute("SELECT COUNT(*) FROM bad_cases WHERE processed = 0").fetchone()[0]
        conn.close()
        if unprocessed >= self.retraining_threshold:
            self.trigger_retraining()
return case_id
def build_training_data(self) -> Dict[str, List]:
"""
从Bad Case构建训练数据
Returns:
训练数据集
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 获取未处理的Bad Case
cursor.execute("""
SELECT * FROM bad_cases
WHERE processed = 0
ORDER BY severity DESC, timestamp DESC
LIMIT 200
""")
rows = cursor.fetchall()
retrieval_training = [] # 用于检索模型
generation_training = [] # 用于生成模型
for row in rows:
(case_id, query, contexts_json, gen_answer, correct_answer,
failure_type, timestamp, feedback, severity, processed) = row
contexts = json.loads(contexts_json)
if failure_type == "retrieval_failure" and correct_answer:
# 构建检索训练数据(正样本:正确答案来源,负样本:实际检索到的无关内容)
# 这里简化处理
if contexts:
retrieval_training.append({
"query": query,
"positive": correct_answer, # 假设correct_answer是正样本来源
"negatives": contexts[:3], # 实际检索到的作为困难负样本
"type": "contrastive"
})
elif failure_type in ["hallucination", "incomplete"] and correct_answer:
# 构建生成训练数据(偏好对:错误答案 vs 正确答案)
generation_training.append({
"query": query,
"context": " ".join(contexts),
"rejected": gen_answer,
"chosen": correct_answer,
"type": "preference"
})
conn.close()
return {
"retrieval": retrieval_training,
"generation": generation_training,
"total_cases": len(rows)
}
def trigger_retraining(self):
"""触发模型微调"""
print("[ALERT] 触发持续学习...")
# 构建训练数据
train_data = self.build_training_data()
if train_data["total_cases"] < 50:
print("[INFO] Bad Case数量不足,跳过再训练")
return
# 模拟微调过程
print(f"[INFO] 构建训练数据:")
print(f" - 检索优化样本: {len(train_data['retrieval'])}")
print(f" - 生成优化样本: {len(train_data['generation'])}")
# 保存训练数据
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_file = f"training_data_{timestamp}.json"
with open(output_file, 'w') as f:
json.dump(train_data, f, indent=2)
print(f"[INFO] 训练数据已保存至: {output_file}")
# 标记已处理
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
UPDATE bad_cases SET processed = 1
WHERE processed = 0
""")
conn.commit()
conn.close()
print("[INFO] 已标记Bad Case为已处理,等待手动触发训练...")
# 这里应触发实际的训练流水线
# self._run_fine_tuning(train_data)
def get_statistics(self) -> Dict:
"""获取统计信息"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 按类型统计
cursor.execute("""
SELECT failure_type, COUNT(*), AVG(severity)
FROM bad_cases
GROUP BY failure_type
""")
type_stats = {row[0]: {"count": row[1], "avg_severity": row[2]}
for row in cursor.fetchall()}
# 时间趋势
cursor.execute("""
SELECT DATE(timestamp), COUNT(*)
FROM bad_cases
GROUP BY DATE(timestamp)
ORDER BY DATE(timestamp) DESC
LIMIT 7
""")
daily_trend = {row[0]: row[1] for row in cursor.fetchall()}
conn.close()
return {
"total_bad_cases": sum(s["count"] for s in type_stats.values()),
"by_type": type_stats,
"daily_trend": daily_trend,
"retraining_threshold": self.retraining_threshold
}
def main():
parser = argparse.ArgumentParser(description="持续学习系统")
parser.add_argument("--collect-bad-case", "-c", action="store_true", help="收集Bad Case")
parser.add_argument("--query", "-q", default="What is AI?", help="查询")
parser.add_argument("--contexts", default=None, help="检索上下文JSON文件")
parser.add_argument("--bad-answer", "-b", default="Wrong answer", help="错误答案")
parser.add_argument("--correct-answer", "-a", default=None, help="正确答案")
parser.add_argument("--feedback", "-f", default=None, help="用户反馈")
parser.add_argument("--stats", "-s", action="store_true", help="显示统计")
parser.add_argument("--build-training-data", "-t", action="store_true", help="构建训练数据")
args = parser.parse_args()
pipeline = ContinuousLearningPipeline()
if args.collect_bad_case:
contexts = []
if args.contexts:
with open(args.contexts, 'r') as f:
data = json.load(f)
contexts = [c.get("text", "") for c in data.get("selected_chunks", [])]
case_id = pipeline.collect_bad_case(
query=args.query,
retrieved_contexts=contexts,
generated_answer=args.bad_answer,
correct_answer=args.correct_answer,
user_feedback=args.feedback
)
print(f"Bad Case ID: {case_id}")
elif args.stats:
stats = pipeline.get_statistics()
print(json.dumps(stats, indent=2))
elif args.build_training_data:
data = pipeline.build_training_data()
print(f"构建完成: {data['total_cases']} 个案例")
print(f"检索训练样本: {len(data['retrieval'])}")
print(f"生成训练样本: {len(data['generation'])}")
else:
parser.print_help()
if __name__ == "__main__":
main()
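trigger_retraining 中预留的训练入口可以接到具体的微调流程上。下面给出一个用检索优化样本微调嵌入模型的最小草图,基于 sentence-transformers 的 MultipleNegativesRankingLoss;基座模型名、超参数与示例文件名均为假设,偏好样本的生成模型微调(如DPO)此处从略:
Python
import json
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

def fine_tune_retriever(training_data_path: str, output_dir: str = "retriever-ft"):
    """用Bad Case构建的(query, positive, hard negative)三元组微调嵌入模型(示意)"""
    with open(training_data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    examples = []
    for item in data.get("retrieval", []):
        for neg in item.get("negatives", []):
            # 每条样本:查询 + 正例 + 困难负例
            examples.append(InputExample(texts=[item["query"], item["positive"], neg]))

    if not examples:
        print("[INFO] 无检索优化样本,跳过微调")
        return

    model = SentenceTransformer("all-MiniLM-L6-v2")          # 假设的基座嵌入模型
    loader = DataLoader(examples, shuffle=True, batch_size=16)
    loss = losses.MultipleNegativesRankingLoss(model)        # 批内其余正例同时作为负例
    model.fit(train_objectives=[(loader, loss)], epochs=1, warmup_steps=100,
              output_path=output_dir)
    print(f"[INFO] 微调完成,模型已保存至: {output_dir}")

# 用法示例(文件名为假设): fine_tune_retriever("training_data_20240101_120000.json")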
以上代码构成了完整的企业级RAG知识库问答系统。每个脚本均可独立运行并具备可视化能力,涵盖从数据摄取、向量存储、检索优化到评估监控的全流程组件。