一、模块化设计与扩展架构
在前几篇中,我们详细介绍了RAG系统的架构、核心模块和部署方案。本篇将深入探讨如何扩展系统功能,实现自定义处理器和向量化方法,以应对更多样化的业务需求。
1.1 扩展性设计原则
我们的RAG系统遵循以下扩展性设计原则:
核心设计模式包括:
- 抽象基类:定义统一接口,确保实现一致性
- 工厂模式:根据配置动态创建适当的组件实例
- 策略模式:运行时选择不同的算法策略
- 装饰器模式:在不修改原有代码的情况下扩展功能
- 依赖注入:通过配置注入依赖,降低组件间耦合
二、自定义文档处理器开发
(处理器基类定义见 src/data_processing/processors/base.py)
2.1 处理器抽象接口
所有处理器继承自 BaseDocumentProcessor 抽象基类:
python
class BaseDocumentProcessor(ABC):
    """Abstract base class for document processors.

    Defines the unified interface every concrete processor must implement
    so the factory can use them interchangeably.
    """

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize with an optional config; defaults to ProcessorConfig()."""
        self.config = config or ProcessorConfig()
        self._setup_logging()

    @abstractmethod
    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Turn a single file's raw bytes into a list of Document chunks."""
        pass

    @abstractmethod
    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Turn raw text (with optional metadata) into a list of Document chunks."""
        pass
2.2 实现自定义学术论文处理器
假设我们需要处理特定格式的学术论文PDF,可以这样实现:
python
class AcademicPaperProcessor(BaseDocumentProcessor):
    """Processor specialised for academic papers in PDF form.

    Splits a paper into structural sections (title, abstract, introduction,
    ...) and emits one or more Document chunks per section.
    """

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process an academic-paper PDF.

        Raises:
            ValueError: if mime_type is not application/pdf.
        """
        if mime_type != "application/pdf":
            raise ValueError(f"不支持的文件类型: {mime_type},仅支持PDF")
        # Extract raw text with a PDF parsing library.
        text = self._extract_pdf_text(file_content)
        base_metadata = {
            "source": filename,
            "paper_type": "academic",
            "file_type": "pdf",
        }
        return self._build_documents(text, base_metadata)

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process raw paper text through the same section/chunk pipeline.

        The original version was a stub that always returned []; it now
        shares _build_documents with process_file.
        """
        return self._build_documents(text, metadata or {})

    def _build_documents(self, text: str, base_metadata: Dict[str, Any]) -> List[Document]:
        """Shared pipeline: parse sections, chunk each, wrap chunks in Documents."""
        sections = self._parse_paper_structure(text)
        documents = []
        for section_name, section_text in sections.items():
            metadata = {**base_metadata, "section": section_name}
            chunks = self._split_section(section_text, section_name)
            for i, chunk in enumerate(chunks):
                doc = Document(
                    page_content=chunk,
                    metadata={
                        **metadata,
                        "chunk_id": i,
                        "total_chunks": len(chunks),
                    },
                )
                documents.append(self._add_metadata(doc))
        return documents

    def _extract_pdf_text(self, pdf_bytes: bytes) -> str:
        """Extract text from PDF bytes, one page per blank-line-separated block."""
        import io
        # NOTE(review): PyPDF2 is deprecated in favour of pypdf; kept as-is
        # because the project already depends on it.
        import PyPDF2

        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        # Join pages with blank lines so the section parser can split on \n\n.
        return "\n\n".join(page.extract_text() for page in pdf_reader.pages)

    def _parse_paper_structure(self, text: str) -> Dict[str, str]:
        """Identify title/abstract and common paper sections via regexes."""
        import re

        sections = {}
        # Title: the first paragraph of the document.
        title_match = re.search(r'^(.+?)(?=\n\n)', text)
        if title_match:
            sections["title"] = title_match.group(1).strip()
        # Abstract: text after "Abstract" up to a numbered heading or "Introduction".
        # Dot escaped as `\d+\.` — the original `\d+.` matched any character.
        abstract_match = re.search(r'Abstract[:.\s]+(.+?)(?=\n\n\d+\.|\n\nIntroduction)', text, re.DOTALL)
        if abstract_match:
            sections["abstract"] = abstract_match.group(1).strip()
        # Common section headings; lookaheads stop at the next numbered or
        # capitalized heading (dots escaped, same fix as above).
        section_patterns = [
            (r'Introduction(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "introduction"),
            (r'Methods(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "methods"),
            (r'Results(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "results"),
            (r'Discussion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "discussion"),
            (r'Conclusion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z]|$)', "conclusion"),
            (r'References(?:\s|\n)+(.+?)(?=$)', "references"),
        ]
        for pattern, section_name in section_patterns:
            section_match = re.search(pattern, text, re.DOTALL)
            if section_match:
                sections[section_name] = section_match.group(1).strip()
        return sections
2.3 注册自定义处理器
开发完成后,将自定义处理器注册到处理器工厂:
python
# 在document_processor.py中注册新处理器
from src.data_processing.processors.academic_paper_processor import AcademicPaperProcessor
class DocumentProcessor:
    """Processor factory: selects a concrete processor by MIME type."""

    def __init__(self, config=None):
        """Create the factory and populate the default processor table."""
        self.config = config or ProcessorConfig()
        self._processors = {}
        self._register_default_processors()

    def _register_default_processors(self):
        """Register built-in processors, then apply domain-specific overrides."""
        word_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        defaults = {
            "application/pdf": PDFProcessor(self.config),
            word_mime: WordProcessor(self.config),
            # ... other processors registered in the real project
        }
        self._processors.update(defaults)
        # Academic-paper mode replaces the generic PDF processor.
        if self.config.doc_type == DocumentType.ACADEMIC_PAPER:
            self._processors["application/pdf"] = AcademicPaperProcessor(self.config)

    def register_processor(self, mime_type, processor):
        """Register (or override) the processor used for a MIME type."""
        self._processors[mime_type] = processor
三、自定义向量化方法实现
(向量化器基类定义见 src/data_processing/vectorization/base.py)
3.1 向量化器接口与工厂模式
所有向量化器继承自 BaseVectorizer 抽象基类:
python
class BaseVectorizer(ABC):
    """Abstract base class defining the vectorizer interface."""

    def __init__(self, cache_dir: str = './cache/vectorization'):
        """Initialize the vectorizer and ensure the cache directory exists."""
        self.cache_dir = cache_dir
        self._ensure_cache_dir()

    @abstractmethod
    def vectorize(self, text: str) -> np.ndarray:
        """Convert a single text into its vector representation."""
        pass

    @abstractmethod
    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of texts into vectors."""
        pass
