Dify源码深度剖析3

第七章:Agent 智能体框架

7.1 Agent 架构设计

核心组件
Agent 架构
BaseAgentRunner
CotAgentRunner
FunctionCallAgentRunner
ChatAgentRunner
AgentScratchpad
思考记录
ToolEngine
工具执行
OutputParser
输出解析

Agent 基类实现:

python 复制代码
from abc import ABC, abstractmethod
from typing import Generator, List, Dict, Any, Optional
from dataclasses import dataclass, field
from enum import Enum
import re
import json

from core.model_runtime import ModelManager


class AgentStrategy(Enum):
    """
    Agent 策略枚举
    """
    CHAIN_OF_THOUGHT = 'chain-of-thought'    # 思维链策略
    FUNCTION_CALLING = 'function-calling'    # 函数调用策略
    REACT = 'react'                          # ReAct 策略


@dataclass
class AgentScratchpadUnit:
    """
    Agent 思考记录单元
    记录单轮思考-行动-观察的过程
    """
    thought: str = ''              # 思考内容
    action: Dict = None            # 行动(工具调用)
    observation: str = ''          # 观察结果
    
    def is_final(self) -> bool:
        """
        判断是否为最终答案
        
        Returns:
            是否为最终答案
        """
        # 如果没有行动,则为最终答案
        return self.action is None
    
    def to_prompt(self) -> str:
        """
        转换为提示词格式
        
        Returns:
            格式化的字符串
        """
        parts = []
        
        if self.thought:
            parts.append(f"Thought: {self.thought}")
        
        if self.action:
            action_str = f"Action: {self.action.get('action_name')}"
            if 'action_input' in self.action:
                action_str += f"\nAction Input: {self.action['action_input']}"
            parts.append(action_str)
        
        if self.observation:
            parts.append(f"Observation: {self.observation}")
        
        return '\n'.join(parts)


class BaseAgentRunner(ABC):
    """
    Agent 运行器基类
    所有 Agent 策略必须实现此接口
    """
    
    def __init__(
        self,
        model_config: Dict,
        tools: List[Dict],
        max_iterations: int = 5
    ):
        """
        初始化 Agent
        
        Args:
            model_config: 模型配置
            tools: 可用工具列表
            max_iterations: 最大迭代次数
        """
        self.model_config = model_config
        self.tools = tools
        self.max_iterations = max_iterations
        
        # 思考记录
        self.scratchpad: List[AgentScratchpadUnit] = []
    
    @abstractmethod
    def run(
        self,
        query: str,
        inputs: Dict[str, Any]
    ) -> Generator:
        """
        运行 Agent
        
        Args:
            query: 用户问题
            inputs: 输入参数
            
        Yields:
            执行事件
        """
        pass
    
    def _build_tool_descriptions(self) -> str:
        """
        构建工具描述
        
        Returns:
            工具描述字符串
        """
        descriptions = []
        for i, tool in enumerate(self.tools, 1):
            desc = f"{i}. {tool['name']}: {tool.get('description', '')}"
            if 'parameters' in tool:
                params = ', '.join(tool['parameters'].keys())
                desc += f" (参数: {params})"
            descriptions.append(desc)
        
        return '\n'.join(descriptions)
    
    def _get_tool(self, tool_name: str) -> Optional[Dict]:
        """
        获取工具配置
        
        Args:
            tool_name: 工具名称
            
        Returns:
            工具配置或 None
        """
        for tool in self.tools:
            if tool['name'] == tool_name:
                return tool
        return None


