人工智能基础知识笔记三十四:提升RAG效果的几种技术

检索增强生成技术通过在生成阶段引入外部知识,有效缓解了大语言模型"幻觉"和知识陈旧的问题。 然而,传统的RAG系统存在两个普遍痛点:检索的盲目性利用的粗糙性。系统通常盲目信任检索器的结果,无论相关与否都全部塞给生成器;同时,将整篇文档作为上下文,其中大量无关信息反而会"稀释"关键知识,误导生成过程。

在博客 https://blog.csdn.net/jimmyleeee/article/details/156648886 里介绍了可以通过Rerank的技术把RAG获得的结果进行重新排序,把和目标关系最密切的结果排在前面。但是,即使使用Rerank也不能保证RAG的匹配的结果的质量一定就是满足需求的。 本文会介绍几种RAG的技术,可以提高RAG的效果。针对这些问题,研究社区正从"被动接受检索结果"向"主动评估、控制和优化"的方向演进。以下将深入剖析三种前沿技术------Corrective RAG、Reflexion和Self-RAG,它们分别从结果纠错、事后反思、过程自省三个不同维度,共同绘制出构建更健壮、更精准RAG系统的技术蓝图。

1、Corrective RAG:为检索结果装上"质检员"与"净化器"

Corrective RAG (简称CRAG)的核心思想是:不盲目信任任何一次检索,而是先对检索结果进行整体评估和必要修正,再将"净化后"的知识提供给生成器。 CRAG的核心创新在于三重保障机制 :评估检索质量、动态补充知识源、精细化处理文档内容。论文参考:https://arxiv.org/pdf/2401.15884.pdf

1.1 系统流程:

python 复制代码
输入查询 → 初始检索 → 检索评估器 → 置信度分类
    ↓
[高置信度] → 文档净化 → 精炼知识 → 生成响应
    ↓
[低置信度] → 触发网络搜索 → 融合结果 → 生成响应
    ↓  
[模糊置信度] → 净化+搜索 → 加权融合 → 生成响应

1.2 CRAG的本质

增强系统的风险控制能力。它通过"评估-分流-净化/扩充"的流水线,将单点检索失败的风险降到最低,尤其适合对事实准确性要求极高的场景(如问答、报告生成)。其"即插即用"的设计,可以无缝集成到现有RAG管道中。

1.3 "质检"环节:轻量级检索评估器

CRAG在流程前端引入了一个轻量级的检索评估器(基于T5-large等较小模型构建)。它的任务不是重新排序,而是对当前检索到的整组文档进行整体质量评估,并输出一个置信度分数。这个分数决定了系统下一步的"动作"。

  1. "分流"与"净化"行动

    根据评估置信度,系统触发三类行动:

    • Correct(正确) :当认为检索结果整体相关可靠时,进入 "分解-重组"算法。该算法将长文档拆解成更细的知识片段,过滤掉无关文本,再重新组合成精炼的知识条。这相当于从一堆矿石中精准提炼出金属,去除了大量杂质。

    • Incorrect(错误) :当判断检索结果基本不相关时,果断丢弃 初始结果。为了避免静态知识库的局限,CRAG在此触发大规模网络搜索作为补充,从更广阔、动态的互联网中获取新知进行修正。

    • Ambiguous(模糊) :在无法明确判断时,采取软性策略,同时使用净化后的本地检索结果和网络搜索的补充结果。

1.4 Python示例代码:

python 复制代码
import numpy as np
from typing import List, Dict, Tuple, Optional
from enum import Enum
import requests
import json

class ConfidenceLevel(Enum):
    CORRECT = "correct"
    INCORRECT = "incorrect"
    AMBIGUOUS = "ambiguous"

class CRAGRetrievalEvaluator:
    """轻量级检索评估器"""
    
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        from sentence_transformers import SentenceTransformer, util
        self.model = SentenceTransformer(model_name)
        self.threshold_correct = 0.7  # 高置信度阈值
        self.threshold_incorrect = 0.3  # 低置信度阈值
    
    def evaluate_retrieval_quality(self, query: str, retrieved_docs: List[str]) -> Dict:
        """
        评估检索结果的整体质量
        
        参数:
            query: 用户查询
            retrieved_docs: 检索到的文档列表
            
        返回:
            包含置信度和评估详情的字典
        """
        # 将查询和文档编码为向量
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        doc_embeddings = self.model.encode(retrieved_docs, convert_to_tensor=True)
        
        # 计算相似度
        similarities = util.cos_sim(query_embedding, doc_embeddings)[0].cpu().numpy()
        
        # 计算整体质量指标
        avg_similarity = float(np.mean(similarities))
        max_similarity = float(np.max(similarities))
        relevance_std = float(np.std(similarities))
        
        # 综合置信度评分
        confidence_score = 0.6 * avg_similarity + 0.4 * max_similarity - 0.2 * relevance_std
        confidence_score = max(0.0, min(1.0, confidence_score))
        
        # 确定置信度等级
        if confidence_score >= self.threshold_correct:
            confidence_level = ConfidenceLevel.CORRECT
        elif confidence_score <= self.threshold_incorrect:
            confidence_level = ConfidenceLevel.INCORRECT
        else:
            confidence_level = ConfidenceLevel.AMBIGUOUS
        
        return {
            "confidence_score": confidence_score,
            "confidence_level": confidence_level,
            "avg_similarity": avg_similarity,
            "max_similarity": max_similarity,
            "similarity_std": relevance_std,
            "detailed_similarities": similarities.tolist()
        }

class DocumentRefiner:
    """文档净化器 - 实现分解-重组算法"""
    
    def __init__(self, chunk_size: int = 200, chunk_overlap: int = 50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
    
    def decompose_document(self, document: str) -> List[str]:
        """将文档分解为信息单元"""
        # 简单的按句子分割,实际可使用更复杂的NLP方法
        import re
        sentences = re.split(r'(?<=[.!?])\s+', document)
        
        chunks = []
        current_chunk = ""
        
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= self.chunk_size:
                current_chunk += " " + sentence if current_chunk else sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = sentence
        
        if current_chunk:
            chunks.append(current_chunk)
        
        return chunks
    
    def filter_irrelevant_chunks(self, query: str, chunks: List[str], 
                                 similarity_threshold: float = 0.5) -> List[str]:
        """过滤与查询不相关的信息单元"""
        from sentence_transformers import SentenceTransformer, util
        
        model = SentenceTransformer("all-MiniLM-L6-v2")
        query_embedding = model.encode(query, convert_to_tensor=True)
        chunk_embeddings = model.encode(chunks, convert_to_tensor=True)
        
        similarities = util.cos_sim(query_embedding, chunk_embeddings)[0].cpu().numpy()
        
        relevant_chunks = [
            chunk for chunk, sim in zip(chunks, similarities) 
            if sim >= similarity_threshold
        ]
        
        return relevant_chunks
    
    def recompose_knowledge_strips(self, relevant_chunks: List[str]) -> str:
        """将相关片段重组成精炼的知识条"""
        # 简单的拼接,可添加更智能的摘要或重组逻辑
        return "\n".join(relevant_chunks)
    
    def refine_document(self, query: str, document: str) -> str:
        """完整的文档净化流程"""
        chunks = self.decompose_document(document)
        relevant_chunks = self.filter_irrelevant_chunks(query, chunks)
        refined_knowledge = self.recompose_knowledge_strips(relevant_chunks)
        return refined_knowledge

class WebSearchAugmenter:
    """网络搜索增强器"""
    
    def __init__(self, api_key: Optional[str] = None):
        self.api_key = api_key
    
    def search_web(self, query: str, num_results: int = 5) -> List[Dict]:
        """
        执行网络搜索
        
        注意: 实际使用时需要替换为真实的搜索API
        如: SerpAPI, Google Custom Search, DuckDuckGo等
        """
        # 这里使用模拟数据,实际应用中需接入真实API
        if self.api_key:
            # 示例: 使用SerpAPI
            # params = {
            #     "q": query,
            #     "api_key": self.api_key,
            #     "num": num_results
            # }
            # response = requests.get("https://serpapi.com/search", params=params)
            # results = response.json().get("organic_results", [])
            pass
        
        # 模拟返回结果
        mock_results = [
            {
                "title": f"关于{query}的搜索结果1",
                "snippet": f"这是关于{query}的相关信息摘要1。包含关键事实和数据。",
                "link": "https://example.com/result1"
            },
            {
                "title": f"关于{query}的搜索结果2",
                "snippet": f"这是关于{query}的深入分析2。提供了不同的视角和细节。",
                "link": "https://example.com/result2"
            }
        ]
        
        return mock_results[:num_results]

class CRAGOrchestrator:
    """CRAG总控协调器"""
    
    def __init__(self, retriever, generator, use_web_search: bool = True):
        self.retriever = retriever
        self.generator = generator
        self.evaluator = CRAGRetrievalEvaluator()
        self.refiner = DocumentRefiner()
        self.web_searcher = WebSearchAugmenter() if use_web_search else None
    
    def process_query(self, query: str, top_k: int = 5) -> Dict:
        """处理查询的完整CRAG流程"""
        
        # 步骤1: 初始检索
        print(f"步骤1: 检索与查询 '{query}' 相关的文档...")
        retrieved_docs = self.retriever.retrieve(query, top_k)
        
        # 步骤2: 评估检索质量
        print("步骤2: 评估检索结果质量...")
        evaluation = self.evaluator.evaluate_retrieval_quality(query, retrieved_docs)
        confidence_level = evaluation["confidence_level"]
        print(f"  置信度等级: {confidence_level.value} (分数: {evaluation['confidence_score']:.3f})")
        
        # 步骤3: 根据置信度执行不同操作
        if confidence_level == ConfidenceLevel.CORRECT:
            print("步骤3: 执行'正确'操作 - 净化本地检索结果")
            refined_docs = [
                self.refiner.refine_document(query, doc)
                for doc in retrieved_docs
            ]
            knowledge_source = "\n\n".join(refined_docs)
            source_type = "refined_local"
            
        elif confidence_level == ConfidenceLevel.INCORRECT:
            print("步骤3: 执行'错误'操作 - 启动网络搜索")
            if self.web_searcher:
                web_results = self.web_searcher.search_web(query, top_k)
                knowledge_source = "\n\n".join([
                    f"{r['title']}: {r['snippet']}" for r in web_results
                ])
                source_type = "web_search"
            else:
                knowledge_source = retrieved_docs[0] if retrieved_docs else ""
                source_type = "fallback_local"
                
        else:  # AMBIGUOUS
            print("步骤3: 执行'模糊'操作 - 融合本地与网络结果")
            # 净化本地结果
            refined_local = [
                self.refiner.refine_document(query, doc)
                for doc in retrieved_docs[:2]  # 取前2个进行净化
            ]
            
            # 获取网络结果
            web_results = []
            if self.web_searcher:
                web_results = self.web_searcher.search_web(query, 3)
                web_knowledge = [
                    f"{r['title']}: {r['snippet']}" for r in web_results
                ]
            else:
                web_knowledge = []
            
            # 加权融合(简单拼接,实际可实现更复杂的融合策略)
            knowledge_source = "本地知识:\n" + "\n\n".join(refined_local[:2])
            if web_knowledge:
                knowledge_source += "\n\n网络补充:\n" + "\n\n".join(web_knowledge[:2])
            source_type = "hybrid"
        
        # 步骤4: 生成最终响应
        print("步骤4: 基于优化后的知识生成响应...")
        prompt = self._construct_prompt(query, knowledge_source)
        response = self.generator.generate(prompt)
        
        return {
            "query": query,
            "confidence_evaluation": evaluation,
            "knowledge_source": knowledge_source,
            "source_type": source_type,
            "response": response,
            "retrieved_docs": retrieved_docs
        }
    
    def _construct_prompt(self, query: str, knowledge: str) -> str:
        """构建生成提示"""
        prompt_template = """基于以下信息回答问题。如果信息不足,请说明哪些方面信息不足。

相关参考信息:
{knowledge}

问题: {query}

请提供详细、准确的回答,并注明信息的主要来源。"""
        return prompt_template.format(knowledge=knowledge, query=query)

# 示例使用的简单检索器和生成器(实际应用需替换为真实组件)
class SimpleRetriever:
    """简单检索器示例"""
    
    def __init__(self, corpus: Dict[str, str]):
        self.corpus = corpus
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer("all-MiniLM-L6-v2")
        
        # 预编码所有文档
        self.doc_ids = list(corpus.keys())
        self.doc_texts = [corpus[doc_id] for doc_id in self.doc_ids]
        self.doc_embeddings = self.model.encode(self.doc_texts, convert_to_tensor=True)
    
    def retrieve(self, query: str, top_k: int = 5) -> List[str]:
        from sentence_transformers import util
        
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        similarities = util.cos_sim(query_embedding, self.doc_embeddings)[0]
        
        # 获取最相似的文档索引
        top_indices = similarities.argsort(descending=True)[:top_k].cpu().numpy()
        
        return [self.doc_texts[idx] for idx in top_indices]

class SimpleGenerator:
    """简单生成器示例(实际应使用LLM)"""
    
    def generate(self, prompt: str) -> str:
        # 这里返回模拟响应,实际应接入真实LLM
        return f"基于提供的参考信息,回答如下:\n\n这是一个关于'{prompt.split(':')[-1].strip()}'的示例回答。在实际CRAG系统中,这里会是由大型语言模型生成的、基于上下文的详细回答。"

# 使用示例
def main():
    # 1. 准备示例知识库
    sample_corpus = {
        "doc1": "Python是一种高级编程语言,由Guido van Rossum于1991年创建。它支持多种编程范式,包括面向对象、命令式、函数式和过程式编程。",
        "doc2": "机器学习是人工智能的一个分支,使计算机能够从数据中学习而无需明确编程。常见的算法包括线性回归、决策树和神经网络。",
        "doc3": "检索增强生成(RAG)结合了信息检索和文本生成技术,通过检索相关文档来增强大型语言模型的生成能力,减少幻觉。",
        "doc4": "深度学习是机器学习的一个子领域,基于人工神经网络。它在图像识别、自然语言处理和语音识别等领域取得了突破性进展。",
        "doc5": "CRAG(Corrective RAG)是一种改进的RAG框架,包含检索评估器、文档净化和网络搜索增强等组件,提高了生成结果的鲁棒性。"
    }
    
    # 2. 初始化组件
    retriever = SimpleRetriever(sample_corpus)
    generator = SimpleGenerator()
    crag_orchestrator = CRAGOrchestrator(retriever, generator, use_web_search=True)
    
    # 3. 测试不同查询
    test_queries = [
        "什么是CRAG?",
        "请解释量子计算的基本原理",  # 知识库中可能没有的信息
        "机器学习和深度学习有什么区别?"
    ]
    
    for query in test_queries:
        print("\n" + "="*60)
        print(f"处理查询: {query}")
        print("="*60)
        
        result = crag_orchestrator.process_query(query, top_k=3)
        
        print(f"\n生成的响应:\n{result['response']}")
        print(f"\n知识来源类型: {result['source_type']}")
        print(f"置信度分数: {result['confidence_evaluation']['confidence_score']:.3f}")

if __name__ == "__main__":
    main()

1.5 CRAG的适用场景与局限

a) 使用场景

1)事实准确性要求高的问答系统

2) 动态信息查询(需要最新网络信息)

3) 企业知识库与外部知识结合的场景

b) 当前局限与改进方向

  1. 评估器准确性:轻量级评估器可能误判

  2. 网络搜索延迟:实时搜索影响响应时间

  3. 多文档融合:简单拼接可能引入矛盾

2、Reflexion:赋予智能体"复盘反思"与"持续学习"的能力

如果说CRAG专注于单次任务的内部修正,那么Reflexion框架则着眼于让智能体在多次任务尝试中持续进化 。它解决了传统强化学习需要大量试错和昂贵模型微调的难题。论文参考:https://arxiv.org/abs/2303.11366

2.1 系统流程

python 复制代码
用户查询 → 检索相关文档 → 生成初始答案
                          ↓
                   获取反馈信号
                          ↓
                   生成反思文本
                          ↓
                   存储到记忆缓冲区
                          ↓
  新查询 → 检索相关文档+相关反思 → 生成改进答案

2.2 本质

  1. Reflexion将学习过程显式化、语言化。它构建了一个可增长、可检索的"经验库",使智能体具备了类似人类的从错误中学习并积累专业知识的能力。这对于需要多步复杂推理或交互的任务(如调试、策略游戏)效果尤为显著。

  2. Reflexion-RAG的核心创新在于使用语言反馈而非权重更新来强化智能体。传统强化学习需要大量试错和昂贵的模型微调,而Reflexion让智能体通过自然语言进行自我反思,并将这些反思存储在可检索的记忆缓冲区中,实现持续改进。

2.3 核心机制:语言反馈与情景记忆

  • Reflexion的突破在于,它不通过调整模型权重来学习,而是让智能体进行**"口头反思"**。具体流程如下:

    • 智能体执行一项任务(如编写代码、解答问题)。

    • 接收来自环境或评估器的反馈信号(可以是简单的对错分数,也可以是自然语言评语)。

    • 根据反馈,智能体用语言生成一段反思文本,分析刚才失败的原因或成功的经验。

    • 将这段反思文本存入一个情景记忆缓冲区

2.4 如何提升后续表现

当智能体再次遇到相似任务时,它会从记忆缓冲区中检索相关的过往经历和反思结论,并将这些"前车之鉴"作为上下文提示。例如,在编程任务中,如果上次因未处理边界条件而失败,反思文本会记录这一点。下次遇到类似函数时,模型会"想起"这个教训,从而生成更健壮的代码。

2.5 Python示例代码:

python 复制代码
"""
Reflexion-RAG: 具有自我反思与持续学习能力的RAG系统
核心功能:
1. 基础RAG流程:检索 + 生成
2. 反馈收集与反思生成
3. 反思记忆存储与检索
4. 基于反思的答案改进
"""

import json
import hashlib
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from dataclasses import dataclass, asdict
from enum import Enum
import numpy as np
from collections import defaultdict
import pickle
import os

# ==================== 数据模型 ====================

class FeedbackType(Enum):
    """反馈类型枚举"""
    BINARY = "binary"  # 正确/错误
    SCALAR = "scalar"  # 分数,如0-10
    TEXTUAL = "textual"  # 文本反馈
    CODE_EXECUTION = "code_execution"  # 代码执行结果
    AUTOMATED_EVAL = "automated_eval"  # 自动评估结果

@dataclass
class Reflection:
    """反思数据结构"""
    id: str  # 唯一标识
    original_query: str  # 原始查询
    retrieved_docs: List[str]  # 检索到的文档
    generated_answer: str  # 生成的答案
    feedback_type: FeedbackType  # 反馈类型
    feedback_content: Any  # 反馈内容
    reflection_text: str  # 反思文本
    timestamp: datetime  # 创建时间
    embedding: Optional[np.ndarray] = None  # 反思文本的向量表示
    
    def to_dict(self) -> Dict:
        """转换为字典"""
        data = asdict(self)
        data['feedback_type'] = self.feedback_type.value
        data['timestamp'] = self.timestamp.isoformat()
        if self.embedding is not None:
            data['embedding'] = self.embedding.tolist()
        return data
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'Reflection':
        """从字典创建实例"""
        data = data.copy()
        data['feedback_type'] = FeedbackType(data['feedback_type'])
        data['timestamp'] = datetime.fromisoformat(data['timestamp'])
        if 'embedding' in data:
            data['embedding'] = np.array(data['embedding'])
        return cls(**data)

@dataclass
class ReflexionConfig:
    """Reflexion-RAG配置"""
    # 记忆缓冲区配置
    memory_buffer_size: int = 1000  # 最大存储反思数量
    reflection_embedding_model: str = "all-MiniLM-L6-v2"  # 用于反思嵌入的模型
    
    # 反思生成配置
    enable_auto_feedback: bool = True  # 是否启用自动反馈
    feedback_integration_method: str = "concatenate"  # 反馈集成方法: concatenate, weighted, separate
    
    # 检索配置
    max_reflections_per_query: int = 3  # 每查询最多使用的反思数量
    similarity_threshold: float = 0.7  # 反思相似度阈值
    
    # 生成配置
    use_reflection_in_prompt: bool = True  # 是否在提示中使用反思
    reflection_weight: float = 0.3  # 反思在最终决策中的权重

# ==================== 记忆管理系统 ====================

class ReflectionMemoryBuffer:
    """反思记忆缓冲区"""
    
    def __init__(self, config: ReflexionConfig, storage_path: str = "reflection_memory.pkl"):
        self.config = config
        self.storage_path = storage_path
        self.memory = []  # 存储所有反思
        self.query_to_reflections = defaultdict(list)  # 查询到反思的映射
        self.embedding_model = None
        self.reflection_embeddings = []  # 存储所有反思的嵌入向量
        
        # 加载嵌入模型(按需)
        self._init_embedding_model()
        
        # 从磁盘加载现有记忆
        self.load_memory()
    
    def _init_embedding_model(self):
        """初始化嵌入模型(延迟加载)"""
        try:
            from sentence_transformers import SentenceTransformer
            self.embedding_model = SentenceTransformer(self.config.reflection_embedding_model)
        except ImportError:
            print("警告: 未安装sentence-transformers,将无法使用基于嵌入的反思检索")
            self.embedding_model = None
    
    def add_reflection(self, reflection: Reflection):
        """添加反思到记忆缓冲区"""
        # 生成嵌入向量
        if self.embedding_model is not None:
            reflection.embedding = self.embedding_model.encode(
                reflection.reflection_text, 
                convert_to_tensor=False
            )
        
        # 添加到内存
        self.memory.append(reflection)
        query_hash = self._hash_query(reflection.original_query)
        self.query_to_reflections[query_hash].append(reflection.id)
        
        # 限制缓冲区大小
        if len(self.memory) > self.config.memory_buffer_size:
            self.memory.pop(0)
        
        # 保存到磁盘
        self.save_memory()
    
    def retrieve_relevant_reflections(self, query: str, top_k: int = None) -> List[Reflection]:
        """检索与查询相关的反思"""
        if top_k is None:
            top_k = self.config.max_reflections_per_query
        
        relevant_reflections = []
        
        # 方法1: 基于查询哈希的精确匹配
        query_hash = self._hash_query(query)
        if query_hash in self.query_to_reflections:
            reflection_ids = self.query_to_reflections[query_hash]
            for reflection_id in reflection_ids:
                for reflection in self.memory:
                    if reflection.id == reflection_id:
                        relevant_reflections.append(reflection)
                        break
        
        # 方法2: 基于语义相似度的检索(如果嵌入模型可用)
        if self.embedding_model is not None and len(relevant_reflections) < top_k:
            semantic_reflections = self._retrieve_by_semantic_similarity(query, top_k)
            relevant_reflections.extend(semantic_reflections)
        
        # 去重并限制数量
        seen_ids = set()
        unique_reflections = []
        for reflection in relevant_reflections:
            if reflection.id not in seen_ids:
                seen_ids.add(reflection.id)
                unique_reflections.append(reflection)
            if len(unique_reflections) >= top_k:
                break
        
        return unique_reflections[:top_k]
    
    def _retrieve_by_semantic_similarity(self, query: str, top_k: int) -> List[Reflection]:
        """基于语义相似度检索反思"""
        if not self.memory or self.embedding_model is None:
            return []
        
        # 编码查询
        query_embedding = self.embedding_model.encode(query, convert_to_tensor=False)
        
        # 计算相似度
        similarities = []
        for reflection in self.memory:
            if reflection.embedding is not None:
                similarity = self._cosine_similarity(query_embedding, reflection.embedding)
                if similarity >= self.config.similarity_threshold:
                    similarities.append((similarity, reflection))
        
        # 按相似度排序
        similarities.sort(key=lambda x: x[0], reverse=True)
        
        return [reflection for _, reflection in similarities[:top_k]]
    
    def _hash_query(self, query: str) -> str:
        """生成查询的哈希值(用于快速查找)"""
        # 简单实现:使用MD5哈希
        return hashlib.md5(query.lower().encode()).hexdigest()
    
    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """计算余弦相似度"""
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot_product / (norm1 * norm2)
    
    def save_memory(self):
        """保存记忆到磁盘"""
        try:
            # 转换为可序列化的字典
            memory_data = [reflection.to_dict() for reflection in self.memory]
            
            # 保存查询映射
            query_mapping = {k: list(v) for k, v in self.query_to_reflections.items()}
            
            data = {
                'memory': memory_data,
                'query_mapping': query_mapping,
                'config': asdict(self.config)
            }
            
            with open(self.storage_path, 'wb') as f:
                pickle.dump(data, f)
            
            print(f"记忆已保存到 {self.storage_path}")
        except Exception as e:
            print(f"保存记忆失败: {e}")
    
    def load_memory(self):
        """从磁盘加载记忆"""
        if not os.path.exists(self.storage_path):
            print(f"记忆文件不存在: {self.storage_path}")
            return
        
        try:
            with open(self.storage_path, 'rb') as f:
                data = pickle.load(f)
            
            # 恢复反思对象
            self.memory = [Reflection.from_dict(item) for item in data['memory']]
            
            # 恢复查询映射
            self.query_to_reflections = defaultdict(list)
            for query_hash, reflection_ids in data['query_mapping'].items():
                self.query_to_reflections[query_hash].extend(reflection_ids)
            
            print(f"从 {self.storage_path} 加载了 {len(self.memory)} 条反思")
        except Exception as e:
            print(f"加载记忆失败: {e}")
    
    def get_statistics(self) -> Dict[str, Any]:
        """获取记忆缓冲区的统计信息"""
        total_reflections = len(self.memory)
        unique_queries = len(self.query_to_reflections)
        
        # 按反馈类型统计
        feedback_type_counts = defaultdict(int)
        for reflection in self.memory:
            feedback_type_counts[reflection.feedback_type.value] += 1
        
        return {
            "total_reflections": total_reflections,
            "unique_queries": unique_queries,
            "feedback_type_distribution": dict(feedback_type_counts),
            "memory_buffer_usage": f"{total_reflections}/{self.config.memory_buffer_size}"
        }

# ==================== 反馈收集器 ====================

class FeedbackCollector:
    """反馈收集器"""
    
    @staticmethod
    def collect_binary_feedback(correct: bool, additional_notes: str = "") -> Dict:
        """收集二元反馈(正确/错误)"""
        return {
            "type": FeedbackType.BINARY,
            "content": {
                "correct": correct,
                "notes": additional_notes
            }
        }
    
    @staticmethod
    def collect_scalar_feedback(score: float, max_score: float = 10.0, 
                                criteria: Dict = None) -> Dict:
        """收集标量分数反馈"""
        if criteria is None:
            criteria = {"accuracy": 0.4, "completeness": 0.3, "clarity": 0.3}
        
        normalized_score = score / max_score
        return {
            "type": FeedbackType.SCALAR,
            "content": {
                "score": score,
                "max_score": max_score,
                "normalized_score": normalized_score,
                "criteria": criteria
            }
        }
    
    @staticmethod
    def collect_textual_feedback(feedback_text: str) -> Dict:
        """收集文本反馈"""
        return {
            "type": FeedbackType.TEXTUAL,
            "content": {
                "text": feedback_text
            }
        }
    
    @staticmethod
    def collect_code_execution_feedback(code: str, output: str, 
                                        expected_output: str = None) -> Dict:
        """收集代码执行反馈"""
        success = expected_output is None or output.strip() == expected_output.strip()
        return {
            "type": FeedbackType.CODE_EXECUTION,
            "content": {
                "code": code,
                "output": output,
                "expected_output": expected_output,
                "success": success
            }
        }
    
    @staticmethod
    def automated_evaluation_feedback(query: str, answer: str, 
                                      reference_answer: str = None) -> Dict:
        """自动评估反馈"""
        # 简单的基于相似度的自动评估
        from sentence_transformers import SentenceTransformer, util
        
        model = SentenceTransformer("all-MiniLM-L6-v2")
        answer_embedding = model.encode(answer, convert_to_tensor=True)
        
        if reference_answer:
            ref_embedding = model.encode(reference_answer, convert_to_tensor=True)
            similarity = util.cos_sim(answer_embedding, ref_embedding).item()
            score = similarity * 10  # 转换为0-10分
        else:
            # 如果没有参考答案,使用基于长度的简单启发式方法
            score = min(10, len(answer.split()) / 10)
            similarity = None
        
        return {
            "type": FeedbackType.AUTOMATED_EVAL,
            "content": {
                "similarity_to_reference": similarity,
                "auto_score": score,
                "has_reference": reference_answer is not None
            }
        }

# ==================== 反思生成器 ====================

class ReflectionGenerator:
    """反思生成器"""
    
    def __init__(self, llm_client=None):
        self.llm_client = llm_client  # 实际的LLM客户端,这里使用模拟
    
    def generate_reflection(self, query: str, retrieved_docs: List[str], 
                            answer: str, feedback: Dict) -> str:
        """生成反思文本"""
        feedback_type = feedback["type"]
        feedback_content = feedback["content"]
        
        if feedback_type == FeedbackType.BINARY:
            correct = feedback_content["correct"]
            notes = feedback_content.get("notes", "")
            
            if correct:
                reflection = f"对于查询 '{query[:50]}...',我给出了正确的回答。"
                if notes:
                    reflection += f" 用户特别指出: {notes}"
                reflection += " 我应该记住这种成功的模式。"
            else:
                reflection = f"对于查询 '{query[:50]}...',我的回答是错误的。"
                if notes:
                    reflection += f" 具体问题: {notes}"
                reflection += " 我需要调整对这类问题的理解。"
                
        elif feedback_type == FeedbackType.SCALAR:
            score = feedback_content["score"]
            max_score = feedback_content["max_score"]
            normalized_score = feedback_content["normalized_score"]
            
            if normalized_score >= 0.7:
                reflection = f"对于查询 '{query[:50]}...',我的回答评分较高 ({score}/{max_score})。"
                reflection += " 这表明我的方法有效,应该继续使用类似策略。"
            elif normalized_score >= 0.4:
                reflection = f"对于查询 '{query[:50]}...',我的回答评分中等 ({score}/{max_score})。"
                reflection += " 有改进空间,我需要分析哪些部分可以做得更好。"
            else:
                reflection = f"对于查询 '{query[:50]}...',我的回答评分较低 ({score}/{max_score})。"
                reflection += " 我需要从根本上重新思考如何处理这类问题。"
                
        elif feedback_type == FeedbackType.TEXTUAL:
            feedback_text = feedback_content["text"]
            reflection = f"对于查询 '{query[:50]}...',用户反馈: '{feedback_text}'。"
            reflection += " 我需要仔细考虑这个反馈来改进未来的回答。"
            
        elif feedback_type == FeedbackType.CODE_EXECUTION:
            success = feedback_content["success"]
            if success:
                reflection = f"对于代码生成查询,我的代码执行成功。"
                reflection += " 这表明我对编程问题的理解是正确的。"
            else:
                reflection = f"对于代码生成查询,我的代码执行失败或输出不正确。"
                reflection += " 我需要更仔细地分析问题要求并测试边缘情况。"
                
        elif feedback_type == FeedbackType.AUTOMATED_EVAL:
            auto_score = feedback_content.get("auto_score", 0)
            reflection = f"对于查询 '{query[:50]}...',自动评估得分为 {auto_score:.1f}/10。"
            if auto_score >= 7:
                reflection += " 自动评估表明回答质量较好。"
            else:
                reflection += " 自动评估表明有改进空间。"
        
        # 添加上下文信息
        doc_summary = ", ".join([doc[:30] + "..." for doc in retrieved_docs[:2]])
        reflection += f" 我检索了以下文档: {doc_summary}"
        
        return reflection
    
    def generate_reflection_with_llm(self, query: str, retrieved_docs: List[str],
                                     answer: str, feedback: Dict) -> str:
        """使用LLM生成更复杂的反思(如果LLM客户端可用)"""
        if self.llm_client is None:
            # 回退到基于规则的方法
            return self.generate_reflection(query, retrieved_docs, answer, feedback)
        
        # 构建LLM提示
        prompt = self._build_reflection_prompt(query, retrieved_docs, answer, feedback)
        
        try:
            # 这里调用实际的LLM
            # reflection = self.llm_client.generate(prompt)
            # 模拟响应
            reflection = f"LLM生成的反思: 对于查询'{query[:30]}...',基于反馈{feedback['type'].value},"
            reflection += "我分析了回答的质量并确定了改进方向。"
            return reflection
        except Exception as e:
            print(f"LLM反思生成失败: {e}")
            return self.generate_reflection(query, retrieved_docs, answer, feedback)
    
    def _build_reflection_prompt(self, query: str, retrieved_docs: List[str],
                                 answer: str, feedback: Dict) -> str:
        """构建反思生成提示"""
        feedback_str = json.dumps(feedback, indent=2, default=str)
        
        prompt = f"""基于以下交互生成一段反思,帮助改进未来的回答:

查询: {query}

检索到的文档摘要:
{chr(10).join(['- ' + doc[:100] + '...' for doc in retrieved_docs[:3]])}

生成的回答:
{answer}

收到的反馈:
{feedback_str}

请生成一段简短的反思,总结这次交互的经验教训,并指出未来如何改进。反思应该具体、可操作。"""
        
        return prompt

# ==================== Reflexion-RAG主系统 ====================

class ReflexionRAGSystem:
    """Reflexion-RAG主系统"""
    
    def __init__(self, base_rag_system, config: Optional[ReflexionConfig] = None):
        self.base_rag = base_rag_system  # 基础RAG系统
        self.config = config or ReflexionConfig()
        
        # 初始化组件
        self.memory_buffer = ReflectionMemoryBuffer(self.config)
        self.feedback_collector = FeedbackCollector()
        self.reflection_generator = ReflectionGenerator()
        
        # 性能追踪
        self.interaction_history = []
    
    def answer_with_reflection(self, query: str, 
                               collect_feedback: bool = True) -> Dict[str, Any]:
        """使用反思增强的RAG回答查询"""
        print(f"\n处理查询: {query[:50]}...")
        
        # 步骤1: 检索相关反思
        relevant_reflections = self.memory_buffer.retrieve_relevant_reflections(query)
        print(f"检索到 {len(relevant_reflections)} 条相关反思")
        
        # 步骤2: 使用基础RAG生成初始答案
        base_result = self.base_rag.answer(query)
        initial_answer = base_result.get("answer", "")
        retrieved_docs = base_result.get("retrieved_docs", [])
        
        # 步骤3: 如果有相关反思,改进答案
        if relevant_reflections and self.config.use_reflection_in_prompt:
            improved_answer = self._improve_answer_with_reflections(
                query, initial_answer, retrieved_docs, relevant_reflections
            )
        else:
            improved_answer = initial_answer
        
        # 步骤4: 收集反馈(如果启用)
        feedback = None
        if collect_feedback and self.config.enable_auto_feedback:
            feedback = self._collect_auto_feedback(query, improved_answer)
        
        # 步骤5: 记录交互
        interaction_id = hashlib.md5(f"{query}{datetime.now().isoformat()}".encode()).hexdigest()
        interaction_record = {
            "id": interaction_id,
            "query": query,
            "initial_answer": initial_answer,
            "improved_answer": improved_answer,
            "retrieved_docs": retrieved_docs,
            "relevant_reflections": [r.id for r in relevant_reflections],
            "feedback": feedback,
            "timestamp": datetime.now().isoformat()
        }
        self.interaction_history.append(interaction_record)
        
        return {
            "query": query,
            "answer": improved_answer,
            "initial_answer": initial_answer,
            "retrieved_docs": retrieved_docs,
            "relevant_reflections": relevant_reflections,
            "feedback": feedback,
            "interaction_id": interaction_id,
            "used_reflections": len(relevant_reflections) > 0
        }
    
    def _improve_answer_with_reflections(self, query: str, initial_answer: str,
                                         retrieved_docs: List[str], 
                                         reflections: List[Reflection]) -> str:
        """使用反思改进答案"""
        # 简单实现:将反思作为额外上下文
        reflection_texts = [r.reflection_text for r in reflections]
        reflections_summary = "\n\n历史经验教训:\n" + "\n".join(
            [f"- {text}" for text in reflection_texts]
        )
        
        # 构建增强提示
        enhanced_prompt = f"""基于以下信息和历史经验,请回答问题。

相关文档:
{chr(10).join(['- ' + doc[:200] for doc in retrieved_docs[:3]])}

{reflections_summary}

问题: {query}

请考虑历史经验教训,提供更准确的回答。"""
        
        # 使用基础RAG的生成器(这里简化处理)
        # 实际应用中应调用LLM生成改进答案
        improved_answer = f"{initial_answer}\n\n[基于{len(reflections)}条历史反思进行了优化]"
        
        return improved_answer
    
    def _collect_auto_feedback(self, query: str, answer: str) -> Dict:
        """收集自动反馈"""
        # 这里可以集成各种自动评估方法
        # 示例:使用简单的启发式方法
        answer_length = len(answer.split())
        
        if answer_length < 10:
            score = 3.0
        elif answer_length < 50:
            score = 6.0
        else:
            score = 8.0
        
        return self.feedback_collector.collect_scalar_feedback(
            score=score, 
            max_score=10.0,
            criteria={"length": 0.3, "specificity": 0.4, "relevance": 0.3}
        )
    
    def add_feedback_and_reflect(self, interaction_id: str, feedback: Dict):
        """为特定交互添加反馈并生成反思"""
        # 查找交互记录
        interaction = None
        for record in self.interaction_history:
            if record["id"] == interaction_id:
                interaction = record
                break
        
        if not interaction:
            print(f"未找到交互记录: {interaction_id}")
            return
        
        # 生成反思
        reflection_text = self.reflection_generator.generate_reflection(
            query=interaction["query"],
            retrieved_docs=interaction["retrieved_docs"],
            answer=interaction["improved_answer"],
            feedback=feedback
        )
        
        # 创建反思对象
        reflection_id = hashlib.md5(
            f"{interaction_id}{feedback['type']}{datetime.now().isoformat()}".encode()
        ).hexdigest()
        
        reflection = Reflection(
            id=reflection_id,
            original_query=interaction["query"],
            retrieved_docs=interaction["retrieved_docs"],
            generated_answer=interaction["improved_answer"],
            feedback_type=feedback["type"],
            feedback_content=feedback["content"],
            reflection_text=reflection_text,
            timestamp=datetime.now()
        )
        
        # 存储到记忆缓冲区
        self.memory_buffer.add_reflection(reflection)
        
        print(f"已生成并存储反思: {reflection_id}")
        return reflection
    
    def get_system_stats(self) -> Dict[str, Any]:
        """获取系统统计信息"""
        memory_stats = self.memory_buffer.get_statistics()
        
        return {
            "total_interactions": len(self.interaction_history),
            "memory_statistics": memory_stats,
            "config": asdict(self.config)
        }

# ==================== 基础RAG系统(示例实现) ====================

class BaseRAGSystem:
    """基础RAG系统(简化示例)"""
    
    def __init__(self, knowledge_base: Dict[str, str] = None):
        self.knowledge_base = knowledge_base or self._create_sample_kb()
        self.embedding_model = None
    
    def _create_sample_kb(self) -> Dict[str, str]:
        """创建示例知识库"""
        return {
            "doc1": "Python是一种高级编程语言,由Guido van Rossum于1991年创建。",
            "doc2": "机器学习是人工智能的一个分支,使计算机能够从数据中学习。",
            "doc3": "RAG(检索增强生成)结合信息检索和文本生成技术。",
            "doc4": "深度学习基于人工神经网络,在图像识别等领域表现出色。",
            "doc5": "Reflexion是一种通过语言反馈强化学习的方法。",
            "doc6": "大型语言模型如GPT-4在自然语言处理任务中表现优异。"
        }
    
    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        """检索相关文档(简化实现)"""
        # 这里使用简单的关键词匹配,实际应使用向量检索
        query_lower = query.lower()
        relevant_docs = []
        
        for doc_id, content in self.knowledge_base.items():
            # 简单关键词匹配
            score = 0
            for word in query_lower.split():
                if word in content.lower():
                    score += 1
            
            if score > 0:
                relevant_docs.append((score, content))
        
        # 按分数排序
        relevant_docs.sort(key=lambda x: x[0], reverse=True)
        
        return [doc for _, doc in relevant_docs[:top_k]]
    
    def generate_answer(self, query: str, context: List[str]) -> str:
        """生成答案(简化实现)"""
        # 实际应使用LLM生成
        context_summary = " ".join([doc[:100] for doc in context])
        
        return f"基于检索到的信息: {context_summary[:200]}...,回答查询: {query}。这是一个示例回答,实际应用中会使用LLM生成更详细的答案。"
    
    def answer(self, query: str) -> Dict[str, Any]:
        """完整RAG流程"""
        retrieved_docs = self.retrieve(query, top_k=3)
        answer = self.generate_answer(query, retrieved_docs)
        
        return {
            "query": query,
            "answer": answer,
            "retrieved_docs": retrieved_docs
        }

# ==================== 使用示例 ====================

def main():
    """Reflexion-RAG使用示例"""
    print("=" * 60)
    print("Reflexion-RAG 系统演示")
    print("=" * 60)
    
    # 1. 初始化基础RAG系统
    base_rag = BaseRAGSystem()
    
    # 2. 配置Reflexion-RAG
    config = ReflexionConfig(
        memory_buffer_size=50,
        max_reflections_per_query=2,
        similarity_threshold=0.6
    )
    
    # 3. 创建Reflexion-RAG系统
    reflexion_rag = ReflexionRAGSystem(base_rag, config)
    
    # 4. 执行一系列查询(模拟用户交互)
    queries = [
        "什么是Python?",
        "解释机器学习的基本概念",
        "RAG系统如何工作?",
        "深度学习和机器学习有什么区别?"
    ]
    
    for i, query in enumerate(queries, 1):
        print(f"\n{'='*40}")
        print(f"交互 {i}: {query}")
        print(f"{'='*40}")
        
        # 回答问题
        result = reflexion_rag.answer_with_reflection(query)
        print(f"答案: {result['answer'][:150]}...")
        print(f"使用了 {len(result['relevant_reflections'])} 条反思")
        
        # 模拟用户反馈(每隔一次交互提供反馈)
        if i % 2 == 0:
            print("\n模拟用户反馈...")
            
            # 收集反馈
            feedback = FeedbackCollector.collect_binary_feedback(
                correct=True if i % 3 != 0 else False,  # 每3次有一个错误
                additional_notes="回答比较全面" if i % 3 != 0 else "缺少关键细节"
            )
            
            # 添加反馈并生成反思
            reflection = reflexion_rag.add_feedback_and_reflect(
                result["interaction_id"], 
                feedback
            )
            
            if reflection:
                print(f"生成的反思: {reflection.reflection_text[:100]}...")
    
    # 5. 显示系统统计
    print(f"\n{'='*60}")
    print("系统统计信息:")
    print(f"{'='*60}")
    
    stats = reflexion_rag.get_system_stats()
    print(f"总交互次数: {stats['total_interactions']}")
    print(f"记忆缓冲区: {stats['memory_statistics']['memory_buffer_usage']}")
    print(f"反思类型分布: {json.dumps(stats['memory_statistics']['feedback_type_distribution'], indent=2)}")
    
    # 6. 测试新查询(应能从历史中学习)
    print(f"\n{'='*60}")
    print("测试新查询(应能利用历史反思):")
    print(f"{'='*60}")
    
    new_query = "Python编程语言有什么特点?"
    new_result = reflexion_rag.answer_with_reflection(new_query)
    print(f"查询: {new_query}")
    print(f"答案: {new_result['answer'][:200]}...")
    print(f"使用了 {len(new_result['relevant_reflections'])} 条相关反思")

if __name__ == "__main__":
    main()

2.6 扩展应用场景

  1. 代码生成与调试:Reflexion特别适合编程任务,可以从编译错误和执行结果中学习

  2. 对话系统:基于用户满意度反馈改进对话策略

  3. 教育辅导:根据学生答题情况调整教学策略

  4. 内容审核:从审核决策反馈中学习更准确的审核标准

3、Self-RAG:将检索、生成与批判内化为模型的"本能"

Self-RAG 走得更远,它训练一个单一的模型,使其自适应地掌握整个RAG流程的决策权 ,实现真正的"自我驱动"。 论文参考:https://arxiv.org/abs/2310.11511

3.1 系统流程

python 复制代码
输入查询 → 自适应检索决策 → [Retrieve]/[NoRetrieve]
               ↓
       检索相关文档(如需)
               ↓
       文档相关性评估 → [Relevant]/[Irrelevant]
               ↓
       生成带反思标记的内容
               ↓
       内容支持性评估 → [Support]/[NoSupport]
               ↓
       生成质量评估 → [Good]/[Poor]
               ↓
       最终输出

3.2 本质

Self-RAG实现了检索与生成的深度一体化。它将检索决策、内容评估和生成控制都内化到同一个模型的推理逻辑中,使整个系统更加灵活、高效且可控,特别适合开放域问答和需要高事实准确性、高引用透明度的长文本生成任务。

3.3 核心机制

  1. 自适应检索:该出手时才出手

    传统RAG对每个问题都机械地检索固定数量的文档。Self-RAG训练的模型会自行判断:当前问题是否需要检索? 对于常识性问题,模型选择直接依赖内部知识生成,避免不必要的检索开销和噪声引入。

  2. 自我反思标记:精细化过程控制

    Self-RAG在训练时引入了特殊的反思标记。这些标记让模型在生成过程中,能同时对几个关键节点进行自我评估:

    • 检索必要性:是否需要检索?

    • 段落相关性:检索到的这个文档块相关吗?

    • 生成内容支持:我接下来的这句话,有检索到的文档作为依据吗?

    • 生成内容质量:我生成的这段话整体上是否流畅、有用?

Self-RAG的核心是在同一个模型内实现了检索决策、内容生成和结果评估的全流程闭环。与传统RAG的关键区别在于:

维度 传统RAG Self-RAG
检索时机 固定或启发式 自适应决定
评估机制 外部评估 内部自我评估
控制粒度 粗粒度 细粒度(句子/段落级)
反馈机制 无实时反馈 实时自我反思

3.5 代码示例

python 复制代码
"""
Self-RAG完整实现:自适应检索、自我反思的RAG系统
核心特性:
1. 自适应检索决策(是否检索、何时检索)
2. 文档相关性自我评估
3. 生成内容支持性评估
4. 细粒度反思标记生成
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Any, Optional, Tuple, Union
import json
import numpy as np
from dataclasses import dataclass, field
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    PreTrainedTokenizer,
    PreTrainedModel,
    Trainer,
    TrainingArguments
)
from enum import Enum
import logging
from collections import defaultdict
from tqdm import tqdm
import os

# ==================== 配置与数据模型 ====================

class ReflectionToken(Enum):
    """Self-RAG反思标记枚举"""
    # 检索决策标记
    RETRIEVE = "[Retrieve]"
    NO_RETRIEVE = "[NoRetrieve]"
    
    # 文档相关性标记
    RELEVANT = "[Relevant]"
    IRRELEVANT = "[Irrelevant]"
    
    # 内容支持性标记
    SUPPORT = "[Support]"
    NO_SUPPORT = "[NoSupport]"
    
    # 生成质量标记
    GOOD = "[Good]"
    POOR = "[Poor]"
    
    # 生成控制标记
    CONTINUE = "[Continue]"
    END = "[End]"

@dataclass
class SelfRAGConfig:
    """Self-RAG配置"""
    # 模型配置
    base_model_name: str = "mistralai/Mistral-7B-v0.1"
    max_length: int = 4096
    reflection_token_weight: float = 1.5  # 反思标记在损失中的权重
    
    # 检索配置
    retrieval_threshold: float = 0.3  # 检索决策阈值
    top_k_documents: int = 3  # 每次检索的文档数量
    use_adaptive_retrieval: bool = True  # 是否使用自适应检索
    
    # 反思标记配置
    enable_reflection_tokens: bool = True
    reflection_token_positions: List[str] = field(default_factory=lambda: [
        "retrieve_decision",
        "document_relevance", 
        "content_support",
        "generation_quality"
    ])
    
    # 训练配置
    learning_rate: float = 2e-5
    batch_size: int = 2
    gradient_accumulation_steps: int = 4
    num_train_epochs: int = 3

@dataclass
class RetrievalResult:
    """检索结果"""
    documents: List[str]
    scores: List[float]
    metadata: List[Dict[str, Any]]

@dataclass
class GenerationStep:
    """生成步骤记录"""
    step_id: int
    generated_text: str
    reflection_tokens: List[Tuple[str, str]]  # (标记类型, 标记值)
    retrieved_docs: Optional[List[str]] = None
    confidence_scores: Optional[Dict[str, float]] = None

# ==================== Self-RAG模型架构 ====================

class SelfRAGModel(nn.Module):
    """Self-RAG模型:在基础LLM上添加反思标记生成能力"""
    
    def __init__(self, config: SelfRAGConfig):
        super().__init__()
        self.config = config
        
        # 加载基础语言模型
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
        self.base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
        
        # 添加反思标记到词汇表
        self._add_reflection_tokens()
        
        # 反射头:用于生成反思标记
        self.reflection_heads = nn.ModuleDict({
            "retrieve_decision": nn.Linear(self.base_model.config.hidden_size, 2),
            "document_relevance": nn.Linear(self.base_model.config.hidden_size, 2),
            "content_support": nn.Linear(self.base_model.config.hidden_size, 2),
            "generation_quality": nn.Linear(self.base_model.config.hidden_size, 2)
        })
        
        # 检索决策器
        self.retrieval_decision_head = nn.Sequential(
            nn.Linear(self.base_model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
        # 初始化反思标记的嵌入
        self._init_reflection_token_embeddings()
    
    def _add_reflection_tokens(self):
        """将反思标记添加到tokenizer的词汇表中"""
        reflection_tokens = [token.value for token in ReflectionToken]
        
        # 检查是否已添加
        existing_tokens = set(self.tokenizer.get_vocab().keys())
        new_tokens = [token for token in reflection_tokens if token not in existing_tokens]
        
        if new_tokens:
            self.tokenizer.add_tokens(new_tokens)
            self.base_model.resize_token_embeddings(len(self.tokenizer))
            print(f"添加了 {len(new_tokens)} 个反思标记到词汇表")
    
    def _init_reflection_token_embeddings(self):
        """初始化反思标记的嵌入向量"""
        # 获取特殊标记的嵌入
        special_tokens = [token.value for token in ReflectionToken]
        special_token_ids = self.tokenizer.convert_tokens_to_ids(special_tokens)
        
        # 这里可以添加自定义的初始化逻辑
        # 例如,可以将反思标记的嵌入初始化为与语义相近的词的嵌入的平均值
        
    def forward(self, input_ids, attention_mask=None, labels=None, 
                retrieval_context=None, reflection_targets=None):
        """
        Self-RAG前向传播
        
        参数:
            input_ids: 输入token IDs
            attention_mask: 注意力掩码
            labels: 用于训练的目标token IDs
            retrieval_context: 检索到的上下文文档
            reflection_targets: 反思标记的目标值
        
        返回:
            包含损失和输出的字典
        """
        # 基础模型前向传播
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        
        hidden_states = outputs.hidden_states[-1]  # 最后一层隐藏状态
        logits = outputs.logits
        
        total_loss = 0
        reflection_outputs = {}
        
        # 计算反思标记的损失(如果提供目标)
        if reflection_targets is not None:
            reflection_loss = self._compute_reflection_loss(
                hidden_states, reflection_targets
            )
            total_loss += reflection_loss * self.config.reflection_token_weight
            reflection_outputs["reflection_loss"] = reflection_loss
        
        # 计算语言建模损失(如果提供标签)
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            lm_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
            total_loss += lm_loss
            reflection_outputs["lm_loss"] = lm_loss
        
        return {
            "loss": total_loss,
            "logits": logits,
            "hidden_states": hidden_states,
            **reflection_outputs
        }
    
    def _compute_reflection_loss(self, hidden_states, reflection_targets):
        """计算反思标记的损失"""
        reflection_loss = 0
        batch_size = hidden_states.size(0)
        
        for position, head in self.reflection_heads.items():
            if position in reflection_targets:
                # 提取对应位置的隐藏状态
                # 这里简化处理:使用序列的最后一个隐藏状态
                position_hidden = hidden_states[:, -1, :]  # (batch_size, hidden_size)
                
                # 计算反思标记的logits
                reflection_logits = head(position_hidden)  # (batch_size, 2)
                
                # 计算损失
                targets = reflection_targets[position]  # (batch_size,)
                loss = F.cross_entropy(reflection_logits, targets)
                reflection_loss += loss
        
        return reflection_loss / len(self.reflection_heads)
    
    def generate_with_reflection(self, query: str, retriever, max_steps: int = 10):
        """
        带反思标记的生成过程
        
        参数:
            query: 用户查询
            retriever: 检索器实例
            max_steps: 最大生成步骤数
        
        返回:
            生成结果和反思记录
        """
        # 初始化生成状态
        generated_text = ""
        reflection_records = []
        retrieved_docs = []
        step = 0
        
        # 初始输入
        input_text = query
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.device)
        
        while step < max_steps:
            step += 1
            print(f"\n步骤 {step}:")
            
            # 1. 检索决策
            retrieve_decision = self._decide_retrieval(input_text, generated_text)
            reflection_records.append(("retrieve_decision", retrieve_decision))
            
            if retrieve_decision == "retrieve":
                print("决定检索...")
                # 执行检索
                retrieval_results = retriever.retrieve(query, self.config.top_k_documents)
                retrieved_docs = retrieval_results.documents
                
                # 评估文档相关性
                relevance_decisions = []
                for doc in retrieved_docs:
                    relevance = self._assess_document_relevance(query, doc)
                    relevance_decisions.append(relevance)
                    reflection_records.append(("document_relevance", relevance))
                
                # 过滤相关文档
                relevant_docs = [
                    doc for doc, rel in zip(retrieved_docs, relevance_decisions)
                    if rel == "relevant"
                ]
                
                if relevant_docs:
                    # 构建增强的输入
                    context = "\n\n".join(relevant_docs)
                    enhanced_input = f"上下文:\n{context}\n\n问题: {query}\n\n回答:"
                else:
                    enhanced_input = f"问题: {query}\n\n回答:"
            else:
                print("决定不检索...")
                enhanced_input = f"问题: {query}\n\n回答:"
                relevant_docs = []
            
            # 2. 生成内容
            enhanced_ids = self.tokenizer.encode(enhanced_input, return_tensors="pt").to(self.device)
            
            # 生成下一个token
            with torch.no_grad():
                outputs = self.base_model.generate(
                    enhanced_ids,
                    max_length=len(enhanced_ids[0]) + 50,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # 解码生成的文本
            generated_tokens = outputs[0, len(enhanced_ids[0]):]
            generated_segment = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            
            # 3. 评估生成内容
            if relevant_docs:
                support_decision = self._assess_content_support(
                    generated_segment, relevant_docs
                )
                reflection_records.append(("content_support", support_decision))
            
            quality_decision = self._assess_generation_quality(generated_segment)
            reflection_records.append(("generation_quality", quality_decision))
            
            # 4. 更新生成文本
            generated_text += generated_segment
            
            # 5. 决定是否继续
            continue_decision = self._decide_continuation(
                query, generated_text, step, max_steps
            )
            
            if continue_decision == "end":
                reflection_records.append(("generation_control", "end"))
                print("生成结束")
                break
            else:
                reflection_records.append(("generation_control", "continue"))
                # 更新输入,继续生成
                input_text = f"{query}\n\n当前回答: {generated_text}\n\n继续:"
        
        return {
            "query": query,
            "generated_text": generated_text,
            "reflection_records": reflection_records,
            "retrieved_docs": retrieved_docs,
            "steps": step
        }
    
    def _decide_retrieval(self, query: str, current_text: str) -> str:
        """决定是否需要检索"""
        if not self.config.use_adaptive_retrieval:
            return "retrieve"
        
        # 构建决策输入
        decision_input = f"问题: {query}\n\n当前回答: {current_text}\n\n是否需要检索更多信息?"
        input_ids = self.tokenizer.encode(decision_input, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            outputs = self.base_model(input_ids, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1][:, -1, :]
            
            # 使用检索决策头
            retrieval_score = self.retrieval_decision_head(last_hidden).item()
        
        return "retrieve" if retrieval_score > self.config.retrieval_threshold else "no_retrieve"
    
    def _assess_document_relevance(self, query: str, document: str) -> str:
        """评估文档相关性"""
        # 简化的评估:基于嵌入相似度
        # 实际应用中可以使用更复杂的评估方法
        query_embedding = self._get_text_embedding(query)
        doc_embedding = self._get_text_embedding(document[:500])  # 只取前500字符
        
        similarity = F.cosine_similarity(
            query_embedding.unsqueeze(0),
            doc_embedding.unsqueeze(0)
        ).item()
        
        return "relevant" if similarity > 0.7 else "irrelevant"
    
    def _assess_content_support(self, content: str, documents: List[str]) -> str:
        """评估内容是否得到文档支持"""
        # 简化的支持性评估
        content_embedding = self._get_text_embedding(content)
        
        max_similarity = 0
        for doc in documents:
            doc_embedding = self._get_text_embedding(doc[:300])
            similarity = F.cosine_similarity(
                content_embedding.unsqueeze(0),
                doc_embedding.unsqueeze(0)
            ).item()
            max_similarity = max(max_similarity, similarity)
        
        return "support" if max_similarity > 0.6 else "no_support"
    
    def _assess_generation_quality(self, content: str) -> str:
        """评估生成质量"""
        # 简化的质量评估:基于长度和多样性
        tokens = self.tokenizer.encode(content)
        
        if len(tokens) < 5:
            return "poor"
        
        # 计算词汇多样性
        unique_tokens = len(set(tokens))
        diversity = unique_tokens / len(tokens)
        
        return "good" if diversity > 0.7 else "poor"
    
    def _decide_continuation(self, query: str, current_text: str, 
                            current_step: int, max_steps: int) -> str:
        """决定是否继续生成"""
        if current_step >= max_steps:
            return "end"
        
        # 检查是否已回答完整
        completeness_keywords = ["总之", "综上所述", "总的来说", "因此"]
        if any(keyword in current_text for keyword in completeness_keywords):
            return "end"
        
        return "continue"
    
    def _get_text_embedding(self, text: str) -> torch.Tensor:
        """获取文本的嵌入向量"""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.base_model(**inputs, output_hidden_states=True)
            # 使用最后一层隐藏状态的平均值作为文本嵌入
            last_hidden = outputs.hidden_states[-1]
            embedding = last_hidden.mean(dim=1).squeeze()
        
        return embedding
    
    @property
    def device(self):
        return self.base_model.device

# ==================== 训练数据与数据集 ====================

class SelfRAGDataset(Dataset):
    """Self-RAG training dataset"""
    
    def __init__(self, data_path: str, tokenizer: PreTrainedTokenizer, config: SelfRAGConfig):
        self.tokenizer = tokenizer
        self.config = config
        
        # Load raw training data
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        
        # Preprocess into model-ready examples
        self.examples = self._preprocess_data()
    
    def _preprocess_data(self):
        """Preprocess the raw training data"""
        examples = []
        
        for item in self.data:
            query = item["query"]
            documents = item.get("documents", [])
            answer = item["answer"]
            
            # Build a training sequence annotated with reflection tokens
            training_sequence = self._create_training_sequence(query, documents, answer)
            examples.append(training_sequence)
        
        return examples
    
    def _create_training_sequence(self, query: str, documents: List[str], answer: str) -> Dict:
        """Create a training sequence annotated with reflection tokens"""
        
        # 1. Retrieval-decision part
        if documents:
            retrieve_decision = ReflectionToken.RETRIEVE.value
            # Document context (first 3 documents, truncated)
            context = "\n".join([f"[Doc {i+1}] {doc[:200]}" for i, doc in enumerate(documents[:3])])
            input_text = f"{retrieve_decision} Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
        else:
            retrieve_decision = ReflectionToken.NO_RETRIEVE.value
            input_text = f"{retrieve_decision} Question: {query}\n\nAnswer:"
        
        # 2. Generation part (with reflection tokens)
        # Simplified here: real training would insert reflection tokens at finer granularity
        answer_with_reflections = f"{answer} {ReflectionToken.SUPPORT.value} {ReflectionToken.GOOD.value} {ReflectionToken.END.value}"
        
        # 3. Full training text
        full_text = f"{input_text} {answer_with_reflections}"
        
        # 4. Tokenize for the language-modeling loss
        inputs = self.tokenizer(full_text, truncation=True, max_length=self.config.max_length)
        
        # 5. Targets for the reflection loss
        reflection_targets = {
            "retrieve_decision": 0 if documents else 1,  # 0: retrieve, 1: no retrieve
            "document_relevance": 0,  # 0: relevant, 1: irrelevant
            "content_support": 0,     # 0: supported, 1: unsupported
            "generation_quality": 0   # 0: good, 1: poor
        }
        
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "reflection_targets": reflection_targets,
            "query": query,
            "answer": answer
        }
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

# ==================== Retriever Implementation ====================

class VectorRetriever:
    """Vector-based retriever"""
    
    def __init__(self, knowledge_base: Dict[str, str], embedding_model_name: str = "all-MiniLM-L6-v2"):
        self.knowledge_base = knowledge_base
        
        # Initialize the embedding model
        from sentence_transformers import SentenceTransformer
        self.embedding_model = SentenceTransformer(embedding_model_name)
        
        # Precompute embeddings for the whole knowledge base
        self.doc_ids = list(knowledge_base.keys())
        self.doc_texts = [knowledge_base[doc_id] for doc_id in self.doc_ids]
        self.doc_embeddings = self.embedding_model.encode(
            self.doc_texts, 
            convert_to_tensor=True,
            show_progress_bar=True
        )
    
    def retrieve(self, query: str, top_k: int = 3) -> RetrievalResult:
        """Retrieve the documents most relevant to the query"""
        # Encode the query
        query_embedding = self.embedding_model.encode(query, convert_to_tensor=True)
        
        # Cosine similarity against all documents
        from sentence_transformers import util
        similarities = util.cos_sim(query_embedding, self.doc_embeddings)[0]
        
        # Top-k indices by similarity
        top_indices = similarities.argsort(descending=True)[:top_k].cpu().numpy()
        
        # Assemble the result
        documents = [self.doc_texts[idx] for idx in top_indices]
        scores = [similarities[idx].item() for idx in top_indices]
        metadata = [{"doc_id": self.doc_ids[idx]} for idx in top_indices]
        
        return RetrievalResult(
            documents=documents,
            scores=scores,
            metadata=metadata
        )
    
    def add_document(self, doc_id: str, content: str):
        """Add a document to the knowledge base"""
        self.knowledge_base[doc_id] = content
        
        # Incrementally extend the embedding matrix
        new_embedding = self.embedding_model.encode(content, convert_to_tensor=True)
        self.doc_embeddings = torch.cat([self.doc_embeddings, new_embedding.unsqueeze(0)])
        self.doc_ids.append(doc_id)
        self.doc_texts.append(content)

# ==================== Trainer ====================

class SelfRAGTrainer(Trainer):
    """Custom Self-RAG trainer"""
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Custom loss combining the language-modeling loss with the reflection-token loss.
        **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer
        transformers versions.
        """
        # Unpack the batch
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        labels = inputs.get("labels")
        reflection_targets = inputs.get("reflection_targets")
        
        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            reflection_targets=reflection_targets
        )
        
        loss = outputs["loss"]
        
        return (loss, outputs) if return_outputs else loss
    
    def create_optimizer(self):
        """
        Create an optimizer that uses a higher learning rate for the reflection heads
        """
        # Split base-model parameters from reflection-head parameters
        base_model_params = []
        reflection_params = []
        
        for name, param in self.model.named_parameters():
            if "reflection_heads" in name or "retrieval_decision_head" in name:
                reflection_params.append(param)
            else:
                base_model_params.append(param)
        
        optimizer_grouped_parameters = [
            {
                "params": base_model_params,
                "lr": self.args.learning_rate,
                "weight_decay": self.args.weight_decay
            },
            {
                "params": reflection_params,
                "lr": self.args.learning_rate * 2.0,  # higher learning rate for the reflection heads
                "weight_decay": self.args.weight_decay
            }
        ]
        
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.args.learning_rate
        )
        
        return optimizer

# ==================== Complete Self-RAG System ====================

class SelfRAGSystem:
    """Complete Self-RAG system"""
    
    def __init__(self, config: SelfRAGConfig, knowledge_base_path: str = None):
        self.config = config
        
        # Initialize the model
        print("Initializing Self-RAG model...")
        self.model = SelfRAGModel(config)
        
        # Initialize the retriever
        print("Initializing retriever...")
        if knowledge_base_path:
            with open(knowledge_base_path, 'r', encoding='utf-8') as f:
                knowledge_base = json.load(f)
        else:
            knowledge_base = self._create_default_knowledge_base()
        
        self.retriever = VectorRetriever(knowledge_base)
        
        # Training state
        self.is_trained = False
        
        # Interaction log
        self.interaction_log = []
        
        print("Self-RAG system initialized")
    
    def _create_default_knowledge_base(self) -> Dict[str, str]:
        """Create a small default knowledge base"""
        return {
            "doc1": "Self-RAG (Self-Reflective Retrieval-Augmented Generation) is an advanced RAG framework that lets the model evaluate both the retrieved documents and its own output.",
            "doc2": "Large language models (LLMs) such as GPT-4 excel at natural language processing tasks but are prone to hallucination.",
            "doc3": "Retrieval-augmented generation (RAG) reduces LLM hallucination and improves factual accuracy by grounding answers in an external knowledge base.",
            "doc4": "Adaptive retrieval lets the model decide, based on query complexity, whether external information is needed at all.",
            "doc5": "Reflection tokens are special tokens in Self-RAG that steer the generation process and grade content quality.",
            "doc6": "Python is a widely used high-level programming language known for its simplicity and readability.",
            "doc7": "Machine learning is a branch of artificial intelligence that lets computers learn from data without being explicitly programmed."
        }
    
    def train(self, train_data_path: str, output_dir: str = "./self_rag_model"):
        """Train the Self-RAG model"""
        print(f"Starting Self-RAG training with data: {train_data_path}")
        
        # Build the training dataset
        train_dataset = SelfRAGDataset(train_data_path, self.model.tokenizer, self.config)
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=self.config.num_train_epochs,
            per_device_train_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            fp16=torch.cuda.is_available(),
            logging_steps=10,
            save_steps=100,
            eval_steps=None,
            save_total_limit=2,
            remove_unused_columns=False,
            report_to="none"
        )
        
        # Create the trainer
        trainer = SelfRAGTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=self._collate_fn
        )
        
        # Train
        trainer.train()
        
        # Save the model and tokenizer
        trainer.save_model(output_dir)
        self.model.tokenizer.save_pretrained(output_dir)
        
        self.is_trained = True
        print(f"Training finished; model saved to: {output_dir}")
    
    def _collate_fn(self, batch):
        """Custom batching function"""
        # Pad to the longest sequence in the batch
        max_length = max(len(item["input_ids"]) for item in batch)
        
        # GPT-2 and similar tokenizers have no pad token by default,
        # so fall back to the EOS token for padding
        pad_id = self.model.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.model.tokenizer.eos_token_id
        
        input_ids = []
        attention_mask = []
        reflection_targets = {
            "retrieve_decision": [],
            "document_relevance": [],
            "content_support": [],
            "generation_quality": []
        }
        
        for item in batch:
            # Pad input_ids
            padded_input = item["input_ids"] + [pad_id] * (max_length - len(item["input_ids"]))
            input_ids.append(padded_input)
            
            # Attention mask: 1 for real tokens, 0 for padding
            mask = [1] * len(item["input_ids"]) + [0] * (max_length - len(item["input_ids"]))
            attention_mask.append(mask)
            
            # Collect reflection targets
            for key in reflection_targets.keys():
                reflection_targets[key].append(item["reflection_targets"][key])
        
        # Convert to tensors
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        
        # Labels for language modeling; mask out padding positions
        # with -100 so they do not contribute to the loss
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        
        # Reflection targets as tensors
        for key in reflection_targets:
            reflection_targets[key] = torch.tensor(reflection_targets[key], dtype=torch.long)
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "reflection_targets": reflection_targets
        }
    
    def generate(self, query: str, max_steps: int = 5) -> Dict[str, Any]:
        """Generate an answer"""
        print(f"\nProcessing query: {query}")
        
        # Generate with self-reflection
        result = self.model.generate_with_reflection(
            query=query,
            retriever=self.retriever,
            max_steps=max_steps
        )
        
        # Log the interaction
        self.interaction_log.append({
            "timestamp": np.datetime64('now'),
            "query": query,
            "result": result
        })
        
        return result
    
    def analyze_reflection_patterns(self) -> Dict[str, Any]:
        """Analyze reflection patterns across logged interactions"""
        if not self.interaction_log:
            return {"error": "No interactions to analyze"}
        
        # Count how often each reflection type was emitted
        reflection_counts = defaultdict(int)
        retrieval_decisions = {"retrieve": 0, "no_retrieve": 0}
        
        for interaction in self.interaction_log:
            for record in interaction["result"]["reflection_records"]:
                reflection_type, decision = record
                reflection_counts[reflection_type] += 1
                
                if reflection_type == "retrieve_decision":
                    retrieval_decisions[decision] += 1
        
        # Aggregate statistics
        total_interactions = len(self.interaction_log)
        
        return {
            "total_interactions": total_interactions,
            "reflection_counts": dict(reflection_counts),
            "retrieval_decision_rate": {
                "retrieve": retrieval_decisions["retrieve"] / total_interactions,
                "no_retrieve": retrieval_decisions["no_retrieve"] / total_interactions
            },
            "avg_steps_per_interaction": np.mean([
                interaction["result"]["steps"] 
                for interaction in self.interaction_log
            ])
        }
    
    def export_training_data(self, output_path: str):
        """Export logged interactions as training data"""
        training_examples = []
        
        for interaction in self.interaction_log:
            query = interaction["query"]
            generated_text = interaction["result"]["generated_text"]
            retrieved_docs = interaction["result"]["retrieved_docs"]
            
            # Build a training example from the logged interaction
            example = {
                "query": query,
                "documents": retrieved_docs,
                "answer": generated_text,
                "reflection_records": interaction["result"]["reflection_records"]
            }
            
            training_examples.append(example)
        
        # Write to file
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(training_examples, f, ensure_ascii=False, indent=2)
        
        print(f"Exported {len(training_examples)} training examples to: {output_path}")

# ==================== Sample Training Data Generation ====================

def generate_sample_training_data(output_path: str = "sample_training_data.json"):
    """Generate a small sample training data file"""
    sample_data = [
        {
            "query": "What is Self-RAG?",
            "documents": [
                "Self-RAG (Self-Reflective Retrieval-Augmented Generation) is an advanced RAG framework that lets the model evaluate both the retrieved documents and its own output.",
                "Reflection tokens are special tokens in Self-RAG that steer the generation process and grade content quality."
            ],
            "answer": "Self-RAG is a self-reflective retrieval-augmented generation framework. By introducing reflection tokens, it lets the model grade both the retrieved documents and the quality of its own output, improving factual accuracy and reliability."
        },
        {
            "query": "What are the characteristics of Python?",
            "documents": [
                "Python is a widely used high-level programming language known for its simplicity and readability.",
                "Python supports multiple programming paradigms, including object-oriented, imperative, functional, and procedural programming."
            ],
            "answer": "Python is a high-level programming language whose main characteristics include concise, readable syntax, a rich standard library and third-party ecosystem, cross-platform compatibility, and support for multiple programming paradigms."
        },
        {
            "query": "What are the basic concepts of machine learning?",
            "documents": [
                "Machine learning is a branch of artificial intelligence that lets computers learn from data without being explicitly programmed.",
                "Common machine learning paradigms include supervised learning, unsupervised learning, and reinforcement learning."
            ],
            "answer": "Machine learning is a subfield of artificial intelligence whose core idea is to let computer systems learn patterns and regularities from data automatically, without explicit programming. Its main paradigms are supervised learning, unsupervised learning, and reinforcement learning."
        }
    ]
    
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(sample_data, f, ensure_ascii=False, indent=2)
    
    print(f"Sample training data written to: {output_path}")
    return output_path

# ==================== Main Program Example ====================

def main():
    """Self-RAG system demo"""
    print("=" * 60)
    print("Self-RAG System Demo")
    print("=" * 60)
    
    # 1. Generate sample training data
    print("\n1. Generating sample training data...")
    train_data_path = generate_sample_training_data()
    
    # 2. Configure the Self-RAG system
    print("\n2. Configuring the Self-RAG system...")
    config = SelfRAGConfig(
        base_model_name="gpt2",  # small model for demo purposes
        max_length=512,
        batch_size=1,
        num_train_epochs=1
    )
    
    # 3. Initialize the Self-RAG system
    print("\n3. Initializing the Self-RAG system...")
    self_rag = SelfRAGSystem(config)
    
    # 4. Train the model (optional)
    print("\n4. Training the Self-RAG model (skipped in this demo)...")
    # Note: full training needs substantial compute, so it is skipped here
    # self_rag.train(train_data_path, output_dir="./trained_self_rag")
    
    # 5. Generate answers with Self-RAG
    print("\n5. Generating answers with Self-RAG...")
    
    test_queries = [
        "Explain how Self-RAG works",
        "What are the advantages of the Python programming language?",
        "What is the difference between machine learning and deep learning?"
    ]
    
    for query in test_queries:
        print(f"\n{'='*40}")
        print(f"Query: {query}")
        print(f"{'='*40}")
        
        result = self_rag.generate(query, max_steps=3)
        
        print(f"\nGenerated answer:")
        print(f"{result['generated_text']}")
        
        print(f"\nReflection records:")
        for reflection_type, decision in result['reflection_records']:
            print(f"  - {reflection_type}: {decision}")
    
    # 6. Analyze reflection patterns
    print("\n6. Analyzing reflection patterns...")
    analysis = self_rag.analyze_reflection_patterns()
    print(f"Interaction statistics: {json.dumps(analysis, indent=2, ensure_ascii=False)}")
    
    # 7. Export training data
    print("\n7. Exporting training data...")
    self_rag.export_training_data("self_rag_training_data.json")
    
    print("\n" + "=" * 60)
    print("Self-RAG demo finished")
    print("=" * 60)

if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)
    
    # Run the demo
    main()

3.6 Applicable Scenarios

1. Scenario comparison with CRAG and basic RAG

The best way to understand what makes Self-RAG distinctive is by comparison:

  • Basic RAG: suits scenarios with a well-defined knowledge boundary and relatively fixed query patterns, such as customer-support Q&A over a fixed product manual. It retrieves on every query: simple and reliable, but inflexible when facing dynamic knowledge or complex queries.

  • CRAG: focuses on quality-checking and correcting the results of a single retrieval pass. It suits scenarios where retrieval quality is unstable and dynamic knowledge sources such as web search may be called on. It is an enhancement of one retrieval.

  • Self-RAG: its strength is intelligent decision-making and fine-grained control across the whole chain. It not only decides whether to retrieve at all (adaptive retrieval), but also labels how good the retrieved documents are (relevance assessment), whether the answer is grounded (support assessment), and how well it is written (quality assessment) throughout generation. This trades some simplicity for superior performance in open, complex, high-stakes scenarios.

4. Summary and Outlook: A Technical Toolbox for Next-Generation RAG Systems

Viewing these three techniques side by side reveals a multi-layered strategy for improving RAG quality:

| Technique | Core focus | Key mechanisms | Best-fit scenarios |
| --- | --- | --- | --- |
| Corrective RAG | Reliability of a single retrieval pass | Retrieval evaluation, confidence-based routing, document refinement, web augmentation | Q&A and report generation with very high factual-accuracy demands |
| Reflexion | Continual learning of an agent across tasks | Verbal reflection, episodic memory, experience reuse | Multi-step interactive tasks (e.g. coding, debugging, games) |
| Self-RAG | Adaptive control of retrieval and generation | Adaptive retrieval decisions, reflection tokens, a single end-to-end model | Open-domain Q&A, long-form generation requiring precise citations |

In practice these techniques are not mutually exclusive; they can borrow from and be combined with one another. For example, an advanced RAG system could (as the sketch after this list illustrates):

  1. Follow the Self-RAG idea and let the model learn when to invoke retrieval.

  2. After retrieval, run a CRAG-style evaluation-and-refinement module to purify the documents.

  3. Once deployed, use the Reflexion framework so the system keeps reflecting on and improving its own decision logic from user feedback.
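
The following minimal sketch shows one way these three ideas could be wired into a single pipeline. It is an illustration under stated assumptions, not an implementation from any of the three papers: `HybridRAGPipeline` and the callables it accepts are hypothetical stand-ins for components like the Self-RAG retrieval-decision head, the CRAG evaluator, and a Reflexion-style memory built earlier in this article.

```python
# Hypothetical composition sketch: not an official API of any paper.
from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class HybridRAGPipeline:
    """Chains a Self-RAG-style gate, a CRAG-style check, and Reflexion memory."""
    decide_retrieval: Callable[[str], bool]                      # Self-RAG: retrieve at all?
    retrieve: Callable[[str], List[str]]                         # any retriever
    evaluate_and_refine: Callable[[str, List[str]], List[str]]   # CRAG: grade + purify docs
    generate: Callable[[str, List[str]], str]                    # any LLM call
    reflections: List[str] = field(default_factory=list)         # Reflexion: episodic memory

    def answer(self, query: str) -> str:
        docs: List[str] = []
        # 1. Self-RAG idea: only retrieve when the gate says so.
        if self.decide_retrieval(query):
            # 2. CRAG idea: never pass raw retrieval results straight to the generator.
            docs = self.evaluate_and_refine(query, self.retrieve(query))
        # Past verbal reflections ride along as extra context.
        return self.generate(query, self.reflections + docs)

    def reflect(self, feedback: str) -> None:
        # 3. Reflexion idea: store verbal feedback for future episodes.
        self.reflections.append(feedback)

# Toy wiring to show the control flow:
pipeline = HybridRAGPipeline(
    decide_retrieval=lambda q: len(q.split()) > 4,
    retrieve=lambda q: ["Self-RAG adds reflection tokens.", "CRAG grades retrieval."],
    evaluate_and_refine=lambda q, docs: [d for d in docs if "RAG" in d],
    generate=lambda q, ctx: f"[answer to {q!r} grounded in {len(ctx)} snippets]",
)
print(pipeline.answer("How do Self-RAG and CRAG complement each other?"))
pipeline.reflect("Cite document ids explicitly next time.")
```

The design choice here is deliberate: each stage is an injectable callable, so the Self-RAG gate, the CRAG check, and the Reflexion memory can each be swapped out or disabled independently while the overall control flow stays fixed.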

The evolution of RAG is moving from a simple "retrieve + generate" pipeline toward intelligent knowledge workflows capable of evaluation, reflection, decision-making, and learning. For practitioners, understanding the core ideas behind these techniques is like acquiring LEGO bricks with different functions: they can be assembled, according to concrete needs, into more powerful and reliable intelligent applications. The road ahead will surely bring deeper fusion of these capabilities, together with more general and more lightweight adaptation methods.