工厂模式用于创建不同类型的向量化器:
python
class VectorizationFactory:
    """Factory that builds vectorizers by method name."""

    # Registry of supported vectorization backends.
    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer,
    }

    @classmethod
    def create_vectorizer(cls, method: str = 'tfidf', **kwargs) -> BaseVectorizer:
        """Create a vectorizer for `method`; raises ValueError when unknown."""
        method = method.lower()
        registry = cls._vectorizers
        if method not in registry:
            supported_methods = ", ".join(registry.keys())
            raise ValueError(f"不支持的向量化方法: {method}。支持的方法有: {supported_methods}")
        # 创建向量化器... (construction elided in the article)
3.2 OpenAI嵌入模型集成示例
下面我们实现一个OpenAI向量化器,将文本转换为OpenAI提供的嵌入向量:
python
import numpy as np
import os
import time
import logging
from openai import OpenAI
from typing import List
from .base import BaseVectorizer
class OpenAIVectorizer(BaseVectorizer):
    """Vectorizer backed by the OpenAI embeddings API."""

    def __init__(self, model_name="text-embedding-3-small", batch_size=32,
                 api_key=None, dimensions=1536, cache_dir='./cache/vectorization'):
        """Initialize the OpenAI vectorizer.

        Args:
            model_name: OpenAI embedding model name.
            batch_size: number of texts sent per API call.
            api_key: OpenAI API key; falls back to the OPENAI_API_KEY env var.
            dimensions: embedding dimensionality requested from the API.
            cache_dir: local cache directory (handled by the base class).
        """
        super().__init__(cache_dir=cache_dir)
        self.model_name = model_name
        self.batch_size = batch_size
        self.dimensions = dimensions
        # Client reads the key from the argument or the environment.
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.logger = logging.getLogger(self.__class__.__name__)

    def vectorize(self, text: str) -> np.ndarray:
        """Embed one text; returns a zero vector for blank text or on API error."""
        if not text.strip():
            return np.zeros(self.dimensions)
        try:
            response = self.client.embeddings.create(
                model=self.model_name,
                input=text,
                dimensions=self.dimensions
            )
            return np.array(response.data[0].embedding)
        except Exception as e:
            self.logger.error(f"OpenAI向量化失败: {str(e)}")
            # Degrade gracefully: downstream code can still operate on zeros.
            return np.zeros(self.dimensions)

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Embed a batch of texts.

        Returns exactly one vector per input text, in input order; blank
        texts map to zero vectors. (The original filtered blank texts out
        before batching, silently misaligning the returned vectors with the
        inputs, and returned a spurious zero vector for an empty list.)
        """
        # Pre-fill with zero vectors so the output stays aligned with the input.
        results = [np.zeros(self.dimensions) for _ in texts]
        # Only non-blank texts are sent to the API.
        pending = [i for i, text in enumerate(texts) if text.strip()]
        for start in range(0, len(pending), self.batch_size):
            batch_indices = pending[start:start + self.batch_size]
            batch = [texts[i] for i in batch_indices]
            try:
                response = self.client.embeddings.create(
                    model=self.model_name,
                    input=batch,
                    dimensions=self.dimensions
                )
                # Write each embedding back to its original position.
                for idx, data in zip(batch_indices, response.data):
                    results[idx] = np.array(data.embedding)
                # Throttle between batches to stay under API rate limits.
                if start + self.batch_size < len(pending):
                    time.sleep(0.5)
            except Exception as e:
                self.logger.error(f"OpenAI批量向量化失败: {str(e)}")
                # The failed batch keeps its zero vectors; continue with the rest.
        return results

    def get_dimensions(self) -> int:
        """Return the embedding dimensionality."""
        return self.dimensions
3.3 注册自定义向量化器
向量化工厂类需要注册新增的向量化器:
python
# 添加OpenAI向量化器到工厂类
from src.data_processing.vectorization.openai_vectorizer import OpenAIVectorizer
# 更新向量化工厂
class VectorizationFactory:
    """Vectorizer factory, extended with the OpenAI embedding backend."""

    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer,
        'openai': OpenAIVectorizer,  # newly registered vectorizer
    }

    @staticmethod
    def _get_config_from_env(method: str) -> Dict[str, Any]:
        """Build a vectorizer config dict from environment variables."""
        config = {}
        # ... existing per-method config logic elided in the article ...
        # The original snippet used a bare `elif` here, which is a syntax
        # error once the preceding branches are elided; use `if` instead.
        if method == 'openai':
            config['model_name'] = os.getenv('OPENAI_EMBEDDING_MODEL', 'text-embedding-3-small')
            config['batch_size'] = int(os.getenv('OPENAI_BATCH_SIZE', '32'))
            config['dimensions'] = int(os.getenv('OPENAI_EMBEDDING_DIMENSIONS', '1536'))
            # The API key itself is read from OPENAI_API_KEY by the vectorizer.
        return config
四、混合检索策略实现
4.1 多模型混合检索器
在实际应用中,单一检索方法往往不能满足所有需求。我们可以实现一个混合检索策略,结合多种方法的优势:
python
import asyncio
import logging
from typing import List, Tuple, Dict, Any
import jieba
from langchain.schema import Document
from src.data_processing.vectorization.factory import VectorizationFactory
class HybridSearchRetriever:
    """Hybrid retriever: BM25 keyword scores fused with vector similarity."""

    def __init__(self,
                 vector_store,
                 keyword_weight=0.3,
                 semantic_weight=0.7,
                 rerank_model=None):
        """Initialize the hybrid retriever.

        Args:
            vector_store: store providing the corpus and vector search.
            keyword_weight: weight of the normalized BM25 score.
            semantic_weight: weight of the vector similarity score.
            rerank_model: optional reranker applied to merged results.
        """
        self.vector_store = vector_store
        self.keyword_weight = keyword_weight
        self.semantic_weight = semantic_weight
        self.rerank_model = rerank_model
        self.logger = logging.getLogger(self.__class__.__name__)
        # Build the BM25 keyword index over the whole corpus.
        self._initialize_bm25_index()
        # Vectorizer used for query embedding.
        self.vectorizer = VectorizationFactory.create_vectorizer('bge-m3')

    @staticmethod
    def _doc_key(doc):
        """Stable merge key for a document.

        Prefers metadata["id"], then metadata["doc_id"], falling back to the
        object's identity — the original mixed 'doc_id' (index build) with
        "id" (merge), and id-less documents all collapsed onto the key None.
        """
        metadata = getattr(doc, "metadata", None) or {}
        key = metadata.get("id", metadata.get("doc_id"))
        return key if key is not None else id(doc)

    def _initialize_bm25_index(self):
        """Tokenize the corpus with jieba and build a BM25 index over it."""
        from rank_bm25 import BM25Okapi
        docs = self.vector_store.get_all_documents()
        texts = [doc.page_content for doc in docs]
        tokenized_corpus = [list(jieba.cut(text)) for text in texts]
        self.bm25 = BM25Okapi(tokenized_corpus)
        # Same key scheme as _merge_results, for consistency.
        self.doc_ids = [self._doc_key(doc) for doc in docs]
        self.documents = docs

    async def retrieve(self, query: str, top_k: int = 5, threshold: float = 0.0):
        """Hybrid retrieval: keyword + vector search, merge, optional rerank."""
        # Over-fetch from each method so the merge step has candidates to rank.
        keyword_results = await self._keyword_search(query, top_k * 2)
        vector_results = await self._vector_search(query, top_k * 2)
        merged_results = self._merge_results(keyword_results, vector_results)
        # Optional reranking pass when a rerank model is configured.
        if self.rerank_model and len(merged_results) > top_k:
            merged_results = await self._rerank_results(query, merged_results, top_k)
        # Drop results below the score threshold.
        filtered_results = [
            (doc, score) for doc, score in merged_results
            if score >= threshold
        ]
        return filtered_results[:top_k]

    async def _keyword_search(self, query: str, top_k: int):
        """BM25 keyword search over the pre-built index."""
        tokenized_query = list(jieba.cut(query))
        bm25_scores = self.bm25.get_scores(tokenized_query)
        results = []
        for i, score in enumerate(bm25_scores):
            if score > 0:  # keep only documents with some keyword overlap
                results.append((self.documents[i], score))
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    async def _vector_search(self, query: str, top_k: int):
        """Vector similarity search, delegated to the vector store."""
        return await self.vector_store.asimilarity_search_with_score(query, top_k)

    def _merge_results(self, keyword_results, vector_results):
        """Merge keyword and vector hits into a single weighted ranking."""
        merged_map = {}
        # BM25 scores are unbounded; normalize by the batch maximum to [0, 1].
        max_keyword_score = max([score for _, score in keyword_results]) if keyword_results else 1.0
        for doc, score in keyword_results:
            doc_id = self._doc_key(doc)
            normalized_score = score / max_keyword_score
            merged_map[doc_id] = {
                "doc": doc,
                "keyword_score": normalized_score,
                "vector_score": 0.0
            }
        for doc, score in vector_results:
            doc_id = self._doc_key(doc)
            # NOTE(review): assumes the store returns a similarity in [0, 1];
            # if it returns a distance (lower = better), convert it first.
            if doc_id in merged_map:
                merged_map[doc_id]["vector_score"] = score
            else:
                merged_map[doc_id] = {
                    "doc": doc,
                    "keyword_score": 0.0,
                    "vector_score": score
                }
        # Weighted fusion of the two normalized scores.
        result_list = []
        for item in merged_map.values():
            final_score = (
                self.keyword_weight * item["keyword_score"] +
                self.semantic_weight * item["vector_score"]
            )
            result_list.append((item["doc"], final_score))
        result_list.sort(key=lambda x: x[1], reverse=True)
        return result_list
4.2 跨模态检索扩展
RAG系统除了处理文本,也可以扩展为处理图像、音频等多模态数据:
python
import os
import asyncio
import numpy as np
from typing import List, Tuple, Dict, Any
from langchain.schema import Document
class MultiModalRetriever:
    """Multi-modal retriever supporting text and image queries."""

    # String queries ending in one of these extensions are treated as images.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp")

    def __init__(self, vector_stores, embedding_models):
        """Initialize with per-modality vector stores and embedding models."""
        self.vector_stores = vector_stores
        self.embedding_models = embedding_models

    def _detect_modal_type(self, query) -> str:
        """Best-effort modality detection (was called but never defined).

        Strings that look like image file paths count as image queries;
        other strings count as text; non-string payloads (e.g. raw image
        bytes) count as images.
        """
        if isinstance(query, str):
            if query.lower().endswith(self._IMAGE_EXTENSIONS):
                return "image"
            return "text"
        return "image"

    async def retrieve(self, query, modal_type=None, top_k=5):
        """Multi-modal retrieval.

        Raises:
            ValueError: for unsupported modal types (the original silently
            fell through and returned None).
        """
        # None behaves like "auto": detect the modality from the query itself.
        if modal_type in (None, "auto"):
            modal_type = self._detect_modal_type(query)
        if modal_type == "text":
            # Encode with CLIP (for cross-modal image search) and BGE (text).
            clip_embedding = self.embedding_models["clip"].encode_text(query)
            bge_embedding = self.embedding_models["bge"].vectorize(query)
            # Query both modality stores concurrently.
            tasks = [
                self.vector_stores["text"].asimilarity_search_by_vector(bge_embedding, top_k),
                self.vector_stores["image"].asimilarity_search_by_vector(clip_embedding, top_k)
            ]
            text_results, image_results = await asyncio.gather(*tasks)
            # _merge_modal_results is defined elsewhere (elided in the article).
            return self._merge_modal_results(text_results, image_results, top_k)
        if modal_type == "image":
            # CLIP image encoder for image-to-image retrieval.
            image_embedding = self.embedding_models["clip"].encode_image(query)
            return await self.vector_stores["image"].asimilarity_search_by_vector(
                image_embedding, top_k
            )
        raise ValueError(f"不支持的模态类型: {modal_type}")
五、用户反馈优化机制
5.1 反馈数据收集与存储
为进一步提升RAG系统的检索质量,我们可以加入用户反馈机制:
python
import sqlite3
import logging
from datetime import datetime
from typing import List, Dict, Any
class FeedbackOptimizer:
    """Optimizes RAG retrieval results using accumulated user feedback."""

    def __init__(self, vector_store, feedback_db=None):
        """Initialize with a vector store and an optional open sqlite3 connection.

        When feedback_db is None, a local SQLite database is created under data/.
        """
        self.vector_store = vector_store
        self.feedback_db = feedback_db or self._initialize_feedback_db()
        self.logger = logging.getLogger(self.__class__.__name__)

    def _initialize_feedback_db(self):
        """Create (if missing) and open the SQLite feedback database."""
        import os
        # sqlite3.connect does not create intermediate directories; without
        # this, connecting fails when data/ does not exist yet.
        os.makedirs('data', exist_ok=True)
        conn = sqlite3.connect('data/feedback.db')
        c = conn.cursor()
        # Per-document relevance judgements.
        c.execute('''
            CREATE TABLE IF NOT EXISTS feedback (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query_text TEXT,
                doc_id TEXT,
                is_relevant INTEGER,
                timestamp TEXT
            )
        ''')
        # One row per executed query, for usage analytics.
        c.execute('''
            CREATE TABLE IF NOT EXISTS query_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query_text TEXT,
                results_count INTEGER,
                timestamp TEXT
            )
        ''')
        conn.commit()
        return conn

    def record_feedback(self, query, doc_id, is_relevant):
        """Store one relevance judgement for a (query, document) pair."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO feedback (query_text, doc_id, is_relevant, timestamp) VALUES (?, ?, ?, datetime('now'))",
            (query, doc_id, 1 if is_relevant else 0)
        )
        self.feedback_db.commit()

    def record_query(self, query, results_count):
        """Log an executed query and how many results it returned."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO query_log (query_text, results_count, timestamp) VALUES (?, ?, datetime('now'))",
            (query, results_count)
        )
        self.feedback_db.commit()

    def get_relevance_feedback_for_query(self, query, limit=10):
        """Return (doc_id, is_relevant, count) rows for a given query.

        The `limit` parameter was previously accepted but never applied;
        it now caps the number of returned groups.
        """
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "SELECT doc_id, is_relevant, COUNT(*) FROM feedback WHERE query_text = ? GROUP BY doc_id, is_relevant LIMIT ?",
            (query, limit)
        )
        return cursor.fetchall()

    def optimize_results(self, query, initial_results, top_k=5):
        """Re-rank (doc, score) results using historical relevance feedback."""
        feedback = self.get_relevance_feedback_for_query(query)
        # No feedback yet: return the original ranking unchanged.
        if not feedback:
            return initial_results[:top_k]
        # Aggregate feedback counts per document for fast lookup.
        feedback_dict = {}
        for doc_id, is_relevant, count in feedback:
            if doc_id not in feedback_dict:
                feedback_dict[doc_id] = {"relevant": 0, "irrelevant": 0}
            if is_relevant:
                feedback_dict[doc_id]["relevant"] += count
            else:
                feedback_dict[doc_id]["irrelevant"] += count
        # Shift each score by up to ±0.2 based on the feedback ratio.
        adjusted_results = []
        for doc, score in initial_results:
            doc_id = doc.metadata.get("id")
            adjustment = 0
            if doc_id in feedback_dict:
                relevant = feedback_dict[doc_id]["relevant"]
                irrelevant = feedback_dict[doc_id]["irrelevant"]
                if relevant + irrelevant > 0:
                    adjustment = (relevant - irrelevant) / (relevant + irrelevant) * 0.2
            # Clamp the adjusted score to [0, 1].
            adjusted_score = min(1.0, max(0.0, score + adjustment))
            adjusted_results.append((doc, adjusted_score))
        adjusted_results.sort(key=lambda x: x[1], reverse=True)
        # NOTE(review): queries with no prior feedback return early above and
        # are never logged — confirm whether every query should be recorded.
        self.record_query(query, len(initial_results))
        return adjusted_results[:top_k]
5.2 反馈界面实现
为了收集用户反馈,我们需要在前端界面添加反馈按钮:
python
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from typing import List, Optional
router = APIRouter()


class FeedbackRequest(BaseModel):
    """Payload for one document-relevance feedback submission."""
    query: str
    doc_id: str
    is_relevant: bool


@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Record a relevance judgement for a (query, document) pair."""
    try:
        optimizer = get_feedback_optimizer()
        optimizer.record_feedback(request.query, request.doc_id, request.is_relevant)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"提交反馈失败: {str(e)}")
    return {"status": "success", "message": "反馈已记录"}