class CotAgentRunner(BaseAgentRunner):
    """
    思维链 Agent 运行器
    通过 Thought-Action-Observation 循环逐步解决问题
    """
    
    def run(
        self,
        query: str,
        inputs: Dict[str, Any]
    ) -> Generator:
        """
        运行思维链 Agent
        
        Args:
            query: 用户问题
            inputs: 输入参数
            
        Yields:
            执行事件
        """
        # 初始化
        iteration = 0
        is_finished = False
        
        # 构建 Agent 提示词
        system_prompt = self._build_system_prompt()
        
        while not is_finished and iteration < self.max_iterations:
            iteration += 1
            
            # 构建当前轮次的提示词
            messages = self._build_messages(system_prompt, query)
            
            # 调用 LLM
            model_manager = ModelManager()
            model_instance = model_manager.get_model_instance(
                provider=self.model_config.get('provider'),
                model=self.model_config.get('name'),
                credentials=self.model_config.get('credentials', {})
            )
            
            # 流式获取响应
            response_text = ''
            for chunk in model_manager.invoke(
                model_instance=model_instance,
                prompt_messages=messages,
                model_parameters=self.model_config.get('parameters', {}),
                stream=True
            ):
                if 'content' in chunk:
                    response_text += chunk['content']
                    
                    # 发送思考事件
                    yield {
                        'event': 'agent_thought',
                        'data': {
                            'thought': chunk['content'],
                            'iteration': iteration
                        }
                    }
            
            # 解析响应
            scratchpad_unit = self._parse_response(response_text)
            self.scratchpad.append(scratchpad_unit)
            
            # 判断是否完成
            if scratchpad_unit.is_final():
                is_finished = True
                
                yield {
                    'event': 'agent_message',
                    'data': {
                        'answer': scratchpad_unit.thought,
                        'status': 'completed'
                    }
                }
            else:
                # 执行工具
                tool_result = self._execute_tool(scratchpad_unit.action)
                scratchpad_unit.observation = tool_result
                
                # 发送观察事件
                yield {
                    'event': 'agent_observation',
                    'data': {
                        'tool': scratchpad_unit.action.get('action_name'),
                        'result': tool_result,
                        'iteration': iteration
                    }
                }
        
        # 达到最大迭代次数
        if not is_finished:
            yield {
                'event': 'agent_message',
                'data': {
                    'answer': '抱歉,我无法在有限的步骤内完成这个任务。',
                    'status': 'max_iterations_reached'
                }
            }
    
    def _build_system_prompt(self) -> str:
        """
        构建系统提示词
        
        Returns:
            系统提示词
        """
        return f"""你是一个智能助手,可以使用工具来帮助用户解决问题。

你可以使用以下工具:
{self._build_tool_descriptions()}

请按照以下格式思考和行动:

Thought: 思考下一步应该做什么
Action: 工具名称
Action Input: 工具输入参数(JSON 格式)

或者当你有最终答案时:

Thought: 我现在知道最终答案了
Final Answer: 最终答案

开始!"""
    
    def _build_messages(self, system_prompt: str, query: str) -> List[Dict]:
        """
        构建消息列表
        
        Args:
            system_prompt: 系统提示词
            query: 用户问题
            
        Returns:
            消息列表
        """
        messages = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': query}
        ]
        
        # 添加历史思考记录
        for unit in self.scratchpad:
            messages.append({
                'role': 'assistant',
                'content': unit.to_prompt()
            })
        
        return messages
    
    def _parse_response(self, response: str) -> AgentScratchpadUnit:
        """
        解析 LLM 响应
        
        Args:
            response: LLM 响应文本
            
        Returns:
            思考记录单元
        """
        import re
        import json
        
        unit = AgentScratchpadUnit()
        
        # 解析 Thought
        thought_match = re.search(r'Thought:\s*(.+?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
        if thought_match:
            unit.thought = thought_match.group(1).strip()
        
        # 解析 Final Answer
        final_match = re.search(r'Final Answer:\s*(.+?)$', response, re.DOTALL)
        if final_match:
            unit.thought = final_match.group(1).strip()
            return unit  # 没有行动,返回最终答案
        
        # 解析 Action
        action_match = re.search(r'Action:\s*(.+?)(?=Action Input:|$)', response, re.DOTALL)
        if action_match:
            action_name = action_match.group(1).strip()
            
            # 解析 Action Input
            action_input = {}
            input_match = re.search(r'Action Input:\s*(.+?)$', response, re.DOTALL)
            if input_match:
                try:
                    action_input = json.loads(input_match.group(1).strip())
                except json.JSONDecodeError:
                    action_input = {'input': input_match.group(1).strip()}
            
            unit.action = {
                'action_name': action_name,
                'action_input': action_input
            }
        
        return unit
    
    def _execute_tool(self, action: Dict) -> str:
        """
        执行工具
        
        Args:
            action: 行动配置
            
        Returns:
            执行结果
        """
        tool_name = action.get('action_name')
        tool_input = action.get('action_input', {})
        
        tool = self._get_tool(tool_name)
        if not tool:
            return f"Error: Tool '{tool_name}' not found"
        
        # 实际执行工具
        # 这里简化处理,实际会调用 ToolEngine
        try:
            result = f"Tool '{tool_name}' executed with input: {tool_input}"
            return result
        except Exception as e:
            return f"Error executing tool: {str(e)}"


class FunctionCallAgentRunner(BaseAgentRunner):
    """
    函数调用 Agent 运行器
    利用 LLM 原生的 Function Calling 能力进行工具调用
    
    与 CotAgentRunner 的区别:
    - CotAgentRunner:通过文本解析提取工具调用
    - FunctionCallAgentRunner:使用 LLM 原生的 function_call 接口
    
    优势:
    - 更准确的工具参数解析
    - 支持并行工具调用
    - 更好的错误处理
    """
    
    def run(
        self,
        query: str,
        inputs: Dict[str, Any]
    ) -> Generator:
        """
        运行函数调用 Agent
        
        Args:
            query: 用户问题
            inputs: 输入参数
            
        Yields:
            执行事件
        """
        # 初始化
        iteration = 0
        is_finished = False
        
        # 构建消息列表
        messages = [
            {'role': 'system', 'content': self._build_system_prompt()},
            {'role': 'user', 'content': query}
        ]
        
        while not is_finished and iteration < self.max_iterations:
            iteration += 1
            
            # 调用 LLM(启用 function calling)
            model_manager = ModelManager()
            model_instance = model_manager.get_model_instance(
                provider=self.model_config.get('provider'),
                model=self.model_config.get('name'),
                credentials=self.model_config.get('credentials', {})
            )
            
            # 构建工具定义(OpenAI 格式)
            tools = self._build_tools_schema()
            
            # 流式获取响应
            response_text = ''
            tool_calls = []
            
            for chunk in model_manager.invoke(
                model_instance=model_instance,
                prompt_messages=messages,
                model_parameters=self.model_config.get('parameters', {}),
                tools=tools,
                tool_choice='auto',
                stream=True
            ):
                if 'content' in chunk and chunk['content']:
                    response_text += chunk['content']
                    yield {
                        'event': 'agent_thought',
                        'data': {
                            'thought': chunk['content'],
                            'iteration': iteration
                        }
                    }
                
                # 收集工具调用
                if 'tool_calls' in chunk:
                    for tc in chunk['tool_calls']:
                        tool_calls.append(tc)
            
            # 判断是否有工具调用
            if tool_calls:
                # 添加助手消息(包含工具调用)
                messages.append({
                    'role': 'assistant',
                    'content': response_text,
                    'tool_calls': tool_calls
                })
                
                # 执行所有工具调用
                for tool_call in tool_calls:
                    function_name = tool_call['function']['name']
                    function_args = json.loads(tool_call['function']['arguments'])
                    
                    # 执行工具
                    tool_result = self._execute_tool({
                        'action_name': function_name,
                        'action_input': function_args
                    })
                    
                    # 添加工具结果消息
                    messages.append({
                        'role': 'tool',
                        'tool_call_id': tool_call['id'],
                        'content': tool_result
                    })
                    
                    yield {
                        'event': 'agent_observation',
                        'data': {
                            'tool': function_name,
                            'result': tool_result,
                            'iteration': iteration
                        }
                    }
            else:
                # 没有工具调用,返回最终答案
                is_finished = True
                yield {
                    'event': 'agent_message',
                    'data': {
                        'answer': response_text,
                        'status': 'completed'
                    }
                }
        
        # 达到最大迭代次数
        if not is_finished:
            yield {
                'event': 'agent_message',
                'data': {
                    'answer': '抱歉,我无法在有限的步骤内完成这个任务。',
                    'status': 'max_iterations_reached'
                }
            }
    
    def _build_system_prompt(self) -> str:
        """
        构建系统提示词
        
        Returns:
            系统提示词
        """
        return """你是一个智能助手,可以使用工具来帮助用户解决问题。

当需要使用工具时,系统会自动调用相应的工具函数。
请根据用户的问题选择合适的工具,并提供准确的参数。"""
    
    def _build_tools_schema(self) -> List[Dict]:
        """
        构建工具 Schema(OpenAI Function Calling 格式)
        
        Returns:
            工具 Schema 列表
        """
        tools = []
        for tool in self.tools:
            tools.append({
                'type': 'function',
                'function': {
                    'name': tool['name'],
                    'description': tool.get('description', ''),
                    'parameters': tool.get('parameters', {})
                }
            })
        return tools
    
    def _execute_tool(self, action: Dict) -> str:
        """
        执行工具
        
        Args:
            action: 行动配置
            
        Returns:
            执行结果
        """
        tool_name = action.get('action_name')
        tool_input = action.get('action_input', {})
        
        tool = self._get_tool(tool_name)
        if not tool:
            return f"Error: Tool '{tool_name}' not found"
        
        try:
            # 实际执行工具
            result = f"Tool '{tool_name}' executed with input: {tool_input}"
            return result
        except Exception as e:
            return f"Error executing tool: {str(e)}"

7.2 推理策略实现

Tool LLM Agent 用户 Tool LLM Agent 用户 loop [迭代循环] 提交问题 发送 Prompt Thought + Action 解析响应 执行工具 返回结果 记录 Observation 发送下一轮 Prompt Final Answer 返回答案

7.3 工具调用机制

python 复制代码
from typing import Dict, Any, List, Optional
from abc import ABC, abstractmethod
import json


class BaseTool(ABC):
    """
    工具基类
    所有工具必须实现此接口
    """
    
    @property
    @abstractmethod
    def name(self) -> str:
        """工具名称"""
        pass
    
    @property
    @abstractmethod
    def description(self) -> str:
        """工具描述"""
        pass
    
    @property
    def parameters(self) -> Dict:
        """
        参数定义(JSON Schema 格式)
        """
        return {}
    
    @abstractmethod
    def execute(self, parameters: Dict[str, Any]) -> str:
        """
        执行工具
        
        Args:
            parameters: 参数字典
            
        Returns:
            执行结果
        """
        pass


class WebSearchTool(BaseTool):
    """
    网络搜索工具
    """
    
    @property
    def name(self) -> str:
        return 'web_search'
    
    @property
    def description(self) -> str:
        return '搜索互联网获取信息'
    
    @property
    def parameters(self) -> Dict:
        return {
            'type': 'object',
            'properties': {
                'query': {
                    'type': 'string',
                    'description': '搜索关键词'
                },
                'num_results': {
                    'type': 'integer',
                    'description': '返回结果数量',
                    'default': 5
                }
            },
            'required': ['query']
        }
    
    def execute(self, parameters: Dict[str, Any]) -> str:
        """
        执行搜索
        
        Args:
            parameters: 搜索参数
            
        Returns:
            搜索结果
        """
        query = parameters.get('query')
        num_results = parameters.get('num_results', 5)
        
        # 实际实现会调用搜索 API
        results = f"搜索 '{query}' 的前 {num_results} 条结果..."
        return results


class CalculatorTool(BaseTool):
    """
    计算器工具
    """
    
    @property
    def name(self) -> str:
        return 'calculator'
    
    @property
    def description(self) -> str:
        return '执行数学计算'
    
    @property
    def parameters(self) -> Dict:
        return {
            'type': 'object',
            'properties': {
                'expression': {
                    'type': 'string',
                    'description': '数学表达式,如 "2 + 3 * 4"'
                }
            },
            'required': ['expression']
        }
    
    def execute(self, parameters: Dict[str, Any]) -> str:
        """
        执行计算
        
        Args:
            parameters: 计算参数
            
        Returns:
            计算结果
        """
        import ast
        import operator
        
        expression = parameters.get('expression')
        
        # 安全的运算符映射
        ALLOWED_OPERATORS = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
            ast.USub: operator.neg,
            ast.UAdd: operator.pos,
        }
        
        def safe_eval(node):
            """
            安全的表达式求值
            只允许基本数学运算,禁止函数调用和属性访问
            
            Args:
                node: AST 节点
                
            Returns:
                计算结果
                
            Raises:
                ValueError: 不允许的操作
            """
            if isinstance(node, ast.Constant):
                return node.value
            elif isinstance(node, ast.Num):
                return node.n
            elif isinstance(node, ast.BinOp):
                if type(node.op) not in ALLOWED_OPERATORS:
                    raise ValueError(f"不允许的运算符: {type(node.op).__name__}")
                left = safe_eval(node.left)
                right = safe_eval(node.right)
                return ALLOWED_OPERATORS[type(node.op)](left, right)
            elif isinstance(node, ast.UnaryOp):
                if type(node.op) not in ALLOWED_OPERATORS:
                    raise ValueError(f"不允许的运算符: {type(node.op).__name__}")
                operand = safe_eval(node.operand)
                return ALLOWED_OPERATORS[type(node.op)](operand)
            else:
                raise ValueError(f"不允许的表达式类型: {type(node).__name__}")
        
        try:
            # 解析表达式为 AST
            tree = ast.parse(expression, mode='eval')
            # 安全求值
            result = safe_eval(tree.body)
            return f"计算结果: {result}"
        except ValueError as e:
            return f"计算错误: {str(e)}"
        except SyntaxError:
            return "计算错误: 表达式语法无效"
        except Exception as e:
            return f"计算错误: {str(e)}"


