HelloAgents Advanced: Task 03

With the agent framework built in the previous parts in place, the next step is to add two core capabilities to the agent: a memory system and retrieval-augmented generation (RAG).

I. The Memory System

For different scenarios, we designed four memory modules; a short usage sketch follows the four descriptions below.

**Working Memory** plays the role of the agent's short-term memory and mainly stores the context of the current conversation. To guarantee fast access and response, its capacity is deliberately capped (for example, 50 entries by default), and its lifecycle is bound to a single session: it is cleaned up automatically when the session ends.

**Episodic Memory** handles the long-term storage of concrete interaction events and the agent's learning experiences. Unlike working memory, episodic memories carry rich contextual information and support retrospective retrieval by time sequence or by topic; they are the foundation on which the agent reviews and learns from past experience.

**Semantic Memory** stores more abstract knowledge, concepts, and rules. For example, user preferences learned through conversation, standing instructions, and domain knowledge all belong here. This memory is highly persistent and important; it is the core of the agent's knowledge base and associative reasoning.

**Perceptual Memory** is dedicated to multimodal information such as images and audio and supports cross-modal retrieval. Its lifecycle is managed dynamically based on the importance of the information and the available storage.
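
As a quick preview of how these modules are combined, the sketch below constructs the memory tool with a chosen subset of types, using the `MemoryTool` constructor implemented later in this post. This is a minimal sketch: the import path is an assumption, so adjust it to your package layout.

```python
# A minimal sketch; the import path below is an assumption.
from hello_agents.tools.memory_tool import MemoryTool

# Enable only three of the four types; perceptual memory stays off
# unless "perceptual" is included in memory_types.
memory_tool = MemoryTool(
    user_id="alice",
    memory_types=["working", "episodic", "semantic"],
)
```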

1. The MemoryTool Class

To build memory tooling, we first implement a MemoryTool class on top of the framework's Tool base class.

```python
"""记忆工具

为HelloAgents框架提供记忆能力的工具实现。
可以作为工具添加到任何Agent中,让Agent具备记忆功能。
"""

from typing import Dict, Any, List
from datetime import datetime

from ..base import Tool, ToolParameter
from ...memory import MemoryManager, MemoryConfig

class MemoryTool(Tool):
    """记忆工具

    为Agent提供记忆功能:
    - 添加记忆
    - 检索相关记忆
    - 获取记忆摘要
    - 管理记忆生命周期
    """

    def __init__(
        self,
        user_id: str = "default_user",
        memory_config: MemoryConfig = None,
        memory_types: List[str] = None
    ):
        super().__init__(
            name="memory",
            description="记忆工具 - 可以存储和检索对话历史、知识和经验"
        )

        # Initialize the memory manager
        self.memory_config = memory_config or MemoryConfig()
        self.memory_types = memory_types or ["working", "episodic", "semantic"]

        self.memory_manager = MemoryManager(
            config=self.memory_config,
            user_id=user_id,
            enable_working="working" in self.memory_types,
            enable_episodic="episodic" in self.memory_types,
            enable_semantic="semantic" in self.memory_types,
            enable_perceptual="perceptual" in self.memory_types
        )

        # Session state
        self.current_session_id = None
        self.conversation_count = 0

    def run(self, parameters: Dict[str, Any]) -> str:
        """执行工具 - Tool基类要求的接口

        Args:
            parameters: 工具参数字典,必须包含action参数

        Returns:
            执行结果字符串
        """
        if not self.validate_parameters(parameters):
            return "❌ 参数验证失败:缺少必需的参数"

        action = parameters.get("action")
        # 移除action参数,传递其余参数给execute方法
        kwargs = {k: v for k, v in parameters.items() if k != "action"}

        return self.execute(action, **kwargs)

    def get_parameters(self) -> List[ToolParameter]:
        """获取工具参数定义 - Tool基类要求的接口"""
        return [
            ToolParameter(
                name="action",
                type="string",
                description=(
                    "要执行的操作:"
                    "add(添加记忆), search(搜索记忆), summary(获取摘要), stats(获取统计), "
                    "update(更新记忆), remove(删除记忆), forget(遗忘记忆), consolidate(整合记忆), clear_all(清空所有记忆)"
                ),
                required=True
            ),
            ToolParameter(name="content", type="string", description="记忆内容(add/update时可用)", required=False),
            ToolParameter(name="query", type="string", description="搜索查询(search时可用)", required=False),
            ToolParameter(name="memory_type", type="string", description="记忆类型:working, episodic, semantic, perceptual(默认:working)", required=False, default="working"),
            ToolParameter(name="importance", type="number", description="重要性分数,0.0-1.0(add/update时可用)", required=False),
            ToolParameter(name="limit", type="integer", description="搜索结果数量限制(默认:5)", required=False, default=5),
            ToolParameter(name="memory_id", type="string", description="目标记忆ID(update/remove时必需)", required=False),
            ToolParameter(name="strategy", type="string", description="遗忘策略:importance_based/time_based/capacity_based(forget时可用)", required=False, default="importance_based"),
            ToolParameter(name="threshold", type="number", description="遗忘阈值(forget时可用,默认0.1)", required=False, default=0.1),
            ToolParameter(name="max_age_days", type="integer", description="最大保留天数(forget策略为time_based时可用)", required=False, default=30),
            ToolParameter(name="from_type", type="string", description="整合来源类型(consolidate时可用,默认working)", required=False, default="working"),
            ToolParameter(name="to_type", type="string", description="整合目标类型(consolidate时可用,默认episodic)", required=False, default="episodic"),
            ToolParameter(name="importance_threshold", type="number", description="整合重要性阈值(默认0.7)", required=False, default=0.7),
        ]

    def execute(self, action: str, **kwargs) -> str:
        """执行记忆操作

        支持的操作:
        - add: 添加记忆
        - search: 搜索记忆
        - summary: 获取记忆摘要
        - stats: 获取统计信息
        """

        if action == "add":
            return self._add_memory(**kwargs)
        elif action == "search":
            return self._search_memory(**kwargs)
        elif action == "summary":
            return self._get_summary(**kwargs)
        elif action == "stats":
            return self._get_stats()
        elif action == "update":
            return self._update_memory(**kwargs)
        elif action == "remove":
            return self._remove_memory(**kwargs)
        elif action == "forget":
            return self._forget(**kwargs)
        elif action == "consolidate":
            return self._consolidate(**kwargs)
        elif action == "clear_all":
            return self._clear_all()
        else:
            return f"不支持的操作: {action}。支持的操作: add, search, summary, stats, update, remove, forget, consolidate, clear_all"

    def _add_memory(
        self,
        content: str,
        memory_type: str = "working",
        importance: float = 0.5,
        **metadata
    ) -> str:
        """添加记忆"""
        try:
            # 确保会话ID存在
            if self.current_session_id is None:
                self.current_session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            # 添加会话信息到元数据
            metadata.update({
                "session_id": self.current_session_id,
                "timestamp": datetime.now().isoformat()
            })

            memory_id = self.memory_manager.add_memory(
                content=content,
                memory_type=memory_type,
                importance=importance,
                metadata=metadata,
                auto_classify=False  # 禁用自动分类,使用明确指定的类型
            )

            return f"✅ 记忆已添加 (ID: {memory_id[:8]}...)"

        except Exception as e:
            return f"❌ 添加记忆失败: {str(e)}"

    def _search_memory(
        self,
        query: str,
        limit: int = 5,
        memory_types: List[str] = None,
        memory_type: str = None,  # also accept the singular-form parameter
        min_importance: float = 0.1
    ) -> str:
        """Search memories"""
        try:
            # Handle the singular memory_type parameter
            if memory_type and not memory_types:
                memory_types = [memory_type]

            results = self.memory_manager.retrieve_memories(
                query=query,
                limit=limit,
                memory_types=memory_types,
                min_importance=min_importance
            )

            if not results:
                return f"🔍 No memories found for '{query}'"

            # Format the results
            formatted_results = []
            formatted_results.append(f"🔍 Found {len(results)} relevant memories:")

            for i, memory in enumerate(results, 1):
                memory_type_label = {
                    "working": "working memory",
                    "episodic": "episodic memory",
                    "semantic": "semantic memory",
                    "perceptual": "perceptual memory"
                }.get(memory.memory_type, memory.memory_type)

                content_preview = memory.content[:80] + "..." if len(memory.content) > 80 else memory.content
                formatted_results.append(
                    f"{i}. [{memory_type_label}] {content_preview} (importance: {memory.importance:.2f})"
                )

            return "\n".join(formatted_results)

        except Exception as e:
            return f"❌ Failed to search memories: {str(e)}"

    def _get_summary(self, limit: int = 10) -> str:
        """获取记忆摘要"""
        try:
            stats = self.memory_manager.get_memory_stats()

            summary_parts = [
                f"📊 记忆系统摘要",
                f"总记忆数: {stats['total_memories']}",
                f"当前会话: {self.current_session_id or '未开始'}",
                f"对话轮次: {self.conversation_count}"
            ]

            # 各类型记忆统计
            if stats['memories_by_type']:
                summary_parts.append("\n📋 记忆类型分布:")
                for memory_type, type_stats in stats['memories_by_type'].items():
                    count = type_stats.get('count', 0)
                    avg_importance = type_stats.get('avg_importance', 0)
                    type_label = {
                        "working": "工作记忆",
                        "episodic": "情景记忆",
                        "semantic": "语义记忆",
                        "perceptual": "感知记忆"
                    }.get(memory_type, memory_type)

                    summary_parts.append(f"  • {type_label}: {count} 条 (平均重要性: {avg_importance:.2f})")

            # 获取重要记忆
            important_memories = []
            for memory_type in self.memory_types:
                if memory_type in stats['memories_by_type']:
                    memories = self.memory_manager.retrieve_memories(
                        query="",
                        memory_types=[memory_type],
                        limit=3,
                        min_importance=0.7
                    )
                    important_memories.extend(memories)

            if important_memories:
                # 按重要性排序
                important_memories.sort(key=lambda x: x.importance, reverse=True)
                summary_parts.append(f"\n⭐ 重要记忆 (前{min(limit, len(important_memories))}条):")

                for i, memory in enumerate(important_memories[:limit], 1):
                    content_preview = memory.content[:60] + "..." if len(memory.content) > 60 else memory.content
                    summary_parts.append(f"  {i}. {content_preview} (重要性: {memory.importance:.2f})")

            return "\n".join(summary_parts)

        except Exception as e:
            return f"❌ 获取摘要失败: {str(e)}"

    def _get_stats(self) -> str:
        """获取统计信息"""
        try:
            stats = self.memory_manager.get_memory_stats()

            stats_info = [
                f"📈 记忆系统统计",
                f"总记忆数: {stats['total_memories']}",
                f"启用的记忆类型: {', '.join(stats['enabled_types'])}",
                f"会话ID: {self.current_session_id or '未开始'}",
                f"对话轮次: {self.conversation_count}"
            ]

            return "\n".join(stats_info)

        except Exception as e:
            return f"❌ 获取统计信息失败: {str(e)}"

    def auto_record_conversation(self, user_input: str, agent_response: str):
        """自动记录对话

        这个方法可以被Agent调用来自动记录对话历史
        """
        self.conversation_count += 1
        # 记录用户输入
        self._add_memory(
            content=f"用户: {user_input}",
            memory_type="working",
            importance=0.6,
            type="user_input",
            conversation_id=self.conversation_count
        )

        # 记录Agent响应
        self._add_memory(
            content=f"助手: {agent_response}",
            memory_type="working",
            importance=0.7,
            type="agent_response",
            conversation_id=self.conversation_count
        )

        # 如果是重要对话,记录为情景记忆
        if len(agent_response) > 100 or "重要" in user_input or "记住" in user_input:
            interaction_content = f"对话 - 用户: {user_input}\n助手: {agent_response}"
            self._add_memory(
                content=interaction_content,
                memory_type="episodic",
                importance=0.8,
                type="interaction",
                conversation_id=self.conversation_count
            )

    def _update_memory(self, memory_id: str, content: str = None, importance: float = None, **metadata) -> str:
        """更新记忆"""
        try:
            success = self.memory_manager.update_memory(
                memory_id=memory_id,
                content=content,
                importance=importance,
                metadata=metadata or None
            )
            return "✅ 记忆已更新" if success else "⚠️ 未找到要更新的记忆"
        except Exception as e:
            return f"❌ 更新记忆失败: {str(e)}"

    def _remove_memory(self, memory_id: str) -> str:
        """删除记忆"""
        try:
            success = self.memory_manager.remove_memory(memory_id)
            return "✅ 记忆已删除" if success else "⚠️ 未找到要删除的记忆"
        except Exception as e:
            return f"❌ 删除记忆失败: {str(e)}"

    def _forget(self, strategy: str = "importance_based", threshold: float = 0.1, max_age_days: int = 30) -> str:
        """遗忘记忆(支持多种策略)"""
        try:
            count = self.memory_manager.forget_memories(
                strategy=strategy,
                threshold=threshold,
                max_age_days=max_age_days
            )
            return f"🧹 已遗忘 {count} 条记忆(策略: {strategy})"
        except Exception as e:
            return f"❌ 遗忘记忆失败: {str(e)}"

    def _consolidate(self, from_type: str = "working", to_type: str = "episodic", importance_threshold: float = 0.7) -> str:
        """整合记忆(将重要的短期记忆提升为长期记忆)"""
        try:
            count = self.memory_manager.consolidate_memories(
                from_type=from_type,
                to_type=to_type,
                importance_threshold=importance_threshold,
            )
            return f"🔄 已整合 {count} 条记忆为长期记忆({from_type} → {to_type},阈值={importance_threshold})"
        except Exception as e:
            return f"❌ 整合记忆失败: {str(e)}"

    def _clear_all(self) -> str:
        """清空所有记忆"""
        try:
            self.memory_manager.clear_all_memories()
            return "🧽 已清空所有记忆"
        except Exception as e:
            return f"❌ 清空记忆失败: {str(e)}"

    def add_knowledge(self, content: str, importance: float = 0.9):
        """添加知识到语义记忆

        便捷方法,用于添加重要知识
        """
        return self._add_memory(
            content=content,
            memory_type="semantic",
            importance=importance,
            knowledge_type="factual",
            source="manual"
        )

    def get_context_for_query(self, query: str, limit: int = 3) -> str:
        """为查询获取相关上下文

        这个方法可以被Agent调用来获取相关的记忆上下文
        """
        results = self.memory_manager.retrieve_memories(
            query=query,
            limit=limit,
            min_importance=0.3
        )

        if not results:
            return ""

        context_parts = ["相关记忆:"]
        for memory in results:
            context_parts.append(f"- {memory.content}")

        return "\n".join(context_parts)

    def clear_session(self):
        """清除当前会话"""
        self.current_session_id = None
        self.conversation_count = 0

        # Clear working memory
        if hasattr(self.memory_manager, 'working_memory'):
            self.memory_manager.working_memory.clear()

    def consolidate_memories(self):
        """整合记忆"""
        return self.memory_manager.consolidate_memories()

    def forget_old_memories(self, max_age_days: int = 30):
        """遗忘旧记忆"""
        return self.memory_manager.forget_memories(
            strategy="time_based",
            max_age_days=max_age_days
        )
```
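
Before moving on, here is a hedged usage sketch of the tool above, driven through the `run` interface required by the `Tool` base class. The import path is an assumption; the parameter names come from `get_parameters`.

```python
# A minimal sketch; the import path is an assumption.
from hello_agents.tools.memory_tool import MemoryTool

tool = MemoryTool(user_id="alice")

# Store a piece of knowledge as semantic memory via the generic run() interface.
print(tool.run({
    "action": "add",
    "content": "The user prefers concise answers.",
    "memory_type": "semantic",
    "importance": 0.9,
}))

# Search for it later, then check overall statistics.
print(tool.run({"action": "search", "query": "answer style", "limit": 3}))
print(tool.run({"action": "stats"}))
```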

With MemoryTool in place, we still need a memory management module, MemoryManager.

```python
"""记忆管理器 - 记忆核心层的统一管理接口"""

from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import uuid
import logging

from .base import MemoryItem, MemoryConfig
from .types.working import WorkingMemory
from .types.episodic import EpisodicMemory
from .types.semantic import SemanticMemory
from .types.perceptual import PerceptualMemory
# Storage and retrieval are now implemented inside each memory type

logger = logging.getLogger(__name__)

class MemoryManager:
    """记忆管理器 - 统一的记忆操作接口
    
    负责:
    - 记忆生命周期管理
    - 记忆优先级和重要性评估
    - 记忆遗忘和清理机制
    - 多类型记忆的协调管理
    """
    
    def __init__(
        self,
        config: Optional[MemoryConfig] = None,
        user_id: str = "default_user",
        enable_working: bool = True,
        enable_episodic: bool = True,
        enable_semantic: bool = True,
        enable_perceptual: bool = False
    ):
        self.config = config or MemoryConfig()
        self.user_id = user_id
        
        # Storage and retrieval now live inside each memory type
        
        # Initialize each memory type
        self.memory_types = {}
        
        if enable_working:
            self.memory_types['working'] = WorkingMemory(self.config)
        
        if enable_episodic:
            self.memory_types['episodic'] = EpisodicMemory(self.config)
            
        if enable_semantic:
            self.memory_types['semantic'] = SemanticMemory(self.config)
            
        if enable_perceptual:
            self.memory_types['perceptual'] = PerceptualMemory(self.config)
        
        logger.info(f"MemoryManager初始化完成,启用记忆类型: {list(self.memory_types.keys())}")
    
    def add_memory(
        self,
        content: str,
        memory_type: str = "working",
        importance: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None,
        auto_classify: bool = True
    ) -> str:
        """添加记忆
        
        Args:
            content: 记忆内容
            memory_type: 记忆类型
            importance: 重要性分数 (0-1)
            metadata: 元数据
            auto_classify: 是否自动分类到合适的记忆类型
            
        Returns:
            记忆ID
        """
        # 自动分类记忆类型
        if auto_classify:
            memory_type = self._classify_memory_type(content, metadata)
        
        # 计算重要性
        if importance is None:
            importance = self._calculate_importance(content, metadata)
        
        # 创建记忆项
        memory_item = MemoryItem(
            id=str(uuid.uuid4()),
            content=content,
            memory_type=memory_type,
            user_id=self.user_id,
            timestamp=datetime.now(),
            importance=importance,
            metadata=metadata or {}
        )
        
        # 添加到对应的记忆类型
        if memory_type in self.memory_types:
            memory_id = self.memory_types[memory_type].add(memory_item)
            logger.debug(f"添加记忆到 {memory_type}: {memory_id}")
            return memory_id
        else:
            raise ValueError(f"不支持的记忆类型: {memory_type}")
    
    def retrieve_memories(
        self,
        query: str,
        memory_types: Optional[List[str]] = None,
        limit: int = 10,
        min_importance: float = 0.0,
        time_range: Optional[tuple] = None
    ) -> List[MemoryItem]:
        """检索记忆
        
        Args:
            query: 查询内容
            memory_types: 要检索的记忆类型列表
            limit: 返回数量限制
            min_importance: 最小重要性阈值
            time_range: 时间范围 (start_time, end_time)
            
        Returns:
            检索到的记忆列表
        """
        if memory_types is None:
            memory_types = list(self.memory_types.keys())
        
        # 从各个记忆类型中检索
        all_results = []
        per_type_limit = max(1, limit // len(memory_types))

        for memory_type in memory_types:
            if memory_type in self.memory_types:
                memory_instance = self.memory_types[memory_type]
                try:
                    # 使用各个记忆类型自己的检索方法
                    type_results = memory_instance.retrieve(
                        query=query,
                        limit=per_type_limit,
                        min_importance=min_importance,
                        user_id=self.user_id
                    )
                    all_results.extend(type_results)
                except Exception as e:
                    logger.warning(f"检索 {memory_type} 记忆时出错: {e}")
                    continue

        # 按重要性和相关性排序
        all_results.sort(key=lambda x: x.importance, reverse=True)
        return all_results[:limit]
    
    def update_memory(
        self,
        memory_id: str,
        content: Optional[str] = None,
        importance: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """更新记忆
        
        Args:
            memory_id: 记忆ID
            content: 新内容
            importance: 新重要性
            metadata: 新元数据
            
        Returns:
            是否更新成功
        """
        # 查找记忆所在的类型
        for memory_type, memory_instance in self.memory_types.items():
            if memory_instance.has_memory(memory_id):
                return memory_instance.update(memory_id, content, importance, metadata)
        
        logger.warning(f"未找到记忆: {memory_id}")
        return False
    
    def remove_memory(self, memory_id: str) -> bool:
        """删除记忆
        
        Args:
            memory_id: 记忆ID
            
        Returns:
            是否删除成功
        """
        for memory_type, memory_instance in self.memory_types.items():
            if memory_instance.has_memory(memory_id):
                return memory_instance.remove(memory_id)
        
        logger.warning(f"未找到记忆: {memory_id}")
        return False
    
    def forget_memories(
        self,
        strategy: str = "importance_based",
        threshold: float = 0.1,
        max_age_days: int = 30
    ) -> int:
        """记忆遗忘机制
        
        Args:
            strategy: 遗忘策略 ("importance_based", "time_based", "capacity_based")
            threshold: 遗忘阈值
            max_age_days: 最大保存天数
            
        Returns:
            遗忘的记忆数量
        """
        total_forgotten = 0
        
        for memory_type, memory_instance in self.memory_types.items():
            if hasattr(memory_instance, 'forget'):
                forgotten = memory_instance.forget(strategy, threshold, max_age_days)
                total_forgotten += forgotten

        logger.info(f"记忆遗忘完成: {total_forgotten} 条记忆")
        return total_forgotten

    def consolidate_memories(
        self,
        from_type: str = "working",
        to_type: str = "episodic",
        importance_threshold: float = 0.7
    ) -> int:
        """记忆整合 - 将重要的短期记忆转换为长期记忆

        Args:
            from_type: 源记忆类型
            to_type: 目标记忆类型
            importance_threshold: 重要性阈值

        Returns:
            整合的记忆数量
        """
        if from_type not in self.memory_types or to_type not in self.memory_types:
            logger.warning(f"记忆类型不存在: {from_type} -> {to_type}")
            return 0

        # 获取高重要性的源记忆
        source_memory = self.memory_types[from_type]
        target_memory = self.memory_types[to_type]

        # 获取需要整合的记忆
        all_memories = source_memory.get_all()
        candidates = [
            m for m in all_memories
            if m.importance >= importance_threshold
        ]

        consolidated_count = 0
        for memory in candidates:
            # 移动到目标记忆类型
            if source_memory.remove(memory.id):
                memory.memory_type = to_type
                memory.importance *= 1.1  # 提升重要性
                target_memory.add(memory)
                consolidated_count += 1

        logger.info(f"记忆整合完成: {consolidated_count} 条记忆从 {from_type} 转移到 {to_type}")
        return consolidated_count

    def get_memory_stats(self) -> Dict[str, Any]:
        """获取记忆统计信息"""
        stats = {
            "user_id": self.user_id,
            "enabled_types": list(self.memory_types.keys()),
            "total_memories": 0,
            "memories_by_type": {},
            "config": {
                "max_capacity": self.config.max_capacity,
                "importance_threshold": self.config.importance_threshold,
                "decay_factor": self.config.decay_factor
            }
        }

        for memory_type, memory_instance in self.memory_types.items():
            type_stats = memory_instance.get_stats()
            stats["memories_by_type"][memory_type] = type_stats
            # 使用count字段(活跃记忆数),而不是total_count(包含已遗忘的)
            stats["total_memories"] += type_stats.get("count", 0)

        return stats

    def clear_all_memories(self):
        """清空所有记忆"""
        for memory_type, memory_instance in self.memory_types.items():
            memory_instance.clear()
        logger.info("所有记忆已清空")




    def _classify_memory_type(self, content: str, metadata: Optional[Dict[str, Any]]) -> str:
        """自动分类记忆类型"""
        if metadata and metadata.get("type"):
            return metadata["type"]
        
        # 简单的分类逻辑,可以扩展为更复杂的分类器
        if self._is_episodic_content(content):
            return "episodic"
        elif self._is_semantic_content(content):
            return "semantic"
        else:
            return "working"
    
    def _is_episodic_content(self, content: str) -> bool:
        """Check whether content looks episodic"""
        # Chinese trigger words: yesterday, today, tomorrow, last time, remember, happened, experienced
        episodic_keywords = ["昨天", "今天", "明天", "上次", "记得", "发生", "经历"]
        return any(keyword in content for keyword in episodic_keywords)
    
    def _is_semantic_content(self, content: str) -> bool:
        """Check whether content looks semantic"""
        # Chinese trigger words: definition, concept, rule, knowledge, principle, method
        semantic_keywords = ["定义", "概念", "规则", "知识", "原理", "方法"]
        return any(keyword in content for keyword in semantic_keywords)
    
    def _calculate_importance(self, content: str, metadata: Optional[Dict[str, Any]]) -> float:
        """Estimate a memory's importance"""
        importance = 0.5  # base importance
        
        # Based on content length
        if len(content) > 100:
            importance += 0.1
        
        # Based on keywords (Chinese trigger words: important, key, must, note, warning, error)
        important_keywords = ["重要", "关键", "必须", "注意", "警告", "错误"]
        if any(keyword in content for keyword in important_keywords):
            importance += 0.2
        
        # Based on metadata
        if metadata:
            if metadata.get("priority") == "high":
                importance += 0.3
            elif metadata.get("priority") == "low":
                importance -= 0.2
        
        return max(0.0, min(1.0, importance))
    

    def __str__(self) -> str:
        stats = self.get_memory_stats()
        return f"MemoryManager(user={self.user_id}, total={stats['total_memories']})"

2. The Four Memory Types

With the tooling above in place, we can build the four memory modules. Before the individual implementations, a reconstructed sketch of their shared base interface follows.
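
The four implementations below all subclass `BaseMemory` from `..base`. That base class is not shown in this post, so the following is a reconstruction inferred from the method signatures the subclasses implement; treat it as an assumption rather than the framework's actual definition.

```python
# Reconstructed sketch of the BaseMemory contract (an assumption inferred from
# the subclasses below; the real class in ..base may differ).
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

class BaseMemory(ABC):
    def __init__(self, config, storage_backend=None):
        self.config = config
        self.storage_backend = storage_backend

    @abstractmethod
    def add(self, memory_item) -> str: ...

    @abstractmethod
    def retrieve(self, query: str, limit: int = 5, **kwargs) -> List[Any]: ...

    @abstractmethod
    def update(self, memory_id: str, content: Optional[str] = None,
               importance: Optional[float] = None,
               metadata: Optional[Dict[str, Any]] = None) -> bool: ...

    @abstractmethod
    def remove(self, memory_id: str) -> bool: ...

    # Shared bookkeeping interface used by MemoryManager
    def has_memory(self, memory_id: str) -> bool: raise NotImplementedError
    def clear(self) -> None: raise NotImplementedError
    def get_all(self) -> List[Any]: raise NotImplementedError
    def get_stats(self) -> Dict[str, Any]: raise NotImplementedError
    def forget(self, strategy: str = "importance_based",
               threshold: float = 0.1, max_age_days: int = 30) -> int: raise NotImplementedError
```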

2.1 Working Memory (WorkingMemory)

```python
"""工作记忆实现

按照第8章架构设计的工作记忆,提供:
- 短期上下文管理
- 容量和时间限制
- 优先级管理
- 自动清理机制
"""

from typing import List, Dict, Any
from datetime import datetime, timedelta
import heapq

from ..base import BaseMemory, MemoryItem, MemoryConfig

class WorkingMemory(BaseMemory):
    """工作记忆实现
    
    特点:
    - 容量有限(通常10-20条记忆)
    - 时效性强(会话级别)
    - 优先级管理
    - 自动清理过期记忆
    """
    
    def __init__(self, config: MemoryConfig, storage_backend=None):
        super().__init__(config, storage_backend)
        
        # Working-memory-specific configuration
        self.max_capacity = self.config.working_memory_capacity
        self.max_tokens = self.config.working_memory_tokens
        # In-memory TTL (minutes); override by setting working_memory_ttl_minutes on MemoryConfig
        self.max_age_minutes = getattr(self.config, 'working_memory_ttl_minutes', 120)
        self.current_tokens = 0
        self.session_start = datetime.now()
        
        # In-memory storage (working memory needs no persistence)
        self.memories: List[MemoryItem] = []
        
        # Manage memories with a priority queue
        self.memory_heap = []  # (priority, timestamp, memory_item)
    
    def add(self, memory_item: MemoryItem) -> str:
        """添加工作记忆"""
        # 过期清理
        self._expire_old_memories()
        # 计算优先级(重要性 + 时间衰减)
        priority = self._calculate_priority(memory_item)
        
        # 添加到堆中
        heapq.heappush(self.memory_heap, (-priority, memory_item.timestamp, memory_item))
        self.memories.append(memory_item)
        
        # 更新token计数
        self.current_tokens += len(memory_item.content.split())
        
        # 检查容量限制
        self._enforce_capacity_limits()
        
        return memory_item.id
    
    def retrieve(self, query: str, limit: int = 5, user_id: str = None, **kwargs) -> List[MemoryItem]:
        """检索工作记忆 - 混合语义向量检索和关键词匹配"""
        # 过期清理
        self._expire_old_memories()
        if not self.memories:
            return []

        # 过滤已遗忘的记忆
        active_memories = [m for m in self.memories if not m.metadata.get("forgotten", False)]
        
        # 按用户ID过滤(如果提供)
        filtered_memories = active_memories
        if user_id:
            filtered_memories = [m for m in active_memories if m.user_id == user_id]

        if not filtered_memories:
            return []

        # 尝试语义向量检索(如果有嵌入模型)
        vector_scores = {}
        try:
            # 简单的语义相似度计算(使用TF-IDF或其他轻量级方法)
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity
            import numpy as np
            
            # 准备文档
            documents = [query] + [m.content for m in filtered_memories]
            
            # TF-IDF向量化
            vectorizer = TfidfVectorizer(stop_words=None, lowercase=True)
            tfidf_matrix = vectorizer.fit_transform(documents)
            
            # 计算相似度
            query_vector = tfidf_matrix[0:1]
            doc_vectors = tfidf_matrix[1:]
            similarities = cosine_similarity(query_vector, doc_vectors).flatten()
            
            # 存储向量分数
            for i, memory in enumerate(filtered_memories):
                vector_scores[memory.id] = similarities[i]
                
        except Exception as e:
            # 如果向量检索失败,回退到关键词匹配
            vector_scores = {}

        # 计算最终分数
        query_lower = query.lower()
        scored_memories = []
        
        for memory in filtered_memories:
            content_lower = memory.content.lower()
            
            # 获取向量分数(如果有)
            vector_score = vector_scores.get(memory.id, 0.0)
            
            # 关键词匹配分数
            keyword_score = 0.0
            if query_lower in content_lower:
                keyword_score = len(query_lower) / len(content_lower)
            else:
                # 分词匹配
                query_words = set(query_lower.split())
                content_words = set(content_lower.split())
                intersection = query_words.intersection(content_words)
                if intersection:
                    keyword_score = len(intersection) / len(query_words.union(content_words)) * 0.8

            # 混合分数:向量检索 + 关键词匹配
            if vector_score > 0:
                base_relevance = vector_score * 0.7 + keyword_score * 0.3
            else:
                base_relevance = keyword_score
            
            # 时间衰减
            time_decay = self._calculate_time_decay(memory.timestamp)
            base_relevance *= time_decay
            
            # 重要性权重
            importance_weight = 0.8 + (memory.importance * 0.4)
            final_score = base_relevance * importance_weight
            
            if final_score > 0:
                scored_memories.append((final_score, memory))

        # 按分数排序并返回
        scored_memories.sort(key=lambda x: x[0], reverse=True)
        return [memory for _, memory in scored_memories[:limit]]
    
    def update(
        self,
        memory_id: str,
        content: str = None,
        importance: float = None,
        metadata: Dict[str, Any] = None
    ) -> bool:
        """更新工作记忆"""
        for memory in self.memories:
            if memory.id == memory_id:
                old_tokens = len(memory.content.split())
                
                if content is not None:
                    memory.content = content
                    # 更新token计数
                    new_tokens = len(content.split())
                    self.current_tokens = self.current_tokens - old_tokens + new_tokens
                
                if importance is not None:
                    memory.importance = importance
                
                if metadata is not None:
                    memory.metadata.update(metadata)
                
                # 重新计算优先级并更新堆
                self._update_heap_priority(memory)
                
                return True
        return False
    
    def remove(self, memory_id: str) -> bool:
        """删除工作记忆"""
        for i, memory in enumerate(self.memories):
            if memory.id == memory_id:
                # 从列表中删除
                removed_memory = self.memories.pop(i)
                
                # 从堆中删除(标记删除)
                self._mark_deleted_in_heap(memory_id)
                
                # 更新token计数
                self.current_tokens -= len(removed_memory.content.split())
                self.current_tokens = max(0, self.current_tokens)
                
                return True
        return False
    
    def has_memory(self, memory_id: str) -> bool:
        """检查记忆是否存在"""
        return any(memory.id == memory_id for memory in self.memories)
    
    def clear(self):
        """清空所有工作记忆"""
        self.memories.clear()
        self.memory_heap.clear()
        self.current_tokens = 0
    
    def get_stats(self) -> Dict[str, Any]:
        """获取工作记忆统计信息"""
        # 过期清理(惰性)
        self._expire_old_memories()
        
        # 工作记忆中的记忆都是活跃的(已遗忘的记忆会被直接删除)
        active_memories = self.memories
        
        return {
            "count": len(active_memories),  # 活跃记忆数量
            "forgotten_count": 0,  # 工作记忆中已遗忘的记忆会被直接删除
            "total_count": len(self.memories),  # 总记忆数量
            "current_tokens": self.current_tokens,
            "max_capacity": self.max_capacity,
            "max_tokens": self.max_tokens,
            "max_age_minutes": self.max_age_minutes,
            "session_duration_minutes": (datetime.now() - self.session_start).total_seconds() / 60,
            "avg_importance": sum(m.importance for m in active_memories) / len(active_memories) if active_memories else 0.0,
            "capacity_usage": len(active_memories) / self.max_capacity if self.max_capacity > 0 else 0.0,
            "token_usage": self.current_tokens / self.max_tokens if self.max_tokens > 0 else 0.0,
            "memory_type": "working"
        }
    
    def get_recent(self, limit: int = 10) -> List[MemoryItem]:
        """获取最近的记忆"""
        sorted_memories = sorted(
            self.memories, 
            key=lambda x: x.timestamp, 
            reverse=True
        )
        return sorted_memories[:limit]
    
    def get_important(self, limit: int = 10) -> List[MemoryItem]:
        """获取重要记忆"""
        sorted_memories = sorted(
            self.memories,
            key=lambda x: x.importance,
            reverse=True
        )
        return sorted_memories[:limit]

    def get_all(self) -> List[MemoryItem]:
        """获取所有记忆"""
        return self.memories.copy()
    
    def get_context_summary(self, max_length: int = 500) -> str:
        """获取上下文摘要"""
        if not self.memories:
            return "No working memories available."
        
        # 按重要性和时间排序
        sorted_memories = sorted(
            self.memories,
            key=lambda m: (m.importance, m.timestamp),
            reverse=True
        )
        
        summary_parts = []
        current_length = 0
        
        for memory in sorted_memories:
            content = memory.content
            if current_length + len(content) <= max_length:
                summary_parts.append(content)
                current_length += len(content)
            else:
                # 截断最后一个记忆
                remaining = max_length - current_length
                if remaining > 50:  # 至少保留50个字符
                    summary_parts.append(content[:remaining] + "...")
                break
        
        return "Working Memory Context:\n" + "\n".join(summary_parts)
    
    def forget(self, strategy: str = "importance_based", threshold: float = 0.1, max_age_days: int = 1) -> int:
        """工作记忆遗忘机制"""
        forgotten_count = 0
        current_time = datetime.now()
        
        to_remove = []
        
        # 始终先执行TTL过期(分钟级)
        cutoff_ttl = current_time - timedelta(minutes=self.max_age_minutes)
        for memory in self.memories:
            if memory.timestamp < cutoff_ttl:
                to_remove.append(memory.id)
        
        if strategy == "importance_based":
            # 删除低重要性记忆
            for memory in self.memories:
                if memory.importance < threshold:
                    to_remove.append(memory.id)
        
        elif strategy == "time_based":
            # 删除过期记忆(工作记忆通常以小时计算)
            cutoff_time = current_time - timedelta(hours=max_age_days * 24)
            for memory in self.memories:
                if memory.timestamp < cutoff_time:
                    to_remove.append(memory.id)
        
        elif strategy == "capacity_based":
            # 删除超出容量的记忆
            if len(self.memories) > self.max_capacity:
                # 按优先级排序,删除最低的
                sorted_memories = sorted(
                    self.memories,
                    key=lambda m: self._calculate_priority(m)
                )
                excess_count = len(self.memories) - self.max_capacity
                for memory in sorted_memories[:excess_count]:
                    to_remove.append(memory.id)
        
        # 执行删除
        for memory_id in to_remove:
            if self.remove(memory_id):
                forgotten_count += 1
        
        return forgotten_count
    
    def _calculate_priority(self, memory: MemoryItem) -> float:
        """计算记忆优先级"""
        # 基础优先级 = 重要性
        priority = memory.importance
        
        # 时间衰减
        time_decay = self._calculate_time_decay(memory.timestamp)
        priority *= time_decay
        
        return priority
    
    def _calculate_time_decay(self, timestamp: datetime) -> float:
        """计算时间衰减因子"""
        time_diff = datetime.now() - timestamp
        hours_passed = time_diff.total_seconds() / 3600
        
        # 指数衰减(工作记忆衰减更快)
        decay_factor = self.config.decay_factor ** (hours_passed / 6)  # 每6小时衰减
        return max(0.1, decay_factor)  # 最小保持10%的权重
    
    def _enforce_capacity_limits(self):
        """强制执行容量限制"""
        # 检查记忆数量限制
        while len(self.memories) > self.max_capacity:
            self._remove_lowest_priority_memory()
        
        # 检查token限制
        while self.current_tokens > self.max_tokens:
            self._remove_lowest_priority_memory()

    def _expire_old_memories(self):
        """按TTL清理过期记忆,并同步更新堆与token计数"""
        if not self.memories:
            return
        cutoff_time = datetime.now() - timedelta(minutes=self.max_age_minutes)
        # 过滤保留的记忆
        kept: List[MemoryItem] = []
        removed_token_sum = 0
        for m in self.memories:
            if m.timestamp >= cutoff_time:
                kept.append(m)
            else:
                removed_token_sum += len(m.content.split())
        if len(kept) == len(self.memories):
            return
        # 覆盖列表与token
        self.memories = kept
        self.current_tokens = max(0, self.current_tokens - removed_token_sum)
        # 重建堆
        self.memory_heap = []
        for mem in self.memories:
            priority = self._calculate_priority(mem)
            heapq.heappush(self.memory_heap, (-priority, mem.timestamp, mem))
    
    def _remove_lowest_priority_memory(self):
        """删除优先级最低的记忆"""
        if not self.memories:
            return
        
        # 找到优先级最低的记忆
        lowest_priority = float('inf')
        lowest_memory = None
        
        for memory in self.memories:
            priority = self._calculate_priority(memory)
            if priority < lowest_priority:
                lowest_priority = priority
                lowest_memory = memory
        
        if lowest_memory:
            self.remove(lowest_memory.id)
    
    def _update_heap_priority(self, memory: MemoryItem):
        """更新堆中记忆的优先级"""
        # 简单实现:重建堆
        self.memory_heap = []
        for mem in self.memories:
            priority = self._calculate_priority(mem)
            heapq.heappush(self.memory_heap, (-priority, mem.timestamp, mem))
    
    def _mark_deleted_in_heap(self, memory_id: str):
        """在堆中标记删除的记忆"""
        # 由于heapq不支持直接删除,我们标记为已删除
        # 在后续操作中会被清理
        pass
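
The retrieval above combines two signals: when TF-IDF succeeds, the base relevance is `0.7 * vector + 0.3 * keyword`, which is then multiplied by a time-decay factor and by an importance weight in [0.8, 1.2]. Here is a hedged usage sketch; the import paths and the MemoryConfig fields read in `__init__` are assumptions.

```python
# A minimal sketch; import paths and MemoryConfig fields are assumptions.
from datetime import datetime
import uuid

from hello_agents.memory.base import MemoryConfig, MemoryItem
from hello_agents.memory.types.working import WorkingMemory

wm = WorkingMemory(MemoryConfig())  # needs working_memory_capacity/_tokens, decay_factor

wm.add(MemoryItem(
    id=str(uuid.uuid4()),
    content="User asked about Qdrant deployment",
    memory_type="working",
    user_id="alice",
    timestamp=datetime.now(),
    importance=0.8,
    metadata={},
))

# Hybrid retrieval: TF-IDF cosine similarity blended with keyword overlap,
# scaled by time decay and the importance weight.
for m in wm.retrieve("qdrant deployment", limit=3):
    print(m.content)
```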

2.2 Episodic Memory (EpisodicMemory)

```python
"""情景记忆实现

按照第8章架构设计的情景记忆,提供:
- 具体交互事件存储
- 时间序列组织
- 上下文丰富的记忆
- 模式识别能力
"""

from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timedelta
import os
import math
import json
import logging

logger = logging.getLogger(__name__)

from ..base import BaseMemory, MemoryItem, MemoryConfig
from ..storage import SQLiteDocumentStore, QdrantVectorStore
from ..embedding import get_text_embedder, get_dimension

class Episode:
    """情景记忆中的单个情景"""
    
    def __init__(
        self,
        episode_id: str,
        user_id: str,
        session_id: str,
        timestamp: datetime,
        content: str,
        context: Dict[str, Any],
        outcome: Optional[str] = None,
        importance: float = 0.5
    ):
        self.episode_id = episode_id
        self.user_id = user_id
        self.session_id = session_id
        self.timestamp = timestamp
        self.content = content
        self.context = context
        self.outcome = outcome
        self.importance = importance

class EpisodicMemory(BaseMemory):
    """情景记忆实现
    
    特点:
    - 存储具体的交互事件
    - 包含丰富的上下文信息
    - 按时间序列组织
    - 支持模式识别和回溯
    """
    
    def __init__(self, config: MemoryConfig, storage_backend=None):
        super().__init__(config, storage_backend)
        
        # Local in-memory cache
        self.episodes: List[Episode] = []
        self.sessions: Dict[str, List[str]] = {}  # session_id -> episode_ids
        
        # Pattern-recognition cache
        self.patterns_cache = {}
        self.last_pattern_analysis = None

        # Authoritative document store (SQLite)
        db_dir = self.config.storage_path if hasattr(self.config, 'storage_path') else "./memory_data"
        os.makedirs(db_dir, exist_ok=True)
        db_path = os.path.join(db_dir, "memory.db")
        self.doc_store = SQLiteDocumentStore(db_path=db_path)

        # Unified embedding model (multilingual, 384-dimensional by default)
        self.embedder = get_text_embedder()

        # Vector store (Qdrant - use the connection manager to avoid duplicate connections)
        from ..storage.qdrant_store import QdrantConnectionManager
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        self.vector_store = QdrantConnectionManager.get_instance(
            url=qdrant_url,
            api_key=qdrant_api_key,
            collection_name=os.getenv("QDRANT_COLLECTION", "hello_agents_vectors"),
            vector_size=get_dimension(getattr(self.embedder, 'dimension', 384)),
            distance=os.getenv("QDRANT_DISTANCE", "cosine")
        )
    
    def add(self, memory_item: MemoryItem) -> str:
        """添加情景记忆"""
        # 从元数据中提取情景信息
        session_id = memory_item.metadata.get("session_id", "default_session")
        context = memory_item.metadata.get("context", {})
        outcome = memory_item.metadata.get("outcome")
        participants = memory_item.metadata.get("participants", [])
        tags = memory_item.metadata.get("tags", [])
        
        # 创建情景(内存缓存)
        episode = Episode(
            episode_id=memory_item.id,
            user_id=memory_item.user_id,
            session_id=session_id,
            timestamp=memory_item.timestamp,
            content=memory_item.content,
            context=context,
            outcome=outcome,
            importance=memory_item.importance
        )
        self.episodes.append(episode)
        if session_id not in self.sessions:
            self.sessions[session_id] = []
        self.sessions[session_id].append(episode.episode_id)

        # 1) 权威存储(SQLite)
        ts_int = int(memory_item.timestamp.timestamp())
        self.doc_store.add_memory(
            memory_id=memory_item.id,
            user_id=memory_item.user_id,
            content=memory_item.content,
            memory_type="episodic",
            timestamp=ts_int,
            importance=memory_item.importance,
            properties={
                "session_id": session_id,
                "context": context,
                "outcome": outcome,
                "participants": participants,
                "tags": tags
            }
        )

        # 2) 向量索引(Qdrant)
        try:
            embedding = self.embedder.encode(memory_item.content)
            if hasattr(embedding, "tolist"):
                embedding = embedding.tolist()
            self.vector_store.add_vectors(
                vectors=[embedding],
                metadata=[{
                    "memory_id": memory_item.id,
                    "user_id": memory_item.user_id,
                    "memory_type": "episodic",
                    "importance": memory_item.importance,
                    "session_id": session_id,
                    "content": memory_item.content
                }],
                ids=[memory_item.id]
            )
        except Exception:
            # 向量入库失败不影响权威存储
            pass

        return memory_item.id
    
    def retrieve(self, query: str, limit: int = 5, **kwargs) -> List[MemoryItem]:
        """检索情景记忆(结构化过滤 + 语义向量检索)"""
        user_id = kwargs.get("user_id")
        session_id = kwargs.get("session_id")
        time_range: Optional[Tuple[datetime, datetime]] = kwargs.get("time_range")
        importance_threshold: Optional[float] = kwargs.get("importance_threshold")

        # 结构化过滤候选(来自权威库)
        candidate_ids: Optional[set] = None
        if time_range is not None or importance_threshold is not None:
            start_ts = int(time_range[0].timestamp()) if time_range else None
            end_ts = int(time_range[1].timestamp()) if time_range else None
            docs = self.doc_store.search_memories(
                user_id=user_id,
                memory_type="episodic",
                start_time=start_ts,
                end_time=end_ts,
                importance_threshold=importance_threshold,
                limit=1000
            )
            candidate_ids = {d["memory_id"] for d in docs}

        # 向量检索(Qdrant)
        try:
            query_vec = self.embedder.encode(query)
            if hasattr(query_vec, "tolist"):
                query_vec = query_vec.tolist()
            where = {"memory_type": "episodic"}
            if user_id:
                where["user_id"] = user_id
            hits = self.vector_store.search_similar(
                query_vector=query_vec,
                limit=max(limit * 5, 20),
                where=where
            )
        except Exception:
            hits = []

        # 过滤与重排
        now_ts = int(datetime.now().timestamp())
        results: List[Tuple[float, MemoryItem]] = []
        seen = set()
        for hit in hits:
            meta = hit.get("metadata", {})
            mem_id = meta.get("memory_id")
            if not mem_id or mem_id in seen:
                continue
            
            # 检查是否已遗忘
            episode = next((e for e in self.episodes if e.episode_id == mem_id), None)
            if episode and episode.context.get("forgotten", False):
                continue  # 跳过已遗忘的记忆
                
            if candidate_ids is not None and mem_id not in candidate_ids:
                continue
            if session_id and meta.get("session_id") != session_id:
                continue

            # 从权威库读取完整记录
            doc = self.doc_store.get_memory(mem_id)
            if not doc:
                continue

            # 计算综合分数:向量0.6 + 近因0.2 + 重要性0.2
            vec_score = float(hit.get("score", 0.0))
            age_days = max(0.0, (now_ts - int(doc["timestamp"])) / 86400.0)
            recency_score = 1.0 / (1.0 + age_days)
            imp = float(doc.get("importance", 0.5))
            
            # 新评分算法:向量检索纯基于相似度,重要性作为加权因子
            # 基础相似度得分(不受重要性影响)
            base_relevance = vec_score * 0.8 + recency_score * 0.2
            
            # 重要性作为乘法加权因子,范围 [0.8, 1.2]
            importance_weight = 0.8 + (imp * 0.4)
            
            # 最终得分:相似度 * 重要性权重
            combined = base_relevance * importance_weight

            item = MemoryItem(
                id=doc["memory_id"],
                content=doc["content"],
                memory_type=doc["memory_type"],
                user_id=doc["user_id"],
                timestamp=datetime.fromtimestamp(doc["timestamp"]),
                importance=doc.get("importance", 0.5),
                metadata={
                    **doc.get("properties", {}),
                    "relevance_score": combined,
                    "vector_score": vec_score,
                    "recency_score": recency_score
                }
            )
            results.append((combined, item))
            seen.add(mem_id)

        # 若向量检索无结果,回退到简单关键词匹配(内存缓存)
        if not results:
            fallback = super()._generate_id  # 占位以避免未使用警告
            query_lower = query.lower()
            for ep in self._filter_episodes(user_id, session_id, time_range):
                if query_lower in ep.content.lower():
                    recency_score = 1.0 / (1.0 + max(0.0, (now_ts - int(ep.timestamp.timestamp())) / 86400.0))
                    # 回退匹配:新评分算法
                    keyword_score = 0.5  # 简单关键词匹配的基础分数
                    base_relevance = keyword_score * 0.8 + recency_score * 0.2
                    importance_weight = 0.8 + (ep.importance * 0.4)
                    combined = base_relevance * importance_weight
                    item = MemoryItem(
                        id=ep.episode_id,
                        content=ep.content,
                        memory_type="episodic",
                        user_id=ep.user_id,
                        timestamp=ep.timestamp,
                        importance=ep.importance,
                        metadata={
                            "session_id": ep.session_id,
                            "context": ep.context,
                            "outcome": ep.outcome,
                            "relevance_score": combined
                        }
                    )
                    results.append((combined, item))

        results.sort(key=lambda x: x[0], reverse=True)
        return [it for _, it in results[:limit]]
    
    def update(
        self,
        memory_id: str,
        content: str = None,
        importance: float = None,
        metadata: Dict[str, Any] = None
    ) -> bool:
        """更新情景记忆(SQLite为权威,Qdrant按需重嵌入)"""
        updated = False
        for episode in self.episodes:
            if episode.episode_id == memory_id:
                if content is not None:
                    episode.content = content
                if importance is not None:
                    episode.importance = importance
                if metadata is not None:
                    episode.context.update(metadata.get("context", {}))
                    if "outcome" in metadata:
                        episode.outcome = metadata["outcome"]
                updated = True
                break

        # Update SQLite
        doc_updated = self.doc_store.update_memory(
            memory_id=memory_id,
            content=content,
            importance=importance,
            properties=metadata
        )

        # If the content changed, re-embed and upsert into Qdrant
        if content is not None:
            try:
                embedding = self.embedder.encode(content)
                if hasattr(embedding, "tolist"):
                    embedding = embedding.tolist()
                # Fetch the updated record to sync the payload
                doc = self.doc_store.get_memory(memory_id)
                payload = {
                    "memory_id": memory_id,
                    "user_id": doc["user_id"] if doc else "",
                    "memory_type": "episodic",
                    "importance": (doc.get("importance") if doc else importance) or 0.5,
                    "session_id": (doc.get("properties", {}) or {}).get("session_id"),
                    "content": content
                }
                self.vector_store.add_vectors(
                    vectors=[embedding],
                    metadata=[payload],
                    ids=[memory_id]
                )
            except Exception:
                pass

        return updated or doc_updated
    
    def remove(self, memory_id: str) -> bool:
        """删除情景记忆(SQLite + Qdrant)"""
        removed = False
        for i, episode in enumerate(self.episodes):
            if episode.episode_id == memory_id:
                removed_episode = self.episodes.pop(i)
                session_id = removed_episode.session_id
                if session_id in self.sessions:
                    self.sessions[session_id].remove(memory_id)
                    if not self.sessions[session_id]:
                        del self.sessions[session_id]
                removed = True
                break

        # Delete from the authoritative store
        doc_deleted = self.doc_store.delete_memory(memory_id)
        
        # Delete from the vector store
        try:
            self.vector_store.delete_memories([memory_id])
        except Exception:
            pass
        
        return removed or doc_deleted
    
    def has_memory(self, memory_id: str) -> bool:
        """检查记忆是否存在"""
        return any(episode.episode_id == memory_id for episode in self.episodes)
    
    def clear(self):
        """清空所有情景记忆(仅清理episodic,不影响其他类型)"""
        # 内存缓存
        self.episodes.clear()
        self.sessions.clear()
        self.patterns_cache.clear()

        # SQLite内的episodic全部删除
        docs = self.doc_store.search_memories(memory_type="episodic", limit=10000)
        ids = [d["memory_id"] for d in docs]
        for mid in ids:
            self.doc_store.delete_memory(mid)

        # Qdrant按ID删除对应向量
        try:
            if ids:
                self.vector_store.delete_memories(ids)
        except Exception:
            pass

    def forget(self, strategy: str = "importance_based", threshold: float = 0.1, max_age_days: int = 30) -> int:
        """情景记忆遗忘机制(硬删除)"""
        forgotten_count = 0
        current_time = datetime.now()
        
        to_remove = []  # 收集要删除的记忆ID
        
        for episode in self.episodes:
            should_forget = False
            
            if strategy == "importance_based":
                # 基于重要性遗忘
                if episode.importance < threshold:
                    should_forget = True
            elif strategy == "time_based":
                # 基于时间遗忘
                cutoff_time = current_time - timedelta(days=max_age_days)
                if episode.timestamp < cutoff_time:
                    should_forget = True
            elif strategy == "capacity_based":
                # 基于容量遗忘(保留最重要的)
                if len(self.episodes) > self.config.max_capacity:
                    sorted_episodes = sorted(self.episodes, key=lambda e: e.importance)
                    excess_count = len(self.episodes) - self.config.max_capacity
                    if episode in sorted_episodes[:excess_count]:
                        should_forget = True
            
            if should_forget:
                to_remove.append(episode.episode_id)
        
        # 执行硬删除
        for episode_id in to_remove:
            if self.remove(episode_id):
                forgotten_count += 1
                logger.info(f"情景记忆硬删除: {episode_id[:8]}... (策略: {strategy})")
        
        return forgotten_count

    def get_all(self) -> List[MemoryItem]:
        """获取所有情景记忆(转换为MemoryItem格式)"""
        memory_items = []
        for episode in self.episodes:
            memory_item = MemoryItem(
                id=episode.episode_id,
                content=episode.content,
                memory_type="episodic",
                user_id=episode.user_id,
                timestamp=episode.timestamp,
                importance=episode.importance,
                metadata=episode.metadata
            )
            memory_items.append(memory_item)
        return memory_items
    
    def get_stats(self) -> Dict[str, Any]:
        """获取情景记忆统计信息(合并SQLite与Qdrant)"""
        # 硬删除模式:所有episodes都是活跃的
        active_episodes = self.episodes
        
        db_stats = self.doc_store.get_database_stats()
        try:
            vs_stats = self.vector_store.get_collection_stats()
        except Exception:
            vs_stats = {"store_type": "qdrant"}
        return {
            "count": len(active_episodes),  # 活跃记忆数量
            "forgotten_count": 0,  # 硬删除模式下已遗忘的记忆会被直接删除
            "total_count": len(self.episodes),  # 总记忆数量
            "sessions_count": len(self.sessions),
            "avg_importance": sum(e.importance for e in active_episodes) / len(active_episodes) if active_episodes else 0.0,
            "time_span_days": self._calculate_time_span(),
            "memory_type": "episodic",
            "vector_store": vs_stats,
            "document_store": {k: v for k, v in db_stats.items() if k.endswith("_count") or k in ["store_type", "db_path"]}
        }
    
    def get_session_episodes(self, session_id: str) -> List[Episode]:
        """获取指定会话的所有情景"""
        if session_id not in self.sessions:
            return []
        
        episode_ids = self.sessions[session_id]
        return [e for e in self.episodes if e.episode_id in episode_ids]
    
    def find_patterns(self, user_id: str = None, min_frequency: int = 2) -> List[Dict[str, Any]]:
        """发现用户行为模式"""
        # 检查缓存
        cache_key = f"{user_id}_{min_frequency}"
        if (cache_key in self.patterns_cache and 
            self.last_pattern_analysis and 
            (datetime.now() - self.last_pattern_analysis).total_seconds() < 3600):
            return self.patterns_cache[cache_key]
        
        # 过滤情景
        episodes = [e for e in self.episodes if user_id is None or e.user_id == user_id]
        
        # 简单的模式识别:基于内容关键词
        keyword_patterns = {}
        context_patterns = {}
        
        for episode in episodes:
            # 提取关键词
            words = episode.content.lower().split()
            for word in words:
                if len(word) > 3:  # 忽略短词(注:此处按空格分词,主要对英文文本生效)
                    keyword_patterns[word] = keyword_patterns.get(word, 0) + 1
            
            # 提取上下文模式
            for key, value in episode.context.items():
                pattern_key = f"{key}:{value}"
                context_patterns[pattern_key] = context_patterns.get(pattern_key, 0) + 1
        
        # 筛选频繁模式
        patterns = []
        
        for keyword, frequency in keyword_patterns.items():
            if frequency >= min_frequency:
                patterns.append({
                    "type": "keyword",
                    "pattern": keyword,
                    "frequency": frequency,
                    "confidence": frequency / len(episodes)
                })
        
        for context_pattern, frequency in context_patterns.items():
            if frequency >= min_frequency:
                patterns.append({
                    "type": "context",
                    "pattern": context_pattern,
                    "frequency": frequency,
                    "confidence": frequency / len(episodes)
                })
        
        # 按频率排序
        patterns.sort(key=lambda x: x["frequency"], reverse=True)
        
        # 缓存结果
        self.patterns_cache[cache_key] = patterns
        self.last_pattern_analysis = datetime.now()
        
        return patterns
    
    def get_timeline(self, user_id: str = None, limit: int = 50) -> List[Dict[str, Any]]:
        """获取时间线视图"""
        episodes = [e for e in self.episodes if user_id is None or e.user_id == user_id]
        episodes.sort(key=lambda x: x.timestamp, reverse=True)
        
        timeline = []
        for episode in episodes[:limit]:
            timeline.append({
                "episode_id": episode.episode_id,
                "timestamp": episode.timestamp.isoformat(),
                "content": episode.content[:100] + "..." if len(episode.content) > 100 else episode.content,
                "session_id": episode.session_id,
                "importance": episode.importance,
                "outcome": episode.outcome
            })
        
        return timeline
    
    def _filter_episodes(
        self,
        user_id: str = None,
        session_id: str = None,
        time_range: Tuple[datetime, datetime] = None
    ) -> List[Episode]:
        """过滤情景"""
        filtered = self.episodes
        
        if user_id:
            filtered = [e for e in filtered if e.user_id == user_id]
        
        if session_id:
            filtered = [e for e in filtered if e.session_id == session_id]
        
        if time_range:
            start_time, end_time = time_range
            filtered = [e for e in filtered if start_time <= e.timestamp <= end_time]
        
        return filtered
    
    def _calculate_time_span(self) -> float:
        """计算记忆时间跨度(天)"""
        if not self.episodes:
            return 0.0
        
        timestamps = [e.timestamp for e in self.episodes]
        min_time = min(timestamps)
        max_time = max(timestamps)
        
        return (max_time - min_time).total_seconds() / 86400.0
    
    def _persist_episode(self, episode: Episode):
        """持久化情景到存储后端"""
        if self.storage and hasattr(self.storage, 'add_memory'):
            self.storage.add_memory(
                memory_id=episode.episode_id,
                user_id=episode.user_id,
                content=episode.content,
                memory_type="episodic",
                timestamp=int(episode.timestamp.timestamp()),
                importance=episode.importance,
                properties={
                    "session_id": episode.session_id,
                    "context": episode.context,
                    "outcome": episode.outcome
                }
            )
    
    def _remove_from_storage(self, memory_id: str):
        """从存储后端删除"""
        if self.storage and hasattr(self.storage, 'delete_memory'):
            self.storage.delete_memory(memory_id)

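在上面的实现中,遗忘(forget)提供了三种互斥策略,检索侧则提供时间线与行为模式两种"回顾"视角。下面给出一段最小使用示意(假设各数据库服务已按前文配置就绪;导入路径为示意写法,请以项目实际结构为准):

python 复制代码
# 情景记忆的遗忘与回顾(示意代码,导入路径为假设)
from hello_agents.memory import MemoryConfig
from hello_agents.memory.types.episodic import EpisodicMemory  # 假设的模块路径

memory = EpisodicMemory(config=MemoryConfig(max_capacity=100))

# 三种遗忘策略均为硬删除,返回被删除的记忆条数
n1 = memory.forget(strategy="importance_based", threshold=0.2)  # 删除重要性低于0.2的情景
n2 = memory.forget(strategy="time_based", max_age_days=7)       # 删除7天以前的情景
n3 = memory.forget(strategy="capacity_based")                   # 超出容量时删除最不重要的部分

# 回顾式检索:时间线视图与行为模式挖掘
for event in memory.get_timeline(user_id="alice", limit=10):
    print(event["timestamp"], event["content"])
patterns = memory.find_patterns(user_id="alice", min_frequency=2)
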
2.3 语义记忆(SemanticMemory)

python 复制代码
"""语义记忆实现

结合向量检索和知识图谱的混合语义记忆,使用:
- HuggingFace 中文预训练模型进行文本嵌入
- 向量相似度检索进行快速初筛
- 知识图谱进行实体关系推理
- 混合检索策略优化结果质量
"""

from typing import List, Dict, Any, Optional, Set, Tuple
from datetime import datetime, timedelta
import json
import logging
import math
import numpy as np

from ..base import BaseMemory, MemoryItem, MemoryConfig
from ..embedding import get_text_embedder, get_dimension


# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Entity:
    """实体类"""
    
    def __init__(
        self,
        entity_id: str,
        name: str,
        entity_type: str = "MISC",
        description: str = "",
        properties: Dict[str, Any] = None
    ):
        self.entity_id = entity_id
        self.name = name
        self.entity_type = entity_type  # PERSON, ORG, PRODUCT, SKILL, CONCEPT等
        self.description = description
        self.properties = properties or {}
        self.created_at = datetime.now()
        self.updated_at = datetime.now()
        self.frequency = 1  # 出现频率
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "entity_id": self.entity_id,
            "name": self.name,
            "entity_type": self.entity_type,
            "description": self.description,
            "properties": self.properties,
            "frequency": self.frequency
        }

class Relation:
    """关系类"""
    
    def __init__(
        self,
        from_entity: str,
        to_entity: str,
        relation_type: str,
        strength: float = 1.0,
        evidence: str = "",
        properties: Dict[str, Any] = None
    ):
        self.from_entity = from_entity
        self.to_entity = to_entity
        self.relation_type = relation_type
        self.strength = strength
        self.evidence = evidence  # 支持该关系的原文本
        self.properties = properties or {}
        self.created_at = datetime.now()
        self.frequency = 1  # 关系出现频率
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "from_entity": self.from_entity,
            "to_entity": self.to_entity,
            "relation_type": self.relation_type,
            "strength": self.strength,
            "evidence": self.evidence,
            "properties": self.properties,
            "frequency": self.frequency
        }


class SemanticMemory(BaseMemory):
    """增强语义记忆实现
    
    特点:
    - 使用HuggingFace中文预训练模型进行文本嵌入
    - 向量检索进行快速相似度匹配
    - 知识图谱存储实体和关系
    - 混合检索策略:向量+图+语义推理
    """
    
    def __init__(self, config: MemoryConfig, storage_backend=None):
        super().__init__(config, storage_backend)
        
        # 嵌入模型(统一提供)
        self.embedding_model = None
        self._init_embedding_model()
        
        # 专业数据库存储
        self.vector_store = None
        self.graph_store = None
        self._init_databases()
        
        # 实体和关系缓存 (用于快速访问)
        self.entities: Dict[str, Entity] = {}
        self.relations: List[Relation] = []
        
        # 实体识别器
        self.nlp = None
        self._init_nlp()
        
        # 记忆存储
        self.semantic_memories: List[MemoryItem] = []
        self.memory_embeddings: Dict[str, np.ndarray] = {}
        
        logger.info("增强语义记忆初始化完成(使用Qdrant+Neo4j专业数据库)")
    
    def _init_embedding_model(self):
        """初始化统一嵌入模型(由 embedding_provider 管理)。"""
        try:
            self.embedding_model = get_text_embedder()
            # 轻量健康检查与日志
            try:
                test_vec = self.embedding_model.encode("health_check")
                dim = getattr(self.embedding_model, "dimension", len(test_vec))
                logger.info(f"✅ 嵌入模型就绪,维度: {dim}")
            except Exception:
                logger.info("✅ 嵌入模型就绪")
        except Exception as e:
            logger.error(f"❌ 嵌入模型初始化失败: {e}")
            raise
    
    def _init_databases(self):
        """初始化专业数据库存储"""
        try:
            from ...core.database_config import get_database_config
            # 获取数据库配置
            db_config = get_database_config()
            
            # 初始化Qdrant向量数据库(使用连接管理器避免重复连接)
            from ..storage.qdrant_store import QdrantConnectionManager
            qdrant_config = db_config.get_qdrant_config() or {}
            qdrant_config["vector_size"] = get_dimension()
            self.vector_store = QdrantConnectionManager.get_instance(**qdrant_config)
            logger.info("✅ Qdrant向量数据库初始化完成")
            
            # 初始化Neo4j图数据库
            from ..storage.neo4j_store import Neo4jGraphStore
            neo4j_config = db_config.get_neo4j_config()
            self.graph_store = Neo4jGraphStore(**neo4j_config)
            logger.info("✅ Neo4j图数据库初始化完成")
            
            # 验证连接
            vector_health = self.vector_store.health_check()
            graph_health = self.graph_store.health_check()
            
            if not vector_health:
                logger.warning("⚠️ Qdrant连接异常,部分功能可能受限")
            if not graph_health:
                logger.warning("⚠️ Neo4j连接异常,图搜索功能可能受限")
            
            logger.info(f"🏥 数据库健康状态: Qdrant={'✅' if vector_health else '❌'}, Neo4j={'✅' if graph_health else '❌'}")
            
        except Exception as e:
            logger.error(f"❌ 数据库初始化失败: {e}")
            logger.info("💡 请检查数据库配置和网络连接")
            logger.info("💡 参考 DATABASE_SETUP_GUIDE.md 进行配置")
            raise
    
    def _init_nlp(self):
        """初始化NLP处理器 - 智能多语言支持"""
        try:
            import spacy
            self.nlp_models = {}
            
            # 尝试加载多语言模型
            models_to_try = [
                ("zh_core_web_sm", "中文"),
                ("en_core_web_sm", "英文")
            ]
            
            loaded_models = []
            for model_name, lang_name in models_to_try:
                try:
                    nlp = spacy.load(model_name)
                    self.nlp_models[model_name] = nlp
                    loaded_models.append(lang_name)
                    logger.info(f"✅ 加载{lang_name}spaCy模型: {model_name}")
                except OSError:
                    logger.warning(f"⚠️ {lang_name}spaCy模型不可用: {model_name}")
            
            # 设置主要NLP处理器
            if "zh_core_web_sm" in self.nlp_models:
                self.nlp = self.nlp_models["zh_core_web_sm"]
                logger.info("🎯 主要使用中文spaCy模型")
            elif "en_core_web_sm" in self.nlp_models:
                self.nlp = self.nlp_models["en_core_web_sm"]
                logger.info("🎯 主要使用英文spaCy模型")
            else:
                self.nlp = None
                logger.warning("⚠️ 无可用spaCy模型,实体提取将受限")
            
            if loaded_models:
                logger.info(f"📚 可用语言模型: {', '.join(loaded_models)}")
                
        except ImportError:
            logger.warning("⚠️ spaCy不可用,实体提取将受限")
            self.nlp = None
            self.nlp_models = {}
    
    def add(self, memory_item: MemoryItem) -> str:
        """添加语义记忆"""
        try:
            # 1. 生成文本嵌入
            embedding = self.embedding_model.encode(memory_item.content)
            self.memory_embeddings[memory_item.id] = embedding
            
            # 2. 提取实体和关系
            entities = self._extract_entities(memory_item.content)
            relations = self._extract_relations(memory_item.content, entities)
            
            # 3. 存储到Neo4j图数据库
            for entity in entities:
                self._add_entity_to_graph(entity, memory_item)
            
            for relation in relations:
                self._add_relation_to_graph(relation, memory_item)
            
            # 4. 存储到Qdrant向量数据库
            metadata = {
                "memory_id": memory_item.id,
                "user_id": memory_item.user_id,
                "content": memory_item.content,
                "memory_type": memory_item.memory_type,
                "timestamp": int(memory_item.timestamp.timestamp()),
                "importance": memory_item.importance,
                "entities": [e.entity_id for e in entities],
                "entity_count": len(entities),
                "relation_count": len(relations)
            }
            
            success = self.vector_store.add_vectors(
                vectors=[embedding.tolist()],
                metadata=[metadata],
                ids=[memory_item.id]
            )
            
            if not success:
                logger.warning("⚠️ 向量存储失败,但记忆已添加到图数据库")
            
            # 5. 添加实体信息到元数据
            memory_item.metadata["entities"] = [e.entity_id for e in entities]
            memory_item.metadata["relations"] = [
                f"{r.from_entity}-{r.relation_type}-{r.to_entity}" for r in relations
            ]
            
            # 6. 存储记忆
            self.semantic_memories.append(memory_item)
            
            logger.info(f"✅ 添加语义记忆: {len(entities)}个实体, {len(relations)}个关系")
            return memory_item.id
        
        except Exception as e:
            logger.error(f"❌ 添加语义记忆失败: {e}")
            raise
    
    def retrieve(self, query: str, limit: int = 5, **kwargs) -> List[MemoryItem]:
        """检索语义记忆"""
        try:
            user_id = kwargs.get("user_id")

            # 1. 向量检索
            vector_results = self._vector_search(query, limit * 2, user_id)
            
            # 2. 图检索
            graph_results = self._graph_search(query, limit * 2, user_id)
            
            # 3. 混合排序
            combined_results = self._combine_and_rank_results(
                vector_results, graph_results, query, limit
            )

            # 3.1 计算概率(对 combined_score 做 softmax 归一化)
            scores = [r.get("combined_score", r.get("vector_score", 0.0)) for r in combined_results]
            if scores:
                max_s = max(scores)
                exps = [math.exp(s - max_s) for s in scores]
                denom = sum(exps) or 1.0
                probs = [e / denom for e in exps]
            else:
                probs = []
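            # 数值示例:scores=[0.8, 0.5] -> exps=[1.0, e^-0.3≈0.741] -> probs≈[0.574, 0.426]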
            
            # 4. 过滤已遗忘记忆并转换为MemoryItem
            result_memories = []
            for idx, result in enumerate(combined_results):
                memory_id = result.get("memory_id")
                
                # 检查是否已遗忘
                memory = next((m for m in self.semantic_memories if m.id == memory_id), None)
                if memory and memory.metadata.get("forgotten", False):
                    continue  # 跳过已遗忘的记忆
                
                # 处理时间戳
                timestamp = result.get("timestamp")
                if isinstance(timestamp, str):
                    try:
                        timestamp = datetime.fromisoformat(timestamp)
                    except ValueError:
                        timestamp = datetime.now()
                elif isinstance(timestamp, (int, float)):
                    timestamp = datetime.fromtimestamp(timestamp)
                else:
                    timestamp = datetime.now()
                
                # 直接从结果数据构建MemoryItem(附带分数与概率)
                memory_item = MemoryItem(
                    id=result["memory_id"],
                    content=result["content"],
                    memory_type="semantic",
                    user_id=result.get("user_id", "default"),
                    timestamp=timestamp,
                    importance=result.get("importance", 0.5),
                    metadata={
                        **result.get("metadata", {}),
                        "combined_score": result.get("combined_score", 0.0),
                        "vector_score": result.get("vector_score", 0.0),
                        "graph_score": result.get("graph_score", 0.0),
                        "probability": probs[idx] if idx < len(probs) else 0.0,
                    }
                )
                result_memories.append(memory_item)
            
            logger.info(f"✅ 检索到 {len(result_memories)} 条相关记忆")
            return result_memories[:limit]
                
        except Exception as e:
            logger.error(f"❌ 检索语义记忆失败: {e}")
            return []
    
    def _vector_search(self, query: str, limit: int, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Qdrant向量搜索"""
        try:
            # 生成查询向量
            query_embedding = self.embedding_model.encode(query)
            
            # 构建过滤条件
            where_filter = {"memory_type": "semantic"}
            if user_id:
                where_filter["user_id"] = user_id

            # Qdrant向量检索
            results = self.vector_store.search_similar(
                query_vector=query_embedding.tolist(),
                limit=limit,
                where=where_filter if where_filter else None
            )

            # 转换结果格式以保持兼容性
            formatted_results = []
            for result in results:
                formatted_result = {
                    "id": result["id"],
                    "score": result["score"],
                    **result["metadata"]  # 包含所有元数据
                }
                formatted_results.append(formatted_result)

            logger.debug(f"🔍 Qdrant向量搜索返回 {len(formatted_results)} 个结果")
            return formatted_results
                
        except Exception as e:
            logger.error(f"❌ Qdrant向量搜索失败: {e}")
            return []

    def _graph_search(self, query: str, limit: int, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Neo4j图搜索"""
        try:
            # 从查询中提取实体
            query_entities = self._extract_entities(query)
            
            if not query_entities:
                # 如果没有提取到实体,尝试按名称搜索
                entities_by_name = self.graph_store.search_entities_by_name(
                    name_pattern=query, 
                    limit=10
                )
                if entities_by_name:
                    query_entities = [Entity(
                        entity_id=e["id"],
                        name=e["name"],
                        entity_type=e["type"]
                    ) for e in entities_by_name[:3]]
                else:
                    return []
            
            # 在Neo4j图中查找相关实体和记忆
            related_memory_ids = set()
            
            for entity in query_entities:
                try:
                    # 查找相关实体
                    related_entities = self.graph_store.find_related_entities(
                        entity_id=entity.entity_id,
                        max_depth=2,
                        limit=20
                    )
                    
                    # 收集相关记忆ID
                    for rel_entity in related_entities:
                        if "memory_id" in rel_entity:
                            related_memory_ids.add(rel_entity["memory_id"])
                    
                    # 也添加直接匹配的实体记忆
                    entity_rels = self.graph_store.get_entity_relationships(entity.entity_id)
                    for rel in entity_rels:
                        rel_data = rel.get("relationship", {})
                        if "memory_id" in rel_data:
                            related_memory_ids.add(rel_data["memory_id"])
                            
                except Exception as e:
                    logger.debug(f"图搜索实体 {entity.entity_id} 失败: {e}")
                    continue
            
            # 构建结果 - 从向量数据库获取完整记忆信息
            results = []
            for memory_id in list(related_memory_ids)[:limit * 2]:  # 获取更多候选
                try:
                    # 优先从本地缓存获取记忆详情,避免占位向量维度不一致问题
                    mem = self._find_memory_by_id(memory_id)
                    if not mem:
                        continue

                    if user_id and mem.user_id != user_id:
                        continue

                    metadata = {
                        "content": mem.content,
                        "user_id": mem.user_id,
                        "memory_type": mem.memory_type,
                        "importance": mem.importance,
                        "timestamp": int(mem.timestamp.timestamp()),
                        "entities": mem.metadata.get("entities", [])
                    }

                    # 计算图相关性分数
                    graph_score = self._calculate_graph_relevance_neo4j(metadata, query_entities)

                    results.append({
                        "id": memory_id,
                        "memory_id": memory_id,
                        "content": metadata.get("content", ""),
                        "similarity": graph_score,
                        "user_id": metadata.get("user_id"),
                        "memory_type": metadata.get("memory_type"),
                        "importance": metadata.get("importance", 0.5),
                        "timestamp": metadata.get("timestamp"),
                        "entities": metadata.get("entities", [])
                    })

                except Exception as e:
                    logger.debug(f"获取记忆 {memory_id} 详情失败: {e}")
                    continue
            
            # 按图相关性排序
            results.sort(key=lambda x: x["similarity"], reverse=True)
            logger.debug(f"🕸️ Neo4j图搜索返回 {len(results)} 个结果")
            return results[:limit]
            
        except Exception as e:
            logger.error(f"❌ Neo4j图搜索失败: {e}")
            return []

    def _combine_and_rank_results(
        self,
        vector_results: List[Dict[str, Any]],
        graph_results: List[Dict[str, Any]],
        query: str,
        limit: int
    ) -> List[Dict[str, Any]]:
        """混合排序结果 - 仅基于向量与图分数的简单融合"""
        # 合并结果,按内容去重
        combined = {}
        content_seen = set()  # 用于内容去重
        
        # 添加向量结果
        for result in vector_results:
            memory_id = result["memory_id"]
            content = result.get("content", "")
            
            # 内容去重:检查是否已经有相同或高度相似的内容
            content_hash = hash(content.strip())
            if content_hash in content_seen:
                logger.debug(f"⚠️ 跳过重复内容: {content[:30]}...")
                continue
            
            content_seen.add(content_hash)
            combined[memory_id] = {
                **result,
                "vector_score": result.get("score", 0.0), 
                "graph_score": 0.0,
                "content_hash": content_hash
            }
        
        # 添加图结果
        for result in graph_results:
            memory_id = result["memory_id"]
            content = result.get("content", "")
            content_hash = hash(content.strip())
            
            if memory_id in combined:
                combined[memory_id]["graph_score"] = result.get("similarity", 0.0)
            elif content_hash not in content_seen:
                content_seen.add(content_hash)
                combined[memory_id] = {
                    **result,
                    "vector_score": 0.0,
                    "graph_score": result.get("similarity", 0.0),
                    "content_hash": content_hash
                }
        
        # 计算混合分数:相似度为主,重要性为辅助排序因子
        for memory_id, result in combined.items():
            vector_score = result["vector_score"]
            graph_score = result["graph_score"]
            importance = result.get("importance", 0.5)
            
            # 新评分算法:向量检索纯基于相似度,重要性作为加权因子
            # 基础相似度得分(不受重要性影响)
            base_relevance = vector_score * 0.7 + graph_score * 0.3
            
            # 重要性作为乘法加权因子,范围 [0.8, 1.2]
            # importance in [0,1] -> weight in [0.8,1.2]
            importance_weight = 0.8 + (importance * 0.4)
            
            # 最终得分:相似度 * 重要性权重
            combined_score = base_relevance * importance_weight
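            # 数值示例:vector=0.82, graph=0.40, importance=0.9
            # base_relevance = 0.82*0.7 + 0.40*0.3 = 0.694
            # importance_weight = 0.8 + 0.9*0.4 = 1.16
            # combined_score ≈ 0.694 * 1.16 ≈ 0.805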
            
            # 调试信息:查看分数分解
            result["debug_info"] = {
                "base_relevance": base_relevance,
                "importance_weight": importance_weight,
                "combined_score": combined_score
            }

            result["combined_score"] = combined_score
        
        # 应用最小相关性阈值
        min_threshold = 0.1  # 最小相关性阈值
        filtered_results = [
            result for result in combined.values() 
            if result["combined_score"] >= min_threshold
        ]

        # 排序并返回
        sorted_results = sorted(
            filtered_results,
            key=lambda x: x["combined_score"],
            reverse=True
        )
        
        # 调试信息
        logger.debug(f"🔍 向量结果: {len(vector_results)}, 图结果: {len(graph_results)}")
        logger.debug(f"📝 去重后: {len(combined)}, 过滤后: {len(filtered_results)}")
        
        if logger.isEnabledFor(logging.DEBUG):
            for i, result in enumerate(sorted_results[:3]):
                logger.debug(f"  结果{i+1}: 向量={result['vector_score']:.3f}, 图={result['graph_score']:.3f}, 综合={result['combined_score']:.3f}")
        
        return sorted_results[:limit]
    
    def _detect_language(self, text: str) -> str:
        """简单的语言检测"""
        # 统计中文字符比例(无正则,逐字符判断范围)
        chinese_chars = sum(1 for ch in text if '\u4e00' <= ch <= '\u9fff')
        total_chars = len(text.replace(' ', ''))
        
        if total_chars == 0:
            return "en"
        
        chinese_ratio = chinese_chars / total_chars
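        # 例如 "今天天气 nice":中文4字 / 去空格后8字符 = 0.5 > 0.3,判定为 "zh"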
        return "zh" if chinese_ratio > 0.3 else "en"
    
    def _extract_entities(self, text: str) -> List[Entity]:
        """智能多语言实体提取"""
        entities = []
        
        # 检测文本语言
        lang = self._detect_language(text)
        
        # 选择合适的spaCy模型
        selected_nlp = None
        if lang == "zh" and "zh_core_web_sm" in self.nlp_models:
            selected_nlp = self.nlp_models["zh_core_web_sm"]
        elif lang == "en" and "en_core_web_sm" in self.nlp_models:
            selected_nlp = self.nlp_models["en_core_web_sm"]
        else:
            # 使用默认模型
            selected_nlp = self.nlp
        
        logger.debug(f"🌐 检测语言: {lang}, 使用模型: {selected_nlp.meta['name'] if selected_nlp else 'None'}")
        
        # 使用spaCy进行实体识别和词法分析
        if selected_nlp:
            try:
                doc = selected_nlp(text)
                logger.debug(f"📝 spaCy处理文本: '{text}' -> {len(doc.ents)} 个实体")
                
                # 存储词法分析结果,供Neo4j使用
                self._store_linguistic_analysis(doc, text)
                
                if not doc.ents:
                    # 如果没有实体,记录详细的词元信息
                    logger.debug("🔍 未找到实体,词元分析:")
                    for token in doc[:5]:  # 只显示前5个词元
                        logger.debug(f"   '{token.text}' -> POS: {token.pos_}, TAG: {token.tag_}, ENT_IOB: {token.ent_iob_}")
                
                for ent in doc.ents:
                    entity = Entity(
                        entity_id=f"entity_{hash(ent.text)}",
                        name=ent.text,
                        entity_type=ent.label_,
                        description=f"从文本中识别的{ent.label_}实体"
                    )
                    entities.append(entity)
                    # 安全获取置信度信息(spaCy默认不提供实体置信度,仅当扩展属性存在时才能读取)
                    try:
                        confidence = getattr(ent._, "confidence", "N/A")
                    except Exception:
                        confidence = "N/A"
                    
                    logger.debug(f"🏷️ spaCy识别实体: '{ent.text}' -> {ent.label_} (置信度: {confidence})")
                
            except Exception as e:
                logger.warning(f"⚠️ spaCy实体识别失败: {e}")
                import traceback
                logger.debug(f"详细错误: {traceback.format_exc()}")
        else:
            logger.warning("⚠️ 没有可用的spaCy模型进行实体识别")
        
        return entities
    
    def _store_linguistic_analysis(self, doc, text: str):
        """存储spaCy词法分析结果到Neo4j"""
        if not self.graph_store:
            return
            
        try:
            # 为每个词元创建节点
            for token in doc:
                # 跳过标点符号和空格
                if token.is_punct or token.is_space:
                    continue
                    
                token_id = f"token_{hash(token.text + token.pos_)}"
                
                # 添加词元节点到Neo4j
                self.graph_store.add_entity(
                    entity_id=token_id,
                    name=token.text,
                    entity_type="TOKEN",
                    properties={
                        "pos": token.pos_,        # 词性(NOUN, VERB等)
                        "tag": token.tag_,        # 细粒度标签
                        "lemma": token.lemma_,    # 词元原形
                        "is_alpha": token.is_alpha,
                        "is_stop": token.is_stop,
                        "source_text": text[:50],  # 来源文本片段
                        "language": self._detect_language(text)
                    }
                )
                
                # 如果是名词,可能是潜在的概念
                if token.pos_ in ["NOUN", "PROPN"]:
                    concept_id = f"concept_{hash(token.text)}"
                    self.graph_store.add_entity(
                        entity_id=concept_id,
                        name=token.text,
                        entity_type="CONCEPT",
                        properties={
                            "category": token.pos_,
                            "frequency": 1,  # 可以后续累计
                            "source_text": text[:50]
                        }
                    )
                    
                    # 建立词元到概念的关系
                    self.graph_store.add_relationship(
                        from_entity_id=token_id,
                        to_entity_id=concept_id,
                        relationship_type="REPRESENTS",
                        properties={"confidence": 1.0}
                    )
            
            # 建立词元之间的依存关系
            for token in doc:
                if token.is_punct or token.is_space or token.head == token:
                    continue
                    
                from_id = f"token_{hash(token.text + token.pos_)}"
                to_id = f"token_{hash(token.head.text + token.head.pos_)}"
                
                # Neo4j不允许关系类型包含冒号,需要清理
                relation_type = token.dep_.upper().replace(":", "_")
                
                self.graph_store.add_relationship(
                    from_entity_id=from_id,
                    to_entity_id=to_id,
                    relationship_type=relation_type,  # 清理后的依存关系类型
                    properties={
                        "dependency": token.dep_,  # 保留原始依存关系
                        "source_text": text[:50]
                    }
                )
            
            logger.debug(f"🔗 已将词法分析结果存储到Neo4j: {len([t for t in doc if not t.is_punct and not t.is_space])} 个词元")
            
        except Exception as e:
            logger.warning(f"⚠️ 存储词法分析失败: {e}")
    
    def _extract_relations(self, text: str, entities: List[Entity]) -> List[Relation]:
        """提取关系"""
        relations = []
        # 仅保留简单共现关系,不做任何正则/关键词匹配
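        # 注意:n个实体将产生 n*(n-1)/2 条CO_OCCURS关系,实体较多时关系数量会快速增长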
        for i, entity1 in enumerate(entities):
            for entity2 in entities[i+1:]:
                relations.append(Relation(
                    from_entity=entity1.entity_id,
                    to_entity=entity2.entity_id,
                    relation_type="CO_OCCURS",
                    strength=0.5,
                    evidence=text[:100]
                ))
        return relations
    
    def _add_entity_to_graph(self, entity: Entity, memory_item: MemoryItem):
        """添加实体到Neo4j图数据库"""
        try:
            # 准备实体属性
            properties = {
                "name": entity.name,
                "description": entity.description,
                "frequency": entity.frequency,
                "memory_id": memory_item.id,
                "user_id": memory_item.user_id,
                "importance": memory_item.importance,
                **entity.properties
            }
            
            # 添加到Neo4j
            success = self.graph_store.add_entity(
                entity_id=entity.entity_id,
                name=entity.name,
                entity_type=entity.entity_type,
                properties=properties
            )
            
            if success:
                # 同时更新本地缓存
                if entity.entity_id in self.entities:
                    self.entities[entity.entity_id].frequency += 1
                    self.entities[entity.entity_id].updated_at = datetime.now()
                else:
                    self.entities[entity.entity_id] = entity
                    
            return success
            
        except Exception as e:
            logger.error(f"❌ 添加实体到图数据库失败: {e}")
            return False
    
    def _add_relation_to_graph(self, relation: Relation, memory_item: MemoryItem):
        """添加关系到Neo4j图数据库"""
        try:
            # 准备关系属性
            properties = {
                "strength": relation.strength,
                "memory_id": memory_item.id,
                "user_id": memory_item.user_id,
                "importance": memory_item.importance,
                "evidence": relation.evidence
            }
            
            # 添加到Neo4j
            success = self.graph_store.add_relationship(
                from_entity_id=relation.from_entity,
                to_entity_id=relation.to_entity,
                relationship_type=relation.relation_type,
                properties=properties
            )
            
            if success:
                # 同时更新本地缓存
                self.relations.append(relation)
                
            return success
            
        except Exception as e:
            logger.error(f"❌ 添加关系到图数据库失败: {e}")
            return False
    
    def _calculate_graph_relevance_neo4j(self, memory_metadata: Dict[str, Any], query_entities: List[Entity]) -> float:
        """计算Neo4j图相关性分数"""
        try:
            memory_entities = memory_metadata.get("entities", [])
            if not memory_entities or not query_entities:
                return 0.0
            
            # 实体匹配度
            query_entity_ids = {e.entity_id for e in query_entities}
            matching_entities = len(set(memory_entities).intersection(query_entity_ids))
            entity_score = matching_entities / len(query_entity_ids) if query_entity_ids else 0
            
            # 实体数量加权
            entity_count = memory_metadata.get("entity_count", 0)
            entity_density = min(entity_count / 10, 1.0)  # 归一化到[0,1]
            
            # 关系数量加权
            relation_count = memory_metadata.get("relation_count", 0)
            relation_density = min(relation_count / 5, 1.0)  # 归一化到[0,1]
            
            # 综合分数
            relevance_score = (
                entity_score * 0.6 +           # 实体匹配权重60%
                entity_density * 0.2 +         # 实体密度权重20%
                relation_density * 0.2         # 关系密度权重20%
            )
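            # 数值示例:匹配2/4个查询实体(0.5)、5个实体(密度0.5)、5条关系(密度1.0)
            # relevance = 0.5*0.6 + 0.5*0.2 + 1.0*0.2 = 0.6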
            
            return min(relevance_score, 1.0)
            
        except Exception as e:
            logger.debug(f"计算图相关性失败: {e}")
            return 0.0

    def _add_or_update_entity(self, entity: Entity):
        """添加或更新实体"""
        if entity.entity_id in self.entities:
            # 更新现有实体
            existing = self.entities[entity.entity_id]
            existing.frequency += 1
            existing.updated_at = datetime.now()
        else:
            # 添加新实体
            self.entities[entity.entity_id] = entity
    
    def _add_or_update_relation(self, relation: Relation):
        """添加或更新关系"""
        # 检查是否已存在相同关系
        existing_relation = None
        for r in self.relations:
            if (r.from_entity == relation.from_entity and
                r.to_entity == relation.to_entity and
                r.relation_type == relation.relation_type):
                existing_relation = r
                break
        
        if existing_relation:
            # 更新现有关系
            existing_relation.frequency += 1
            existing_relation.strength = min(1.0, existing_relation.strength + 0.1)
        else:
            # 添加新关系
            self.relations.append(relation)
    
    # 旧的图相关性计算方法已被 _calculate_graph_relevance_neo4j 替代
    
    def _find_memory_by_id(self, memory_id: str) -> Optional[MemoryItem]:
        """根据ID查找记忆"""
        logger.debug(f"🔍 查找记忆ID: {memory_id}, 当前记忆数: {len(self.semantic_memories)}")
        for memory in self.semantic_memories:
            if memory.id == memory_id:
                logger.debug(f"✅ 找到记忆: {memory.content[:50]}...")
                return memory
        logger.debug(f"❌ 未找到记忆ID: {memory_id}")
        return None
    
    def update(
        self,
        memory_id: str,
        content: str = None,
        importance: float = None,
        metadata: Dict[str, Any] = None
    ) -> bool:
        """更新语义记忆"""
        memory = self._find_memory_by_id(memory_id)
        if not memory:
            return False
        
        try:
            if content is not None:
                # 重新生成嵌入和提取实体
                embedding = self.embedding_model.encode(content)
                self.memory_embeddings[memory_id] = embedding
                
                # 清理旧的实体关系
                old_entities = memory.metadata.get("entities", [])
                self._cleanup_entities_and_relations(old_entities)
                
                # 提取新的实体和关系
                memory.content = content
                entities = self._extract_entities(content)
                relations = self._extract_relations(content, entities)
                
                # 更新知识图谱
                for entity in entities:
                    self._add_or_update_entity(entity)
                for relation in relations:
                    self._add_or_update_relation(relation)
                
                # 更新元数据
                memory.metadata["entities"] = [e.entity_id for e in entities]
                memory.metadata["relations"] = [
                    f"{r.from_entity}-{r.relation_type}-{r.to_entity}" for r in relations
                ]
                
            if importance is not None:
                memory.importance = importance
            
            if metadata is not None:
                memory.metadata.update(metadata)
            
            return True
            
        except Exception as e:
            logger.error(f"❌ 更新记忆失败: {e}")
        return False
    
    def remove(self, memory_id: str) -> bool:
        """删除语义记忆"""
        memory = self._find_memory_by_id(memory_id)
        if not memory:
            return False
        
        try:
            # 删除向量
            self.vector_store.delete_memories([memory_id])
            
            # 清理实体和关系
            entities = memory.metadata.get("entities", [])
            self._cleanup_entities_and_relations(entities)
            
            # 删除记忆
            self.semantic_memories.remove(memory)
            if memory_id in self.memory_embeddings:
                del self.memory_embeddings[memory_id]
            
            return True
            
        except Exception as e:
            logger.error(f"❌ 删除记忆失败: {e}")
        return False
    
    def _cleanup_entities_and_relations(self, entity_ids: List[str]):
        """清理实体和关系"""
        # 这里可以实现更智能的清理逻辑
        # 例如,如果实体不再被任何记忆引用,则删除它
        pass
    
    def has_memory(self, memory_id: str) -> bool:
        """检查记忆是否存在"""
        return self._find_memory_by_id(memory_id) is not None
    
    def forget(self, strategy: str = "importance_based", threshold: float = 0.1, max_age_days: int = 30) -> int:
        """语义记忆遗忘机制(硬删除)"""
        forgotten_count = 0
        current_time = datetime.now()
        
        to_remove = []  # 收集要删除的记忆ID
        
        for memory in self.semantic_memories:
            should_forget = False
            
            if strategy == "importance_based":
                # 基于重要性遗忘
                if memory.importance < threshold:
                    should_forget = True
            elif strategy == "time_based":
                # 基于时间遗忘
                cutoff_time = current_time - timedelta(days=max_age_days)
                if memory.timestamp < cutoff_time:
                    should_forget = True
            elif strategy == "capacity_based":
                # 基于容量遗忘(保留最重要的)
                if len(self.semantic_memories) > self.config.max_capacity:
                    sorted_memories = sorted(self.semantic_memories, key=lambda m: m.importance)
                    excess_count = len(self.semantic_memories) - self.config.max_capacity
                    if memory in sorted_memories[:excess_count]:
                        should_forget = True
            
            if should_forget:
                to_remove.append(memory.id)
        
        # 执行硬删除
        for memory_id in to_remove:
            if self.remove(memory_id):
                forgotten_count += 1
                logger.info(f"语义记忆硬删除: {memory_id[:8]}... (策略: {strategy})")
        
        return forgotten_count

    def clear(self):
        """清空所有语义记忆 - 包括专业数据库"""
        try:
            # 清空Qdrant向量数据库
            if self.vector_store:
                success = self.vector_store.clear_collection()
                if success:
                    logger.info("✅ Qdrant向量数据库已清空")
                else:
                    logger.warning("⚠️ Qdrant清空失败")
            
            # 清空Neo4j图数据库
            if self.graph_store:
                success = self.graph_store.clear_all()
                if success:
                    logger.info("✅ Neo4j图数据库已清空")
                else:
                    logger.warning("⚠️ Neo4j清空失败")
            
            # 清空本地缓存
            self.semantic_memories.clear()
            self.memory_embeddings.clear()
            self.entities.clear()
            self.relations.clear()
            
            logger.info("🧹 语义记忆系统已完全清空")
            
        except Exception as e:
            logger.error(f"❌ 清空语义记忆失败: {e}")
            # 即使数据库清空失败,也要清空本地缓存
            self.semantic_memories.clear()
            self.memory_embeddings.clear()
            self.entities.clear()
            self.relations.clear()

    def get_all(self) -> List[MemoryItem]:
        """获取所有语义记忆"""
        return self.semantic_memories.copy()
    
    def get_stats(self) -> Dict[str, Any]:
        """获取语义记忆统计信息"""
        graph_stats = {}
        try:
            if self.graph_store:
                graph_stats = self.graph_store.get_stats() or {}
        except Exception:
            graph_stats = {}

        # 硬删除模式:所有记忆都是活跃的
        active_memories = self.semantic_memories

        return {
            "count": len(active_memories),  # 活跃记忆数量
            "forgotten_count": 0,  # 硬删除模式下已遗忘的记忆会被直接删除
            "total_count": len(self.semantic_memories),  # 总记忆数量
            "entities_count": len(self.entities),
            "relations_count": len(self.relations),
            "graph_nodes": graph_stats.get("total_nodes", 0),
            "graph_edges": graph_stats.get("total_relationships", 0),
            "avg_importance": sum(m.importance for m in active_memories) / len(active_memories) if active_memories else 0.0,
            "memory_type": "enhanced_semantic"
        }
    
    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """获取实体"""
        return self.entities.get(entity_id)
    
    def search_entities(self, query: str, limit: int = 10) -> List[Entity]:
        """搜索实体"""
        query_lower = query.lower()
        scored_entities = []
        
        for entity in self.entities.values():
            score = 0.0
            
            # 名称匹配
            if query_lower in entity.name.lower():
                score += 2.0
            
            # 类型匹配
            if query_lower in entity.entity_type.lower():
                score += 1.0
            
            # 描述匹配
            if query_lower in entity.description.lower():
                score += 0.5
            
            # 频率权重
            score *= math.log(1 + entity.frequency)
            
            if score > 0:
                scored_entities.append((score, entity))
        
        scored_entities.sort(key=lambda x: x[0], reverse=True)
        return [entity for _, entity in scored_entities[:limit]]
    
    def get_related_entities(
        self,
        entity_id: str,
        relation_types: List[str] = None,
        max_hops: int = 2
    ) -> List[Dict[str, Any]]:
        """获取相关实体 - 使用Neo4j图数据库"""
        
        related = []
        
        try:
            # 使用Neo4j图数据库查找相关实体
            if not self.graph_store:
                logger.warning("⚠️ Neo4j图数据库不可用")
                return []
            
            # 使用Neo4j查找相关实体
            related_entities = self.graph_store.find_related_entities(
                entity_id=entity_id,
                relationship_types=relation_types,
                max_depth=max_hops,
                limit=50
            )
            
            # 转换格式以保持兼容性
            for entity_data in related_entities:
                # 尝试从本地缓存获取实体对象
                entity_obj = self.entities.get(entity_data.get("id"))
                if not entity_obj:
                    # 如果本地缓存没有,创建临时实体对象
                    entity_obj = Entity(
                        entity_id=entity_data.get("id", entity_id),
                        name=entity_data.get("name", ""),
                        entity_type=entity_data.get("type", "MISC")
                    )
                
                related.append({
                    "entity": entity_obj,
                    "relation_type": entity_data.get("relationship_path", ["RELATED"])[-1] if entity_data.get("relationship_path") else "RELATED",
                    "strength": 1.0 / max(entity_data.get("distance", 1), 1),  # 距离越近强度越高
                    "distance": entity_data.get("distance", max_hops)
                })
            
            # 按距离和强度排序
            related.sort(key=lambda x: (x["distance"], -x["strength"]))
            
        except Exception as e:
            logger.error(f"❌ 获取相关实体失败: {e}")
        
        return related
    
    def export_knowledge_graph(self) -> Dict[str, Any]:
        """导出知识图谱 - 从Neo4j获取统计信息"""
        try:
            # 从Neo4j获取统计信息
            stats = {}
            if self.graph_store:
                stats = self.graph_store.get_stats()
            
            return {
                "entities": {eid: entity.to_dict() for eid, entity in self.entities.items()},
                "relations": [relation.to_dict() for relation in self.relations],
                "graph_stats": {
                    "total_nodes": stats.get("total_nodes", 0),
                    "entity_nodes": stats.get("entity_nodes", 0),
                    "memory_nodes": stats.get("memory_nodes", 0),
                    "total_relationships": stats.get("total_relationships", 0),
                    "cached_entities": len(self.entities),
                    "cached_relations": len(self.relations)
                }
            }
        except Exception as e:
            logger.error(f"❌ 导出知识图谱失败: {e}")
            return {
                "entities": {},
                "relations": [],
                "graph_stats": {"error": str(e)}
            }
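
上述语义记忆的典型用法是:先通过 add 写入记忆(内部自动完成文本嵌入、实体与关系抽取,并写入Qdrant/Neo4j),再通过 retrieve 进行"向量初筛+图谱扩展"的混合检索。下面给出一段最小使用示意(假设Qdrant与Neo4j服务已按前文配置就绪;导入路径为示意写法,请以项目实际结构为准):

python 复制代码
# 语义记忆的添加与混合检索(示意代码,导入路径为假设)
from datetime import datetime

from hello_agents.memory import MemoryConfig, MemoryItem
from hello_agents.memory.types.semantic import SemanticMemory  # 假设的模块路径

memory = SemanticMemory(config=MemoryConfig())

memory.add(MemoryItem(
    id="mem_001",
    content="用户Alice偏好简洁的中文回答",
    memory_type="semantic",
    user_id="alice",
    timestamp=datetime.now(),
    importance=0.9,
    metadata={}
))

# 混合检索:返回的metadata中带有combined_score与softmax归一化后的probability
for item in memory.retrieve("Alice的回答偏好", limit=3, user_id="alice"):
    print(item.content, item.metadata["combined_score"], item.metadata["probability"])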

2.4 感知记忆(PerceptualMemory)

python 复制代码
"""感知记忆实现(长存的多模态)

按照第8章架构设计的感知记忆(长期、多模态),提供:
- 多模态数据存储(文本、图像、音频等)
- 结构化元数据 + 向量索引(SQLite + Qdrant)
- 同模态检索(跨模态在无CLIP/CLAP依赖时有限)
- 懒加载编码:文本用 sentence-transformers;图像/音频用轻量确定性哈希向量
"""

from typing import List, Dict, Any, Optional, Union, Tuple
from datetime import datetime, timedelta
import hashlib
import os
import random
import logging

from ..base import BaseMemory, MemoryItem, MemoryConfig
from ..storage import SQLiteDocumentStore, QdrantVectorStore
from ..embedding import get_text_embedder, get_dimension

logger = logging.getLogger(__name__)

class Perception:
    """感知数据实体"""
    
    def __init__(
        self,
        perception_id: str,
        data: Any,
        modality: str,
        encoding: Optional[List[float]] = None,
        metadata: Dict[str, Any] = None
    ):
        self.perception_id = perception_id
        self.data = data
        self.modality = modality  # text, image, audio, video, structured
        self.encoding = encoding or []
        self.metadata = metadata or {}
        self.timestamp = datetime.now()
        self.data_hash = self._calculate_hash()
    
    def _calculate_hash(self) -> str:
        """计算数据哈希"""
        if isinstance(self.data, str):
            return hashlib.md5(self.data.encode()).hexdigest()
        elif isinstance(self.data, bytes):
            return hashlib.md5(self.data).hexdigest()
        else:
            return hashlib.md5(str(self.data).encode()).hexdigest()

class PerceptualMemory(BaseMemory):
    """感知记忆实现
    
    特点:
    - 支持多模态数据(文本、图像、音频等)
    - 跨模态相似性搜索
    - 感知数据的语义理解
    - 支持内容生成和检索
    """
    
    def __init__(self, config: MemoryConfig, storage_backend=None):
        super().__init__(config, storage_backend)
        
        # 感知数据存储(内存缓存)
        self.perceptions: Dict[str, Perception] = {}
        self.perceptual_memories: List[MemoryItem] = []
        
        # 模态索引
        self.modality_index: Dict[str, List[str]] = {}  # modality -> perception_ids
        
        # 支持的模态
        self.supported_modalities = set(self.config.perceptual_memory_modalities)
        
        # 文档权威存储(SQLite)
        db_dir = getattr(self.config, 'storage_path', "./memory_data")
        os.makedirs(db_dir, exist_ok=True)
        db_path = os.path.join(db_dir, "memory.db")
        self.doc_store = SQLiteDocumentStore(db_path=db_path)

        # 嵌入维度(与统一文本嵌入保持一致)
        self.text_embedder = get_text_embedder()
        self.vector_dim = get_dimension(getattr(self.text_embedder, 'dimension', 384))

        # 可选加载:图像CLIP与音频CLAP(缺依赖则优雅降级为哈希编码)
        self._clip_model = None
        self._clip_processor = None
        self._clap_model = None
        self._clap_processor = None
        self._image_dim = None
        self._audio_dim = None
        try:
            from transformers import CLIPModel, CLIPProcessor
            clip_name = os.getenv("CLIP_MODEL", "openai/clip-vit-base-patch32")
            self._clip_model = CLIPModel.from_pretrained(clip_name)
            self._clip_processor = CLIPProcessor.from_pretrained(clip_name)
            # 估计输出维度
            self._image_dim = self._clip_model.config.projection_dim if hasattr(self._clip_model.config, 'projection_dim') else 512
        except Exception:
            self._clip_model = None
            self._clip_processor = None
            self._image_dim = self.vector_dim
        try:
            from transformers import ClapProcessor, ClapModel
            clap_name = os.getenv("CLAP_MODEL", "laion/clap-htsat-unfused")
            self._clap_model = ClapModel.from_pretrained(clap_name)
            self._clap_processor = ClapProcessor.from_pretrained(clap_name)
            # 估计输出维度
            self._audio_dim = getattr(self._clap_model.config, 'projection_dim', None) or 512
        except Exception:
            self._clap_model = None
            self._clap_processor = None
            self._audio_dim = self.vector_dim

        # 向量存储(Qdrant)--- 按模态拆分集合,避免维度冲突,使用连接管理器避免重复连接
        from ..storage.qdrant_store import QdrantConnectionManager
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        base_collection = os.getenv("QDRANT_COLLECTION", "hello_agents_vectors")
        distance = os.getenv("QDRANT_DISTANCE", "cosine")
        
        self.vector_stores: Dict[str, QdrantVectorStore] = {}
        # 文本集合
        self.vector_stores["text"] = QdrantConnectionManager.get_instance(
            url=qdrant_url,
            api_key=qdrant_api_key,
            collection_name=f"{base_collection}_perceptual_text",
            vector_size=self.vector_dim,
            distance=distance
        )
        # 图像集合(若CLIP不可用,维度退化为text维度)
        self.vector_stores["image"] = QdrantConnectionManager.get_instance(
            url=qdrant_url,
            api_key=qdrant_api_key,
            collection_name=f"{base_collection}_perceptual_image",
            vector_size=int(self._image_dim or self.vector_dim),
            distance=distance
        )
        # 音频集合(若CLAP不可用,维度退化为text维度)
        self.vector_stores["audio"] = QdrantConnectionManager.get_instance(
            url=qdrant_url,
            api_key=qdrant_api_key,
            collection_name=f"{base_collection}_perceptual_audio",
            vector_size=int(self._audio_dim or self.vector_dim),
            distance=distance
        )
        
        # 编码器(轻量实现;真实场景可替换为CLIP/CLAP等)
        self.encoders = self._init_encoders()
    
    def add(self, memory_item: MemoryItem) -> str:
        """添加感知记忆(SQLite权威 + Qdrant向量)"""
        modality = memory_item.metadata.get("modality", "text")
        raw_data = memory_item.metadata.get("raw_data", memory_item.content)
        if modality not in self.supported_modalities:
            raise ValueError(f"不支持的模态类型: {modality}")

        # 编码感知数据
        perception = self._encode_perception(raw_data, modality, memory_item.id)

        # 缓存与索引
        self.perceptions[perception.perception_id] = perception
        if modality not in self.modality_index:
            self.modality_index[modality] = []
        self.modality_index[modality].append(perception.perception_id)

        # 存储记忆项(缓存)
        memory_item.metadata["perception_id"] = perception.perception_id
        memory_item.metadata["modality"] = modality
        # 不把大向量放到metadata中,避免膨胀
        self.perceptual_memories.append(memory_item)

        # 1) SQLite 权威入库
        ts_int = int(memory_item.timestamp.timestamp())
        self.doc_store.add_memory(
            memory_id=memory_item.id,
            user_id=memory_item.user_id,
            content=memory_item.content,
            memory_type="perceptual",
            timestamp=ts_int,
            importance=memory_item.importance,
            properties={
                "perception_id": perception.perception_id,
                "modality": modality,
                "context": memory_item.metadata.get("context", {}),
                "tags": memory_item.metadata.get("tags", []),
            }
        )

        # 2) Qdrant 向量入库(按模态写入对应集合)
        try:
            vector = perception.encoding
            store = self._get_vector_store_for_modality(modality)
            store.add_vectors(
                vectors=[vector],
                metadata=[{
                    "memory_id": memory_item.id,
                    "user_id": memory_item.user_id,
                    "memory_type": "perceptual",
                    "modality": modality,
                    "importance": memory_item.importance,
                    "content": memory_item.content,
                }],
                ids=[memory_item.id]
            )
        except Exception as e:
            logger.warning(f"⚠️ 感知记忆向量入库失败(SQLite记录已保留): {e}")

        return memory_item.id
    
    def retrieve(self, query: str, limit: int = 5, **kwargs) -> List[MemoryItem]:
        """检索感知记忆(可筛模态;同模态向量检索+时间/重要性融合)"""
        user_id = kwargs.get("user_id")
        target_modality = kwargs.get("target_modality")  # 可选:限制目标模态
        query_modality = kwargs.get("query_modality", target_modality or "text")

        # 仅在同模态情况下进行向量检索(跨模态需要CLIP/CLAP,此处保留简单回退)
        try:
            qvec = self._encode_data(query, query_modality)
            where = {"memory_type": "perceptual"}
            if user_id:
                where["user_id"] = user_id
            if target_modality:
                where["modality"] = target_modality
            store = self._get_vector_store_for_modality(target_modality or query_modality)
            hits = store.search_similar(
                query_vector=qvec,
                limit=max(limit * 5, 20),
                where=where
            )
        except Exception:
            hits = []

        # 融合排序
        now_ts = int(datetime.now().timestamp())
        results: List[Tuple[float, MemoryItem]] = []
        seen = set()
        for hit in hits:
            meta = hit.get("metadata", {})
            mem_id = meta.get("memory_id")
            if not mem_id or mem_id in seen:
                continue
            if target_modality and meta.get("modality") != target_modality:
                continue
            doc = self.doc_store.get_memory(mem_id)
            if not doc:
                continue
            vec_score = float(hit.get("score", 0.0))
            age_days = max(0.0, (now_ts - int(doc["timestamp"])) / 86400.0)
            recency_score = 1.0 / (1.0 + age_days)
            imp = float(doc.get("importance", 0.5))
            
            # 新评分算法:向量检索纯基于相似度,重要性作为加权因子
            # 基础相似度得分(不受重要性影响)
            base_relevance = vec_score * 0.8 + recency_score * 0.2
            
            # 重要性作为乘法加权因子,范围 [0.8, 1.2]
            importance_weight = 0.8 + (imp * 0.4)
            
            # 最终得分:相似度 * 重要性权重
            combined = base_relevance * importance_weight
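            # 例:vec_score=0.9, recency_score=0.5, imp=1.0 时,
            # base_relevance = 0.9*0.8 + 0.5*0.2 = 0.82,importance_weight = 1.2,combined ≈ 0.984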

            item = MemoryItem(
                id=doc["memory_id"],
                content=doc["content"],
                memory_type=doc["memory_type"],
                user_id=doc["user_id"],
                timestamp=datetime.fromtimestamp(doc["timestamp"]),
                importance=imp,
                metadata={**doc.get("properties", {}), "relevance_score": combined,
                          "vector_score": vec_score, "recency_score": recency_score}
            )
            results.append((combined, item))
            seen.add(mem_id)

        # 简单回退:若向量检索无命中,则在内存缓存中按目标模态过滤,并用关键词匹配兜底
        if not results:
            for m in self.perceptual_memories:
                if target_modality and m.metadata.get("modality") != target_modality:
                    continue
                if query.lower() in (m.content or "").lower():
                    recency_score = 1.0 / (1.0 + max(0.0, (now_ts - int(m.timestamp.timestamp())) / 86400.0))
                    # 回退匹配:新评分算法
                    keyword_score = 0.5  # 简单关键词匹配的基础分数
                    base_relevance = keyword_score * 0.8 + recency_score * 0.2
                    importance_weight = 0.8 + (m.importance * 0.4)
                    combined = base_relevance * importance_weight
                    results.append((combined, m))

        results.sort(key=lambda x: x[0], reverse=True)
        return [it for _, it in results[:limit]]
    
    def update(
        self,
        memory_id: str,
        content: str = None,
        importance: float = None,
        metadata: Dict[str, Any] = None
    ) -> bool:
        """更新感知记忆"""
        updated = False
        modality_cache = None
        for memory in self.perceptual_memories:
            if memory.id == memory_id:
                if content is not None:
                    memory.content = content
                if importance is not None:
                    memory.importance = importance
                if metadata is not None:
                    memory.metadata.update(metadata)
                modality_cache = memory.metadata.get("modality", "text")
                updated = True
                break

        # 更新SQLite
        self.doc_store.update_memory(
            memory_id=memory_id,
            content=content,
            importance=importance,
            properties=metadata
        )

        # 如内容或原始数据改变,则重嵌入并upsert到Qdrant
        if content is not None or (metadata and "raw_data" in metadata):
            modality = metadata.get("modality", modality_cache or "text") if metadata else (modality_cache or "text")
            raw = metadata.get("raw_data", content) if metadata else content
            try:
                perception = self._encode_perception(raw or "", modality, memory_id)
                payload = self.doc_store.get_memory(memory_id) or {}
                store = self._get_vector_store_for_modality(modality)  # 与 add/remove 一致:按模态选择对应集合
                store.add_vectors(
                    vectors=[perception.encoding],
                    metadata=[{
                        "memory_id": memory_id,
                        "user_id": payload.get("user_id", ""),
                        "memory_type": "perceptual",
                        "modality": modality,
                        "importance": (payload.get("importance") or importance) or 0.5,
                        "content": content or (payload.get("content", "")),
                    }],
                    ids=[memory_id]
                )
            except Exception:
                pass

        return updated
    
    def remove(self, memory_id: str) -> bool:
        """删除感知记忆"""
        removed = False
        for i, memory in enumerate(self.perceptual_memories):
            if memory.id == memory_id:
                removed_memory = self.perceptual_memories.pop(i)
                perception_id = removed_memory.metadata.get("perception_id")
                if perception_id and perception_id in self.perceptions:
                    perception = self.perceptions.pop(perception_id)
                    modality = perception.modality
                    if modality in self.modality_index:
                        if perception_id in self.modality_index[modality]:
                            self.modality_index[modality].remove(perception_id)
                        if not self.modality_index[modality]:
                            del self.modality_index[modality]
                removed = True
                break

        # 权威库删除
        self.doc_store.delete_memory(memory_id)
        # 向量库删除(所有模态集合尝试删除)
        for store in self.vector_stores.values():
            try:
                store.delete_memories([memory_id])
            except Exception:
                pass

        return removed
    
    def has_memory(self, memory_id: str) -> bool:
        """检查记忆是否存在"""
        return any(memory.id == memory_id for memory in self.perceptual_memories)
    
    def forget(self, strategy: str = "importance_based", threshold: float = 0.1, max_age_days: int = 30) -> int:
        """感知记忆遗忘机制(硬删除)"""
        forgotten_count = 0
        current_time = datetime.now()
        
        to_remove = []  # 收集要删除的记忆ID

        # capacity_based 只需排序一次:提前算出超额的、重要性最低的记忆集合
        excess_ids = set()
        if strategy == "capacity_based" and len(self.perceptual_memories) > self.config.max_capacity:
            sorted_memories = sorted(self.perceptual_memories, key=lambda m: m.importance)
            excess_count = len(self.perceptual_memories) - self.config.max_capacity
            excess_ids = {m.id for m in sorted_memories[:excess_count]}

        for memory in self.perceptual_memories:
            should_forget = False

            if strategy == "importance_based":
                # 基于重要性遗忘
                if memory.importance < threshold:
                    should_forget = True
            elif strategy == "time_based":
                # 基于时间遗忘
                cutoff_time = current_time - timedelta(days=max_age_days)
                if memory.timestamp < cutoff_time:
                    should_forget = True
            elif strategy == "capacity_based":
                # 基于容量遗忘(保留最重要的)
                if memory.id in excess_ids:
                    should_forget = True
            
            if should_forget:
                to_remove.append(memory.id)
        
        # 执行硬删除
        for memory_id in to_remove:
            if self.remove(memory_id):
                forgotten_count += 1
                logger.info(f"感知记忆硬删除: {memory_id[:8]}... (策略: {strategy})")
        
        return forgotten_count

    def clear(self):
        """清空所有感知记忆"""
        self.perceptual_memories.clear()
        self.perceptions.clear()
        self.modality_index.clear()
        # 删除SQLite中的perceptual记录
        docs = self.doc_store.search_memories(memory_type="perceptual", limit=10000)
        ids = [d["memory_id"] for d in docs]
        for mid in ids:
            self.doc_store.delete_memory(mid)
        # 删除Qdrant向量(所有模态集合)
        for store in self.vector_stores.values():
            try:
                if ids:
                    store.delete_memories(ids)
            except Exception:
                pass

    def get_all(self) -> List[MemoryItem]:
        """获取所有感知记忆"""
        return self.perceptual_memories.copy()
    
    def get_stats(self) -> Dict[str, Any]:
        """获取感知记忆统计信息"""
        # 硬删除模式:所有记忆都是活跃的
        active_memories = self.perceptual_memories
        
        modality_counts = {modality: len(ids) for modality, ids in self.modality_index.items()}
        vs_stats_all = {}
        for mod, store in self.vector_stores.items():
            try:
                vs_stats_all[mod] = store.get_collection_stats()
            except Exception:
                vs_stats_all[mod] = {"store_type": "qdrant"}
        db_stats = self.doc_store.get_database_stats()
        
        return {
            "count": len(active_memories),  # 活跃记忆数量
            "forgotten_count": 0,  # 硬删除模式下已遗忘的记忆会被直接删除
            "total_count": len(self.perceptual_memories),  # 总记忆数量
            "perceptions_count": len(self.perceptions),
            "modality_counts": modality_counts,
            "supported_modalities": list(self.supported_modalities),
            "avg_importance": sum(m.importance for m in active_memories) / len(active_memories) if active_memories else 0.0,
            "memory_type": "perceptual",
            "vector_stores": vs_stats_all,
            "document_store": {k: v for k, v in db_stats.items() if k.endswith("_count") or k in ["store_type", "db_path"]}
        }
    
    def cross_modal_search(
        self,
        query: Any,
        query_modality: str,
        target_modality: str = None,
        limit: int = 5
    ) -> List[MemoryItem]:
        """跨模态搜索"""
        return self.retrieve(
            query=str(query),
            limit=limit,
            query_modality=query_modality,
            target_modality=target_modality
        )
    
    def get_by_modality(self, modality: str, limit: int = 10) -> List[MemoryItem]:
        """按模态获取记忆"""
        if modality not in self.modality_index:
            return []
        
        perception_ids = self.modality_index[modality]
        results = []
        
        for memory in self.perceptual_memories:
            if memory.metadata.get("perception_id") in perception_ids:
                results.append(memory)
                if len(results) >= limit:
                    break
        
        return results
    
    def generate_content(self, prompt: str, target_modality: str) -> Optional[str]:
        """基于感知记忆生成内容"""
        # 简化的内容生成实现
        # 实际应用中需要使用生成模型
        
        if target_modality not in self.supported_modalities:
            return None
        
        # 检索相关感知记忆
        relevant_memories = self.retrieve(prompt, limit=3)
        
        if not relevant_memories:
            return None
        
        # 简单的内容组合
        if target_modality == "text":
            contents = [memory.content for memory in relevant_memories]
            return f"基于感知记忆生成的内容:\n" + "\n".join(contents)
        
        return f"生成的{target_modality}内容(基于{len(relevant_memories)}个相关记忆)"
    
    def _init_encoders(self) -> Dict[str, Any]:
        """初始化编码器(轻量、确定性,统一输出self.vector_dim维)"""
        encoders = {}
        for modality in self.supported_modalities:
            if modality == "text":
                encoders[modality] = self._text_encoder
            elif modality == "image":
                encoders[modality] = self._image_encoder
            elif modality == "audio":
                encoders[modality] = self._audio_encoder
            else:
                encoders[modality] = self._default_encoder
        return encoders
    
    def _encode_perception(self, data: Any, modality: str, memory_id: str) -> Perception:
        """编码感知数据"""
        encoding = self._encode_data(data, modality)
        
        perception = Perception(
            perception_id=f"perception_{memory_id}",
            data=data,
            modality=modality,
            encoding=encoding,
            metadata={"source": "memory_system"}
        )
        
        return perception
    
    def _encode_data(self, data: Any, modality: str) -> List[float]:
        """编码数据为固定维度向量(按模态维度对齐)"""
        target_dim = self._get_dim_for_modality(modality)
        encoder = self.encoders.get(modality, self._default_encoder)
        vec = encoder(data)
        if not isinstance(vec, list):
            vec = list(vec)
        if len(vec) < target_dim:
            vec = vec + [0.0] * (target_dim - len(vec))
        elif len(vec) > target_dim:
            vec = vec[:target_dim]
        return vec
    
    def _text_encoder(self, text: str) -> List[float]:
        """文本编码器(使用嵌入模型)"""
        emb = self.text_embedder.encode(text or "")
        if hasattr(emb, "tolist"):
            emb = emb.tolist()
        return emb
    
    def _image_encoder_hash(self, image_data: Any) -> List[float]:
        """图像编码器(轻量确定性哈希向量,跨环境稳定)"""
        try:
            if isinstance(image_data, (bytes, bytearray)):
                data_bytes = bytes(image_data)
            elif isinstance(image_data, str) and os.path.exists(image_data):
                with open(image_data, 'rb') as f:
                    data_bytes = f.read()
            else:
                data_bytes = str(image_data).encode('utf-8', errors='ignore')
            hex_str = hashlib.sha256(data_bytes).hexdigest()
            return self._hash_to_vector(hex_str, self._get_dim_for_modality("image"))
        except Exception:
            return self._hash_to_vector(str(image_data), self._get_dim_for_modality("image"))

    def _image_encoder(self, image_data: Any) -> List[float]:
        """图像编码器(优先CLIP,不可用则哈希)"""
        if self._clip_model is None or self._clip_processor is None:
            return self._image_encoder_hash(image_data)
        try:
            from PIL import Image
            if isinstance(image_data, str) and os.path.exists(image_data):
                image = Image.open(image_data).convert('RGB')
            elif isinstance(image_data, (bytes, bytearray)):
                from io import BytesIO
                image = Image.open(BytesIO(bytes(image_data))).convert('RGB')
            else:
                # 退回到哈希
                return self._image_encoder_hash(image_data)
            inputs = self._clip_processor(images=image, return_tensors="pt")
            with self._no_grad():
                feats = self._clip_model.get_image_features(**inputs)
            vec = feats[0].detach().cpu().numpy().tolist()
            return vec
        except Exception:
            return self._image_encoder_hash(image_data)
    
    def _audio_encoder_hash(self, audio_data: Any) -> List[float]:
        """音频编码器(轻量确定性哈希向量)"""
        try:
            if isinstance(audio_data, (bytes, bytearray)):
                data_bytes = bytes(audio_data)
            elif isinstance(audio_data, str) and os.path.exists(audio_data):
                with open(audio_data, 'rb') as f:
                    data_bytes = f.read()
            else:
                data_bytes = str(audio_data).encode('utf-8', errors='ignore')
            hex_str = hashlib.sha256(data_bytes).hexdigest()
            return self._hash_to_vector(hex_str, self._get_dim_for_modality("audio"))
        except Exception:
            return self._hash_to_vector(str(audio_data), self._get_dim_for_modality("audio"))

    def _audio_encoder(self, audio_data: Any) -> List[float]:
        """音频编码器(优先CLAP,不可用则哈希)"""
        if self._clap_model is None or self._clap_processor is None:
            return self._audio_encoder_hash(audio_data)
        try:
            import numpy as np
            # 加载音频(需要 librosa)
            import librosa
            if isinstance(audio_data, str) and os.path.exists(audio_data):
                speech, sr = librosa.load(audio_data, sr=48000, mono=True)
            elif isinstance(audio_data, (bytes, bytearray)):
                # 临时文件方式加载
                import tempfile
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    tmp.write(bytes(audio_data))
                    tmp_path = tmp.name
                speech, sr = librosa.load(tmp_path, sr=48000, mono=True)
                try:
                    os.remove(tmp_path)
                except Exception:
                    pass
            else:
                return self._audio_encoder_hash(audio_data)
            inputs = self._clap_processor(audios=speech, sampling_rate=48000, return_tensors="pt")
            with self._no_grad():
                feats = self._clap_model.get_audio_features(**inputs)
            vec = feats[0].detach().cpu().numpy().tolist()
            return vec
        except Exception:
            return self._audio_encoder_hash(audio_data)

    def _default_encoder(self, data: Any) -> List[float]:
        """默认编码器(退化为文本嵌入或哈希)"""
        try:
            return self._text_encoder(str(data))
        except Exception:
            return self._hash_to_vector(str(data), self.vector_dim)
    
    def _calculate_similarity(self, encoding1: List[float], encoding2: List[float]) -> float:
        """计算编码相似度"""
        if not encoding1 or not encoding2:
            return 0.0
        
        # 确保长度一致
        min_len = min(len(encoding1), len(encoding2))
        if min_len == 0:
            return 0.0
        
        # 计算余弦相似度
        dot_product = sum(a * b for a, b in zip(encoding1[:min_len], encoding2[:min_len]))
        norm1 = sum(a * a for a in encoding1[:min_len]) ** 0.5
        norm2 = sum(a * a for a in encoding2[:min_len]) ** 0.5
        
        if norm1 == 0 or norm2 == 0:
            return 0.0
        
        return dot_product / (norm1 * norm2)

    def _hash_to_vector(self, data_str: str, dim: int) -> List[float]:
        """将字符串哈希为固定维度的[0,1]向量(确定性)"""
        seed = int(hashlib.sha256(data_str.encode("utf-8", errors="ignore")).hexdigest(), 16) % (2**32)
        rng = random.Random(seed)
        return [rng.random() for _ in range(dim)]

    class _no_grad:
        def __enter__(self):
            try:
                import torch
                self.prev = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            except Exception:
                self.prev = None
            return self
        def __exit__(self, exc_type, exc, tb):
            try:
                import torch
                if self.prev is not None:
                    torch.set_grad_enabled(self.prev)
            except Exception:
                pass

    def _get_vector_store_for_modality(self, modality: Optional[str]) -> QdrantVectorStore:
        mod = (modality or "text").lower()
        return self.vector_stores.get(mod, self.vector_stores["text"])

    def _get_dim_for_modality(self, modality: Optional[str]) -> int:
        mod = (modality or "text").lower()
        if mod == "image":
            return int(self._image_dim or self.vector_dim)
        if mod == "audio":
            return int(self._audio_dim or self.vector_dim)
        return int(self.vector_dim)

有了上述模块,智能体就可以实现记忆功能。
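
下面给出一段最小用法示意(非框架正式接口,构造参数与入口方法名均为假设,以实际源码为准):先添加一条图像模态的记忆,再用同一张图片做同模态检索。

python 复制代码
# 最小用法示意(假设:PerceptualMemory 以 config 构造,添加入口为 add(memory_item))
from datetime import datetime

memory = PerceptualMemory(config=MemoryConfig())  # 构造方式以实际框架为准

item = MemoryItem(
    id="img_001",
    content="一张展示神经网络结构的示意图",
    memory_type="perceptual",
    user_id="default_user",
    timestamp=datetime.now(),
    importance=0.8,
    metadata={"modality": "image", "raw_data": "./figures/nn.png"},  # 示例路径
)
memory.add(item)  # SQLite 权威入库 + Qdrant 按模态写入

# 同模态检索:查询与目标均为 image,向量检索走 CLIP(不可用时退化为哈希向量)
hits = memory.retrieve("./figures/nn.png", limit=3,
                       query_modality="image", target_modality="image")
for h in hits:
    print(h.id, round(h.metadata.get("relevance_score", 0.0), 3))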

二、RAG系统:知识检索增强

1.RAG工具基类

首先,我们需要建立一个统一的RAG接口RAGTool。

python 复制代码
"""RAG工具 - 检索增强生成

为HelloAgents框架提供简洁易用的RAG能力:
- 🔄 数据流程:用户数据 → 文档解析 → 向量化存储 → 智能检索 → LLM增强问答
- 📚 多格式支持:PDF、Word、Excel、PPT、图片、音频、网页等
- 🧠 智能问答:自动检索相关内容,注入提示词,生成准确答案
- 🏷️ 命名空间:支持多项目隔离,便于管理不同知识库

使用示例:
```python
# 1. 初始化RAG工具
rag = RAGTool()

# 2. 添加文档
rag.run({"action": "add_document", "file_path": "document.pdf"})

# 3. 智能问答
answer = rag.run({"action": "ask", "question": "什么是机器学习?"})
```
"""

from typing import Dict, Any, List, Optional
import hashlib
import os
import time

from ..base import Tool, ToolParameter
from ...memory.rag.pipeline import create_rag_pipeline
from ...core.llm import HelloAgentsLLM

class RAGTool(Tool):
    """RAG工具
    
    提供完整的 RAG 能力:
    - 添加多格式文档(PDF、Office、图片、音频等)
    - 智能检索与召回
    - LLM 增强问答
    - 知识库管理
    """
    
    def __init__(
        self,
        knowledge_base_path: str = "./knowledge_base",
        qdrant_url: str = None,
        qdrant_api_key: str = None,
        collection_name: str = "rag_knowledge_base",
        rag_namespace: str = "default"
    ):
        super().__init__(
            name="rag",
            description="RAG工具 - 支持多格式文档检索增强生成,提供智能问答能力"
        )
        
        self.knowledge_base_path = knowledge_base_path
        self.qdrant_url = qdrant_url or os.getenv("QDRANT_URL")
        self.qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
        self.collection_name = collection_name
        self.rag_namespace = rag_namespace
        self._pipelines: Dict[str, Dict[str, Any]] = {}
        
        # 确保知识库目录存在
        os.makedirs(knowledge_base_path, exist_ok=True)
        
        # 初始化组件
        self._init_components()
    
    def _init_components(self):
        """初始化RAG组件"""
        try:
            # 初始化默认命名空间的 RAG 管道
            default_pipeline = create_rag_pipeline(
                qdrant_url=self.qdrant_url,
                qdrant_api_key=self.qdrant_api_key,
                collection_name=self.collection_name,
                rag_namespace=self.rag_namespace
            )
            self._pipelines[self.rag_namespace] = default_pipeline

            # 初始化 LLM 用于回答生成
            self.llm = HelloAgentsLLM()

            self.initialized = True
            print(f"✅ RAG工具初始化成功: namespace={self.rag_namespace}, collection={self.collection_name}")
            
        except Exception as e:
            self.initialized = False
            self.init_error = str(e)
            print(f"❌ RAG工具初始化失败: {e}")

    def _get_pipeline(self, namespace: Optional[str] = None) -> Dict[str, Any]:
        """获取指定命名空间的 RAG 管道,若不存在则自动创建"""
        target_ns = namespace or self.rag_namespace
        if target_ns in self._pipelines:
            return self._pipelines[target_ns]

        pipeline = create_rag_pipeline(
            qdrant_url=self.qdrant_url,
            qdrant_api_key=self.qdrant_api_key,
            collection_name=self.collection_name,
            rag_namespace=target_ns
        )
        self._pipelines[target_ns] = pipeline
        return pipeline

    def run(self, parameters: Dict[str, Any]) -> str:
        """执行工具 - Tool基类要求的接口

        Args:
            parameters: 工具参数字典,必须包含action参数

        Returns:
            执行结果字符串
        """
        if not self.validate_parameters(parameters):
            return "❌ 参数验证失败:缺少必需的参数"

        action = parameters.get("action")
        # 移除action参数,传递其余参数给execute方法
        kwargs = {k: v for k, v in parameters.items() if k != "action"}

        return self.execute(action, **kwargs)

    def get_parameters(self) -> List[ToolParameter]:
        """获取工具参数定义 - Tool基类要求的接口"""
        return [
            # 核心操作参数
            ToolParameter(
                name="action",
                type="string",
                description="操作类型:add_document(添加文档), add_text(添加文本), ask(智能问答), search(搜索), stats(统计), clear(清空)",
                required=True
            ),
            
            # 内容参数
            ToolParameter(
                name="file_path",
                type="string",
                description="文档文件路径(支持PDF、Word、Excel、PPT、图片、音频等多种格式)",
                required=False
            ),
            ToolParameter(
                name="text",
                type="string",
                description="要添加的文本内容",
                required=False
            ),
            ToolParameter(
                name="question",
                type="string", 
                description="用户问题(用于智能问答)",
                required=False
            ),
            ToolParameter(
                name="query",
                type="string",
                description="搜索查询词(用于基础搜索)",
                required=False
            ),
            
            # 可选配置参数
            ToolParameter(
                name="namespace",
                type="string",
                description="知识库命名空间(用于隔离不同项目,默认:default)",
                required=False,
                default="default"
            ),
            ToolParameter(
                name="limit",
                type="integer",
                description="返回结果数量(默认:5)",
                required=False,
                default=5
            ),
            ToolParameter(
                name="include_citations",
                type="boolean",
                description="是否包含引用来源(默认:true)",
                required=False,
                default=True
            )
        ]
    
    def execute(self, action: str, **kwargs) -> str:
        """执行RAG操作
        
        主要操作流程:
        1. add_document/add_text: 数据 → 解析 → 分块 → 向量化 → 存储
        2. ask: 问题 → 检索 → 上下文注入 → LLM生成答案
        3. search: 查询 → 向量检索 → 返回相关片段
        """
        
        if not self.initialized:
            return f"❌ RAG工具未正确初始化,请检查配置: {getattr(self, 'init_error', '未知错误')}"
        
        try:
            # 参数预处理:校验失败抛出的 ValueError 会被下方统一捕获并转为错误提示
            kwargs = self._preprocess_parameters(action, **kwargs)

            if action == "add_document":
                return self._add_document(**kwargs)
            elif action == "add_text":
                return self._add_text(**kwargs)
            elif action == "ask":
                return self._ask(**kwargs)
            elif action == "search":
                return self._search(**kwargs)
            elif action == "stats":
                return self._get_stats(namespace=kwargs.get("namespace"))
            elif action == "clear":
                return self._clear_knowledge_base(**kwargs)
            else:
                available_actions = ["add_document", "add_text", "ask", "search", "stats", "clear"]
                return f"❌ 不支持的操作: {action}\n✅ 可用操作: {', '.join(available_actions)}"
                
        except Exception as e:
            return f"❌ 执行操作 '{action}' 时发生错误: {str(e)}"
    
    def _preprocess_parameters(self, action: str, **kwargs) -> Dict[str, Any]:
        """预处理参数,设置默认值和验证"""
        # 设置默认值
        defaults = {
            "namespace": "default",
            "limit": 5,
            "include_citations": True,
            "enable_advanced_search": True,
            "max_chars": 1200,
            "min_score": 0.1,
            "chunk_size": 800,
            "chunk_overlap": 100
        }
        
        for key, value in defaults.items():
            if key not in kwargs or kwargs[key] is None:
                kwargs[key] = value
        
        # 参数验证
        if action in ["add_document"] and not kwargs.get("file_path"):
            raise ValueError("add_document 操作需要提供 file_path 参数")
        elif action in ["add_text"] and not kwargs.get("text"):
            raise ValueError("add_text 操作需要提供 text 参数")
        elif action in ["ask"] and not (kwargs.get("question") or kwargs.get("query")):
            raise ValueError("ask 操作需要提供 question 或 query 参数")
        elif action in ["search"] and not (kwargs.get("query") or kwargs.get("question")):
            raise ValueError("search 操作需要提供 query 或 question 参数")
            
        return kwargs

    def _add_document(self, file_path: str, document_id: str = None, namespace: Optional[str] = None, chunk_size: int = 800, chunk_overlap: int = 100, **kwargs) -> str:
        """添加文档到知识库(支持多格式)"""
        try:
            if not file_path or not os.path.exists(file_path):
                return f"❌ 文件不存在: {file_path}"
            
            pipeline = self._get_pipeline(namespace)
            t0 = time.time()

            chunks_added = pipeline["add_documents"](
                file_paths=[file_path],
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
            
            t1 = time.time()
            process_ms = int((t1 - t0) * 1000)
            
            if chunks_added == 0:
                return f"⚠️ 未能从文件解析内容: {os.path.basename(file_path)}"
            
            return (
                f"✅ 文档已添加到知识库: {os.path.basename(file_path)}\n"
                f"📊 分块数量: {chunks_added}\n"
                f"⏱️ 处理时间: {process_ms}ms\n"
                f"📝 命名空间: {pipeline.get('namespace', self.rag_namespace)}"
            )
            
        except Exception as e:
            return f"❌ 添加文档失败: {str(e)}"
    
    def _add_text(self, text: str, document_id: str = None, metadata: Optional[Dict[str, Any]] = None, namespace: Optional[str] = None, chunk_size: int = 800, chunk_overlap: int = 100, **kwargs) -> str:
        """添加文本到知识库"""
        try:
            if not text or not text.strip():
                return "❌ 文本内容不能为空"
            
            # 创建临时文件(用内容哈希生成确定性ID;内置 hash() 受 PYTHONHASHSEED 影响,跨进程不稳定)
            document_id = document_id or f"text_{hashlib.md5(text.encode('utf-8')).hexdigest()[:8]}"
            tmp_path = os.path.join(self.knowledge_base_path, f"{document_id}.md")
            
            try:
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    f.write(text)
                
                pipeline = self._get_pipeline(namespace)
                t0 = time.time()

                chunks_added = pipeline["add_documents"](
                    file_paths=[tmp_path],
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
                
                t1 = time.time()
                process_ms = int((t1 - t0) * 1000)
                
                if chunks_added == 0:
                    return f"⚠️ 未能从文本生成有效分块"
                
                return (
                    f"✅ 文本已添加到知识库: {document_id}\n"
                    f"📊 分块数量: {chunks_added}\n"
                    f"⏱️ 处理时间: {process_ms}ms\n"
                    f"📝 命名空间: {pipeline.get('namespace', self.rag_namespace)}"
                )
                
            finally:
                # 清理临时文件
                try:
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                except Exception:
                    pass
            
        except Exception as e:
            return f"❌ 添加文本失败: {str(e)}"
    
    def _search(self, query: str, limit: int = 5, min_score: float = 0.1, enable_advanced_search: bool = True, max_chars: int = 1200, include_citations: bool = True, namespace: Optional[str] = None, **kwargs) -> str:
        """搜索知识库"""
        try:
            if not query or not query.strip():
                return "❌ 搜索查询不能为空"
            
            # 使用统一 RAG 管道搜索
            pipeline = self._get_pipeline(namespace)

            if enable_advanced_search:
                results = pipeline["search_advanced"](
                    query=query,
                    top_k=limit,
                    enable_mqe=True,
                    enable_hyde=True,
                    score_threshold=min_score if min_score > 0 else None
                )
            else:
                results = pipeline["search"](
                    query=query,
                    top_k=limit,
                    score_threshold=min_score if min_score > 0 else None
                )
            
            if not results:
                return f"🔍 未找到与 '{query}' 相关的内容"
            
            # 格式化搜索结果
            search_result = ["搜索结果:"]
            for i, result in enumerate(results, 1):
                meta = result.get("metadata", {})
                score = result.get("score", 0.0)
                raw_content = meta.get("content", "")
                content = raw_content[:200] + "..." if len(raw_content) > 200 else raw_content
                source = meta.get("source_path", "unknown")
                
                # 安全处理Unicode
                def clean_text(text):
                    try:
                        return str(text).encode('utf-8', errors='ignore').decode('utf-8')
                    except Exception:
                        return str(text)
                
                clean_content = clean_text(content)
                clean_source = clean_text(source)
                
                search_result.append(f"\n{i}. 文档: **{clean_source}** (相似度: {score:.3f})")
                search_result.append(f"   {clean_content}")
                
                if include_citations and meta.get("heading_path"):
                    clean_heading = clean_text(str(meta['heading_path']))
                    search_result.append(f"   章节: {clean_heading}")
            
            return "\n".join(search_result)
            
        except Exception as e:
            return f"❌ 搜索失败: {str(e)}"
    
    def _ask(self, question: Optional[str] = None, query: Optional[str] = None, limit: int = 5, enable_advanced_search: bool = True, include_citations: bool = True, max_chars: int = 1200, namespace: Optional[str] = None, **kwargs) -> str:
        """智能问答:检索 → 上下文注入 → LLM生成答案
        
        核心流程:
        1. 解析用户问题
        2. 智能检索相关内容
        3. 构建上下文和提示词
        4. LLM生成准确答案
        5. 添加引用来源
        """
        try:
            # 获取用户问题(question 优先级高于 query)
            user_question = question or query
            if not user_question or not user_question.strip():
                return "❌ 请提供要询问的问题"
            
            user_question = user_question.strip()
            print(f"🔍 智能问答: {user_question}")
            
            # 1. 检索相关内容
            pipeline = self._get_pipeline(namespace)
            search_start = time.time()
            
            if enable_advanced_search:
                results = pipeline["search_advanced"](
                    query=user_question,
                    top_k=limit,
                    enable_mqe=True,
                    enable_hyde=True
                )
            else:
                results = pipeline["search"](
                    query=user_question,
                    top_k=limit
                )
            
            search_time = int((time.time() - search_start) * 1000)
            
            if not results:
                return (
                    f"🤔 抱歉,我在知识库中没有找到与「{user_question}」相关的信息。\n\n"
                    f"💡 建议:\n"
                    f"• 尝试使用更简洁的关键词\n"
                    f"• 检查是否已添加相关文档\n"
                    f"• 使用 stats 操作查看知识库状态"
                )
            
            # 2. 智能整理上下文
            context_parts = []
            citations = []
            total_score = 0
            
            for i, result in enumerate(results):
                meta = result.get("metadata", {})
                content = meta.get("content", "").strip()
                source = meta.get("source_path", "unknown")
                score = result.get("score", 0.0)
                total_score += score
                
                if content:
                    # 清理内容格式
                    cleaned_content = self._clean_content_for_context(content)
                    context_parts.append(f"片段 {i+1}:{cleaned_content}")
                    
                    if include_citations:
                        citations.append({
                            "index": i+1,
                            "source": os.path.basename(source),
                            "score": score
                        })
            
            # 3. 构建上下文(智能截断)
            context = "\n\n".join(context_parts)
            if len(context) > max_chars:
                # 智能截断,保持完整性
                context = self._smart_truncate_context(context, max_chars)
            
            # 4. 构建增强提示词
            system_prompt = self._build_system_prompt()
            user_prompt = self._build_user_prompt(user_question, context)
            
            enhanced_prompt = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            
            # 5. 调用 LLM 生成答案
            llm_start = time.time()
            answer = self.llm.invoke(enhanced_prompt)
            llm_time = int((time.time() - llm_start) * 1000)
            
            if not answer or not answer.strip():
                return "❌ LLM未能生成有效答案,请稍后重试"
            
            # 6. 构建最终回答
            final_answer = self._format_final_answer(
                question=user_question,
                answer=answer.strip(),
                citations=citations if include_citations else None,
                search_time=search_time,
                llm_time=llm_time,
                avg_score=total_score / len(results) if results else 0
            )
            
            return final_answer
            
        except Exception as e:
            return f"❌ 智能问答失败: {str(e)}\n💡 请检查知识库状态或稍后重试"
    
    def _clean_content_for_context(self, content: str) -> str:
        """清理内容用于上下文"""
        # 移除过多的换行和空格
        content = " ".join(content.split())
        # 截断过长内容
        if len(content) > 300:
            content = content[:300] + "..."
        return content
    
    def _smart_truncate_context(self, context: str, max_chars: int) -> str:
        """智能截断上下文,保持段落完整性"""
        if len(context) <= max_chars:
            return context
        
        # 寻找最近的段落分隔符
        truncated = context[:max_chars]
        last_break = truncated.rfind("\n\n")
        
        if last_break > max_chars * 0.7:  # 如果断点位置合理
            return truncated[:last_break] + "\n\n[...更多内容被截断]"
        else:
            return truncated[:max_chars-20] + "...[内容被截断]"
    
    def _build_system_prompt(self) -> str:
        """构建系统提示词"""
        return (
            "你是一个专业的知识助手,具备以下能力:\n"
            "1. 📖 精准理解:仔细理解用户问题的核心意图\n"
            "2. 🎯 可信回答:严格基于提供的上下文信息回答,不编造内容\n"
            "3. 🔍 信息整合:从多个片段中提取关键信息,形成完整答案\n"
            "4. 💡 清晰表达:用简洁明了的语言回答,适当使用结构化格式\n"
            "5. 🚫 诚实表达:如果上下文不足以回答问题,请坦诚说明\n\n"
            "回答格式要求:\n"
            "• 直接回答核心问题\n"
            "• 必要时使用要点或步骤\n"
            "• 引用关键原文时使用引号\n"
            "• 避免重复和冗余"
        )
    
    def _build_user_prompt(self, question: str, context: str) -> str:
        """构建用户提示词"""
        return (
            f"请基于以下上下文信息回答问题:\n\n"
            f"【问题】{question}\n\n"
            f"【相关上下文】\n{context}\n\n"
            f"【要求】请提供准确、有帮助的回答。如果上下文信息不足,请说明需要什么额外信息。"
        )
    
    def _format_final_answer(self, question: str, answer: str, citations: Optional[List[Dict]] = None, search_time: int = 0, llm_time: int = 0, avg_score: float = 0) -> str:
        """格式化最终答案"""
        result = [f"🤖 **智能问答结果**\n"]
        result.append(answer)
        
        if citations:
            result.append("\n\n📚 **参考来源**")
            for citation in citations:
                score_emoji = "🟢" if citation["score"] > 0.8 else "🟡" if citation["score"] > 0.6 else "🔵"
                result.append(f"{score_emoji} [{citation['index']}] {citation['source']} (相似度: {citation['score']:.3f})")
        
        # 添加性能信息(调试模式)
        result.append(f"\n⚡ 检索: {search_time}ms | 生成: {llm_time}ms | 平均相似度: {avg_score:.3f}")
        
        return "\n".join(result)

    def _clear_knowledge_base(self, confirm: bool = False, namespace: Optional[str] = None, **kwargs) -> str:
        """清空知识库"""
        try:
            if not confirm:
                return (
                    "⚠️ 危险操作:清空知识库将删除所有数据!\n"
                    "请使用 confirm=true 参数确认执行。"
                )
            
            pipeline = self._get_pipeline(namespace)
            store = pipeline.get("store")
            namespace_id = pipeline.get("namespace", self.rag_namespace)
            success = store.clear_collection() if store else False
            
            if success:
                # 重新初始化该命名空间
                self._pipelines[namespace_id] = create_rag_pipeline(
                    qdrant_url=self.qdrant_url,
                    qdrant_api_key=self.qdrant_api_key,
                    collection_name=self.collection_name,
                    rag_namespace=namespace_id
                )
                return f"✅ 知识库已成功清空(命名空间:{namespace_id})"
            else:
                return "❌ 清空知识库失败"
            
        except Exception as e:
            return f"❌ 清空知识库失败: {str(e)}"

    def _get_stats(self, namespace: Optional[str] = None) -> str:
        """获取知识库统计"""
        try:
            pipeline = self._get_pipeline(namespace)
            stats = pipeline["get_stats"]()
            
            stats_info = [
                "📊 **RAG 知识库统计**",
                f"📝 命名空间: {pipeline.get('namespace', self.rag_namespace)}",
                f"📋 集合名称: {self.collection_name}",
                f"📂 存储根路径: {self.knowledge_base_path}"
            ]
            
            # 添加存储统计
            if stats:
                store_type = stats.get("store_type", "unknown")
                total_vectors = (
                    stats.get("points_count") or 
                    stats.get("vectors_count") or 
                    stats.get("count") or 0
                )
                
                stats_info.extend([
                    f"📦 存储类型: {store_type}",
                    f"📊 文档分块数: {int(total_vectors)}",
                ])
                
                if "config" in stats:
                    config = stats["config"]
                    if isinstance(config, dict):
                        vector_size = config.get("vector_size", "unknown")
                        distance = config.get("distance", "unknown")
                        stats_info.extend([
                            f"🔢 向量维度: {vector_size}",
                            f"📎 距离度量: {distance}"
                        ])
            
            # 添加系统状态
            stats_info.extend([
                "",
                "🟢 **系统状态**",
                f"✅ RAG 管道: {'正常' if self.initialized else '异常'}",
                f"✅ LLM 连接: {'正常' if hasattr(self, 'llm') else '异常'}"
            ])
            
            return "\n".join(stats_info)
            
        except Exception as e:
            return f"❌ 获取统计信息失败: {str(e)}"

    def get_relevant_context(self, query: str, limit: int = 3, max_chars: int = 1200, namespace: Optional[str] = None) -> str:
        """为查询获取相关上下文
        
        这个方法可以被Agent调用来获取相关的知识库上下文
        """
        try:
            if not query:
                return ""
            
            # 使用统一 RAG 管道搜索
            pipeline = self._get_pipeline(namespace)
            results = pipeline["search"](
                query=query,
                top_k=limit
            )
            
            if not results:
                return ""
            
            # 合并上下文
            context_parts = []
            for result in results:
                content = result.get("metadata", {}).get("content", "")
                if content:
                    context_parts.append(content)
            
            merged_context = "\n\n".join(context_parts)
            
            # 限制长度
            if len(merged_context) > max_chars:
                merged_context = merged_context[:max_chars] + "..."
            
            return merged_context
            
        except Exception as e:
            return f"获取上下文失败: {str(e)}"
    
    def batch_add_texts(self, texts: List[str], document_ids: Optional[List[str]] = None, chunk_size: int = 800, chunk_overlap: int = 100, namespace: Optional[str] = None) -> str:
        """批量添加文本"""
        try:
            if not texts:
                return "❌ 文本列表不能为空"
            
            if document_ids and len(document_ids) != len(texts):
                return "❌ 文本数量和文档ID数量不匹配"
            
            pipeline = self._get_pipeline(namespace)
            t0 = time.time()
            
            total_chunks = 0
            successful_files = []
            
            for i, text in enumerate(texts):
                if not text or not text.strip():
                    continue
                    
                doc_id = document_ids[i] if document_ids else f"batch_text_{i}"
                tmp_path = os.path.join(self.knowledge_base_path, f"{doc_id}.md")
                
                try:
                    with open(tmp_path, 'w', encoding='utf-8') as f:
                        f.write(text)
                    
                    chunks_added = pipeline["add_documents"](
                        file_paths=[tmp_path],
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                    
                    total_chunks += chunks_added
                    successful_files.append(doc_id)
                    
                finally:
                    # 清理临时文件
                    try:
                        if os.path.exists(tmp_path):
                            os.remove(tmp_path)
                    except Exception:
                        pass
            
            t1 = time.time()
            process_ms = int((t1 - t0) * 1000)
            
            return (
                f"✅ 批量添加完成\n"
                f"📊 成功文件: {len(successful_files)}/{len(texts)}\n"
                f"📊 总分块数: {total_chunks}\n"
                f"⏱️ 处理时间: {process_ms}ms"
            )
            
        except Exception as e:
            return f"❌ 批量添加失败: {str(e)}"
    
    def clear_all_namespaces(self) -> str:
        """清空当前工具管理的所有命名空间数据"""
        try:
            for ns, pipeline in self._pipelines.items():
                store = pipeline.get("store")
                if store:
                    store.clear_collection()
            self._pipelines.clear()
            # 重新初始化默认命名空间
            self._init_components()
            return "✅ 所有命名空间数据已清空并重新初始化"
        except Exception as e:
            return f"❌ 清空所有命名空间失败: {str(e)}"
    
    # ========================================
    # 便捷接口方法(简化用户调用)
    # ========================================
    
    def add_document(self, file_path: str, namespace: str = "default") -> str:
        """便捷方法:添加单个文档"""
        return self.run({
            "action": "add_document",
            "file_path": file_path,
            "namespace": namespace
        })
    
    def add_text(self, text: str, namespace: str = "default", document_id: str = None) -> str:
        """便捷方法:添加文本内容"""
        return self.run({
            "action": "add_text",
            "text": text,
            "namespace": namespace,
            "document_id": document_id
        })
    
    def ask(self, question: str, namespace: str = "default", **kwargs) -> str:
        """便捷方法:智能问答"""
        params = {
            "action": "ask",
            "question": question,
            "namespace": namespace
        }
        params.update(kwargs)
        return self.run(params)
    
    def search(self, query: str, namespace: str = "default", **kwargs) -> str:
        """便捷方法:搜索知识库"""
        params = {
            "action": "search",
            "query": query,
            "namespace": namespace
        }
        params.update(kwargs)
        return self.run(params)
    
    def add_documents_batch(self, file_paths: List[str], namespace: str = "default") -> str:
        """批量添加多个文档"""
        if not file_paths:
            return "❌ 文件路径列表不能为空"
        
        results = []
        successful = 0
        failed = 0
        total_chunks = 0
        start_time = time.time()
        
        for i, file_path in enumerate(file_paths, 1):
            print(f"📄 处理文档 {i}/{len(file_paths)}: {os.path.basename(file_path)}")
            
            try:
                result = self.add_document(file_path, namespace)
                if "✅" in result:
                    successful += 1
                    # 提取分块数量
                    if "分块数量:" in result:
                        chunks = int(result.split("分块数量: ")[1].split("\n")[0])
                        total_chunks += chunks
                else:
                    failed += 1
                    results.append(f"❌ {os.path.basename(file_path)}: 处理失败")
            except Exception as e:
                failed += 1
                results.append(f"❌ {os.path.basename(file_path)}: {str(e)}")
        
        process_time = int((time.time() - start_time) * 1000)
        
        summary = [
            "📊 **批量处理完成**",
            f"✅ 成功: {successful}/{len(file_paths)} 个文档",
            f"📊 总分块数: {total_chunks}",
            f"⏱️ 总耗时: {process_time}ms",
            f"📝 命名空间: {namespace}"
        ]
        
        if failed > 0:
            summary.append(f"❌ 失败: {failed} 个文档")
            summary.append("\n**失败详情:**")
            summary.extend(results)
        
        return "\n".join(summary)
    
    def add_texts_batch(self, texts: List[str], namespace: str = "default", document_ids: Optional[List[str]] = None) -> str:
        """批量添加多个文本"""
        if not texts:
            return "❌ 文本列表不能为空"
        
        if document_ids and len(document_ids) != len(texts):
            return "❌ 文本数量和文档ID数量不匹配"
        
        results = []
        successful = 0
        failed = 0
        total_chunks = 0
        start_time = time.time()
        
        for i, text in enumerate(texts):
            doc_id = document_ids[i] if document_ids else f"batch_text_{i+1}"
            print(f"📝 处理文本 {i+1}/{len(texts)}: {doc_id}")
            
            try:
                result = self.add_text(text, namespace, doc_id)
                if "✅" in result:
                    successful += 1
                    # 提取分块数量
                    if "分块数量:" in result:
                        chunks = int(result.split("分块数量: ")[1].split("\n")[0])
                        total_chunks += chunks
                else:
                    failed += 1
                    results.append(f"❌ {doc_id}: 处理失败")
            except Exception as e:
                failed += 1
                results.append(f"❌ {doc_id}: {str(e)}")
        
        process_time = int((time.time() - start_time) * 1000)
        
        summary = [
            "📊 **批量文本处理完成**",
            f"✅ 成功: {successful}/{len(texts)} 个文本",
            f"📊 总分块数: {total_chunks}",
            f"⏱️ 总耗时: {process_time}ms",
            f"📝 命名空间: {namespace}"
        ]
        
        if failed > 0:
            summary.append(f"❌ 失败: {failed} 个文本")
            summary.append("\n**失败详情:**")
            summary.extend(results)
        
        return "\n".join(summary)

2.多模态文档载入

在这个智能体中,我们使用 MarkItDown 作为统一的文档转换引擎,它支持几乎所有常见的文档格式:无论是 PDF、Word、Excel 还是图片,最终都会被转换为标准的 Markdown 文本。

python 复制代码
def _convert_to_markdown(path: str) -> str:
    """
    Universal document reader using MarkItDown with enhanced PDF processing.
    Converts any supported file format to markdown text.
    """
    if not os.path.exists(path):
        return ""
    
    # 对PDF文件使用增强处理
    ext = (os.path.splitext(path)[1] or '').lower()
    if ext == '.pdf':
        return _enhanced_pdf_processing(path)
    
    # 其他格式使用原有MarkItDown
    md_instance = _get_markitdown_instance()
    if md_instance is None:
        return _fallback_text_reader(path)
    
    try:
        result = md_instance.convert(path)
        text = getattr(result, "text_content", None)
        if isinstance(text, str) and text.strip():
            return text
        return ""
    except Exception as e:
        print(f"[WARNING] MarkItDown failed for {path}: {e}")
        return _fallback_text_reader(path)
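
代码中引用的 _fallback_text_reader 未在此展示,其职责是在 MarkItDown 不可用或解析失败时按纯文本兜底读取。一个可能的极简实现如下(仅为示意,实际实现以框架源码为准):

python 复制代码
def _fallback_text_reader(path: str) -> str:
    """纯文本兜底读取(示意实现):忽略无法解码的字节,失败时返回空串"""
    try:
        with open(path, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()
    except Exception:
        return ""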

3.智能分块策略
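
分块分两步:先按 Markdown 标题层级把文本切成带 heading_path 的段落,再按近似 token 数把段落聚合成块,并在相邻块之间保留一段重叠,降低语义在块边界被切断的影响。完整代码之后附有一段调用示意。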

python 复制代码
def _split_paragraphs_with_headings(text: str) -> List[Dict]:
    lines = text.splitlines()
    heading_stack: List[str] = []
    paragraphs: List[Dict] = []
    buf: List[str] = []
    char_pos = 0

    def flush_buf(end_pos: int):
        # 把累积的段落行写入结果并清空缓冲区,避免标题行前的段落被重复输出
        if not buf:
            return
        content = "\n".join(buf).strip()
        buf.clear()
        if not content:
            return
        paragraphs.append({
            "content": content,
            "heading_path": " > ".join(heading_stack) if heading_stack else None,
            "start": max(0, end_pos - len(content)),
            "end": end_pos,
        })

    for ln in lines:
        raw = ln
        if raw.strip().startswith("#"):
            # 标题行:先冲刷缓冲区,再维护标题栈
            flush_buf(char_pos)
            level = len(raw) - len(raw.lstrip('#'))
            title = raw.lstrip('#').strip()
            if level <= 0:
                level = 1
            if level <= len(heading_stack):
                heading_stack = heading_stack[:level-1]
            heading_stack.append(title)
            char_pos += len(raw) + 1
            continue
        # 段落累积:空行作为段落分隔
        if raw.strip() == "":
            flush_buf(char_pos)
        else:
            buf.append(raw)
        char_pos += len(raw) + 1
    flush_buf(char_pos)
    if not paragraphs:
        paragraphs = [{"content": text, "heading_path": None, "start": 0, "end": len(text)}]
    return paragraphs


def _chunk_paragraphs(paragraphs: List[Dict], chunk_tokens: int, overlap_tokens: int) -> List[Dict]:
    chunks: List[Dict] = []
    cur: List[Dict] = []
    cur_tokens = 0
    i = 0
    while i < len(paragraphs):
        p = paragraphs[i]
        p_tokens = _approx_token_len(p["content"]) or 1
        if cur_tokens + p_tokens <= chunk_tokens or not cur:
            cur.append(p)
            cur_tokens += p_tokens
            i += 1
        else:
            # emit current chunk
            content = "\n\n".join(x["content"] for x in cur)
            start = cur[0]["start"]
            end = cur[-1]["end"]
            heading_path = next((x["heading_path"] for x in reversed(cur) if x.get("heading_path")), None)
            chunks.append({
                "content": content,
                "start": start,
                "end": end,
                "heading_path": heading_path,
            })
            # build overlap by keeping tail tokens
            if overlap_tokens > 0 and cur:
                kept: List[Dict] = []
                kept_tokens = 0
                for x in reversed(cur):
                    t = _approx_token_len(x["content"]) or 1
                    if kept_tokens + t > overlap_tokens:
                        break
                    kept.append(x)
                    kept_tokens += t
                cur = list(reversed(kept))
                cur_tokens = kept_tokens
                # 防止死循环:若保留的重叠内容加上当前段落仍超过 chunk_tokens,
                # 则放弃本次重叠,确保循环一定能消费到第 i 个段落
                if cur and cur_tokens + p_tokens > chunk_tokens:
                    cur = []
                    cur_tokens = 0
            else:
                cur = []
                cur_tokens = 0
    if cur:
        content = "\n\n".join(x["content"] for x in cur)
        start = cur[0]["start"]
        end = cur[-1]["end"]
        heading_path = next((x["heading_path"] for x in reversed(cur) if x.get("heading_path")), None)
        chunks.append({
            "content": content,
            "start": start,
            "end": end,
            "heading_path": heading_path,
        })
    return chunks
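
下面是一段调用示意。_approx_token_len 在框架源码的其他位置定义,这里用"字符数的一半"粗略近似,便于独立运行:

python 复制代码
# 调用示意:_approx_token_len 为演示用的粗略近似,实际以框架实现为准
def _approx_token_len(s: str) -> int:
    return max(1, len(s) // 2)

md = "# 简介\n机器学习让计算机从数据中学习规律。\n\n## 方法\n监督学习、无监督学习与强化学习。"
paras = _split_paragraphs_with_headings(md)
chunks = _chunk_paragraphs(paras, chunk_tokens=50, overlap_tokens=10)
for c in chunks:
    print(c["heading_path"], "->", c["content"][:30])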

4.统一嵌入与向量存储
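
分块完成后,index_chunks 负责把每个块向量化并写入 Qdrant:统一通过 get_text_embedder 获取嵌入器,按批次编码并做维度对齐校验;单批失败时降级为更小批次重试,仍失败则以零向量占位,最后连同命名空间等元数据批量入库。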

python 复制代码
def index_chunks(
    store = None, 
    chunks: List[Dict] = None, 
    cache_db: Optional[str] = None, 
    batch_size: int = 64,
    rag_namespace: str = "default"
) -> None:
    """
    Index markdown chunks with unified embedding and Qdrant storage.
    Uses the 百炼 (Bailian) embedding API, with a fallback to sentence-transformers.
    """
    if not chunks:
        print("[RAG] No chunks to index")
        return
    
    # Use unified embedding from embedding module
    embedder = get_text_embedder()
    dimension = get_dimension(384)
    
    # Create default Qdrant store if not provided
    if store is None:
        store = _create_default_vector_store(dimension)
        print(f"[RAG] Created default Qdrant store with dimension {dimension}")
    
    # Preprocess markdown texts for better embeddings
    processed_texts = []
    for c in chunks:
        raw_content = c["content"]
        processed_content = _preprocess_markdown_for_embedding(raw_content)
        processed_texts.append(processed_content)
    
    print(f"[RAG] Embedding start: total_texts={len(processed_texts)} batch_size={batch_size}")
    
    # Batch encoding with unified embedder
    vecs: List[List[float]] = []
    for i in range(0, len(processed_texts), batch_size):
        part = processed_texts[i:i+batch_size]
        try:
            # Use unified embedder directly (handles caching internally)
            part_vecs = embedder.encode(part)
            
            # Normalize to List[List[float]]
            if not isinstance(part_vecs, list):
                # 单个numpy数组转为列表中的列表
                if hasattr(part_vecs, "tolist"):
                    part_vecs = [part_vecs.tolist()]
                else:
                    part_vecs = [list(part_vecs)]
            else:
                # 检查是否是嵌套列表
                if part_vecs and not isinstance(part_vecs[0], (list, tuple)) and hasattr(part_vecs[0], "__len__"):
                    # numpy数组列表 -> 转换每个数组
                    normalized_vecs = []
                    for v in part_vecs:
                        if hasattr(v, "tolist"):
                            normalized_vecs.append(v.tolist())
                        else:
                            normalized_vecs.append(list(v))
                    part_vecs = normalized_vecs
                elif part_vecs and not isinstance(part_vecs[0], (list, tuple)):
                    # 扁平的 float 列表说明只返回了单个向量,包装成 [[...]]
                    # (此分支中 part_vecs 必为 list,无需再做 tolist 判断)
                    part_vecs = [list(part_vecs)]
            
            for v in part_vecs:
                try:
                    # Ensure the vector is a plain list of floats
                    if hasattr(v, "tolist"):
                        v = v.tolist()
                    v_norm = [float(x) for x in v]
                    if len(v_norm) != dimension:
                        print(f"[WARNING] Unexpected vector dimension: expected {dimension}, got {len(v_norm)}")
                        # Pad with zeros or truncate to the expected dimension
                        if len(v_norm) < dimension:
                            v_norm.extend([0.0] * (dimension - len(v_norm)))
                        else:
                            v_norm = v_norm[:dimension]
                    vecs.append(v_norm)
                except Exception as e:
                    print(f"[WARNING] Vector conversion failed: {e}; substituting a zero vector")
                    vecs.append([0.0] * dimension)
                
        except Exception as e:
            print(f"[WARNING] Batch {i} encoding failed: {e}")
            print(f"[RAG] Retrying batch {i} in smaller sub-batches...")

            # Retry by splitting the batch into sub-batches of 8
            all_ok = True
            for j in range(0, len(part), 8):
                small_part = part[j:j+8]
                try:
                    import time
                    time.sleep(2)  # brief back-off to avoid rate limiting

                    small_vecs = embedder.encode(small_part)
                    # Normalize to List[List[float]] (same rules as above)
                    if not isinstance(small_vecs, list):
                        small_vecs = small_vecs.tolist() if hasattr(small_vecs, "tolist") else list(small_vecs)
                    if small_vecs and not isinstance(small_vecs[0], (list, tuple)):
                        if hasattr(small_vecs[0], "__len__"):
                            small_vecs = [v.tolist() if hasattr(v, "tolist") else list(v) for v in small_vecs]
                        else:
                            small_vecs = [small_vecs]

                    for v in small_vecs:
                        try:
                            v_norm = [float(x) for x in v]
                            if len(v_norm) != dimension:
                                print(f"[WARNING] Unexpected vector dimension: expected {dimension}, got {len(v_norm)}")
                                if len(v_norm) < dimension:
                                    v_norm.extend([0.0] * (dimension - len(v_norm)))
                                else:
                                    v_norm = v_norm[:dimension]
                            vecs.append(v_norm)
                        except Exception as e2:
                            print(f"[WARNING] Sub-batch vector conversion failed: {e2}")
                            vecs.append([0.0] * dimension)
                            all_ok = False
                except Exception as e2:
                    print(f"[WARNING] Sub-batch {j//8} still failed: {e2}")
                    # Fill with zero vectors so the vector count stays aligned with chunks
                    for _ in range(len(small_part)):
                        vecs.append([0.0] * dimension)
                    all_ok = False

            if not all_ok:
                print(f"[ERROR] Batch {i} could not be fully embedded; zero vectors were substituted")
        
        print(f"[RAG] Embedding progress: {min(i+batch_size, len(processed_texts))}/{len(processed_texts)}")
    
    # Prepare metadata with RAG tags
    metas: List[Dict] = []
    ids: List[str] = []
    for ch in chunks:
        meta = {
            "memory_id": ch["id"],
            "user_id": "rag_user",
            "memory_type": "rag_chunk",
            "content": ch["content"],  # Keep original markdown content
            "data_source": "rag_pipeline",  # RAG identification tag
            "rag_namespace": rag_namespace,
            "is_rag_data": True,  # Clear RAG data marker
        }
        # Merge chunk metadata
        meta.update(ch.get("metadata", {}))
        metas.append(meta)
        ids.append(ch["id"])
    
    print(f"[RAG] Qdrant upsert start: n={len(vecs)}")
    success = store.add_vectors(vectors=vecs, metadata=metas, ids=ids)
    if success:
        print(f"[RAG] Qdrant upsert done: {len(vecs)} vectors indexed")
    else:
        print(f"[RAG] Qdrant upsert failed")
        raise RuntimeError("Failed to index vectors to Qdrant")
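
A hedged end-to-end sketch tying the two steps together (the chunking function's name is not shown in this excerpt, so chunk_markdown below is a placeholder for whatever it is called in your code):

python 复制代码
# Hypothetical wiring; replace chunk_markdown with the actual chunker defined above.
with open("notes.md", "r", encoding="utf-8") as f:
    md_text = f.read()

chunks = assign_chunk_ids(chunk_markdown(md_text))  # add the ids that index_chunks needs
index_chunks(chunks=chunks, batch_size=64, rag_namespace="demo")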

三、Building an Intelligent Document Q&A Assistant

With the MemoryTool and RAGTool in place, we can now combine them into a small assistant class that loads PDFs, answers questions about them, and remembers the interaction history.

1. Core Class Initialization

python 复制代码
import os
import time
from datetime import datetime
from typing import Any, Dict

# MemoryTool and RAGTool are the tool classes built earlier in this series;
# import them from wherever they live in your project layout.

class PDFLearningAssistant:
    """Intelligent document Q&A assistant"""

    def __init__(self, user_id: str = "default_user"):
        """Initialize the learning assistant

        Args:
            user_id: User ID, used to isolate data between different users
        """
        self.user_id = user_id
        self.session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Initialize tools
        self.memory_tool = MemoryTool(user_id=user_id)
        self.rag_tool = RAGTool(rag_namespace=f"pdf_{user_id}")

        # Learning statistics
        self.stats = {
            "session_start": datetime.now(),
            "documents_loaded": 0,
            "questions_asked": 0,
            "concepts_learned": 0
        }

        # Currently loaded document
        self.current_document = None
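
Because the memory store is keyed by user_id and the RAG namespace is derived from it, assistants created with different ids cannot see each other's documents or memories:

python 复制代码
alice = PDFLearningAssistant(user_id="alice")  # namespace "pdf_alice"
bob = PDFLearningAssistant(user_id="bob")      # namespace "pdf_bob", fully isolated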

2. Loading a PDF File

python 复制代码
    def load_document(self, pdf_path: str) -> Dict[str, Any]:
        """Load a PDF document into the knowledge base

        Args:
            pdf_path: Path to the PDF file

        Returns:
            Dict: Result with success and message fields
        """
        if not os.path.exists(pdf_path):
            return {"success": False, "message": f"File not found: {pdf_path}"}

        start_time = time.time()

        try:
            # Process the PDF with the RAG tool
            result = self.rag_tool.run({
                "action": "add_document",
                "file_path": pdf_path,
                "chunk_size": 1000,
                "chunk_overlap": 200
            })

            process_time = time.time() - start_time

            # The RAG tool returns a human-readable string message
            self.current_document = os.path.basename(pdf_path)
            self.stats["documents_loaded"] += 1

            # Record the event in episodic memory
            self.memory_tool.run({
                "action": "add",
                "content": f"Loaded document '{self.current_document}'",
                "memory_type": "episodic",
                "importance": 0.9,
                "event_type": "document_loaded",
                "session_id": self.session_id
            })

            return {
                "success": True,
                # Surface the RAG tool's own status string alongside the timing
                "message": f"Loaded successfully! (took {process_time:.1f}s) {result}",
                "document": self.current_document
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"Failed to load: {e}"
            }

3. Intelligent Q&A

The flow mirrors the memory design from earlier: the question is first logged to working memory, the RAG tool retrieves and generates the answer (optionally with MQE and HyDE query expansion), and the interaction is then recorded in episodic memory for later review.

python 复制代码
    def ask(self, question: str, use_advanced_search: bool = True) -> str:
        """Ask the loaded document a question

        Args:
            question: The user's question
            use_advanced_search: Whether to use advanced retrieval
                (MQE multi-query expansion + HyDE hypothetical document embeddings)

        Returns:
            str: The answer
        """
        if not self.current_document:
            return "⚠️ Please load a document first! Use the load_document() method to load a PDF."

        # Log the question to working memory
        self.memory_tool.run({
            "action": "add",
            "content": f"Question: {question}",
            "memory_type": "working",
            "importance": 0.6,
            "session_id": self.session_id
        })

        # Retrieve the answer with the RAG tool
        answer = self.rag_tool.run({
            "action": "ask",
            "question": question,
            "limit": 5,
            "enable_advanced_search": use_advanced_search,
            "enable_mqe": use_advanced_search,
            "enable_hyde": use_advanced_search
        })

        # Record the interaction in episodic memory
        self.memory_tool.run({
            "action": "add",
            "content": f"Studied the question '{question}'",
            "memory_type": "episodic",
            "importance": 0.7,
            "event_type": "qa_interaction",
            "session_id": self.session_id
        })

        self.stats["questions_asked"] += 1

        return answer

4. Running Results
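
A minimal driver script consistent with the API above (the file name sample.pdf and the questions are placeholders):

python 复制代码
if __name__ == "__main__":
    assistant = PDFLearningAssistant(user_id="demo_user")

    # Load a PDF into the knowledge base
    result = assistant.load_document("sample.pdf")  # placeholder path
    print(result["message"])

    if result["success"]:
        # Advanced retrieval (MQE + HyDE) is on by default
        print(assistant.ask("What is the main topic of this document?"))
        # Plain retrieval, skipping query expansion
        print(assistant.ask("List the key terms it defines.", use_advanced_search=False))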