@router.get("/feedback/stats")
async def get_feedback_stats(query: Optional[str] = None):
    """Return feedback statistics, optionally scoped to a single query."""
    try:
        optimizer = get_feedback_optimizer()
        stats = (
            optimizer.get_relevance_feedback_for_query(query)
            if query
            else optimizer.get_global_feedback_stats()
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取反馈统计失败: {str(e)}")
    return {"status": "success", "data": stats}
六、查询意图分析器扩展
(推荐查询处理器定义见 src/chains/processors/recommendation.py)
6.1 自定义意图处理器
我们已经有了推荐查询处理器,现在我们可以扩展更多专用意图处理器:
python
import numpy as np
from typing import List, Dict, Any
from src.data_processing.vectorization.factory import VectorizationFactory
class QueryIntentClassifier:
    """Intent classifier based on embedding similarity to example queries."""

    def __init__(self, embedding_model=None):
        """Initialize with an embedding model (defaults to bge-m3)."""
        self.embedding_model = embedding_model or VectorizationFactory.create_vectorizer("bge-m3")
        # Few-shot example queries per intent class.
        self.intent_examples = {
            "比较分析": [
                "A和B有什么区别?",
                "哪一个更好,X还是Y?",
                "比较一下P和Q的优缺点"
            ],
            "因果解释": [
                "为什么会出现这种情况?",
                "导致X的原因是什么?",
                "这个问题的根源是什么?"
            ],
            "列举信息": [
                "列出所有的X",
                "有哪些方法可以做Y?",
                "X包含哪些组成部分?"
            ],
            "概念解释": [
                "什么是X?",
                "X的定义是什么?",
                "如何理解Y概念?"
            ],
            "数据统计": [
                "X的平均值是多少?",
                "Y的增长率是多少?",
                "Z的分布情况怎样?"
            ],
            "操作指导": [
                "如何做X?",
                "执行Y的步骤是什么?",
                "使用Z的方法有哪些?"
            ],
            "推荐建议": [
                "推荐几款好用的X",
                "有什么适合Y的工具?",
                "帮我选择一个合适的Z"
            ]
        }
        # Unit-normalized centroid vector per intent.
        self.intent_vectors = self._compute_intent_vectors()

    def _compute_intent_vectors(self):
        """Average and unit-normalize the example embeddings per intent."""
        intent_vectors = {}
        for intent, examples in self.intent_examples.items():
            vectors = self.embedding_model.batch_vectorize(examples)
            avg_vector = np.mean(vectors, axis=0)
            # Guard against a degenerate all-zero centroid before normalizing.
            norm = np.linalg.norm(avg_vector)
            if norm > 0:
                avg_vector = avg_vector / norm
            intent_vectors[intent] = avg_vector
        return intent_vectors

    def classify_intent(self, query: str):
        """Classify the query; returns (intent_name, cosine_similarity).

        Falls back to "一般信息查询" when the best similarity is below 0.5.
        """
        query_vector = np.asarray(self.embedding_model.vectorize(query), dtype=float)
        # Normalize the query vector too — the intent vectors are unit
        # vectors, so without this the dot product is not a cosine and the
        # fixed 0.5 threshold is meaningless (original bug).
        norm = np.linalg.norm(query_vector)
        if norm > 0:
            query_vector = query_vector / norm
        similarities = {
            intent: float(np.dot(query_vector, vector))
            for intent, vector in self.intent_vectors.items()
        }
        max_intent = max(similarities, key=similarities.get)
        max_similarity = similarities[max_intent]
        if max_similarity < 0.5:
            return "一般信息查询", max_similarity
        return max_intent, max_similarity
6.2 自定义比较分析处理器
每种意图需要专门的处理器,以比较分析为例:
python
import re
from typing import List, Dict, Any
from .base import QueryProcessor
class ComparisonQueryProcessor(QueryProcessor):
    """Processor for comparison queries ("A和B的区别", "X还是Y更好", ...)."""

    # Keywords signalling a comparison intent; matched against the
    # lower-cased query so "VS"/"vs"/"Vs" all hit (original check was
    # case-sensitive and only matched uppercase "VS").
    _COMPARISON_KEYWORDS = ["比较", "区别", "差异", "优缺点", "对比", "相比", "vs", "好坏"]

    # Sentence patterns shared by can_handle() and entity extraction —
    # previously duplicated verbatim in both methods.
    _COMPARISON_PATTERNS = [
        r"(.+)和(.+)的区别",
        r"(.+)与(.+)的(差异|不同)",
        r"(.+)相比(.+)怎么样",
        r"(.+)还是(.+)更好"
    ]

    def process(self, query: str, documents: List[Any], **kwargs) -> Dict[str, Any]:
        """Answer a comparison query from the retrieved documents."""
        if not documents:
            return {
                "answer": "抱歉,没有找到相关的比较信息。",
                "sources": []
            }
        # Entities to compare: from the query, or (fallback) the documents.
        entities = self._extract_comparison_entities(query)
        if len(entities) < 2:
            entities = self._extract_entities_from_documents(documents)
        # Gather per-entity evidence, then render the comparison.
        entity_info = self._collect_entity_information(entities, documents)
        comparison_table = self._generate_comparison_table(entity_info)
        conclusion = self._generate_comparison_conclusion(entity_info, query)
        answer = f"根据您的比较请求,以下是{', '.join(entities)}的对比分析:\n\n{comparison_table}\n\n{conclusion}"
        return {
            "answer": answer,
            "sources": documents,
            "entities": entities,
            "comparison_table": comparison_table
        }

    def can_handle(self, query: str) -> bool:
        """Return True when the query looks like a comparison request."""
        lowered = query.lower()
        if any(keyword in lowered for keyword in self._COMPARISON_KEYWORDS):
            return True
        return any(re.search(pattern, query) for pattern in self._COMPARISON_PATTERNS)

    def _extract_comparison_entities(self, query: str) -> List[str]:
        """Extract the two compared entities via the shared sentence patterns."""
        for pattern in self._COMPARISON_PATTERNS:
            match = re.search(pattern, query)
            if match:
                return [match.group(1).strip(), match.group(2).strip()]
        # TODO: fall back to tokenization + NER for free-form queries.
        return []

    def _extract_entities_from_documents(self, documents: List[Any]) -> List[str]:
        """Extract candidate comparison entities from the documents (stub)."""
        # ... entity extraction logic elided in the article
        return []

    def _collect_entity_information(self, entities: List[str], documents: List[Any]) -> Dict[str, Dict]:
        """Collect advantages, disadvantages, features and mention counts per entity."""
        entity_info = {}
        for entity in entities:
            entity_info[entity] = {
                "advantages": [],
                "disadvantages": [],
                "features": {},
                "mentions": 0
            }
            for doc in documents:
                text = doc.page_content
                # Count how often the entity is mentioned.
                if entity in text:
                    entity_info[entity]["mentions"] += text.count(entity)
                # Pull out pros, cons and named features from the text.
                advantages = self._extract_advantages(text, entity)
                entity_info[entity]["advantages"].extend(advantages)
                disadvantages = self._extract_disadvantages(text, entity)
                entity_info[entity]["disadvantages"].extend(disadvantages)
                features = self._extract_features(text, entity)
                for feature, value in features.items():
                    if feature in entity_info[entity]["features"]:
                        entity_info[entity]["features"][feature].append(value)
                    else:
                        entity_info[entity]["features"][feature] = [value]
        return entity_info

    def _extract_advantages(self, text: str, entity: str) -> List[str]:
        """Extract advantage statements for the entity (stub)."""
        patterns = [
            f"{entity}的优点",
            f"{entity}的好处",
            f"{entity}的优势"
        ]
        # ... extraction logic elided in the article
        return []

    def _extract_disadvantages(self, text: str, entity: str) -> List[str]:
        """Extract disadvantage statements for the entity (stub)."""
        # ... extraction logic elided in the article
        return []

    def _extract_features(self, text: str, entity: str) -> Dict[str, str]:
        """Extract feature/value pairs for the entity (stub)."""
        # ... extraction logic elided in the article
        return {}

    def _generate_comparison_table(self, entity_info: Dict[str, Dict]) -> str:
        """Render the collected entity information as a Markdown table."""
        entities = list(entity_info.keys())
        table = f"| 特性 | {' | '.join(entities)} |\n"
        table += f"| --- | {' | '.join(['---' for _ in entities])} |\n"
        # One row per feature seen for any entity; "-" when missing.
        all_features = set()
        for entity, info in entity_info.items():
            all_features.update(info["features"].keys())
        for feature in sorted(all_features):
            row = f"| {feature} | "
            for entity in entities:
                if feature in entity_info[entity]["features"]:
                    values = entity_info[entity]["features"][feature]
                    row += f"{values[0]} | "
                else:
                    row += "- | "
            table += row + "\n"
        # Summary rows: up to three pros and cons per entity.
        table += f"| 优点 | {' | '.join([', '.join(info['advantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"
        table += f"| 缺点 | {' | '.join([', '.join(info['disadvantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"
        return table

    def _generate_comparison_conclusion(self, entity_info: Dict[str, Dict], query: str) -> str:
        """Generate a short textual conclusion for the comparison."""
        entities = list(entity_info.keys())
        if len(entities) < 2:
            return "无法生成比较结论,找不到足够的实体信息。"
        # ... conclusion generation logic elided in the article
        return "根据以上对比,每个选项都有各自的优缺点,具体选择取决于您的具体需求和场景。"
七、实际案例分析
7.1 自定义法律文档处理器
法律文档有其特殊性,以下是一个专门处理法律文档的处理器示例:
python
import re
from typing import List, Dict, Any, Optional
from langchain.schema import Document
from src.data_processing.processors.base import BaseDocumentProcessor, ProcessorConfig
class LegalDocumentProcessor(BaseDocumentProcessor):
    """Processor specialised for legal documents (court filings etc.)."""

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize the processor and load the legal-term glossary."""
        super().__init__(config)
        self.legal_terms = self._load_legal_terms()

    def _load_legal_terms(self):
        """Load the legal-term glossary (inline here; load from a file in production)."""
        return {
            "原告": "起诉方,请求法院裁判的一方",
            "被告": "被起诉方,被请求法院裁判的一方",
            "诉讼": "通过法院解决纠纷的法律程序",
            # ... more legal terms
        }

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Extract text from a legal file by MIME type, then delegate to process_text."""
        word_mime_types = (
            "application/msword",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )
        if mime_type == "application/pdf":
            # NOTE(review): _extract_pdf_text / _extract_word_text are assumed
            # to be provided by the base class or a mixin — they are not
            # defined in this class; confirm before use.
            text = self._extract_pdf_text(file_content)
        elif mime_type in word_mime_types:
            text = self._extract_word_text(file_content)
        else:
            # Fall back to decoding as plain text in the configured encoding.
            text = file_content.decode(self.config.encoding, errors='ignore')
        return self.process_text(text, {"source": filename, "mime_type": mime_type})

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Split a legal document into structured sections and annotate terms."""
        sections = self._parse_legal_document_structure(text)
        documents = []
        for section_name, section_content in sections.items():
            section_metadata = {
                **(metadata or {}),
                "section": section_name,
                "document_type": "legal"
            }
            # Inline glossary notes for known legal terms.
            annotated_content = self._annotate_legal_terms(section_content)
            doc = Document(
                page_content=annotated_content,
                metadata=section_metadata
            )
            documents.append(self._add_metadata(doc))
        return documents

    def _parse_legal_document_structure(self, text: str) -> Dict[str, str]:
        """Extract the standard sections of a Chinese court document."""
        sections = {}
        # (pattern, section key) pairs for the common parts of a filing/judgment.
        section_patterns = [
            (r"案号[::]\s*(.+?)(?=\n)", "case_number"),
            (r"原告[::]\s*(.+?)(?=\n被告)", "plaintiff"),
            (r"被告[::]\s*(.+?)(?=\n)", "defendant"),
            (r"诉讼请求[::]\s*(.+?)(?=\n)", "claims"),
            (r"事实与理由[::]\s*(.+?)(?=\n)", "facts_and_reasons"),
            (r"裁判结果[::]\s*(.+?)(?=\n)", "judgment"),
            (r"裁判理由[::]\s*(.+?)(?=\n)", "reasoning")
        ]
        for pattern, section_name in section_patterns:
            match = re.search(pattern, text, re.DOTALL)
            if match:
                sections[section_name] = match.group(1).strip()
        # Fallback: no recognized structure — split into plain paragraphs.
        if not sections:
            paragraphs = re.split(r'\n\s*\n', text)
            for i, para in enumerate(paragraphs):
                sections[f"paragraph_{i+1}"] = para.strip()
        return sections

    def _annotate_legal_terms(self, text: str) -> str:
        """Append a glossary note to the first occurrence of each known term."""
        annotated_text = text
        for term, definition in self.legal_terms.items():
            if term in annotated_text:
                # Annotate only the first occurrence to avoid cluttering the text.
                term_with_note = f"{term}[注: {definition}]"
                annotated_text = annotated_text.replace(term, term_with_note, 1)
        return annotated_text
7.2 金融数据向量化器
针对金融数据的特殊性,我们可以实现专用的向量化器:
python
import numpy as np
from typing import List, Dict, Any
from src.data_processing.vectorization.base import BaseVectorizer
import re
class FinancialDataVectorizer(BaseVectorizer):
    """Vectorizer for financial text: semantic embedding fused with numeric features."""

    def __init__(self, cache_dir='./cache/vectorization',
                 base_model='bge-m3',
                 numerical_weight=0.3):
        """Initialize with a base semantic model and a numeric-feature weight."""
        super().__init__(cache_dir)
        # Base model supplies the semantic part of the vector.
        from src.data_processing.vectorization.factory import VectorizationFactory
        self.base_vectorizer = VectorizationFactory.create_vectorizer(base_model)
        # Relative weight of the numeric-feature block in the fused vector.
        self.numerical_weight = numerical_weight
        self.financial_terms = self._load_financial_terms()

    def _load_financial_terms(self):
        """Load the financial-term list (inline here; load from a file in production)."""
        return [
            "股票", "债券", "基金", "期货", "期权", "保险", "理财",
            "利率", "汇率", "通货膨胀", "GDP", "PPI", "CPI", "PMI",
            "资产", "负债", "股东", "投资", "风险", "收益", "波动"
        ]

    def vectorize(self, text: str) -> np.ndarray:
        """Embed financial text: semantic vector fused with numeric features."""
        numerical_features = self._extract_numerical_features(text)
        semantic_vector = self.base_vectorizer.vectorize(text)
        return self._combine_features(semantic_vector, numerical_features)

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Vectorize each text independently (no API batching involved)."""
        return [self.vectorize(text) for text in texts]

    def _extract_numerical_features(self, text: str) -> Dict[str, float]:
        """Extract numeric signals (percentages, monetary amounts) from text."""
        features = {}
        # Percentages like "12.5%"; the decimal dot is escaped — the original
        # `\d+.?\d*` let the dot match any character.
        percentage_pattern = r'(\d+\.?\d*)%'
        percentages = re.findall(percentage_pattern, text)
        if percentages:
            features['percentage_avg'] = sum(float(p) for p in percentages) / len(percentages)
            features['percentage_count'] = len(percentages)
        # Monetary amounts followed by a Chinese magnitude or currency unit.
        amount_pattern = r'(\d+\.?\d*)\s*(万|亿|千|百万|美元|元|美金|英镑|欧元)'
        amounts = re.findall(amount_pattern, text)
        if amounts:
            # Convert magnitude units to a standard scale.
            std_amounts = []
            for amount, unit in amounts:
                value = float(amount)
                if unit == '万':
                    value *= 10000
                elif unit == '亿':
                    value *= 100000000
                # ... other unit conversions elided
                std_amounts.append(value)
            if std_amounts:
                features['amount_avg'] = sum(std_amounts) / len(std_amounts)
                features['amount_max'] = max(std_amounts)
                features['amount_count'] = len(std_amounts)
        # TODO: date-span features (time-range extraction).
        return features

    def _combine_features(self, semantic_vector: np.ndarray, numerical_features: Dict[str, float]) -> np.ndarray:
        """Fuse numeric features into the semantic vector (output dim unchanged)."""
        # No numeric features: the semantic vector passes through untouched.
        if not numerical_features:
            return semantic_vector
        # Fixed 10-slot block reserved for numeric features.
        numerical_vector = np.zeros(10)
        feature_index = {
            'percentage_avg': 0,
            'percentage_count': 1,
            'amount_avg': 2,
            'amount_max': 3,
            'amount_count': 4,
            # ... further feature slots
        }
        for feature, value in numerical_features.items():
            if feature in feature_index:
                numerical_vector[feature_index[feature]] = value
        # Scale the numeric block to [0, 1] by its own maximum.
        num_max = np.max(numerical_vector) if np.max(numerical_vector) > 0 else 1.0
        numerical_vector = numerical_vector / num_max
        # Keep the output dimension equal to the base model's by truncating
        # the semantic vector. The original ran sklearn PCA on a single
        # sample, which can yield at most one component and fails at runtime
        # for n_components > 1.
        new_dim = semantic_vector.shape[0] - numerical_vector.shape[0]
        semantic_reduced = semantic_vector[:new_dim]
        # Weighted concatenation of the two blocks.
        semantic_weight = 1 - self.numerical_weight
        combined = np.concatenate([
            semantic_reduced * semantic_weight,
            numerical_vector * self.numerical_weight
        ])
        # Unit-normalize the fused vector (guarding against a zero norm).
        norm = np.linalg.norm(combined)
        return combined / norm if norm > 0 else combined
八、总结与展望
8.1 扩展RAG系统的最佳实践
通过本文的介绍,我们展示了如何在RAG系统中实现高度自定义的功能扩展。总结最佳实践如下:
- 抽象基类设计:使用抽象基类定义统一接口,确保各组件遵循相同约定
- 工厂模式解耦:通过工厂模式创建组件实例,降低组件间耦合
- 配置驱动初始化:使用环境变量和配置文件驱动组件初始化,提高灵活性
- 专业化处理器:针对特定领域或文档类型开发专用处理器,提高处理质量
- 混合检索策略:结合多种检索方法,平衡关键词匹配和语义检索的优势
- 用户反馈闭环:收集用户反馈并应用于结果优化,持续提升检索质量
- 查询意图分析:根据查询意图选择专用处理器,提供更精准的回答
8.2 未来扩展方向
RAG系统仍有多个可以探索的扩展方向:
- 多模态处理:扩展到图像、音频、视频等多模态数据处理
- 时间感知检索:支持时间序列数据和趋势分析查询
- 自适应检索:根据用户历史查询行为自动调整检索策略
- 多源融合:支持多数据源查询结果的智能融合
- 可解释检索:提供检索结果的可解释性,帮助用户理解结果来源
- 离线预计算:为常见查询路径预计算结果,提升响应速度
下篇预告:《RAG系统效能提升的七个关键实践》将详解:
- 分块策略优化(表格/代码/文本差异化处理)
- 缓存机制设计(向量缓存/结果缓存/模型缓存)
- 异步处理实现(文档处理流水线优化)
- 安全防护方案(输入过滤/权限控制)
- 效果评估方法(检索准确率/响应时间/QPS)