class ToolEngine:
    """
    工具引擎
    管理和执行工具
    """
    
    def __init__(self):
        self._tools: Dict[str, BaseTool] = {}
        self._register_default_tools()
    
    def _register_default_tools(self):
        """注册默认工具"""
        self.register_tool(WebSearchTool())
        self.register_tool(CalculatorTool())
    
    def register_tool(self, tool: BaseTool):
        """
        注册工具
        
        Args:
            tool: 工具实例
        """
        self._tools[tool.name] = tool
    
    def get_tool_schemas(self) -> List[Dict]:
        """
        获取所有工具的 Schema
        
        Returns:
            工具 Schema 列表
        """
        schemas = []
        for tool in self._tools.values():
            schemas.append({
                'name': tool.name,
                'description': tool.description,
                'parameters': tool.parameters
            })
        return schemas
    
    def invoke(
        self,
        tool_name: str,
        parameters: Dict[str, Any]
    ) -> str:
        """
        调用工具
        
        Args:
            tool_name: 工具名称
            parameters: 参数
            
        Returns:
            执行结果
        """
        tool = self._tools.get(tool_name)
        if not tool:
            raise ValueError(f'Tool not found: {tool_name}')
        
        return tool.execute(parameters)

7.4 Agent 执行时序图

ToolEngine LLM ModelManager AgentRunner Client ToolEngine LLM ModelManager AgentRunner Client alt [需要工具调用] [最终答案] loop [迭代 (max_iterations)] run(query, inputs) 初始化 scratchpad 构建消息(含历史) invoke(model, messages) API 调用 流式响应 LLMResultChunk 解析响应 agent_thought 事件 invoke(tool, params) 工具结果 agent_observation 事件 记录到 scratchpad agent_message 事件 结束循环 workflow_finished 事件


