第七章:Agent 智能体框架
7.1 Agent 架构设计
核心组件
Agent 架构
BaseAgentRunner
CotAgentRunner
FunctionCallAgentRunner
ChatAgentRunner
AgentScratchpad
思考记录
ToolEngine
工具执行
OutputParser
输出解析
Agent 基类实现:
python
from abc import ABC, abstractmethod
from typing import Generator, List, Dict, Any, Optional
from dataclasses import dataclass, field
from enum import Enum
import re
import json
from core.model_runtime import ModelManager
class AgentStrategy(Enum):
    """
    Enumeration of the agent reasoning strategies.
    """
    CHAIN_OF_THOUGHT = 'chain-of-thought'  # chain-of-thought strategy
    FUNCTION_CALLING = 'function-calling'  # function-calling strategy
    REACT = 'react'  # ReAct strategy
@dataclass
class AgentScratchpadUnit:
    """
    One round of an agent's Thought -> Action -> Observation cycle.

    Attributes:
        thought: The reasoning text (or the final answer text).
        action: The tool call as a dict with ``action_name`` and, optionally,
            ``action_input``; ``None`` means no tool call was requested.
        observation: The text recorded after the tool was executed.
    """
    thought: str = ''
    action: Optional[Dict] = None
    observation: str = ''

    def is_final(self) -> bool:
        """
        Tell whether this unit carries the final answer.

        Returns:
            True when no action is attached to the unit.
        """
        return self.action is None

    @staticmethod
    def _format_action_input(action_input) -> str:
        """Render an action input as JSON text; strings are passed through as-is."""
        if isinstance(action_input, str):
            return action_input
        try:
            # The prompt format asks for JSON, so history must not use Python repr
            return json.dumps(action_input, ensure_ascii=False)
        except (TypeError, ValueError):
            # Not JSON-serializable: fall back to the plain string form
            return str(action_input)

    def to_prompt(self) -> str:
        """
        Render the unit in the ``Thought:/Action:/Observation:`` prompt format.

        Returns:
            The non-empty sections joined by newlines ('' when all are empty).
        """
        parts = []
        if self.thought:
            parts.append(f"Thought: {self.thought}")
        if self.action:
            action_str = f"Action: {self.action.get('action_name')}"
            if 'action_input' in self.action:
                rendered = self._format_action_input(self.action['action_input'])
                action_str += f"\nAction Input: {rendered}"
            parts.append(action_str)
        if self.observation:
            parts.append(f"Observation: {self.observation}")
        return '\n'.join(parts)
class BaseAgentRunner(ABC):
    """
    Base class for agent runners.

    Every agent strategy must implement ``run``.
    """

    def __init__(
        self,
        model_config: Dict,
        tools: List[Dict],
        max_iterations: int = 5
    ):
        """
        Initialize the agent.

        Args:
            model_config: Model configuration.
            tools: Available tools; each is a dict with ``name`` and
                optionally ``description`` and ``parameters``.
            max_iterations: Maximum number of iterations.
        """
        self.model_config = model_config
        self.tools = tools
        self.max_iterations = max_iterations
        # Record of the rounds performed so far
        self.scratchpad: List[AgentScratchpadUnit] = []

    @abstractmethod
    def run(
        self,
        query: str,
        inputs: Dict[str, Any]
    ) -> Generator:
        """
        Run the agent.

        Args:
            query: The user's question.
            inputs: Input parameters.

        Yields:
            Execution events.
        """

    @staticmethod
    def _parameter_names(parameters) -> List[str]:
        """Return parameter names from a flat mapping or a JSON-Schema object."""
        if not isinstance(parameters, dict):
            return []
        properties = parameters.get('properties')
        if isinstance(properties, dict):
            # JSON-Schema form: the real parameter names live under 'properties',
            # not among the schema keywords ('type', 'required', ...)
            return list(properties)
        return list(parameters)

    def _build_tool_descriptions(self) -> str:
        """
        Build the numbered tool list shown to the model.

        Returns:
            One line per tool, joined by newlines.
        """
        descriptions = []
        for i, tool in enumerate(self.tools, 1):
            desc = f"{i}. {tool['name']}: {tool.get('description', '')}"
            names = self._parameter_names(tool.get('parameters'))
            if names:
                params = ', '.join(names)
                desc += f" (参数: {params})"
            descriptions.append(desc)
        return '\n'.join(descriptions)

    def _get_tool(self, tool_name: str) -> Optional[Dict]:
        """
        Look up a tool configuration by name.

        Args:
            tool_name: Name of the tool.

        Returns:
            The first tool whose ``name`` matches, or None.
        """
        return next(
            (tool for tool in self.tools if tool['name'] == tool_name),
            None
        )
class CotAgentRunner(BaseAgentRunner):
    """
    Chain-of-thought agent runner.

    Works through a task by looping Thought -> Action -> Observation: the
    model's text reply is parsed for a tool call, the tool is executed and
    its result is recorded, until a final answer is produced or
    ``max_iterations`` is reached.
    """

    def run(
        self,
        query: str,
        inputs: Dict[str, Any]
    ) -> Generator:
        """
        Run the chain-of-thought agent.

        Args:
            query: The user's question.
            inputs: Input parameters (not read by this implementation).

        Yields:
            ``agent_thought`` events for streamed text, ``agent_observation``
            events after each tool call and one closing ``agent_message``.
        """
        iteration = 0
        is_finished = False
        system_prompt = self._build_system_prompt()

        # The model handle does not change between iterations: resolve it once
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            provider=self.model_config.get('provider'),
            model=self.model_config.get('name'),
            credentials=self.model_config.get('credentials', {})
        )

        while not is_finished and iteration < self.max_iterations:
            iteration += 1
            # Rebuilt every round so that it includes the scratchpad so far
            messages = self._build_messages(system_prompt, query)

            # Stream the model response
            response_text = ''
            for chunk in model_manager.invoke(
                model_instance=model_instance,
                prompt_messages=messages,
                model_parameters=self.model_config.get('parameters', {}),
                stream=True
            ):
                # Skip chunks with no text (content may be None or ''), as
                # FunctionCallAgentRunner does; None would break the += below
                if 'content' in chunk and chunk['content']:
                    response_text += chunk['content']
                    yield {
                        'event': 'agent_thought',
                        'data': {
                            'thought': chunk['content'],
                            'iteration': iteration
                        }
                    }

            # Parse the response and record it
            scratchpad_unit = self._parse_response(response_text)
            self.scratchpad.append(scratchpad_unit)

            if scratchpad_unit.is_final():
                is_finished = True
                yield {
                    'event': 'agent_message',
                    'data': {
                        'answer': scratchpad_unit.thought,
                        'status': 'completed'
                    }
                }
            else:
                # Execute the tool and store the observation for the next round
                tool_result = self._execute_tool(scratchpad_unit.action)
                scratchpad_unit.observation = tool_result
                yield {
                    'event': 'agent_observation',
                    'data': {
                        'tool': scratchpad_unit.action.get('action_name'),
                        'result': tool_result,
                        'iteration': iteration
                    }
                }

        # The iteration limit was reached without a final answer
        if not is_finished:
            yield {
                'event': 'agent_message',
                'data': {
                    'answer': '抱歉,我无法在有限的步骤内完成这个任务。',
                    'status': 'max_iterations_reached'
                }
            }

    def _build_system_prompt(self) -> str:
        """
        Build the system prompt, including the tool list and reply format.

        Returns:
            The system prompt text.
        """
        return f"""你是一个智能助手,可以使用工具来帮助用户解决问题。
你可以使用以下工具:
{self._build_tool_descriptions()}
请按照以下格式思考和行动:
Thought: 思考下一步应该做什么
Action: 工具名称
Action Input: 工具输入参数(JSON 格式)
或者当你有最终答案时:
Thought: 我现在知道最终答案了
Final Answer: 最终答案
开始!"""

    def _build_messages(self, system_prompt: str, query: str) -> List[Dict]:
        """
        Build the message list for the next model call.

        Args:
            system_prompt: The system prompt.
            query: The user's question.

        Returns:
            System and user messages followed by one assistant message per
            scratchpad unit.
        """
        messages = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': query}
        ]
        messages.extend(
            {'role': 'assistant', 'content': unit.to_prompt()}
            for unit in self.scratchpad
        )
        return messages

    def _parse_response(self, response: str) -> AgentScratchpadUnit:
        """
        Parse the model's text reply into a scratchpad unit.

        Args:
            response: The model's response text.

        Returns:
            A unit with ``action`` set when a tool call was found; otherwise
            a final unit whose ``thought`` holds the answer text.
        """
        unit = AgentScratchpadUnit()

        # Thought
        thought_match = re.search(
            r'Thought:\s*(.+?)(?=Action:|Final Answer:|$)', response, re.DOTALL
        )
        if thought_match:
            unit.thought = thought_match.group(1).strip()

        # Final Answer takes precedence over any action
        final_match = re.search(r'Final Answer:\s*(.+?)$', response, re.DOTALL)
        if final_match:
            unit.thought = final_match.group(1).strip()
            return unit

        # Action
        action_match = re.search(
            r'Action:\s*(.+?)(?=Action Input:|$)', response, re.DOTALL
        )
        if action_match:
            action_name = action_match.group(1).strip()
            action_input = {}
            # Stop at a following "Observation:" line: the model may invent one
            # after the input, which must not become part of the tool input
            input_match = re.search(
                r'Action Input:\s*(.+?)(?=\n\s*Observation:|$)', response, re.DOTALL
            )
            if input_match:
                raw_input = input_match.group(1).strip()
                try:
                    action_input = json.loads(raw_input)
                except json.JSONDecodeError:
                    # Not JSON: pass the raw text through under a generic key
                    action_input = {'input': raw_input}
            unit.action = {
                'action_name': action_name,
                'action_input': action_input
            }
        elif not unit.thought:
            # No recognizable marker at all: use the raw reply as the answer
            # instead of emitting an empty final message
            unit.thought = response.strip()
        return unit

    def _execute_tool(self, action: Dict) -> str:
        """
        Execute the tool named by an action.

        Args:
            action: Dict with ``action_name`` and ``action_input``.

        Returns:
            The result text, or an ``Error: ...`` message when the tool is
            not among ``self.tools``.
        """
        tool_name = action.get('action_name')
        tool_input = action.get('action_input', {})
        tool = self._get_tool(tool_name)
        if not tool:
            return f"Error: Tool '{tool_name}' not found"
        # Simplified placeholder; the original note says the real
        # implementation would call ToolEngine here
        try:
            result = f"Tool '{tool_name}' executed with input: {tool_input}"
            return result
        except Exception as e:
            return f"Error executing tool: {str(e)}"
class FunctionCallAgentRunner(BaseAgentRunner):
    """
    Function-calling agent runner.

    Uses the model's function-calling interface for tool calls, whereas
    CotAgentRunner extracts tool calls by parsing the reply text.
    """

    def run(
        self,
        query: str,
        inputs: Dict[str, Any]
    ) -> Generator:
        """
        Run the function-calling agent.

        Args:
            query: The user's question.
            inputs: Input parameters (not read by this implementation).

        Yields:
            ``agent_thought`` events for streamed text, ``agent_observation``
            events after each tool call and one closing ``agent_message``.
        """
        iteration = 0
        is_finished = False
        messages = [
            {'role': 'system', 'content': self._build_system_prompt()},
            {'role': 'user', 'content': query}
        ]

        # None of these depend on the iteration, so build them once
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            provider=self.model_config.get('provider'),
            model=self.model_config.get('name'),
            credentials=self.model_config.get('credentials', {})
        )
        tools = self._build_tools_schema()

        while not is_finished and iteration < self.max_iterations:
            iteration += 1

            # Stream the response, collecting text and tool calls
            response_text = ''
            tool_calls = []
            for chunk in model_manager.invoke(
                model_instance=model_instance,
                prompt_messages=messages,
                model_parameters=self.model_config.get('parameters', {}),
                tools=tools,
                tool_choice='auto',
                stream=True
            ):
                if 'content' in chunk and chunk['content']:
                    response_text += chunk['content']
                    yield {
                        'event': 'agent_thought',
                        'data': {
                            'thought': chunk['content'],
                            'iteration': iteration
                        }
                    }
                # NOTE(review): assumes each streamed entry is a complete tool
                # call rather than a partial delta -- confirm with ModelManager
                if 'tool_calls' in chunk:
                    tool_calls.extend(chunk['tool_calls'])

            if tool_calls:
                # Record the assistant turn that requested the tool calls
                messages.append({
                    'role': 'assistant',
                    'content': response_text,
                    'tool_calls': tool_calls
                })
                for tool_call in tool_calls:
                    function_name = tool_call['function']['name']
                    function_args, error = self._parse_tool_arguments(
                        tool_call['function']['arguments']
                    )
                    if error is None:
                        tool_result = self._execute_tool({
                            'action_name': function_name,
                            'action_input': function_args
                        })
                    else:
                        # Report bad arguments back to the model instead of
                        # aborting the whole run
                        tool_result = error
                    # Every requested call gets a result message
                    messages.append({
                        'role': 'tool',
                        'tool_call_id': tool_call['id'],
                        'content': tool_result
                    })
                    yield {
                        'event': 'agent_observation',
                        'data': {
                            'tool': function_name,
                            'result': tool_result,
                            'iteration': iteration
                        }
                    }
            else:
                # No tool call: the text is the final answer
                is_finished = True
                yield {
                    'event': 'agent_message',
                    'data': {
                        'answer': response_text,
                        'status': 'completed'
                    }
                }

        # The iteration limit was reached without a final answer
        if not is_finished:
            yield {
                'event': 'agent_message',
                'data': {
                    'answer': '抱歉,我无法在有限的步骤内完成这个任务。',
                    'status': 'max_iterations_reached'
                }
            }

    @staticmethod
    def _parse_tool_arguments(raw_arguments):
        """Decode tool-call arguments; return ``(args, None)`` or ``(None, error_text)``."""
        if isinstance(raw_arguments, dict):
            return raw_arguments, None
        if not raw_arguments:
            # Empty or missing arguments mean "no arguments"
            return {}, None
        try:
            return json.loads(raw_arguments), None
        except (TypeError, json.JSONDecodeError) as e:
            return None, f"Error: invalid tool arguments: {str(e)}"

    def _build_system_prompt(self) -> str:
        """
        Build the system prompt.

        Returns:
            The system prompt text.
        """
        return """你是一个智能助手,可以使用工具来帮助用户解决问题。
当需要使用工具时,系统会自动调用相应的工具函数。
请根据用户的问题选择合适的工具,并提供准确的参数。"""

    def _build_tools_schema(self) -> List[Dict]:
        """
        Build the tool schema list in OpenAI function-calling format.

        Returns:
            One ``{'type': 'function', 'function': {...}}`` entry per tool.
        """
        return [
            {
                'type': 'function',
                'function': {
                    'name': tool['name'],
                    'description': tool.get('description', ''),
                    'parameters': tool.get('parameters', {})
                }
            }
            for tool in self.tools
        ]

    def _execute_tool(self, action: Dict) -> str:
        """
        Execute the tool named by an action.

        Args:
            action: Dict with ``action_name`` and ``action_input``.

        Returns:
            The result text, or an ``Error: ...`` message when the tool is
            not among ``self.tools``.
        """
        tool_name = action.get('action_name')
        tool_input = action.get('action_input', {})
        tool = self._get_tool(tool_name)
        if not tool:
            return f"Error: Tool '{tool_name}' not found"
        try:
            # Placeholder for the real tool execution
            result = f"Tool '{tool_name}' executed with input: {tool_input}"
            return result
        except Exception as e:
            return f"Error executing tool: {str(e)}"
7.2 推理策略实现
时序图(参与者:用户、Agent、LLM、Tool):
1. 用户向 Agent 提交问题。
2. 迭代循环:Agent 向 LLM 发送 Prompt;LLM 返回 Thought + Action;Agent 解析响应并执行工具;Tool 返回结果;Agent 记录 Observation,并向 LLM 发送下一轮 Prompt。
3. LLM 给出 Final Answer 后,Agent 向用户返回答案。
7.3 工具调用机制
python
from typing import Dict, Any, List, Optional
from abc import ABC, abstractmethod
import json
class BaseTool(ABC):
    """
    Abstract base class for tools.

    Concrete tools must provide ``name``, ``description`` and ``execute``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the tool."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Description of the tool."""

    @property
    def parameters(self) -> Dict:
        """
        Parameter definition in JSON Schema format (empty by default).
        """
        return dict()

    @abstractmethod
    def execute(self, parameters: Dict[str, Any]) -> str:
        """
        Execute the tool.

        Args:
            parameters: Dictionary of parameters.

        Returns:
            The execution result.
        """
class WebSearchTool(BaseTool):
    """
    Web search tool.
    """

    @property
    def name(self) -> str:
        return 'web_search'

    @property
    def description(self) -> str:
        return '搜索互联网获取信息'

    @property
    def parameters(self) -> Dict:
        # JSON Schema: a required query string and an optional result count
        properties = {
            'query': {
                'type': 'string',
                'description': '搜索关键词'
            },
            'num_results': {
                'type': 'integer',
                'description': '返回结果数量',
                'default': 5
            }
        }
        return {
            'type': 'object',
            'properties': properties,
            'required': ['query']
        }

    def execute(self, parameters: Dict[str, Any]) -> str:
        """
        Run a search.

        Args:
            parameters: Search parameters (``query`` and optional
                ``num_results``, which defaults to 5).

        Returns:
            A placeholder result string; the original note says a real
            implementation would call a search API.
        """
        query = parameters.get('query')
        num_results = parameters.get('num_results', 5)
        return f"搜索 '{query}' 的前 {num_results} 条结果..."
class CalculatorTool(BaseTool):
    """
    Calculator tool that evaluates arithmetic expressions without ``eval``.
    """

    # Largest allowed bit length of an integer power result; keeps inputs such
    # as 9**9**9 from consuming unbounded CPU and memory
    _MAX_POWER_RESULT_BITS = 1_000_000

    @property
    def name(self) -> str:
        return 'calculator'

    @property
    def description(self) -> str:
        return '执行数学计算'

    @property
    def parameters(self) -> Dict:
        return {
            'type': 'object',
            'properties': {
                'expression': {
                    'type': 'string',
                    'description': '数学表达式,如 "2 + 3 * 4"'
                }
            },
            'required': ['expression']
        }

    @classmethod
    def _check_power(cls, base, exponent):
        """Raise ValueError when an integer power would be unreasonably large."""
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and base.bit_length() * exponent > cls._MAX_POWER_RESULT_BITS
        ):
            raise ValueError("指数运算结果过大")

    def execute(self, parameters: Dict[str, Any]) -> str:
        """
        Evaluate an arithmetic expression.

        Only int/float literals combined with ``+ - * / // % **`` and unary
        ``+``/``-`` are accepted.

        Args:
            parameters: Must contain ``expression`` (a string).

        Returns:
            ``计算结果: <value>`` on success, otherwise a ``计算错误: ...``
            message; this method does not raise for bad input.
        """
        import ast
        import operator

        expression = parameters.get('expression')
        if not isinstance(expression, str) or not expression.strip():
            return "计算错误: 表达式语法无效"

        binary_operators = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_operators = {
            ast.USub: operator.neg,
            ast.UAdd: operator.pos,
        }

        def safe_eval(node):
            """
            Evaluate an AST node, allowing only basic arithmetic.

            Raises:
                ValueError: For any operator, constant or node type that is
                    not allowed.
            """
            if isinstance(node, ast.Constant):
                value = node.value
                # Numbers only: strings/bytes would allow e.g. 'a' * 10**9,
                # and bool is an int subclass that is not arithmetic input
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    raise ValueError(f"不允许的常量类型: {type(value).__name__}")
                return value
            if isinstance(node, ast.BinOp):
                op_type = type(node.op)
                if op_type not in binary_operators:
                    raise ValueError(f"不允许的运算符: {op_type.__name__}")
                left = safe_eval(node.left)
                right = safe_eval(node.right)
                if op_type is ast.Pow:
                    self._check_power(left, right)
                return binary_operators[op_type](left, right)
            if isinstance(node, ast.UnaryOp):
                op_type = type(node.op)
                if op_type not in unary_operators:
                    raise ValueError(f"不允许的运算符: {op_type.__name__}")
                return unary_operators[op_type](safe_eval(node.operand))
            raise ValueError(f"不允许的表达式类型: {type(node).__name__}")

        try:
            # Strip first: leading whitespace is a syntax error in eval mode
            tree = ast.parse(expression.strip(), mode='eval')
            result = safe_eval(tree.body)
            return f"计算结果: {result}"
        except ValueError as e:
            return f"计算错误: {str(e)}"
        except SyntaxError:
            return "计算错误: 表达式语法无效"
        except Exception as e:
            # e.g. ZeroDivisionError, OverflowError, RecursionError
            return f"计算错误: {str(e)}"
class ToolEngine:
    """
    Tool engine: keeps a registry of tools by name and executes them.
    """

    def __init__(self):
        self._tools: Dict[str, BaseTool] = {}
        self._register_default_tools()

    def _register_default_tools(self):
        """Register the built-in tools."""
        for tool_class in (WebSearchTool, CalculatorTool):
            self.register_tool(tool_class())

    def register_tool(self, tool: BaseTool):
        """
        Register a tool under its ``name``.

        Args:
            tool: Tool instance; replaces any tool already stored under the
                same name.
        """
        self._tools[tool.name] = tool

    def get_tool_schemas(self) -> List[Dict]:
        """
        Collect the schema of every registered tool.

        Returns:
            A list of dicts with ``name``, ``description`` and ``parameters``.
        """
        return [
            {
                'name': registered.name,
                'description': registered.description,
                'parameters': registered.parameters
            }
            for registered in self._tools.values()
        ]

    def invoke(
        self,
        tool_name: str,
        parameters: Dict[str, Any]
    ) -> str:
        """
        Invoke a registered tool.

        Args:
            tool_name: Name of the tool.
            parameters: Parameters passed to the tool's ``execute``.

        Returns:
            The tool's execution result.

        Raises:
            ValueError: If no tool is registered under ``tool_name``.
        """
        tool = self._tools.get(tool_name)
        if tool:
            return tool.execute(parameters)
        raise ValueError(f'Tool not found: {tool_name}')
7.4 Agent 执行时序图
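时序图(参与者:Client、AgentRunner、ModelManager、LLM、ToolEngine):
1. Client 调用 AgentRunner 的 run(query, inputs),AgentRunner 初始化 scratchpad。
2. 迭代循环(最多 max_iterations 次):AgentRunner 构建消息(含历史),调用 ModelManager 的 invoke(model, messages);ModelManager 向 LLM 发起 API 调用,LLM 返回流式响应,ModelManager 将 LLMResultChunk 交给 AgentRunner;AgentRunner 解析响应,并向 Client 发送 agent_thought 事件。
   - 需要工具调用时:AgentRunner 调用 ToolEngine 的 invoke(tool, params),得到工具结果,向 Client 发送 agent_observation 事件,并记录到 scratchpad。
   - 得到最终答案时:AgentRunner 向 Client 发送 agent_message 事件,结束循环。
3. 最后发送 workflow_finished 事件。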
第八章:会话管理与可观测性
8.1 会话数据模型
python
from sqlalchemy import Column, String, Text, JSON, DateTime, Integer
from sqlalchemy.dialects.postgresql import UUID
from models.base import Base
import uuid
from datetime import datetime
class ConversationVariable(Base):
    """
    Conversation variable model.

    Stores state information for multi-turn conversations.
    """
    __tablename__ = 'conversation_variables'
    id = Column(UUID, primary_key=True, default=uuid.uuid4)
    # Conversation this variable belongs to (indexed)
    conversation_id = Column(UUID, nullable=False, index=True)
    # Variable name
    name = Column(String(255), nullable=False)
    # Variable value (JSON)
    value = Column(JSON, nullable=True)
    # Variable type; defaults to 'string'
    value_type = Column(String(50), default='string')
    # Creation time
    created_at = Column(DateTime, default=datetime.utcnow)
    # Update time; also refreshed on update via onupdate
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class MessageFeedback(Base):
    """
    Message feedback model.

    Stores user feedback on a message.
    """
    __tablename__ = 'message_feedbacks'
    id = Column(UUID, primary_key=True, default=uuid.uuid4)
    # Message the feedback refers to (indexed)
    message_id = Column(UUID, nullable=False, index=True)
    # Feedback type: like/dislike
    rating = Column(String(50), nullable=False)
    # Feedback content (optional)
    content = Column(Text, nullable=True)
    # Creator (optional)
    created_by = Column(UUID, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
8.2 上下文窗口管理
python
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class ContextWindowConfig:
    """
    Context window configuration.
    """
    max_tokens: int = 4096  # maximum number of tokens
    max_messages: int = 20  # maximum number of messages
    strategy: str = 'truncate'  # truncation strategy: truncate/compress
class ContextWindowManager:
    """
    Context window manager.

    Trims the conversation history so that it respects the configured token
    and message-count limits.
    """

    def __init__(self, config: "ContextWindowConfig"):
        """
        Initialize the manager.

        Args:
            config: Configuration providing ``max_tokens`` and ``max_messages``.
        """
        self.config = config

    def manage(
        self,
        messages: List[Dict],
        system_prompt: Optional[str] = None,
        new_message: Optional[str] = None
    ) -> List[Dict]:
        """
        Fit the history into the context window.

        Args:
            messages: Historical messages.
            system_prompt: System prompt (counted towards the token total).
            new_message: New message (counted towards the token total).

        Returns:
            The (possibly shortened) message list.
        """
        total_tokens = self._count_tokens(messages, system_prompt, new_message)
        # Over the token budget: drop older messages
        if total_tokens > self.config.max_tokens:
            messages = self._truncate_messages(messages, system_prompt, new_message)
        # Cap the message count, keeping the most recent ones
        if len(messages) > self.config.max_messages:
            messages = messages[-self.config.max_messages:]
        return messages

    def _count_tokens(
        self,
        messages: List[Dict],
        system_prompt: Optional[str] = None,
        new_message: Optional[str] = None
    ) -> int:
        """
        Estimate the token count as ``len(text) // 4`` per piece.

        Args:
            messages: Message list; a missing or None ``content`` counts as 0.
            system_prompt: System prompt.
            new_message: New message.

        Returns:
            The estimated number of tokens.
        """
        total = 0
        if system_prompt:
            total += len(system_prompt) // 4  # rough estimate
        for msg in messages:
            # content can be None (e.g. an assistant turn holding only tool calls)
            content = msg.get('content') or ''
            total += len(content) // 4
        if new_message:
            total += len(new_message) // 4
        return total

    def _truncate_messages(
        self,
        messages: List[Dict],
        system_prompt: Optional[str] = None,
        new_message: Optional[str] = None
    ) -> List[Dict]:
        """
        Drop messages from the middle of the history to fit the token budget.

        The first message and the most recent message are always kept; more
        recent messages are added, newest first, while the budget and
        ``max_messages`` allow.

        Args:
            messages: Original message list.
            system_prompt: System prompt, charged against the budget.
            new_message: New message, charged against the budget.

        Returns:
            The truncated list in original order, without duplicates.
        """
        if len(messages) <= 2:
            return messages

        first = messages[0]
        budget = self.config.max_tokens - self._count_tokens(
            [first], system_prompt, new_message
        )
        # Slots left for the tail once the first message is kept
        tail_limit = max(self.config.max_messages - 1, 1)

        tail = []
        for msg in reversed(messages[1:]):
            if len(tail) >= tail_limit:
                break
            cost = self._count_tokens([msg])
            # Keep the newest message even when it alone exceeds the budget
            if tail and cost > budget:
                break
            tail.append(msg)
            budget -= cost
        tail.reverse()
        return [first] + tail
8.3 API 接口层设计
python
from flask import Blueprint, request, Response, stream_with_context
from services.chat_service import ChatService
import json
# Flask blueprint on which the chat routes below are registered
chat_api = Blueprint('chat', __name__)
@chat_api.route('/chat-messages', methods=['POST'])
def create_chat_message():
    """
    Create a chat message.

    Request Body:
        {
            "inputs": {},
            "query": "user question",
            "response_mode": "streaming",
            "conversation_id": "optional, continues a conversation",
            "user": "user identifier"
        }

    Response:
        An SSE stream when ``response_mode`` is "streaming" (the default),
        otherwise the result of the synchronous call. A 400 error is
        returned when the body is not a JSON object or ``query`` is missing.
    """
    # silent=True returns None for a missing or invalid JSON body instead of
    # raising, so bad requests get the explicit 400 below
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return {'error': 'query is required'}, 400

    # Parameter validation
    query = data.get('query')
    if not query:
        return {'error': 'query is required'}, 400
    response_mode = data.get('response_mode', 'streaming')

    chat_service = ChatService()
    if response_mode == 'streaming':
        def generate():
            for event in chat_service.chat_stream(
                query=query,
                inputs=data.get('inputs', {}),
                conversation_id=data.get('conversation_id'),
                user=data.get('user')
            ):
                # SSE framing: one "data:" line followed by a blank line
                yield f"data: {json.dumps(event)}\n\n"
            yield "data: [DONE]\n\n"

        return Response(
            stream_with_context(generate()),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'X-Accel-Buffering': 'no'
            }
        )

    # Synchronous response
    result = chat_service.chat_sync(
        query=query,
        inputs=data.get('inputs', {}),
        conversation_id=data.get('conversation_id'),
        user=data.get('user')
    )
    return result
@chat_api.route('/conversations/<conversation_id>/messages', methods=['GET'])
def get_conversation_messages(conversation_id):
    """
    List the messages of a conversation.

    Args:
        conversation_id: Conversation ID.

    Query Parameters:
        first_id: ID of the first message (used for pagination).
        limit: Number of messages to return (defaults to 20).

    Response:
        {
            "data": [...],
            "has_more": true
        }
        ``has_more`` is true when exactly ``limit`` messages came back.
    """
    query_args = request.args
    first_id = query_args.get('first_id')
    limit = query_args.get('limit', 20, type=int)

    service = ChatService()
    page = service.get_messages(
        conversation_id=conversation_id,
        first_id=first_id,
        limit=limit
    )
    has_more = len(page) == limit
    return {'data': page, 'has_more': has_more}
8.4 可观测性体系
可观测性三支柱
监控目标
API 响应时间
Token 消耗
错误率
工作流执行时间
日志 Logs
问题诊断
指标 Metrics
性能监控
追踪 Traces
链路分析
告警系统
追踪系统实现:
python
import time
import uuid
from functools import wraps
from typing import Dict, Any, Optional
import logging
# Module-level logger named after this module
logger = logging.getLogger(__name__)
class TraceContext:
    """
    Trace context.

    Holds the trace ID, the finished spans and the span currently open.
    """

    def __init__(self, trace_id: str = None):
        # A missing (or empty) ID is replaced by a freshly generated UUID string
        if not trace_id:
            trace_id = str(uuid.uuid4())
        self.trace_id = trace_id
        self.spans: list = []
        self.current_span: Optional[Dict] = None
class Tracer:
    """
    Tracer that records the spans of a request.

    Spans may be nested: starting a span while another is open suspends the
    outer one, and ending the inner span resumes it.

    NOTE(review): state lives on the class, so every caller in the process
    shares one trace -- confirm this is acceptable for concurrent requests.
    """
    _context: Optional["TraceContext"] = None
    # Open spans suspended by a nested start_span, innermost last
    _span_stack: list = []

    @classmethod
    def start_trace(cls, trace_id: Optional[str] = None) -> str:
        """
        Start a new trace, discarding any trace in progress.

        Args:
            trace_id: Trace ID to use; passed to TraceContext.

        Returns:
            The trace ID of the new context.
        """
        cls._context = TraceContext(trace_id)
        cls._span_stack = []
        return cls._context.trace_id

    @classmethod
    def end_trace(cls) -> Dict:
        """
        End the trace.

        Returns:
            A dict with ``trace_id``, ``spans`` (finished spans only) and
            ``total_duration`` (sum over top-level spans, so nested time is
            not counted twice); ``{}`` when no trace is active.
        """
        context = cls._context
        if context is None:
            return {}
        result = {
            'trace_id': context.trace_id,
            'spans': context.spans,
            'total_duration': sum(
                s['duration'] for s in context.spans
                if s.get('parent_span_id') is None
            )
        }
        cls._context = None
        cls._span_stack = []
        return result

    @classmethod
    def start_span(
        cls,
        operation: str,
        attributes: Optional[Dict] = None
    ) -> Optional[str]:
        """
        Start a span, nested under the currently open span if there is one.

        Args:
            operation: Operation name.
            attributes: Initial attributes (copied, never mutated).

        Returns:
            The span ID, or None when no trace is active.
        """
        context = cls._context
        if context is None:
            return None
        parent = context.current_span
        span_id = str(uuid.uuid4())
        span = {
            'span_id': span_id,
            'parent_span_id': parent['span_id'] if parent is not None else None,
            'operation': operation,
            'start_time': time.time(),
            # Copy so that end_span's update() cannot modify the caller's dict
            'attributes': dict(attributes) if attributes else {},
            'events': []
        }
        if parent is not None:
            # Suspend the outer span instead of overwriting (and losing) it
            cls._span_stack.append(parent)
        context.current_span = span
        return span_id

    @classmethod
    def end_span(cls, status: str = 'ok', attributes: Optional[Dict] = None):
        """
        End the current span and resume its parent, if any.

        Args:
            status: Status stored on the span.
            attributes: Extra attributes merged into the span.
        """
        context = cls._context
        if context is None or not context.current_span:
            return
        span = context.current_span
        span['end_time'] = time.time()
        span['duration'] = span['end_time'] - span['start_time']
        span['status'] = status
        if attributes:
            span['attributes'].update(attributes)
        context.spans.append(span)
        context.current_span = cls._span_stack.pop() if cls._span_stack else None

    @classmethod
    def add_event(cls, name: str, attributes: Optional[Dict] = None):
        """
        Add a timestamped event to the current span.

        Args:
            name: Event name.
            attributes: Event attributes.
        """
        context = cls._context
        if context is None or not context.current_span:
            return
        context.current_span['events'].append({
            'name': name,
            'timestamp': time.time(),
            'attributes': attributes or {}
        })
def traced(operation: str):
    """
    Decorator factory that wraps each call of a function in a Tracer span.

    Args:
        operation: Operation name recorded on the span.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Positional arguments are recorded as a string of at most 100 chars
            span_attributes = {
                'function': func.__name__,
                'args': str(args)[:100]
            }
            Tracer.start_span(operation, span_attributes)
            try:
                outcome = func(*args, **kwargs)
                Tracer.end_span('ok')
            except Exception as e:
                # Close the span as failed, then let the error propagate
                Tracer.end_span('error', {
                    'error': str(e)
                })
                raise
            return outcome
        return wrapper
    return decorator
第九章:安全机制与扩展系统
9.1 多租户与数据隔离
数据隔离
多租户架构
请求
认证中间件
租户识别
权限校验
数据过滤
租户 A 数据
租户 B 数据
租户 C 数据
租户隔离实现:
python
from functools import wraps
from flask import request, g
from typing import Optional
import uuid
class TenantContext:
    """
    Tenant context.

    Carries the tenant ID, user ID and role of the current request.
    """

    def __init__(
        self,
        tenant_id: str,
        user_id: str,
        role: str
    ):
        self.tenant_id, self.user_id, self.role = tenant_id, user_id, role
def get_tenant_context() -> Optional[TenantContext]:
    """
    Fetch the tenant context stored on Flask's ``g`` for this request.

    Returns:
        The tenant context, or None when none has been set.
    """
    try:
        return g.tenant_context
    except AttributeError:
        return None
def tenant_required(f):
    """
    Decorator that requires tenant information on the request.

    Reads the ``X-Tenant-ID`` and ``X-User-ID`` headers; when either is
    missing the request is answered with 401, otherwise a TenantContext is
    stored on ``g`` before the wrapped view runs.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        headers = request.headers
        tenant_id = headers.get('X-Tenant-ID')
        user_id = headers.get('X-User-ID')
        if tenant_id and user_id:
            # The original notes say a real implementation would verify the
            # tenant and load the role from the database; the role is fixed here
            g.tenant_context = TenantContext(
                tenant_id=tenant_id,
                user_id=user_id,
                role='user'
            )
            return f(*args, **kwargs)
        return {'error': 'Tenant information required'}, 401
    return decorated
class TenantQueryFilter:
    """
    Tenant query filter.

    Adds a tenant condition to queries on models that have ``tenant_id``.
    """

    @staticmethod
    def filter_by_tenant(query, model_class):
        """
        Restrict a query to the current tenant.

        Args:
            query: SQLAlchemy query.
            model_class: Model class being queried.

        Returns:
            The query filtered by ``tenant_id`` when a tenant context exists
            and the model has that attribute; otherwise the query unchanged.
        """
        context = get_tenant_context()
        # NOTE(review): with no tenant context the query is returned
        # unfiltered -- confirm callers never reach this without a context
        if context and hasattr(model_class, 'tenant_id'):
            return query.filter(model_class.tenant_id == context.tenant_id)
        return query
9.2 代码沙箱安全机制
python
from typing import Dict, Any, Optional
import subprocess
import json
import tempfile
import os
class CodeSandbox:
    """
    Code sandbox.

    Runs user code inside a Docker container started with no network and
    with memory and CPU limits.
    """

    def __init__(
        self,
        timeout: int = 30,
        memory_limit: int = 128,  # MB
        cpu_limit: float = 1.0  # cores
    ):
        """
        Initialize the sandbox.

        Args:
            timeout: Execution timeout in seconds.
            memory_limit: Memory limit in MB.
            cpu_limit: CPU limit in cores.
        """
        self.timeout = timeout
        self.memory_limit = memory_limit
        self.cpu_limit = cpu_limit

    def execute(
        self,
        code: str,
        variables: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Execute code.

        Args:
            code: Python code; if it defines ``main``, that function is
                called with ``variables`` as keyword arguments.
            variables: Input variables (must be JSON-serializable).

        Returns:
            ``{'success': True, 'output': ...}`` or
            ``{'success': False, 'error': ...}``.
        """
        script = self._prepare_script(code, variables)
        # Explicit UTF-8 so non-ASCII code does not depend on the locale
        script_file = tempfile.NamedTemporaryFile(
            mode='w',
            suffix='.py',
            delete=False,
            encoding='utf-8'
        )
        script_path = script_file.name
        try:
            with script_file:
                script_file.write(script)
            return self._execute_in_container(script_path)
        finally:
            # Remove the temp file even when writing or execution fails
            os.unlink(script_path)

    def _prepare_script(
        self,
        code: str,
        variables: Dict[str, Any]
    ) -> str:
        """
        Wrap user code with input decoding and output printing.

        Args:
            code: User code.
            variables: Input variables.

        Returns:
            The complete script; when run it prints one JSON object (the
            dict returned by the user's ``main``, ``{'output': value}`` for
            other return values, or ``{}`` when there is no ``main``).
        """
        # Embed the inputs as a JSON string decoded at runtime: raw JSON is
        # not valid Python source (true/false/null)
        inputs_literal = repr(json.dumps(variables or {}))
        header = (
            "import json\n"
            "import sys\n"
            "# 输入变量\n"
            f"inputs = json.loads({inputs_literal})\n"
            "# 用户代码\n"
        )
        # The entry point must not be called 'main', or it would shadow the
        # user's function; look that up in module globals instead
        footer = (
            "\n"
            "# 提取输出\n"
            "def _sandbox_entry():\n"
            "    result = {}\n"
            "    # 调用用户定义的 main 函数(如果存在)\n"
            "    user_main = globals().get('main')\n"
            "    if callable(user_main):\n"
            "        output = user_main(**inputs)\n"
            "        if isinstance(output, dict):\n"
            "            result = output\n"
            "        else:\n"
            "            result['output'] = output\n"
            "    # 输出结果\n"
            "    print(json.dumps(result))\n"
            "if __name__ == '__main__':\n"
            "    _sandbox_entry()\n"
        )
        return "\n" + header + code + footer

    @staticmethod
    def _kill_container(container_name: str) -> None:
        """Best-effort removal of a container that outlived the timeout."""
        try:
            subprocess.run(
                ['docker', 'kill', container_name],
                capture_output=True,
                timeout=10
            )
        except (OSError, subprocess.SubprocessError):
            # Nothing more can be done here; the timeout is still reported
            pass

    def _execute_in_container(self, script_path: str) -> Dict[str, Any]:
        """
        Run the script in a Docker container.

        Args:
            script_path: Path of the script on the host.

        Returns:
            The execution result dict; errors are reported in it rather
            than raised.
        """
        import uuid

        # Named so the container can be killed if the client times out
        container_name = f'sandbox-{uuid.uuid4().hex}'
        try:
            cmd = [
                'docker', 'run',
                '--rm',  # remove the container after it exits
                '--name', container_name,
                '--network', 'none',  # no network access
                '--memory', f'{self.memory_limit}m',  # memory limit
                '--cpus', str(self.cpu_limit),  # CPU limit
                '-v', f'{script_path}:/app/script.py:ro',  # read-only mount
                'python:3.10-slim',
                'python', '/app/script.py'
            ]
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.timeout
            )
            if result.returncode != 0:
                return {
                    'success': False,
                    'error': result.stderr
                }
            try:
                output = json.loads(result.stdout)
            except json.JSONDecodeError:
                # stdout was not pure JSON: return it as raw text
                output = {'result': result.stdout}
            return {
                'success': True,
                'output': output
            }
        except subprocess.TimeoutExpired:
            # The timeout only stops the docker client process; the
            # container itself has to be killed explicitly
            self._kill_container(container_name)
            return {
                'success': False,
                'error': 'Execution timeout'
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }
9.3 插件扩展系统
python
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
@dataclass
class PluginMetadata:
    """
    Plugin metadata.
    """
    name: str  # plugin name
    version: str  # version number
    description: str  # description
    author: str  # author
    dependencies: Optional[List[str]] = None  # dependencies; None when not given
class BasePlugin(ABC):
    """
    Abstract base class for plugins.

    Concrete plugins must provide ``metadata``, ``initialize`` and
    ``teardown``.
    """

    @property
    @abstractmethod
    def metadata(self) -> PluginMetadata:
        """Metadata describing the plugin."""

    @abstractmethod
    def initialize(self, config: Dict[str, Any]):
        """
        Initialize the plugin.

        Args:
            config: Configuration.
        """

    @abstractmethod
    def teardown(self):
        """Clean up the plugin."""
class ToolPlugin(BasePlugin):
    """
    Tool plugin: contributes additional tools to Dify.
    """

    @abstractmethod
    def get_tools(self) -> List[Dict]:
        """
        Return the tools this plugin provides.

        Returns:
            List of tools.
        """
class NodePlugin(BasePlugin):
    """
    Node plugin: contributes additional workflow node types.
    """

    @abstractmethod
    def get_node_types(self) -> List[Dict]:
        """
        Return the node types this plugin provides.

        Returns:
            List of node types.
        """
class PluginManager:
    """
    Plugin manager.

    Loads and unloads plugins and aggregates what they provide.
    """

    def __init__(self):
        # Loaded plugins keyed by metadata.name
        self._plugins: Dict[str, BasePlugin] = {}

    def load_plugin(self, plugin: "BasePlugin", config: Optional[Dict] = None):
        """
        Initialize a plugin and register it under its metadata name.

        A plugin already registered under the same name is torn down first,
        so it is not replaced without its cleanup having run.

        Args:
            plugin: Plugin instance.
            config: Configuration; an empty dict is passed when omitted.
        """
        metadata = plugin.metadata
        self.unload_plugin(metadata.name)
        plugin.initialize(config or {})
        self._plugins[metadata.name] = plugin

    def unload_plugin(self, plugin_name: str):
        """
        Tear down and remove a plugin; unknown names are ignored.

        Args:
            plugin_name: Name of the plugin.
        """
        if plugin_name in self._plugins:
            self._plugins[plugin_name].teardown()
            del self._plugins[plugin_name]

    def get_all_tools(self) -> List[Dict]:
        """
        Collect the tools of every loaded ToolPlugin.

        Returns:
            List of tools.
        """
        return [
            tool
            for plugin in self._plugins.values()
            if isinstance(plugin, ToolPlugin)
            for tool in plugin.get_tools()
        ]

    def get_all_node_types(self) -> List[Dict]:
        """
        Collect the node types of every loaded NodePlugin.

        Returns:
            List of node types.
        """
        return [
            node_type
            for plugin in self._plugins.values()
            if isinstance(plugin, NodePlugin)
            for node_type in plugin.get_node_types()
        ]
9.4 任务调度系统
python
import logging
from typing import Any, Dict, List

from celery import Celery
# Module-level logger named after this module
logger = logging.getLogger(__name__)
# Create the Celery application; broker and result backend are two Redis
# databases on localhost
celery_app = Celery(
    'dify',
    broker='redis://localhost:6379/0',
    backend='redis://localhost:6379/1'
)
# Configuration
celery_app.conf.update(
    task_serializer='json',
    accept_content=['json'],
    result_serializer='json',
    timezone='UTC',
    enable_utc=True,
    task_track_started=True,
    task_time_limit=3600,  # 1-hour time limit
    worker_prefetch_multiplier=1
)
@celery_app.task(bind=True, name='document_indexing')
def document_indexing_task(
    self,
    document_id: str,
    dataset_id: str
) -> Dict[str, Any]:
    """
    Document indexing task.

    Args:
        self: Celery task instance.
        document_id: Document ID.
        dataset_id: Dataset ID.

    Returns:
        A dict with ``status`` 'completed' (plus ``segment_count``) or
        'failed' (plus ``error``); failures are returned, not raised.
    """
    from services.document_service import DocumentService

    def report_progress(progress):
        # Publish the progress as PROGRESS task state
        return self.update_state(
            state='PROGRESS',
            meta={'status': 'processing', 'progress': progress}
        )

    logger.info(f'Starting document indexing: {document_id}')
    try:
        report_progress(0)
        # Run the indexing
        service = DocumentService()
        result = service.index_document(
            document_id=document_id,
            dataset_id=dataset_id,
            progress_callback=lambda p: report_progress(p)
        )
        logger.info(f'Document indexing completed: {document_id}')
        return {
            'status': 'completed',
            'document_id': document_id,
            'segment_count': result.get('segment_count', 0)
        }
    except Exception as e:
        logger.exception(f'Document indexing failed: {document_id}')
        return {
            'status': 'failed',
            'document_id': document_id,
            'error': str(e)
        }
@celery_app.task(name='batch_workflow_execution')
def batch_workflow_execution_task(
    workflow_id: str,
    inputs_list: List[Dict]
) -> Dict[str, Any]:
    """
    Batch workflow execution task.

    Runs the workflow once per entry of ``inputs_list``; a failing entry is
    recorded and does not stop the remaining ones.

    Args:
        workflow_id: Workflow ID.
        inputs_list: One inputs dict per run.

    Returns:
        A dict with ``total``, ``success``, ``failed`` and the per-run
        ``results``.
    """
    from core.workflow import WorkflowRunner

    results = []
    for i, inputs in enumerate(inputs_list):
        try:
            runner = WorkflowRunner(workflow_id, inputs)
            # Drain the event stream; only the final state is used
            for _ in runner.run():
                pass
            results.append({
                'index': i,
                'status': 'success',
                # NOTE(review): assumes variable_pool is JSON-serializable,
                # as result_serializer is 'json' -- confirm
                'outputs': runner.state.variable_pool
            })
        except Exception as e:
            results.append({
                'index': i,
                'status': 'failed',
                'error': str(e)
            })

    # Every result is either 'success' or 'failed', so one count is enough
    succeeded = sum(1 for r in results if r['status'] == 'success')
    return {
        'total': len(inputs_list),
        'success': succeeded,
        'failed': len(results) - succeeded,
        'results': results
    }