Extending Your RAG System: Custom Processors and Vectorization Methods

1. Modular Design and Extension Architecture

In the previous installments, we covered the RAG system's architecture, core modules, and deployment options in detail. This installment digs into how to extend the system with custom processors and vectorization methods, so it can serve a wider range of business requirements.

1.1 Extensibility Design Principles

Our RAG system's extensibility rests on a handful of core design patterns (a registration sketch follows the list):

  1. Abstract base classes: define uniform interfaces so implementations stay consistent
  2. Factory pattern: create the appropriate component instances dynamically from configuration
  3. Strategy pattern: select among algorithm strategies at runtime
  4. Decorator pattern: extend behavior without modifying existing code
  5. Dependency injection: inject dependencies through configuration to reduce coupling between components
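
To make these patterns concrete, here is a minimal sketch of how an abstract base class, a registry, and a factory fit together. `BaseComponent` and `ComponentFactory` are illustrative names, not classes from this project.

```python
from abc import ABC, abstractmethod

class BaseComponent(ABC):
    """Uniform interface every pluggable component implements."""

    @abstractmethod
    def run(self, payload: str) -> str: ...

class ComponentFactory:
    """Factory with a registry, so new strategies plug in without edits."""

    _registry: dict[str, type[BaseComponent]] = {}

    @classmethod
    def register(cls, name: str, component_cls: type[BaseComponent]) -> None:
        cls._registry[name] = component_cls

    @classmethod
    def create(cls, name: str, **kwargs) -> BaseComponent:
        if name not in cls._registry:
            raise ValueError(f"Unknown component: {name}")
        return cls._registry[name](**kwargs)
```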

2. Developing Custom Document Processors

2.1 The Processor Abstract Interface

All processors inherit from the BaseDocumentProcessor abstract base class:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from langchain.schema import Document

# ProcessorConfig is defined alongside this class in base.py

class BaseDocumentProcessor(ABC):
    """Base class for document processors."""

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize the processor with a configuration."""
        self.config = config or ProcessorConfig()
        self._setup_logging()

    @abstractmethod
    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a single file."""
        pass

    @abstractmethod
    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process raw text."""
        pass
```

2.2 Implementing a Custom Academic-Paper Processor

Suppose we need to handle academic-paper PDFs in a particular format. We can implement that like this:

```python
import io
import re
from typing import Any, Dict, List, Optional

import PyPDF2
from langchain.schema import Document

class AcademicPaperProcessor(BaseDocumentProcessor):
    """Academic-paper processor, specialized for paper-style PDFs."""

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a paper PDF file."""
        if mime_type != "application/pdf":
            raise ValueError(f"Unsupported file type: {mime_type}; only PDF is supported")

        # Extract text with a PDF parsing library
        text = self._extract_pdf_text(file_content)

        # Parse the paper structure (title, abstract, sections, etc.)
        sections = self._parse_paper_structure(text)

        # Create documents for each section
        documents = []
        for section_name, section_text in sections.items():
            # Build the metadata
            metadata = {
                "source": filename,
                "section": section_name,
                "paper_type": "academic",
                "file_type": "pdf"
            }

            # Split into chunks
            chunks = self._split_section(section_text, section_name)

            # Create a Document object per chunk
            for i, chunk in enumerate(chunks):
                doc = Document(
                    page_content=chunk,
                    metadata={
                        **metadata,
                        "chunk_id": i,
                        "total_chunks": len(chunks)
                    }
                )
                documents.append(self._add_metadata(doc))

        return documents

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process raw text."""
        # Parse the paper structure
        sections = self._parse_paper_structure(text)

        # Same handling as in process_file...
        documents = []
        # ...process each section...

        return documents

    def _extract_pdf_text(self, pdf_bytes: bytes) -> str:
        """Extract PDF text, preserving the paper's formatting."""
        # Build a PDF reader over the in-memory bytes
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        full_text = []

        # Extract the text of each page
        for page in pdf_reader.pages:
            full_text.append(page.extract_text())

        return "\n\n".join(full_text)

    def _parse_paper_structure(self, text: str) -> Dict[str, str]:
        """Parse the paper structure: title, abstract, introduction, methods, results, discussion, etc."""
        # Identify the paper's sections with regular expressions (an NLP model would also work)
        sections = {}

        # Title: usually at the very start of the document
        title_match = re.search(r'^(.+?)(?=\n\n)', text)
        if title_match:
            sections["title"] = title_match.group(1).strip()

        # Abstract: usually follows the title, introduced by "Abstract"
        abstract_match = re.search(r'Abstract[:.\s]+(.+?)(?=\n\n\d+\.|\n\nIntroduction)', text, re.DOTALL)
        if abstract_match:
            sections["abstract"] = abstract_match.group(1).strip()

        # Common paper section headings
        section_patterns = [
            (r'Introduction(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "introduction"),
            (r'Methods(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "methods"),
            (r'Results(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "results"),
            (r'Discussion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "discussion"),
            (r'Conclusion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z]|$)', "conclusion"),
            (r'References(?:\s|\n)+(.+?)(?=$)', "references")
        ]

        for pattern, section_name in section_patterns:
            section_match = re.search(pattern, text, re.DOTALL)
            if section_match:
                sections[section_name] = section_match.group(1).strip()

        return sections
```
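
A quick usage sketch; the `paper.pdf` path is illustrative, and `_split_section` / `_add_metadata` come from the base class:

```python
processor = AcademicPaperProcessor()

with open("paper.pdf", "rb") as f:  # illustrative path
    docs = processor.process_file(f.read(), "paper.pdf", "application/pdf")

for doc in docs[:3]:
    print(doc.metadata["section"], doc.metadata["chunk_id"], doc.page_content[:80])
```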

2.3 Registering the Custom Processor

Once it is implemented, register the custom processor with the processor factory:

```python
# Register the new processor in document_processor.py
from src.data_processing.processors.academic_paper_processor import AcademicPaperProcessor

class DocumentProcessor:
    """Processor factory that picks the right processor by MIME type."""

    def __init__(self, config=None):
        """Initialize the processor factory."""
        self.config = config or ProcessorConfig()
        self._processors = {}
        self._register_default_processors()

    def _register_default_processors(self):
        """Register the default processors."""
        # Existing processors
        self._processors.update({
            "application/pdf": PDFProcessor(self.config),
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": WordProcessor(self.config),
            # ...other processors
        })

        # Register the custom academic-paper processor (overriding the default PDF processor)
        if self.config.doc_type == DocumentType.ACADEMIC_PAPER:
            self._processors["application/pdf"] = AcademicPaperProcessor(self.config)

    def register_processor(self, mime_type, processor):
        """Register a custom processor."""
        self._processors[mime_type] = processor
```
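
A processor can also be attached at runtime through `register_processor`, without touching the factory's defaults. A brief sketch; the `text/x-paper` MIME type is made up for illustration:

```python
doc_processor = DocumentProcessor()

# Override PDF handling, or claim a custom MIME type, at runtime
doc_processor.register_processor("application/pdf", AcademicPaperProcessor())
doc_processor.register_processor("text/x-paper", AcademicPaperProcessor())  # hypothetical type
```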

3. Implementing Custom Vectorization Methods

3.1 The Vectorizer Interface and Factory Pattern

All vectorizers inherit from BaseVectorizer:

```python
from abc import ABC, abstractmethod
from typing import List

import numpy as np

class BaseVectorizer(ABC):
    """Base class defining the vectorizer interface."""

    def __init__(self, cache_dir: str = './cache/vectorization'):
        """Initialize the vectorizer."""
        self.cache_dir = cache_dir
        self._ensure_cache_dir()

    @abstractmethod
    def vectorize(self, text: str) -> np.ndarray:
        """Convert a text into a vector."""
        pass

    @abstractmethod
    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of texts into vectors."""
        pass
```

A factory creates the different vectorizer types:

```python
class VectorizationFactory:
    """Factory class for creating the different vectorizer types."""

    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer
    }

    @classmethod
    def create_vectorizer(cls, method: str = 'tfidf', **kwargs) -> BaseVectorizer:
        """Create a vectorizer."""
        method = method.lower()
        if method not in cls._vectorizers:
            supported_methods = ", ".join(cls._vectorizers.keys())
            raise ValueError(f"Unsupported vectorization method: {method}. Supported methods: {supported_methods}")

        # Instantiate the registered vectorizer class with the given kwargs
        return cls._vectorizers[method](**kwargs)
```

3.2 Integrating an OpenAI Embedding Model

Below we implement an OpenAI vectorizer that turns text into embeddings served by the OpenAI API:

```python
import logging
import os
import time
from typing import List

import numpy as np
from openai import OpenAI

from .base import BaseVectorizer

class OpenAIVectorizer(BaseVectorizer):
    """Vectorizer backed by the OpenAI embeddings API."""

    def __init__(self, model_name="text-embedding-3-small", batch_size=32,
                 api_key=None, dimensions=1536, cache_dir='./cache/vectorization'):
        """Initialize the OpenAI vectorizer.

        Args:
            model_name: OpenAI embedding model name
            batch_size: batch size per API call
            api_key: OpenAI API key
            dimensions: embedding dimensionality
            cache_dir: cache directory
        """
        super().__init__(cache_dir=cache_dir)
        self.model_name = model_name
        self.batch_size = batch_size
        self.dimensions = dimensions

        # Initialize the OpenAI client
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.logger = logging.getLogger(self.__class__.__name__)

    def vectorize(self, text: str) -> np.ndarray:
        """Convert a text into a vector.

        Args:
            text: the text to vectorize

        Returns:
            The vector representation of the text
        """
        if not text.strip():
            # Map empty text to a zero vector
            return np.zeros(self.dimensions)

        try:
            # Call the OpenAI API to generate the embedding
            response = self.client.embeddings.create(
                model=self.model_name,
                input=text,
                dimensions=self.dimensions
            )

            # Extract the embedding vector
            embedding = response.data[0].embedding
            return np.array(embedding)

        except Exception as e:
            self.logger.error(f"OpenAI vectorization failed: {e}")
            # Fall back to a zero vector on failure
            return np.zeros(self.dimensions)

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of texts into vectors.

        Args:
            texts: the texts to vectorize

        Returns:
            One vector per input text, in the same order
        """
        # Keep output aligned with input: empty texts map to zero vectors
        results: List[np.ndarray] = [np.zeros(self.dimensions)] * len(texts)
        non_empty = [(i, t) for i, t in enumerate(texts) if t.strip()]

        # Process in batches to respect API limits
        for start in range(0, len(non_empty), self.batch_size):
            batch = non_empty[start:start + self.batch_size]

            try:
                # Batched OpenAI API call
                response = self.client.embeddings.create(
                    model=self.model_name,
                    input=[t for _, t in batch],
                    dimensions=self.dimensions
                )

                # Slot each embedding back into its original position
                for (i, _), data in zip(batch, response.data):
                    results[i] = np.array(data.embedding)

                # Brief pause between batches to stay under rate limits
                if start + self.batch_size < len(non_empty):
                    time.sleep(0.5)

            except Exception as e:
                self.logger.error(f"OpenAI batch vectorization failed: {e}")
                # The zero-vector placeholders remain for this batch

        return results

    def get_dimensions(self) -> int:
        """Return the embedding dimensionality."""
        return self.dimensions
```
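
A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment:

```python
vectorizer = OpenAIVectorizer(model_name="text-embedding-3-small", dimensions=1536)

vec = vectorizer.vectorize("What is retrieval-augmented generation?")
print(vec.shape)  # (1536,)

vecs = vectorizer.batch_vectorize(["first text", "", "third text"])
print(len(vecs))  # 3 — the empty string maps to a zero vector
```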

3.3 Registering the Custom Vectorizer

The new vectorizer must be registered with the vectorization factory:

```python
import os
from typing import Any, Dict

# Add the OpenAI vectorizer to the factory
from src.data_processing.vectorization.openai_vectorizer import OpenAIVectorizer

# Updated vectorization factory
class VectorizationFactory:
    """Factory class for creating the different vectorizer types."""

    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer,
        'openai': OpenAIVectorizer  # newly added vectorizer
    }

    @staticmethod
    def _get_config_from_env(method: str) -> Dict[str, Any]:
        """Read configuration from environment variables."""
        config = {}
        # ...existing config logic for the other methods...

        # OpenAI configuration
        if method == 'openai':
            config['model_name'] = os.getenv('OPENAI_EMBEDDING_MODEL', 'text-embedding-3-small')
            config['batch_size'] = int(os.getenv('OPENAI_BATCH_SIZE', '32'))
            config['dimensions'] = int(os.getenv('OPENAI_EMBEDDING_DIMENSIONS', '1536'))
            # The API key itself is read from the environment by the vectorizer

        return config
```
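
With registration in place, callers obtain the vectorizer by name, optionally driven by environment variables:

```python
import os

# Configuration pulled from the environment (values here are examples)
os.environ.setdefault("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")

vectorizer = VectorizationFactory.create_vectorizer("openai")
embedding = vectorizer.vectorize("hybrid retrieval for RAG")
```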

4. Implementing a Hybrid Retrieval Strategy

4.1 A Multi-Method Hybrid Retriever

In practice, no single retrieval method covers every need. We can implement a hybrid strategy that combines the strengths of several methods:

```python
import logging

import jieba
from rank_bm25 import BM25Okapi

from src.data_processing.vectorization.factory import VectorizationFactory

class HybridSearchRetriever:
    """Hybrid retriever that combines the strengths of several retrieval methods."""

    def __init__(self,
                 vector_store,
                 keyword_weight=0.3,
                 semantic_weight=0.7,
                 rerank_model=None):
        """Initialize the hybrid retriever."""
        self.vector_store = vector_store
        self.keyword_weight = keyword_weight
        self.semantic_weight = semantic_weight
        self.rerank_model = rerank_model
        self.logger = logging.getLogger(self.__class__.__name__)

        # Build the BM25 index
        self._initialize_bm25_index()

        # Initialize the vectorizer
        self.vectorizer = VectorizationFactory.create_vectorizer('bge-m3')

    def _initialize_bm25_index(self):
        """Build the BM25 keyword index."""
        # Fetch all documents
        docs = self.vector_store.get_all_documents()
        texts = [doc.page_content for doc in docs]

        # Tokenize (jieba, since the corpus is Chinese)
        tokenized_corpus = [list(jieba.cut(text)) for text in texts]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.doc_ids = [doc.metadata.get('id') for doc in docs]
        self.documents = docs

    async def retrieve(self, query: str, top_k: int = 5, threshold: float = 0.0):
        """Run the hybrid retrieval."""
        # 1. Keyword retrieval
        keyword_results = await self._keyword_search(query, top_k * 2)

        # 2. Vector retrieval
        vector_results = await self._vector_search(query, top_k * 2)

        # 3. Merge the result sets
        merged_results = self._merge_results(keyword_results, vector_results)

        # 4. Rerank (if a reranking model is configured)
        if self.rerank_model and len(merged_results) > top_k:
            merged_results = await self._rerank_results(query, merged_results, top_k)

        # 5. Drop results below the score threshold
        filtered_results = [
            (doc, score) for doc, score in merged_results
            if score >= threshold
        ]

        # Return the top_k results
        return filtered_results[:top_k]

    async def _keyword_search(self, query: str, top_k: int):
        """BM25 keyword retrieval."""
        # Tokenize the query
        tokenized_query = list(jieba.cut(query))

        # Compute BM25 scores
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Pair documents with their scores
        results = []
        for i, score in enumerate(bm25_scores):
            if score > 0:  # keep only documents with a positive score
                results.append((self.documents[i], score))

        # Sort by score
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    async def _vector_search(self, query: str, top_k: int):
        """Vector-similarity retrieval."""
        # Delegate to the vector store's retrieval method
        return await self.vector_store.asimilarity_search_with_score(query, top_k)

    def _merge_results(self, keyword_results, vector_results):
        """Merge the keyword and vector result sets."""
        # Map document IDs to merged entries
        merged_map = {}

        # Keyword results
        max_keyword_score = max([score for _, score in keyword_results]) if keyword_results else 1.0
        for doc, score in keyword_results:
            doc_id = doc.metadata.get("id")
            # Normalize BM25 scores into [0, 1]
            normalized_score = score / max_keyword_score
            merged_map[doc_id] = {
                "doc": doc,
                "keyword_score": normalized_score,
                "vector_score": 0.0
            }

        # Vector results (similarities assumed to already be in [0, 1])
        for doc, score in vector_results:
            doc_id = doc.metadata.get("id")
            if doc_id in merged_map:
                # Update the existing entry
                merged_map[doc_id]["vector_score"] = score
            else:
                # Add a new entry
                merged_map[doc_id] = {
                    "doc": doc,
                    "keyword_score": 0.0,
                    "vector_score": score
                }

        # Weighted total score
        result_list = []
        for item in merged_map.values():
            final_score = (
                self.keyword_weight * item["keyword_score"] +
                self.semantic_weight * item["vector_score"]
            )
            result_list.append((item["doc"], final_score))

        # Sort by final score
        result_list.sort(key=lambda x: x[1], reverse=True)
        return result_list
```
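
A usage sketch, assuming `vector_store` exposes the `get_all_documents` and `asimilarity_search_with_score` methods used above:

```python
import asyncio

retriever = HybridSearchRetriever(vector_store, keyword_weight=0.3, semantic_weight=0.7)

results = asyncio.run(retriever.retrieve("向量检索和关键词检索的区别", top_k=5))
for doc, score in results:
    print(f"{score:.3f}", doc.metadata.get("id"))
```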

4.2 A Cross-Modal Retrieval Extension

Beyond text, a RAG system can be extended to handle images, audio, and other modalities:

```python
import asyncio

class MultiModalRetriever:
    """Multi-modal retriever supporting text, images, and other modalities."""

    def __init__(self, vector_stores, embedding_models):
        """Initialize the multi-modal retriever."""
        self.vector_stores = vector_stores
        self.embedding_models = embedding_models

    async def retrieve(self, query, modal_type=None, top_k=5):
        """Run the multi-modal retrieval."""
        # Detect the modality automatically if requested
        if modal_type == "auto":
            modal_type = self._detect_modal_type(query)

        # Text query
        if modal_type == "text":
            # Encode with both the CLIP text encoder and the BGE encoder
            clip_embedding = self.embedding_models["clip"].encode_text(query)
            bge_embedding = self.embedding_models["bge"].vectorize(query)

            # Retrieve from each modality's store in parallel
            tasks = [
                self.vector_stores["text"].asimilarity_search_by_vector(bge_embedding, top_k),
                self.vector_stores["image"].asimilarity_search_by_vector(clip_embedding, top_k)
            ]

            text_results, image_results = await asyncio.gather(*tasks)

            # Merge the per-modality results
            return self._merge_modal_results(text_results, image_results, top_k)

        # Image query
        elif modal_type == "image":
            # Encode with the CLIP image encoder
            image_embedding = self.embedding_models["clip"].encode_image(query)

            # Retrieve image-modality data
            results = await self.vector_stores["image"].asimilarity_search_by_vector(
                image_embedding, top_k
            )

            return results

        # Unknown modality: fail loudly instead of silently returning None
        raise ValueError(f"Unsupported modal_type: {modal_type}")
```
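
The merge helper is not shown above; here is one plausible sketch, living on the same class, assuming each result is a `(doc, score)` pair with scores on a comparable scale:

```python
    def _merge_modal_results(self, text_results, image_results, top_k):
        """Merge per-modality (doc, score) lists into one ranked list (illustrative)."""
        merged = list(text_results) + list(image_results)
        merged.sort(key=lambda pair: pair[1], reverse=True)
        return merged[:top_k]
```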

5. A User-Feedback Optimization Loop

5.1 Collecting and Storing Feedback Data

To further improve the RAG system's retrieval quality, we can add a user-feedback mechanism:

```python
import logging
import os
import sqlite3

class FeedbackOptimizer:
    """Optimize RAG retrieval results based on user feedback."""

    def __init__(self, vector_store, feedback_db=None):
        """Initialize the feedback optimizer."""
        self.vector_store = vector_store
        self.feedback_db = feedback_db or self._initialize_feedback_db()
        self.logger = logging.getLogger(self.__class__.__name__)

    def _initialize_feedback_db(self):
        """Initialize the feedback database."""
        os.makedirs('data', exist_ok=True)  # ensure the data directory exists
        conn = sqlite3.connect('data/feedback.db')
        c = conn.cursor()

        # Feedback table
        c.execute('''
        CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query_text TEXT,
            doc_id TEXT,
            is_relevant INTEGER,
            timestamp TEXT
        )
        ''')

        # Query-log table
        c.execute('''
        CREATE TABLE IF NOT EXISTS query_log (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query_text TEXT,
            results_count INTEGER,
            timestamp TEXT
        )
        ''')

        conn.commit()
        return conn

    def record_feedback(self, query, doc_id, is_relevant):
        """Record one piece of user feedback."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO feedback (query_text, doc_id, is_relevant, timestamp) VALUES (?, ?, ?, datetime('now'))",
            (query, doc_id, 1 if is_relevant else 0)
        )
        self.feedback_db.commit()

    def record_query(self, query, results_count):
        """Record a query in the query log."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO query_log (query_text, results_count, timestamp) VALUES (?, ?, datetime('now'))",
            (query, results_count)
        )
        self.feedback_db.commit()

    def get_relevance_feedback_for_query(self, query, limit=10):
        """Fetch the relevance feedback recorded for a given query."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "SELECT doc_id, is_relevant, COUNT(*) FROM feedback WHERE query_text = ? GROUP BY doc_id, is_relevant",
            (query,)
        )
        return cursor.fetchall()

    def optimize_results(self, query, initial_results, top_k=5):
        """Rerank retrieval results using historical feedback."""
        # Fetch the historical feedback for this query
        feedback = self.get_relevance_feedback_for_query(query)

        # Without feedback, return the original ranking
        if not feedback:
            return initial_results[:top_k]

        # Index the feedback by document for quick lookup
        feedback_dict = {}
        for doc_id, is_relevant, count in feedback:
            if doc_id not in feedback_dict:
                feedback_dict[doc_id] = {"relevant": 0, "irrelevant": 0}

            if is_relevant:
                feedback_dict[doc_id]["relevant"] += count
            else:
                feedback_dict[doc_id]["irrelevant"] += count

        # Adjust scores with the feedback
        adjusted_results = []
        for doc, score in initial_results:
            doc_id = doc.metadata.get("id")

            # Compute the feedback adjustment factor
            adjustment = 0
            if doc_id in feedback_dict:
                # Relevant votes raise the score, irrelevant votes lower it
                relevant = feedback_dict[doc_id]["relevant"]
                irrelevant = feedback_dict[doc_id]["irrelevant"]

                # Scale the adjustment by the vote ratio
                if relevant + irrelevant > 0:
                    adjustment = (relevant - irrelevant) / (relevant + irrelevant) * 0.2

            # Clamp the adjusted score to [0, 1]
            adjusted_score = min(1.0, max(0.0, score + adjustment))
            adjusted_results.append((doc, adjusted_score))

        # Sort by the adjusted score
        adjusted_results.sort(key=lambda x: x[1], reverse=True)

        # Log this query
        self.record_query(query, len(initial_results))

        return adjusted_results[:top_k]
```
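
Wiring the optimizer behind a retriever might look like this sketch (the query and `initial_results` are illustrative):

```python
optimizer = FeedbackOptimizer(vector_store)

# initial_results is a list of (Document, score) pairs from any retriever
reranked = optimizer.optimize_results("什么是混合检索?", initial_results, top_k=5)

# Later, record what the user actually found useful
optimizer.record_feedback("什么是混合检索?", reranked[0][0].metadata.get("id"), is_relevant=True)
```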

5.2 Implementing the Feedback Interface

To collect user feedback, the frontend needs feedback buttons; they call the API endpoints below:

```python
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

router = APIRouter()

class FeedbackRequest(BaseModel):
    query: str
    doc_id: str
    is_relevant: bool

@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Submit document-relevance feedback."""
    try:
        feedback_optimizer = get_feedback_optimizer()
        feedback_optimizer.record_feedback(
            request.query,
            request.doc_id,
            request.is_relevant
        )
        return {"status": "success", "message": "Feedback recorded"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to submit feedback: {e}")

@router.get("/feedback/stats")
async def get_feedback_stats(query: Optional[str] = None):
    """Return feedback statistics."""
    try:
        feedback_optimizer = get_feedback_optimizer()
        if query:
            # Feedback for one specific query
            stats = feedback_optimizer.get_relevance_feedback_for_query(query)
        else:
            # Global feedback statistics
            stats = feedback_optimizer.get_global_feedback_stats()
        return {"status": "success", "data": stats}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch feedback stats: {e}")
```
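
`get_feedback_optimizer` is assumed above; one plausible implementation is a module-level singleton (illustrative, with the vector store injected by your application wiring):

```python
_feedback_optimizer = None

def get_feedback_optimizer():
    """Return a process-wide FeedbackOptimizer instance (illustrative sketch)."""
    global _feedback_optimizer
    if _feedback_optimizer is None:
        _feedback_optimizer = FeedbackOptimizer(vector_store=None)  # inject your store here
    return _feedback_optimizer
```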

6. Extending the Query-Intent Analyzer

6.1 Custom Intent Processors

We already have a recommendation-query processor; now we can extend the system with more specialized intent processors, starting with an intent classifier:

```python
import numpy as np

from src.data_processing.vectorization.factory import VectorizationFactory

class QueryIntentClassifier:
    """Query-intent classifier based on vector similarity."""

    def __init__(self, embedding_model=None):
        """Initialize the intent classifier."""
        # Vectorization model
        self.embedding_model = embedding_model or VectorizationFactory.create_vectorizer("bge-m3")

        # Intents and example queries (kept in Chinese, the system's query language)
        self.intent_examples = {
            "比较分析": [  # comparison analysis
                "A和B有什么区别?",
                "哪一个更好,X还是Y?",
                "比较一下P和Q的优缺点"
            ],
            "因果解释": [  # causal explanation
                "为什么会出现这种情况?",
                "导致X的原因是什么?",
                "这个问题的根源是什么?"
            ],
            "列举信息": [  # enumeration
                "列出所有的X",
                "有哪些方法可以做Y?",
                "X包含哪些组成部分?"
            ],
            "概念解释": [  # concept definition
                "什么是X?",
                "X的定义是什么?",
                "如何理解Y概念?"
            ],
            "数据统计": [  # statistics
                "X的平均值是多少?",
                "Y的增长率是多少?",
                "Z的分布情况怎样?"
            ],
            "操作指导": [  # how-to
                "如何做X?",
                "执行Y的步骤是什么?",
                "使用Z的方法有哪些?"
            ],
            "推荐建议": [  # recommendation
                "推荐几款好用的X",
                "有什么适合Y的工具?",
                "帮我选择一个合适的Z"
            ]
        }

        # Precompute the intent vectors
        self.intent_vectors = self._compute_intent_vectors()

    def _compute_intent_vectors(self):
        """Precompute the mean vector representation of each intent."""
        intent_vectors = {}
        for intent, examples in self.intent_examples.items():
            # Vectorize every example
            vectors = self.embedding_model.batch_vectorize(examples)
            # Average them
            avg_vector = np.mean(vectors, axis=0)
            # Normalize to unit length
            avg_vector = avg_vector / np.linalg.norm(avg_vector)
            # Store the intent vector
            intent_vectors[intent] = avg_vector
        return intent_vectors

    def classify_intent(self, query: str):
        """Classify the intent of a query."""
        # Vectorize and normalize the query, so the dot product is a cosine similarity
        query_vector = self.embedding_model.vectorize(query)
        query_vector = query_vector / np.linalg.norm(query_vector)

        # Similarity to each intent
        similarities = {}
        for intent, vector in self.intent_vectors.items():
            similarity = np.dot(query_vector, vector)
            similarities[intent] = similarity

        # Pick the most similar intent
        max_intent = max(similarities, key=similarities.get)
        max_similarity = similarities[max_intent]

        # Below the threshold, fall back to a generic information query
        if max_similarity < 0.5:
            return "一般信息查询", max_similarity  # generic information query

        return max_intent, max_similarity
```
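
A quick classification sketch:

```python
classifier = QueryIntentClassifier()

intent, confidence = classifier.classify_intent("MySQL和PostgreSQL有什么区别?")
print(intent, f"{confidence:.2f}")  # expected: 比较分析 (comparison analysis)
```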

6.2 A Custom Comparison-Analysis Processor

Each intent needs a dedicated processor. Taking comparison analysis as an example:

```python
import re
from typing import Any, Dict, List

from .base import QueryProcessor

class ComparisonQueryProcessor(QueryProcessor):
    """Processor for comparison-analysis queries."""

    def process(self, query: str, documents: List[Any], **kwargs) -> Dict[str, Any]:
        """Handle a comparison-analysis query."""
        if not documents:
            return {
                "answer": "Sorry, no relevant comparison information was found.",
                "sources": []
            }

        # Extract the entities to compare
        entities = self._extract_comparison_entities(query)

        if len(entities) < 2:
            # If the query did not yield multiple entities, try the documents
            entities = self._extract_entities_from_documents(documents)

        # Collect information about each entity
        entity_info = self._collect_entity_information(entities, documents)

        # Build a comparison table
        comparison_table = self._generate_comparison_table(entity_info)

        # Draw a comparison conclusion
        conclusion = self._generate_comparison_conclusion(entity_info, query)

        # Assemble the final answer
        answer = f"Here is the comparison of {', '.join(entities)} you asked for:\n\n{comparison_table}\n\n{conclusion}"

        return {
            "answer": answer,
            "sources": documents,
            "entities": entities,
            "comparison_table": comparison_table
        }

    def can_handle(self, query: str) -> bool:
        """Decide whether this is a comparison-analysis query."""
        # Comparison keywords (Chinese: compare, difference, pros and cons, versus, ...)
        comparison_keywords = ["比较", "区别", "差异", "优缺点", "对比", "相比", "VS", "好坏"]
        # Comparison sentence patterns ("difference between A and B", "is A or B better", ...)
        comparison_patterns = [
            r"(.+)和(.+)的区别",
            r"(.+)与(.+)的(差异|不同)",
            r"(.+)相比(.+)怎么样",
            r"(.+)还是(.+)更好"
        ]

        # Keyword check
        if any(keyword in query for keyword in comparison_keywords):
            return True

        # Sentence-pattern check
        for pattern in comparison_patterns:
            if re.search(pattern, query):
                return True

        return False

    def _extract_comparison_entities(self, query: str) -> List[str]:
        """Extract the entities to compare from the query."""
        # Pull the entities out with regular expressions
        patterns = [
            r"(.+)和(.+)的区别",
            r"(.+)与(.+)的(差异|不同)",
            r"(.+)相比(.+)怎么样",
            r"(.+)还是(.+)更好"
        ]

        for pattern in patterns:
            match = re.search(pattern, query)
            if match:
                # Take the two matched entities
                entities = [match.group(1).strip(), match.group(2).strip()]
                return entities

        # Fall back to tokenization and NER
        # ...

        return []

    def _extract_entities_from_documents(self, documents: List[Any]) -> List[str]:
        """Extract candidate comparison entities from the documents."""
        # Entity extraction over the documents
        # ...
        return []

    def _collect_entity_information(self, entities: List[str], documents: List[Any]) -> Dict[str, Dict]:
        """Collect information about each entity from the documents."""
        entity_info = {}

        for entity in entities:
            entity_info[entity] = {
                "advantages": [],
                "disadvantages": [],
                "features": {},
                "mentions": 0
            }

            # Mine the documents for this entity
            for doc in documents:
                text = doc.page_content

                # Count mentions
                if entity in text:
                    entity_info[entity]["mentions"] += text.count(entity)

                # Extract advantages
                advantages = self._extract_advantages(text, entity)
                entity_info[entity]["advantages"].extend(advantages)

                # Extract disadvantages
                disadvantages = self._extract_disadvantages(text, entity)
                entity_info[entity]["disadvantages"].extend(disadvantages)

                # Extract features
                features = self._extract_features(text, entity)
                for feature, value in features.items():
                    if feature in entity_info[entity]["features"]:
                        entity_info[entity]["features"][feature].append(value)
                    else:
                        entity_info[entity]["features"][feature] = [value]

        return entity_info

    def _extract_advantages(self, text: str, entity: str) -> List[str]:
        """Extract an entity's advantages."""
        patterns = [
            f"{entity}的优点",   # "advantages of <entity>"
            f"{entity}的好处",   # "benefits of <entity>"
            f"{entity}的优势"    # "strengths of <entity>"
        ]
        # ...extraction logic
        return []

    def _extract_disadvantages(self, text: str, entity: str) -> List[str]:
        """Extract an entity's disadvantages."""
        # ...extraction logic
        return []

    def _extract_features(self, text: str, entity: str) -> Dict[str, str]:
        """Extract an entity's features."""
        # ...extraction logic
        return {}

    def _generate_comparison_table(self, entity_info: Dict[str, Dict]) -> str:
        """Build a Markdown comparison table."""
        # Table header
        entities = list(entity_info.keys())
        table = f"| Feature | {' | '.join(entities)} |\n"
        table += f"| --- | {' | '.join(['---' for _ in entities])} |\n"

        # One row per feature seen on any entity
        all_features = set()
        for entity, info in entity_info.items():
            all_features.update(info["features"].keys())

        for feature in sorted(all_features):
            row = f"| {feature} | "
            for entity in entities:
                if feature in entity_info[entity]["features"]:
                    values = entity_info[entity]["features"][feature]
                    row += f"{values[0]} | "
                else:
                    row += "- | "
            table += row + "\n"

        # Advantage / disadvantage rows
        table += f"| Advantages | {' | '.join([', '.join(info['advantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"
        table += f"| Disadvantages | {' | '.join([', '.join(info['disadvantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"

        return table

    def _generate_comparison_conclusion(self, entity_info: Dict[str, Dict], query: str) -> str:
        """Generate the comparison conclusion."""
        entities = list(entity_info.keys())

        if len(entities) < 2:
            return "Cannot draw a comparison conclusion: not enough entity information was found."

        # Derive a conclusion from the advantages, disadvantages, and features
        # ...generation logic

        return "Based on the comparison above, each option has its own strengths and weaknesses; the right choice depends on your specific needs and scenario."
```

7. Case Studies

7.1 A Custom Legal-Document Processor

Legal documents have their own peculiarities. Here is a processor dedicated to legal filings:

```python
import re
from typing import Any, Dict, List, Optional

from langchain.schema import Document
from src.data_processing.processors.base import BaseDocumentProcessor, ProcessorConfig

class LegalDocumentProcessor(BaseDocumentProcessor):
    """Processor specialized for legal documents."""

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize the legal-document processor."""
        super().__init__(config)
        # Legal-term glossary
        self.legal_terms = self._load_legal_terms()

    def _load_legal_terms(self):
        """Load the legal-term glossary."""
        # In production this should be loaded from an external file
        return {
            "原告": "起诉方,请求法院裁判的一方",          # plaintiff
            "被告": "被起诉方,被请求法院裁判的一方",      # defendant
            "诉讼": "通过法院解决纠纷的法律程序",          # litigation
            # more legal terms...
        }

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a legal-document file."""
        # Extract the text
        if mime_type == "application/pdf":
            text = self._extract_pdf_text(file_content)
        elif mime_type in ("application/msword",
                           "application/vnd.openxmlformats-officedocument.wordprocessingml.document"):
            text = self._extract_word_text(file_content)
        else:
            text = file_content.decode(self.config.encoding, errors='ignore')

        # Hand the extracted text to process_text
        return self.process_text(text, {"source": filename, "mime_type": mime_type})

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process legal-document text."""
        # Parse the document structure
        sections = self._parse_legal_document_structure(text)

        # Create a document per section
        documents = []
        for section_name, section_content in sections.items():
            # Build the metadata
            section_metadata = {
                **(metadata or {}),
                "section": section_name,
                "document_type": "legal"
            }

            # Annotate legal terms
            annotated_content = self._annotate_legal_terms(section_content)

            # Create the document
            doc = Document(
                page_content=annotated_content,
                metadata=section_metadata
            )
            documents.append(self._add_metadata(doc))

        return documents

    def _parse_legal_document_structure(self, text: str) -> Dict[str, str]:
        """Parse the structure of a legal document."""
        sections = {}

        # Common structure of Chinese court filings (case number, plaintiff, defendant, ...)
        section_patterns = [
            (r"案号[::]\s*(.+?)(?=\n)", "case_number"),
            (r"原告[::]\s*(.+?)(?=\n被告)", "plaintiff"),
            (r"被告[::]\s*(.+?)(?=\n)", "defendant"),
            (r"诉讼请求[::]\s*(.+?)(?=\n)", "claims"),
            (r"事实与理由[::]\s*(.+?)(?=\n)", "facts_and_reasons"),
            (r"裁判结果[::]\s*(.+?)(?=\n)", "judgment"),
            (r"裁判理由[::]\s*(.+?)(?=\n)", "reasoning")
        ]

        # Extract each section
        for pattern, section_name in section_patterns:
            match = re.search(pattern, text, re.DOTALL)
            if match:
                sections[section_name] = match.group(1).strip()

        # If no structured content was found, fall back to paragraph splitting
        if not sections:
            paragraphs = re.split(r'\n\s*\n', text)
            for i, para in enumerate(paragraphs):
                sections[f"paragraph_{i+1}"] = para.strip()

        return sections

    def _annotate_legal_terms(self, text: str) -> str:
        """Annotate legal terms with their definitions."""
        annotated_text = text
        for term, definition in self.legal_terms.items():
            # A term may occur many times; annotate only the first occurrence
            if term in annotated_text:
                term_with_note = f"{term}[注: {definition}]"
                annotated_text = annotated_text.replace(term, term_with_note, 1)

        return annotated_text
```
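
Usage mirrors the other processors; a brief sketch with an illustrative file path (`_extract_pdf_text` / `_extract_word_text` are assumed helpers on the class):

```python
processor = LegalDocumentProcessor()

with open("judgment.docx", "rb") as f:  # illustrative path
    docs = processor.process_file(
        f.read(), "judgment.docx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
print([d.metadata["section"] for d in docs])
```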

7.2 A Financial-Data Vectorizer

Financial data has its own quirks, so we can implement a dedicated vectorizer for it:

```python
import re
from typing import Dict, List

import numpy as np

from src.data_processing.vectorization.base import BaseVectorizer

class FinancialDataVectorizer(BaseVectorizer):
    """Vectorizer specialized for financial text and data."""

    def __init__(self, cache_dir='./cache/vectorization',
                 base_model='bge-m3',
                 numerical_weight=0.3):
        """Initialize the financial-data vectorizer."""
        super().__init__(cache_dir)
        # Use a base model for the semantic part of the vector
        from src.data_processing.vectorization.factory import VectorizationFactory
        self.base_vectorizer = VectorizationFactory.create_vectorizer(base_model)

        # Weight of the numerical features
        self.numerical_weight = numerical_weight

        # Financial-term vocabulary
        self.financial_terms = self._load_financial_terms()

    def _load_financial_terms(self):
        """Load the financial-term vocabulary."""
        # In production this should be loaded from an external file
        return [
            "股票", "债券", "基金", "期货", "期权", "保险", "理财",
            "利率", "汇率", "通货膨胀", "GDP", "PPI", "CPI", "PMI",
            "资产", "负债", "股东", "投资", "风险", "收益", "波动"
        ]

    def vectorize(self, text: str) -> np.ndarray:
        """Convert financial text into a vector."""
        # 1. Extract numerical features
        numerical_features = self._extract_numerical_features(text)

        # 2. Get the base semantic vector
        semantic_vector = self.base_vectorizer.vectorize(text)

        # 3. Fuse the numerical features with the semantic vector
        combined_vector = self._combine_features(semantic_vector, numerical_features)

        return combined_vector

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of financial texts into vectors."""
        results = []
        for text in texts:
            vector = self.vectorize(text)
            results.append(vector)
        return results

    def _extract_numerical_features(self, text: str) -> Dict[str, float]:
        """Extract numerical features from the text."""
        features = {}

        # Percentages
        percentage_pattern = r'(\d+\.?\d*)%'
        percentages = re.findall(percentage_pattern, text)
        if percentages:
            features['percentage_avg'] = sum(float(p) for p in percentages) / len(percentages)
            features['percentage_count'] = len(percentages)

        # Monetary amounts (Chinese units: 万 = 10^4, 亿 = 10^8, plus currency words)
        amount_pattern = r'(\d+\.?\d*)\s*(万|亿|千|百万|美元|元|美金|英镑|欧元)'
        amounts = re.findall(amount_pattern, text)
        if amounts:
            # Convert to a standard unit
            std_amounts = []
            for amount, unit in amounts:
                value = float(amount)
                if unit == '万':
                    value *= 10_000
                elif unit == '亿':
                    value *= 100_000_000
                # conversions for the other units...
                std_amounts.append(value)

            if std_amounts:
                features['amount_avg'] = sum(std_amounts) / len(std_amounts)
                features['amount_max'] = max(std_amounts)
                features['amount_count'] = len(std_amounts)

        # Date spans (e.g., reporting periods)
        # ...date extraction and arithmetic

        return features

    def _combine_features(self, semantic_vector: np.ndarray, numerical_features: Dict[str, float]) -> np.ndarray:
        """Fuse the numerical features into the semantic vector."""
        # Without numerical features, the semantic vector stands alone
        if not numerical_features:
            return semantic_vector

        # Fixed-length numerical feature vector (10 dimensions reserved)
        numerical_vector = np.zeros(10)

        # Slot each feature into its reserved dimension
        feature_index = {
            'percentage_avg': 0,
            'percentage_count': 1,
            'amount_avg': 2,
            'amount_max': 3,
            'amount_count': 4,
            # indices for further features...
        }

        for feature, value in numerical_features.items():
            if feature in feature_index:
                numerical_vector[feature_index[feature]] = value

        # Normalize the numerical features
        num_max = np.max(numerical_vector) if np.max(numerical_vector) > 0 else 1.0
        numerical_vector = numerical_vector / num_max

        # Shrink the semantic vector so the final dimensionality stays unchanged.
        # (PCA cannot be fit on a single sample, so we simply truncate here; a
        # dimensionality reduction fit on a corpus would preserve more information.)
        original_dim = semantic_vector.shape[0]
        new_dim = original_dim - len(numerical_vector)
        semantic_reduced = semantic_vector[:new_dim]

        # Weighted concatenation
        semantic_weight = 1 - self.numerical_weight
        combined = np.concatenate([
            semantic_reduced * semantic_weight,
            numerical_vector * self.numerical_weight
        ])

        # Normalize the final vector
        combined = combined / np.linalg.norm(combined)

        return combined
```
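
A quick sketch of the combined embedding in action:

```python
vectorizer = FinancialDataVectorizer(base_model='bge-m3', numerical_weight=0.3)

vec = vectorizer.vectorize("公司营收增长15.3%,净利润达2.5亿元")
print(vec.shape)  # same dimensionality as the base model's output
```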

8. Summary and Outlook

8.1 Best Practices for Extending a RAG System

This article has shown how to build highly customized extensions into a RAG system. The best practices boil down to:

  1. Abstract base classes: define uniform interfaces so every component follows the same contract
  2. Factory decoupling: create component instances through factories to reduce coupling between components
  3. Configuration-driven initialization: drive component setup through environment variables and config files for flexibility
  4. Specialized processors: build processors for specific domains or document types to raise processing quality
  5. Hybrid retrieval: combine retrieval methods to balance keyword matching against semantic search
  6. Feedback loop: collect user feedback and feed it back into result ranking to keep improving retrieval quality
  7. Intent analysis: route queries to dedicated processors by intent for more precise answers

8.2 Future Directions

There are still plenty of directions in which a RAG system can grow:

  1. Multi-modal processing: extend to images, audio, video, and other modalities
  2. Time-aware retrieval: support time-series data and trend-analysis queries
  3. Adaptive retrieval: tune the retrieval strategy automatically from a user's query history
  4. Multi-source fusion: intelligently merge results queried from multiple data sources
  5. Explainable retrieval: make results explainable so users understand where answers come from
  6. Offline precomputation: precompute results for common query paths to cut response time

Coming next: "Seven Key Practices for Boosting RAG System Performance" will cover:

  • Chunking-strategy optimization (differentiated handling of tables/code/text)
  • Cache design (vector cache / result cache / model cache)
  • Async processing (document-pipeline optimization)
  • Security hardening (input filtering / access control)
  • Evaluation methods (retrieval accuracy / response time / QPS)

Project repository: github.com/bikeread/ra...

后端