第八章:会话管理与可观测性

8.1 会话数据模型

python 复制代码
from sqlalchemy import Column, String, Text, JSON, DateTime, Integer
from sqlalchemy.dialects.postgresql import UUID
from models.base import Base
import uuid
from datetime import datetime


class ConversationVariable(Base):
    """
    会话变量模型
    存储多轮对话的状态信息
    """
    __tablename__ = 'conversation_variables'
    
    id = Column(UUID, primary_key=True, default=uuid.uuid4)
    
    # 所属会话
    conversation_id = Column(UUID, nullable=False, index=True)
    
    # 变量名
    name = Column(String(255), nullable=False)
    
    # 变量值(JSON 格式)
    value = Column(JSON, nullable=True)
    
    # 变量类型
    value_type = Column(String(50), default='string')
    
    # 创建时间
    created_at = Column(DateTime, default=datetime.utcnow)
    
    # 更新时间
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class MessageFeedback(Base):
    """
    消息反馈模型
    存储用户对消息的反馈
    """
    __tablename__ = 'message_feedbacks'
    
    id = Column(UUID, primary_key=True, default=uuid.uuid4)
    
    # 关联消息
    message_id = Column(UUID, nullable=False, index=True)
    
    # 反馈类型:like/dislike
    rating = Column(String(50), nullable=False)
    
    # 反馈内容
    content = Column(Text, nullable=True)
    
    # 创建者
    created_by = Column(UUID, nullable=True)
    
    created_at = Column(DateTime, default=datetime.utcnow)

