AI Agent Workflow基类及实现类,快速实现一个react agent
基类
复制代码
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Literal
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from .llm_config import LLMConfig, DEFAULT_OLLAMA_CONFIG
class BaseAgentWorkflow(ABC):
"""
Agent工作流的抽象基类
提供最大的灵活性,所有核心组件都需要子类实现:
- 工具定义和管理
- LLM配置
- State状态结构
- StateGraph工作流构建
"""
def __init__(self, llm_config: Optional[LLMConfig] = None):
"""
初始化Agent工作流
Args:
llm_config: LLM配置,如果不提供则使用默认配置
"""
self.llm_config = llm_config or DEFAULT_OLLAMA_CONFIG
self.llm = self.get_llm()
self.tools = self.get_tools()
self.tools_dict = self.get_tools_dict()
self.llm_with_tools = self.llm.bind_tools(self.tools)
self.graph = self.build_graph()
@abstractmethod
def get_tools(self) -> List[BaseTool]:
"""
子类必须实现:返回工具列表
Returns:
工具对象列表
"""
pass
@abstractmethod
def get_tools_dict(self) -> Dict[str, BaseTool]:
"""
子类必须实现:返回工具名称到工具对象的映射
Returns:
工具名称到工具对象的字典映射
"""
pass
@abstractmethod
def get_llm(self) -> ChatOpenAI:
"""
子类必须实现:返回配置好的LLM实例
Returns:
配置好的ChatOpenAI实例
"""
pass
@abstractmethod
def get_state_schema(self) -> type:
"""
子类必须实现:返回State类型定义
Returns:
State类型定义
"""
pass
@abstractmethod
def build_graph(self) -> StateGraph:
"""
子类必须实现:构建自定义的StateGraph工作流
Returns:
编译好的StateGraph实例
"""
pass
# 提供一些可选的通用节点函数,子类可以选择使用
def default_call_model_node(self, state):
"""
默认的模型调用节点(可选使用)
Args:
state: 当前状态
Returns:
包含AI消息的状态更新
"""
ai_msg = self.llm_with_tools.invoke(state['messages'])
子类(继承基类),即插即用,自带测试,可直接运行
复制代码
"""
重构版本 - 通过继承BaseAgentWorkflow实现test_agent_tool_id.py的功能
这个文件展示如何使用新的基类体系来实现与原始 test_agent_tool_id.py 完全相同的功能,
包括相同的工具、工作流结构、执行逻辑和结果处理。
"""
from typing import List, Dict, Any, Literal, Annotated
from typing_extensions import TypedDict
from langchain_core.tools import tool, BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from base import BaseAgentWorkflow, LLMConfig
# ==================== 工具函数定义 ====================
@tool
def add(a: int, b: int) -> int:
"""Adds a and b."""
return a + b
@tool
def multiply(a: int, b: int) -> int:
"""Multiplies a and b."""
return a * b
@tool
def query_weather(location: str) -> str:
"""查询天气"""
print("------------------------------------------调用 query_weather ------------------------------------------------")
return f"{location} :天气是暴风雨"
@tool
def query_film_info(film: str) -> str:
"""查询电影票房等信息"""
print("------------------------------------------调用 query_film_info ------------------------------------------------")
return f"{film} :电影票房是80亿"
class TestAgent(BaseAgentWorkflow):
"""
重构版本的测试Agent - 使用BaseAgentWorkflow实现
功能与原始 test_agent_tool_id.py 完全相同:
- 数学运算(加法、乘法)
- 天气查询
- 电影信息查询
- 相同的工作流结构和节点命名
- 相同的结果处理逻辑
"""
def __init__(self):
"""初始化Agent的LLM配置"""
llm_config = LLMConfig(
model="qwen3:8b",
api_key="ollama",
base_url="http://localhost:11434/v1/"
)
super().__init__(llm_config)
# ==================== 基类抽象方法实现 ====================
def get_tools(self) -> List[BaseTool]:
"""返回工具列表"""
return [add, multiply, query_weather, query_film_info]
def get_tools_dict(self) -> Dict[str, BaseTool]:
"""返回工具字典映射"""
return {
"add": add,
"multiply": multiply,
"query_weather": query_weather,
"query_film_info": query_film_info
}
def get_llm(self) -> ChatOpenAI:
"""返回配置好的LLM实例"""
return ChatOpenAI(**self.llm_config.to_dict())
def get_state_schema(self) -> type:
"""返回State结构定义"""
class State(TypedDict):
messages: Annotated[list, add_messages]
return State
# ==================== 自定义节点函数 ====================
def call_tool_distribute_model(self, state):
"""
模型调用节点
"""
# 使用绑定工具的模型处理用户查询
ai_msg = self.llm_with_tools.invoke(state['messages'])
# 输出AI消息响应和工具调用信息
print(ai_msg)
print(ai_msg.tool_calls)
return {"messages": [ai_msg]}
def tool_action(self, state):
"""
工具执行节点
"""
messages = []
for tool_call in state['messages'][-1].tool_calls:
# 根据工具调用的名称选择相应的工具函数
selected_tool = self.tools_dict[tool_call["name"].lower()]
print(f"selected_tool: {selected_tool}")
# 调用选中的工具函数并获取结果
tool_msg = selected_tool.invoke(tool_call)
print(f"tool_msg: {tool_msg}")
# 将工具调用的结果添加到messages列表中
messages.append(tool_msg)
return {"messages": messages}
def route_model_output(self, state) -> Literal["__end__", "tool_action"]:
"""
路由函数
"""
last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage):
raise ValueError(
f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
)
if not last_message.tool_calls:
return "__end__"
return "tool_action"
# ==================== 自定义StateGraph构建 ====================
def build_graph(self) -> StateGraph:
"""
构建StateGraph工作流
"""
builder = StateGraph(self.get_state_schema())
# 添加节点
builder.add_node("call_tool_distribute_model", self.call_tool_distribute_model)
builder.add_node("tool_action", self.tool_action)
# 添加边
builder.add_edge("__start__", "call_tool_distribute_model")
# 添加条件边
builder.add_conditional_edges(
"call_tool_distribute_model",
self.route_model_output,
)
builder.add_edge("tool_action", "call_tool_distribute_model")
# 编译并设置名称
graph = builder.compile()
graph.name = "ReAct Agent"
return graph
# ==================== 自定义结果处理 ====================
def process_result_detailed(self, result: Dict[str, Any]):
"""
详细结果处理
"""
print("\n=== 工作流执行完成 ===")
# 1. 查找并显示最终答案
final_answer = None
for message in reversed(result["messages"]):
if isinstance(message, AIMessage):
# 查找明确标记为"Answer"的内容
if "Answer:" in message.content:
# 提取Answer之后的内容
final_answer = message.content.split("Answer:", 1)[-1].strip()
break
elif not message.tool_calls: # 最后一条非工具调用的AI消息
final_answer = message.content
break
if final_answer:
print(f"最终结果: {final_answer}")
else:
print("未能找到明确的最终答案")
last_ai_message = result["messages"][-1].content
print(f"最后一条AI回复: {last_ai_message}")
# 2. 检查工具调用结果
print("\n工具调用历史:")
for message in result["messages"]:
if isinstance(message, ToolMessage):
# 打印工具调用返回的结果
print(f"工具 '{message.tool_call_id}': {message.content}")
def run_with_detailed_output(self, message: str):
"""
运行工作流并详细输出
"""
messages = [HumanMessage(content=message)]
result = self.graph.invoke({"messages": messages})
self.process_result_detailed(result)
return result
def main():
"""
主函数 - 演示功能
"""
print("=== 测试Agent (使用BaseAgentWorkflow) ===\n")
# 创建Agent实例
agent = TestAgent()
# 查询
input_message = "上海天气如何,以及电影《哪吒》的票房是多少"
print(f"输入查询: {input_message}")
print("=" * 60)
# 执行工作流(使用详细输出格式)
result = agent.run_with_detailed_output(input_message)
print("\n" + "=" * 60)
print("执行完成!")
# # 额外测试:数学运算
# print("\n=== 额外测试:数学运算 ===")
# result2 = agent.run_with_detailed_output("3 * 12 结果是多少")
def compare_with_original():
"""
说明
"""
print("""
=== 功能说明 ===
1. **工具函数**:
- add() - 加法运算
- multiply() - 乘法运算
- query_weather() - 天气查询
- query_film_info() - 电影信息查询
2. **LLM配置**:
- qwen3:8b 模型
- Ollama本地部署配置
3. **工作流结构**:
- 节点名称:call_tool_distribute_model, tool_action
- 路由函数:route_model_output
4. **优点**:
✨ 使用面向对象的设计,代码更清晰
✨ 继承基类,便于扩展和维护
✨ 可以轻松创建类似的Agent变体
""")
if __name__ == "__main__":
main()
compare_with_original()
一个测试agent
复制代码
"""
重构版本 - 通过继承BaseAgentWorkflow实现test_agent_tool_id.py的功能
这个文件展示如何使用新的基类体系来实现与原始 test_agent_tool_id.py 完全相同的功能,
包括相同的工具、工作流结构、执行逻辑和结果处理。
"""
from typing import List, Dict, Any, Literal, Annotated
from typing_extensions import TypedDict
from langchain_core.tools import tool, BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from base import BaseAgentWorkflow, LLMConfig
# ==================== 工具函数定义 ====================
@tool
def add(a: int, b: int) -> int:
"""Adds a and b."""
return a + b
@tool
def multiply(a: int, b: int) -> int:
"""Multiplies a and b."""
return a * b
@tool
def query_weather(location: str) -> str:
"""查询天气"""
print("------------------------------------------调用 query_weather ------------------------------------------------")
return f"{location} :天气是暴风雨"
@tool
def query_film_info(film: str) -> str:
"""查询电影票房等信息"""
print("------------------------------------------调用 query_film_info ------------------------------------------------")
return f"{film} :电影票房是80亿"
class TestAgent(BaseAgentWorkflow):
"""
重构版本的测试Agent - 使用BaseAgentWorkflow实现
功能与原始 test_agent_tool_id.py 完全相同:
- 数学运算(加法、乘法)
- 天气查询
- 电影信息查询
- 相同的工作流结构和节点命名
- 相同的结果处理逻辑
"""
def __init__(self):
"""初始化Agent的LLM配置"""
llm_config = LLMConfig(
model="qwen3:8b",
api_key="ollama",
base_url="http://localhost:11434/v1/"
)
super().__init__(llm_config)
# ==================== 基类抽象方法实现 ====================
def get_tools(self) -> List[BaseTool]:
"""返回工具列表"""
return [add, multiply, query_weather, query_film_info]
def get_tools_dict(self) -> Dict[str, BaseTool]:
"""返回工具字典映射"""
return {
"add": add,
"multiply": multiply,
"query_weather": query_weather,
"query_film_info": query_film_info
}
def get_llm(self) -> ChatOpenAI:
"""返回配置好的LLM实例"""
return ChatOpenAI(**self.llm_config.to_dict())
def get_state_schema(self) -> type:
"""返回State结构定义"""
class State(TypedDict):
messages: Annotated[list, add_messages]
return State
# ==================== 自定义节点函数 ====================
def call_tool_distribute_model(self, state):
"""
模型调用节点
"""
# 使用绑定工具的模型处理用户查询
ai_msg = self.llm_with_tools.invoke(state['messages'])
# 输出AI消息响应和工具调用信息
print(ai_msg)
print(ai_msg.tool_calls)
return {"messages": [ai_msg]}
def tool_action(self, state):
"""
工具执行节点
"""
messages = []
for tool_call in state['messages'][-1].tool_calls:
# 根据工具调用的名称选择相应的工具函数
selected_tool = self.tools_dict[tool_call["name"].lower()]
print(f"selected_tool: {selected_tool}")
# 调用选中的工具函数并获取结果
tool_msg = selected_tool.invoke(tool_call)
print(f"tool_msg: {tool_msg}")
# 将工具调用的结果添加到messages列表中
messages.append(tool_msg)
return {"messages": messages}
def route_model_output(self, state) -> Literal["__end__", "tool_action"]:
"""
路由函数
"""
last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage):
raise ValueError(
f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
)
if not last_message.tool_calls:
return "__end__"
return "tool_action"
# ==================== 自定义StateGraph构建 ====================
def build_graph(self) -> StateGraph:
"""
构建StateGraph工作流
"""
builder = StateGraph(self.get_state_schema())
# 添加节点
builder.add_node("call_tool_distribute_model", self.call_tool_distribute_model)
builder.add_node("tool_action", self.tool_action)
# 添加边
builder.add_edge("__start__", "call_tool_distribute_model")
# 添加条件边
builder.add_conditional_edges(
"call_tool_distribute_model",
self.route_model_output,
)
builder.add_edge("tool_action", "call_tool_distribute_model")
# 编译并设置名称
graph = builder.compile()
graph.name = "ReAct Agent"
return graph
# ==================== 自定义结果处理 ====================
def process_result_detailed(self, result: Dict[str, Any]):
"""
详细结果处理
"""
print("\n=== 工作流执行完成 ===")
# 1. 查找并显示最终答案
final_answer = None
for message in reversed(result["messages"]):
if isinstance(message, AIMessage):
# 查找明确标记为"Answer"的内容
if "Answer:" in message.content:
# 提取Answer之后的内容
final_answer = message.content.split("Answer:", 1)[-1].strip()
break
elif not message.tool_calls: # 最后一条非工具调用的AI消息
final_answer = message.content
break
if final_answer:
print(f"最终结果: {final_answer}")
else:
print("未能找到明确的最终答案")
last_ai_message = result["messages"][-1].content
print(f"最后一条AI回复: {last_ai_message}")
# 2. 检查工具调用结果
print("\n工具调用历史:")
for message in result["messages"]:
if isinstance(message, ToolMessage):
# 打印工具调用返回的结果
print(f"工具 '{message.tool_call_id}': {message.content}")
def run_with_detailed_output(self, message: str):
"""
运行工作流并详细输出
"""
messages = [HumanMessage(content=message)]
result = self.graph.invoke({"messages": messages})
self.process_result_detailed(result)
return result
def main():
"""
主函数 - 演示功能
"""
print("=== 测试Agent (使用BaseAgentWorkflow) ===\n")
# 创建Agent实例
agent = TestAgent()
# 查询
input_message = "上海天气如何,以及电影《哪吒》的票房是多少"
print(f"输入查询: {input_message}")
print("=" * 60)
# 执行工作流(使用详细输出格式)
result = agent.run_with_detailed_output(input_message)
print("\n" + "=" * 60)
print("执行完成!")
# # 额外测试:数学运算
# print("\n=== 额外测试:数学运算 ===")
# result2 = agent.run_with_detailed_output("3 * 12 结果是多少")
def compare_with_original():
"""
说明
"""
print("""
=== 功能说明 ===
1. **工具函数**:
- add() - 加法运算
- multiply() - 乘法运算
- query_weather() - 天气查询
- query_film_info() - 电影信息查询
2. **LLM配置**:
- qwen3:8b 模型
- Ollama本地部署配置
3. **工作流结构**:
- 节点名称:call_tool_distribute_model, tool_action
- 路由函数:route_model_output
4. **优点**:
✨ 使用面向对象的设计,代码更清晰
✨ 继承基类,便于扩展和维护
✨ 可以轻松创建类似的Agent变体
""")
if __name__ == "__main__":
main()
compare_with_original()
另一个版本的基类(细微改动)
复制代码
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Literal
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from .llm_config import LLMConfig, DEFAULT_OLLAMA_CONFIG
class BaseAgentWorkflow(ABC):
"""
Agent工作流的抽象基类
提供最大的灵活性,所有核心组件都需要子类实现:
- 工具定义和管理
- LLM配置
- State状态结构
- StateGraph工作流构建
"""
def __init__(self, llm_config: Optional[LLMConfig] = None):
"""
初始化Agent工作流
Args:
llm_config: LLM配置,如果不提供则使用默认配置
"""
self.llm_config = llm_config or DEFAULT_OLLAMA_CONFIG
self.llm = self.get_llm()
self.tools = self.get_tools()
self.tools_dict = self.get_tools_dict()
self.llm_with_tools = self.llm.bind_tools(self.tools)
self.graph = self.build_graph()
@abstractmethod
def get_tools(self) -> List[BaseTool]:
"""
子类必须实现:返回工具列表
Returns:
工具对象列表
"""
pass
@abstractmethod
def get_tools_dict(self) -> Dict[str, BaseTool]:
"""
子类必须实现:返回工具名称到工具对象的映射
Returns:
工具名称到工具对象的字典映射
"""
pass
@abstractmethod
def get_llm(self) -> ChatOpenAI:
"""
子类必须实现:返回配置好的LLM实例
Returns:
配置好的ChatOpenAI实例
"""
pass
@abstractmethod
def get_state_schema(self) -> type:
"""
子类必须实现:返回State类型定义
Returns:
State类型定义
"""
pass
@abstractmethod
def build_graph(self) -> StateGraph:
"""
子类必须实现:构建自定义的StateGraph工作流
Returns:
编译好的StateGraph实例
"""
pass
# 提供一些可选的通用节点函数,子类可以选择使用
def default_call_model_node(self, state):
"""
默认的模型调用节点(可选使用)
Args:
state: 当前状态
Returns:
包含AI消息的状态更新
"""
ai_msg = self.llm_with_tools.invoke(state['messages'])
print(f"AI回复: {ai_msg}")
print(f"工具调用: {ai_msg.tool_calls}")
return {"messages": [ai_msg]}
def default_tool_action_node(self, state):
"""
默认的工具执行节点(可选使用)
Args:
state: 当前状态
Returns:
包含工具执行结果的状态更新
"""
messages = []
for tool_call in state['messages'][-1].tool_calls:
selected_tool = self.tools_dict[tool_call["name"].lower()]
print(f"执行工具: {selected_tool}")
tool_msg = selected_tool.invoke(tool_call)
print(f"工具结果: {tool_msg}")
messages.append(tool_msg)
return {"messages": messages}
def default_route_function(self, state) -> Literal["__end__", "tool_action"]:
"""
默认的路由函数(可选使用)
Args:
state: 当前状态
Returns:
下一个节点名称
"""
last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage):
raise ValueError(
f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
)
if not last_message.tool_calls:
return "__end__"
return "tool_action"
def invoke(self, message: str) -> Dict[str, Any]:
"""
标准执行接口
Args:
message: 用户输入消息
Returns:
执行结果字典
"""
messages = [HumanMessage(content=message)]
return self.graph.invoke({"messages": messages})
def process_result(self, result: Dict[str, Any]) -> str:
"""
默认结果处理,子类可重写
Args:
result: graph.invoke的返回结果
Returns:
处理后的最终答案字符串
"""
# 查找并返回最终答案
final_answer = None
for message in reversed(result["messages"]):
if isinstance(message, AIMessage):
if "Answer:" in message.content:
final_answer = message.content.split("Answer:", 1)[-1].strip()
break
elif not message.tool_calls:
final_answer = message.content
break
return final_answer or "未能找到明确的最终答案"
def run(self, message: str, verbose: bool = True) -> str:
"""
运行工作流并返回处理后的结果
Args:
message: 用户输入消息
verbose: 是否打印详细信息
Returns:
最终答案字符串
"""
if verbose:
print(f"用户输入: {message}")
print("=" * 50)
result = self.invoke(message)
final_answer = self.process_result(result)
if verbose:
print("=" * 50)
print(f"最终结果: {final_answer}")
# 打印工具调用历史
print("\n工具调用历史:")
for message in result["messages"]:
if isinstance(message, ToolMessage):
print(f"工具 '{message.tool_call_id}': {message.content}")
return final_answer