复制代码
"""
自定义工作流Agent示例 - 直接继承BaseAgentWorkflow
演示如何创建完全自定义的工作流结构,
包括自定义State、节点函数和工作流图。
"""
from typing import Annotated, Dict, Any, Literal, List
from typing_extensions import TypedDict
from langchain_core.tools import tool, BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from base import BaseAgentWorkflow, LLMConfig
class CustomComplexAgent(BaseAgentWorkflow):
"""
自定义复杂Agent工作流
特点:
- 扩展的State结构(包含上下文和步骤计数)
- 自定义的预处理和后处理节点
- 复杂的条件路由逻辑
- 多步骤工作流控制
"""
def __init__(self, llm_config: LLMConfig = None):
"""初始化自定义复杂Agent"""
super().__init__(llm_config)
def get_state_schema(self) -> type:
"""
自定义State结构,包含额外字段
"""
class ComplexState(TypedDict):
messages: Annotated[list, add_messages]
context: Dict[str, Any] # 上下文信息
step_count: int # 步骤计数
max_steps: int # 最大步骤限制
return ComplexState
def get_llm(self) -> ChatOpenAI:
"""配置LLM实例"""
return ChatOpenAI(**self.llm_config.to_dict())
@tool
def search_database(self, query: str) -> str:
"""模拟数据库搜索"""
print(f"------------------------------------------搜索数据库: {query} ------------------------------------------------")
return f"数据库搜索结果: 找到与'{query}'相关的3条记录"
@tool
def send_email(self, recipient: str, subject: str, content: str) -> str:
"""模拟发送邮件"""
print(f"------------------------------------------发送邮件------------------------------------------------")
print(f"收件人: {recipient}")
print(f"主题: {subject}")
print(f"内容: {content}")
return f"邮件已成功发送给 {recipient}"
@tool
def analyze_data(self, data: str) -> str:
"""模拟数据分析"""
print(f"------------------------------------------分析数据: {data} ------------------------------------------------")
return f"数据分析结果: {data} 的趋势为上升,置信度85%"
def get_tools(self) -> List[BaseTool]:
"""返回工具列表"""
return [self.search_database, self.send_email, self.analyze_data]
def get_tools_dict(self) -> Dict[str, BaseTool]:
"""返回工具映射"""
return {
"search_database": self.search_database,
"send_email": self.send_email,
"analyze_data": self.analyze_data
}
def preprocessing_node(self, state):
"""预处理节点:初始化上下文和计数器"""
print("→ 预处理阶段:初始化上下文")
# 初始化或更新上下文
context = state.get("context", {})
context["start_time"] = "2024-01-01 10:00:00"
context["user_id"] = "user_001"
return {
"context": context,
"step_count": 1,
"max_steps": 5
}
def custom_call_model_node(self, state):
"""自定义模型调用节点:包含步骤控制"""
print(f"→ 模型调用阶段:第{state['step_count']}步")
# 检查步骤限制
if state["step_count"] > state["max_steps"]:
return {
"messages": [AIMessage(content="已达到最大步骤限制,停止执行。")]
}
# 在系统消息中添加上下文信息
context_info = f"当前上下文: 用户ID={state['context']['user_id']}, 步骤={state['step_count']}/{state['max_steps']}"
# 调用模型
ai_msg = self.llm_with_tools.invoke(state['messages'])
print(f"AI回复: {ai_msg}")
print(f"工具调用: {ai_msg.tool_calls}")
return {
"messages": [ai_msg],
"step_count": state["step_count"] + 1
}
def custom_tool_action_node(self, state):
"""自定义工具执行节点"""
print("→ 工具执行阶段")
messages = []
for tool_call in state['messages'][-1].tool_calls:
selected_tool = self.tools_dict[tool_call["name"].lower()]
print(f"执行工具: {selected_tool}")
tool_msg = selected_tool.invoke(tool_call)
print(f"工具结果: {tool_msg}")
messages.append(tool_msg)
return {"messages": messages}
def postprocessing_node(self, state):
"""后处理节点:清理和总结"""
print("→ 后处理阶段:生成总结")
# 生成执行总结
summary = f"任务完成。共执行了{state['step_count']-1}个步骤。"
summary_msg = AIMessage(content=f"执行总结: {summary}")
return {"messages": [summary_msg]}
def custom_route_function(self, state) -> Literal["tool_action", "postprocessing", "__end__"]:
"""自定义路由函数:包含复杂的条件判断"""
last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage):
raise ValueError(f"Expected AIMessage, got {type(last_message).__name__}")
# 检查步骤限制
if state["step_count"] > state["max_steps"]:
return "postprocessing"
# 检查是否有工具调用
if last_message.tool_calls:
return "tool_action"
# 检查是否需要进入后处理
if "完成" in last_message.content or "结束" in last_message.content:
return "postprocessing"
return "__end__"
def build_graph(self) -> StateGraph:
"""
构建自定义复杂工作流图
工作流结构:
开始 → 预处理 → 调用模型 → [路由判断] → 工具执行 ↗
↓ ↓
后处理 ← ← ← ← ← ← ← ↙
↓
结束
"""
builder = StateGraph(self.get_state_schema())
# 添加节点
builder.add_node("preprocessing", self.preprocessing_node)
builder.add_node("call_model", self.custom_call_model_node)
builder.add_node("tool_action", self.custom_tool_action_node)
builder.add_node("postprocessing", self.postprocessing_node)
# 构建复杂的边关系
builder.add_edge("__start__", "preprocessing")
builder.add_edge("preprocessing", "call_model")
builder.add_conditional_edges(
"call_model",
self.custom_route_function,
{
"tool_action": "tool_action",
"postprocessing": "postprocessing",
"__end__": "__end__"
}
)
builder.add_edge("tool_action", "call_model")
builder.add_edge("postprocessing", "__end__")
# 编译并设置名称
graph = builder.compile()
graph.name = "Custom Complex Agent"
return graph
def process_result(self, result: Dict[str, Any]) -> str:
"""自定义结果处理"""
# 查找执行总结
for message in reversed(result["messages"]):
if isinstance(message, AIMessage) and "执行总结" in message.content:
return message.content
# 如果没有总结,返回最后的AI回复
return super().process_result(result)
def main():
"""主函数 - 演示自定义Agent使用"""
print("=== 自定义复杂Agent示例 ===\n")
# 创建Agent实例
agent = CustomComplexAgent()
# 测试用例1:数据搜索和分析
print("测试1: 数据搜索和分析任务")
result1 = agent.run("请搜索用户行为数据,然后分析结果趋势")
print(f"结果: {result1}\n")
# 测试用例2:复合任务(搜索 + 分析 + 邮件)
print("测试2: 复合任务")
result2 = agent.run("搜索销售数据,分析趋势,然后将结果发邮件给manager@company.com")
print(f"结果: {result2}\n")
if __name__ == "__main__":
main()