8.2 上下文窗口管理

python 复制代码
from typing import List, Dict, Any, Optional
from dataclasses import dataclass


@dataclass
class ContextWindowConfig:
    """
    上下文窗口配置
    """
    max_tokens: int = 4096           # 最大 Token 数
    max_messages: int = 20           # 最大消息数
    strategy: str = 'truncate'       # 截断策略:truncate/compress


class ContextWindowManager:
    """
    上下文窗口管理器
    负责管理对话历史的截断和压缩
    """
    
    def __init__(self, config: ContextWindowConfig):
        """
        初始化管理器
        
        Args:
            config: 配置
        """
        self.config = config
    
    def manage(
        self,
        messages: List[Dict],
        system_prompt: str = None,
        new_message: str = None
    ) -> List[Dict]:
        """
        管理上下文窗口
        
        Args:
            messages: 历史消息列表
            system_prompt: 系统提示词
            new_message: 新消息
            
        Returns:
            处理后的消息列表
        """
        # 计算当前 Token 数
        total_tokens = self._count_tokens(messages, system_prompt, new_message)
        
        # 如果超过限制,执行截断
        if total_tokens > self.config.max_tokens:
            messages = self._truncate_messages(messages)
        
        # 限制消息数量
        if len(messages) > self.config.max_messages:
            messages = messages[-self.config.max_messages:]
        
        return messages
    
    def _count_tokens(
        self,
        messages: List[Dict],
        system_prompt: str = None,
        new_message: str = None
    ) -> int:
        """
        计算 Token 数量
        
        Args:
            messages: 消息列表
            system_prompt: 系统提示词
            new_message: 新消息
            
        Returns:
            Token 数量
        """
        total = 0
        
        if system_prompt:
            total += len(system_prompt) // 4  # 粗略估算
        
        for msg in messages:
            content = msg.get('content', '')
            total += len(content) // 4
        
        if new_message:
            total += len(new_message) // 4
        
        return total
    
    def _truncate_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        截断消息列表
        
        Args:
            messages: 原始消息列表
            
        Returns:
            截断后的消息列表
        """
        # 保留第一条和最后几条消息
        if len(messages) <= 2:
            return messages
        
        # 保留第一条(通常是系统消息或重要上下文)
        truncated = [messages[0]]
        
        # 添加最后的消息
        truncated.extend(messages[-(self.config.max_messages - 1):])
        
        return truncated

8.3 API 接口层设计

python 复制代码
from flask import Blueprint, request, Response, stream_with_context
from services.chat_service import ChatService
import json

chat_api = Blueprint('chat', __name__)


@chat_api.route('/chat-messages', methods=['POST'])
def create_chat_message():
    """
    创建聊天消息
    
    Request Body:
        {
            "inputs": {},
            "query": "用户问题",
            "response_mode": "streaming",
            "conversation_id": "可选,用于继续对话",
            "user": "用户标识"
        }
    
    Response:
        SSE 流式响应或 JSON 响应
    """
    data = request.get_json()
    
    # 参数校验
    query = data.get('query')
    if not query:
        return {'error': 'query is required'}, 400
    
    response_mode = data.get('response_mode', 'streaming')
    
    # 调用服务
    chat_service = ChatService()
    
    if response_mode == 'streaming':
        # 流式响应
        def generate():
            for event in chat_service.chat_stream(
                query=query,
                inputs=data.get('inputs', {}),
                conversation_id=data.get('conversation_id'),
                user=data.get('user')
            ):
                # SSE 格式
                yield f"data: {json.dumps(event)}\n\n"
            
            yield "data: [DONE]\n\n"
        
        return Response(
            stream_with_context(generate()),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'X-Accel-Buffering': 'no'
            }
    )
    else:
        # 同步响应
        result = chat_service.chat_sync(
            query=query,
            inputs=data.get('inputs', {}),
            conversation_id=data.get('conversation_id'),
            user=data.get('user')
        )
        
        return result


@chat_api.route('/conversations/<conversation_id>/messages', methods=['GET'])
def get_conversation_messages(conversation_id):
    """
    获取会话消息列表
    
    Args:
        conversation_id: 会话 ID
    
    Query Parameters:
        first_id: 第一条消息 ID(用于分页)
        limit: 返回数量
    
    Response:
        {
            "data": [...],
            "has_more": true
        }
    """
    first_id = request.args.get('first_id')
    limit = request.args.get('limit', 20, type=int)
    
    chat_service = ChatService()
    messages = chat_service.get_messages(
        conversation_id=conversation_id,
        first_id=first_id,
        limit=limit
    )
    
    return {
        'data': messages,
        'has_more': len(messages) == limit
    }

8.4 可观测性体系

可观测性三支柱
监控目标
API 响应时间
Token 消耗
错误率
工作流执行时间
日志 Logs
问题诊断
指标 Metrics
性能监控
追踪 Traces
链路分析
告警系统

追踪系统实现:

python 复制代码
import time
import uuid
from functools import wraps
from typing import Dict, Any, Optional
import logging

logger = logging.getLogger(__name__)


class TraceContext:
    """
    追踪上下文
    存储当前请求的追踪信息
    """
    
    def __init__(self, trace_id: str = None):
        self.trace_id = trace_id or str(uuid.uuid4())
        self.spans: list = []
        self.current_span: Optional[Dict] = None


class Tracer:
    """
    追踪器
    记录请求的执行链路
    """
    
    _context: Optional[TraceContext] = None
    
    @classmethod
    def start_trace(cls, trace_id: str = None) -> str:
        """
        开始追踪
        
        Args:
            trace_id: 追踪 ID
            
        Returns:
            追踪 ID
        """
        cls._context = TraceContext(trace_id)
        return cls._context.trace_id
    
    @classmethod
    def end_trace(cls) -> Dict:
        """
        结束追踪
        
        Returns:
            追踪结果
        """
        if cls._context:
            result = {
                'trace_id': cls._context.trace_id,
                'spans': cls._context.spans,
                'total_duration': sum(s['duration'] for s in cls._context.spans)
            }
            cls._context = None
            return result
        return {}
    
    @classmethod
    def start_span(
        cls,
        operation: str,
        attributes: Dict = None
    ) -> str:
        """
        开始 Span
        
        Args:
            operation: 操作名称
            attributes: 属性
            
        Returns:
            Span ID
        """
        if not cls._context:
            return None
        
        span_id = str(uuid.uuid4())
        span = {
            'span_id': span_id,
            'operation': operation,
            'start_time': time.time(),
            'attributes': attributes or {},
            'events': []
        }
        
        cls._context.current_span = span
        return span_id
    
    @classmethod
    def end_span(cls, status: str = 'ok', attributes: Dict = None):
        """
        结束 Span
        
        Args:
            status: 状态
            attributes: 额外属性
        """
        if not cls._context or not cls._context.current_span:
            return
        
        span = cls._context.current_span
        span['end_time'] = time.time()
        span['duration'] = span['end_time'] - span['start_time']
        span['status'] = status
        
        if attributes:
            span['attributes'].update(attributes)
        
        cls._context.spans.append(span)
        cls._context.current_span = None
    
    @classmethod
    def add_event(cls, name: str, attributes: Dict = None):
        """
        添加事件
        
        Args:
            name: 事件名称
            attributes: 属性
        """
        if not cls._context or not cls._context.current_span:
            return
        
        cls._context.current_span['events'].append({
            'name': name,
            'timestamp': time.time(),
            'attributes': attributes or {}
        })


def traced(operation: str):
    """
    追踪装饰器
    
    Args:
        operation: 操作名称
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            Tracer.start_span(operation, {
                'function': func.__name__,
                'args': str(args)[:100]
            })
            
            try:
                result = func(*args, **kwargs)
                Tracer.end_span('ok')
                return result
            except Exception as e:
                Tracer.end_span('error', {
                    'error': str(e)
                })
                raise
        
        return wrapper
    return decorator

第九章:安全机制与扩展系统

9.1 多租户与数据隔离

数据隔离
多租户架构
请求
认证中间件
租户识别
权限校验
数据过滤
租户 A 数据
租户 B 数据
租户 C 数据

租户隔离实现:

python 复制代码
from functools import wraps
from flask import request, g
from typing import Optional
import uuid


class TenantContext:
    """
    租户上下文
    存储当前请求的租户信息
    """
    
    def __init__(
        self,
        tenant_id: str,
        user_id: str,
        role: str
    ):
        self.tenant_id = tenant_id
        self.user_id = user_id
        self.role = role


def get_tenant_context() -> Optional[TenantContext]:
    """
    获取当前租户上下文
    
    Returns:
        租户上下文或 None
    """
    return getattr(g, 'tenant_context', None)


def tenant_required(f):
    """
    租户验证装饰器
    确保请求包含有效的租户信息
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        # 从请求头获取租户信息
        tenant_id = request.headers.get('X-Tenant-ID')
        user_id = request.headers.get('X-User-ID')
        
        if not tenant_id or not user_id:
            return {'error': 'Tenant information required'}, 401
        
        # 验证租户是否存在
        # 实际实现会查询数据库
        
        # 设置租户上下文
        g.tenant_context = TenantContext(
            tenant_id=tenant_id,
            user_id=user_id,
            role='user'  # 从数据库获取
        )
        
        return f(*args, **kwargs)
    
    return decorated


class TenantQueryFilter:
    """
    租户查询过滤器
    自动为查询添加租户过滤条件
    """
    
    @staticmethod
    def filter_by_tenant(query, model_class):
        """
        为查询添加租户过滤
        
        Args:
            query: SQLAlchemy 查询
            model_class: 模型类
            
        Returns:
            过滤后的查询
        """
        context = get_tenant_context()
        if not context:
            return query
        
        # 检查模型是否有 tenant_id 字段
        if hasattr(model_class, 'tenant_id'):
            query = query.filter(model_class.tenant_id == context.tenant_id)
        
        return query

9.2 代码沙箱安全机制

python 复制代码
from typing import Dict, Any, Optional
import subprocess
import json
import tempfile
import os


class CodeSandbox:
    """
    代码沙箱
    在隔离环境中安全执行用户代码
    """
    
    def __init__(
        self,
        timeout: int = 30,
        memory_limit: int = 128,  # MB
        cpu_limit: float = 1.0    # 核数
    ):
        """
        初始化沙箱
        
        Args:
            timeout: 执行超时(秒)
            memory_limit: 内存限制(MB)
            cpu_limit: CPU 限制(核数)
        """
        self.timeout = timeout
        self.memory_limit = memory_limit
        self.cpu_limit = cpu_limit
    
    def execute(
        self,
        code: str,
        variables: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        执行代码
        
        Args:
            code: Python 代码
            variables: 输入变量
            
        Returns:
            执行结果
        """
        # 准备执行脚本
        script = self._prepare_script(code, variables)
        
        # 创建临时文件
        with tempfile.NamedTemporaryFile(
            mode='w',
            suffix='.py',
            delete=False
        ) as f:
            f.write(script)
            script_path = f.name
        
        try:
            # 在 Docker 容器中执行
            result = self._execute_in_container(script_path)
            return result
        
        finally:
            # 清理临时文件
            os.unlink(script_path)
    
    def _prepare_script(
        self,
        code: str,
        variables: Dict[str, Any]
    ) -> str:
        """
        准备执行脚本
        
        Args:
            code: 用户代码
            variables: 输入变量
            
        Returns:
            完整的执行脚本
        """
        # 包装代码,添加输入输出处理
        script = f'''
import json
import sys

# 输入变量
inputs = {json.dumps(variables)}

# 用户代码
{code}

# 提取输出
def main():
    result = {{}}
    
    # 调用用户定义的 main 函数(如果存在)
    if 'main' in dir():
        output = main(**inputs)
        if isinstance(output, dict):
            result = output
        else:
            result['output'] = output
    
    # 输出结果
    print(json.dumps(result))

if __name__ == '__main__':
    main()
'''
        return script
    
    def _execute_in_container(self, script_path: str) -> Dict[str, Any]:
        """
        在 Docker 容器中执行脚本
        
        Args:
            script_path: 脚本路径
            
        Returns:
            执行结果
        """
        try:
            # 构建 Docker 命令
            cmd = [
                'docker', 'run',
                '--rm',  # 执行后删除容器
                '--network', 'none',  # 禁用网络
                '--memory', f'{self.memory_limit}m',  # 内存限制
                '--cpus', str(self.cpu_limit),  # CPU 限制
                '-v', f'{script_path}:/app/script.py:ro',  # 挂载脚本
                'python:3.10-slim',  # 使用精简镜像
                'python', '/app/script.py'
            ]
            
            # 执行命令
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.timeout
            )
            
            if result.returncode == 0:
                try:
                    output = json.loads(result.stdout)
                    return {
                        'success': True,
                        'output': output
                    }
                except json.JSONDecodeError:
                    return {
                        'success': True,
                        'output': {'result': result.stdout}
                    }
            else:
                return {
                    'success': False,
                    'error': result.stderr
                }
        
        except subprocess.TimeoutExpired:
            return {
                'success': False,
                'error': 'Execution timeout'
            }
        
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

9.3 插件扩展系统

python 复制代码
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
from dataclasses import dataclass


@dataclass
class PluginMetadata:
    """
    插件元数据
    """
    name: str                          # 插件名称
    version: str                       # 版本号
    description: str                   # 描述
    author: str                        # 作者
    dependencies: List[str] = None     # 依赖


class BasePlugin(ABC):
    """
    插件基类
    所有插件必须实现此接口
    """
    
    @property
    @abstractmethod
    def metadata(self) -> PluginMetadata:
        """插件元数据"""
        pass
    
    @abstractmethod
    def initialize(self, config: Dict[str, Any]):
        """
        初始化插件
        
        Args:
            config: 配置
        """
        pass
    
    @abstractmethod
    def teardown(self):
        """清理插件"""
        pass


class ToolPlugin(BasePlugin):
    """
    工具插件
    添加新的工具到 Dify
    """
    
    @abstractmethod
    def get_tools(self) -> List[Dict]:
        """
        获取插件提供的工具
        
        Returns:
            工具列表
        """
        pass


class NodePlugin(BasePlugin):
    """
    节点插件
    添加新的工作流节点类型
    """
    
    @abstractmethod
    def get_node_types(self) -> List[Dict]:
        """
        获取插件提供的节点类型
        
        Returns:
            节点类型列表
        """
        pass


class PluginManager:
    """
    插件管理器
    负责插件的加载、管理和调用
    """
    
    def __init__(self):
        self._plugins: Dict[str, BasePlugin] = {}
    
    def load_plugin(self, plugin: BasePlugin, config: Dict = None):
        """
        加载插件
        
        Args:
            plugin: 插件实例
            config: 配置
        """
        metadata = plugin.metadata
        plugin.initialize(config or {})
        self._plugins[metadata.name] = plugin
    
    def unload_plugin(self, plugin_name: str):
        """
        卸载插件
        
        Args:
            plugin_name: 插件名称
        """
        if plugin_name in self._plugins:
            self._plugins[plugin_name].teardown()
            del self._plugins[plugin_name]
    
    def get_all_tools(self) -> List[Dict]:
        """
        获取所有插件提供的工具
        
        Returns:
            工具列表
        """
        tools = []
        for plugin in self._plugins.values():
            if isinstance(plugin, ToolPlugin):
                tools.extend(plugin.get_tools())
        return tools
    
    def get_all_node_types(self) -> List[Dict]:
        """
        获取所有插件提供的节点类型
        
        Returns:
            节点类型列表
        """
        node_types = []
        for plugin in self._plugins.values():
            if isinstance(plugin, NodePlugin):
                node_types.extend(plugin.get_node_types())
        return node_types

9.4 任务调度系统

python 复制代码
from celery import Celery
from typing import Dict, Any
import logging

logger = logging.getLogger(__name__)

# 创建 Celery 应用
celery_app = Celery(
    'dify',
    broker='redis://localhost:6379/0',
    backend='redis://localhost:6379/1'
)

# 配置
celery_app.conf.update(
    task_serializer='json',
    accept_content=['json'],
    result_serializer='json',
    timezone='UTC',
    enable_utc=True,
    task_track_started=True,
    task_time_limit=3600,  # 1小时超时
    worker_prefetch_multiplier=1
)


@celery_app.task(bind=True, name='document_indexing')
def document_indexing_task(
    self,
    document_id: str,
    dataset_id: str
) -> Dict[str, Any]:
    """
    文档索引任务
    
    Args:
        self: Celery 任务实例
        document_id: 文档 ID
        dataset_id: 数据集 ID
        
    Returns:
        执行结果
    """
    from services.document_service import DocumentService
    
    logger.info(f'Starting document indexing: {document_id}')
    
    try:
        # 更新任务状态
        self.update_state(
            state='PROGRESS',
            meta={'status': 'processing', 'progress': 0}
        )
        
        # 执行索引
        document_service = DocumentService()
        result = document_service.index_document(
            document_id=document_id,
            dataset_id=dataset_id,
            progress_callback=lambda p: self.update_state(
                state='PROGRESS',
                meta={'status': 'processing', 'progress': p}
            )
        )
        
        logger.info(f'Document indexing completed: {document_id}')
        
        return {
            'status': 'completed',
            'document_id': document_id,
            'segment_count': result.get('segment_count', 0)
        }
    
    except Exception as e:
        logger.exception(f'Document indexing failed: {document_id}')
        
        return {
            'status': 'failed',
            'document_id': document_id,
            'error': str(e)
        }


@celery_app.task(name='batch_workflow_execution')
def batch_workflow_execution_task(
    workflow_id: str,
    inputs_list: List[Dict]
) -> Dict[str, Any]:
    """
    批量工作流执行任务
    
    Args:
        workflow_id: 工作流 ID
        inputs_list: 输入列表
        
    Returns:
        执行结果
    """
    from core.workflow import WorkflowRunner
    
    results = []
    
    for i, inputs in enumerate(inputs_list):
        try:
            runner = WorkflowRunner(workflow_id, inputs)
            
            # 执行工作流
            for event in runner.run():
                pass  # 收集结果
            
            results.append({
                'index': i,
                'status': 'success',
                'outputs': runner.state.variable_pool
            })
        
        except Exception as e:
            results.append({
                'index': i,
                'status': 'failed',
                'error': str(e)
            })
    
    return {
        'total': len(inputs_list),
        'success': sum(1 for r in results if r['status'] == 'success'),
        'failed': sum(1 for r in results if r['status'] == 'failed'),
        'results': results
    }

相关推荐
@Ma2 小时前
企业微信智能机器人 Python 插件获取回调和发送消息支持文字图片语音视频
python·机器人·企业微信
IT_陈寒2 小时前
JavaScript开发者必看:3个让代码效率翻倍的隐藏技巧
前端·人工智能·后端
七夜zippoe2 小时前
消息队列选型:Kafka vs RabbitMQ vs Redis 深度对比
redis·python·kafka·消息队列·rabbitmq
赵谨言2 小时前
基于YOLOv5的海棠花花朵检测识别:文献综述与研究展望
大数据·开发语言·经验分享·python
-Excalibur-2 小时前
IP数据包在计算机网络传输的全过程
java·网络·c++·笔记·python·网络协议·智能路由器
希望永不加班2 小时前
如何在 SpringBoot 里自定义 Spring MVC 配置
java·spring boot·后端·spring·mvc
weixin199701080162 小时前
“迷你京东”全栈架构设计与实现
java·大数据·python·数据库架构
supersolon2 小时前
OpenClaw安装碰到的一些问题和解决方法
linux·运维·ai·openclaw·龙虾
Welcome_Back2 小时前
SpringBoot后端开发测试全指南
spring boot·后端·log4j