Installation
```shell
pip install langgraph-checkpoint-sqlite
```
Async checkpoint initialization:
```python
import aiosqlite

from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

# An in-memory SQLite connection; check_same_thread=False lets the
# saver's worker thread use it.
conn = aiosqlite.connect(":memory:", check_same_thread=False)
memory = AsyncSqliteSaver(conn)
```
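Depending on your `langgraph-checkpoint-sqlite` version, the saver also ships a `from_conn_string` helper; in recent releases it is an async context manager that opens and closes the connection for you. A minimal sketch (the `graph` variable is assumed to be the `StateGraph` built later in this post):

```python
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

async def run():
    # In recent versions, from_conn_string is an async context manager
    # that owns the connection's lifecycle.
    async with AsyncSqliteSaver.from_conn_string(":memory:") as memory:
        app = graph.compile(checkpointer=memory)  # graph: your StateGraph
        ...
```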
If you stream the graph asynchronously, make sure the `llm` node and any related nodes are converted to async operations as well:
```python
async def llm(self, state: AgentState):
    llm_msgs = state['messages']
    if self.systemMessage:
        llm_msgs = self.systemMessage + state['messages']
    print(f'ask llm to handle request msg, msg: {llm_msgs}')
    try:
        # Key fix: await the async call and get the result directly
        msg = await self.model.ainvoke(llm_msgs)
        print(f'msg={msg}')
        return {'messages': [msg]}  # make sure we return a message object, not a coroutine
    except Exception as e:
        print(f"Model invocation error: {e}")
        # Return an error message (it must be a valid Message type)
        from langchain_core.messages import AIMessage
        return {'messages': [AIMessage(content=f"Error: {str(e)}")]}

async def take_action_tool(self, state: AgentState):
    current_tools: List[ToolCall] = state['messages'][-1].tool_calls
    results = []
    for t in current_tools:
        tool_result = await self.tools[t['name']].ainvoke(t['args'])
        results.append(ToolMessage(
            tool_call_id=t['id'],
            content=str(tool_result),
            name=t['name'],
        ))
    print('Back to model')
    return {'messages': results}
```
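For contrast, this is the anti-pattern the "key fix" comment refers to: without `await`, `ainvoke` returns a coroutine object, and that coroutine (not an `AIMessage`) gets appended to state, breaking everything downstream. A deliberately broken sketch:

```python
# Anti-pattern (do not use): calling ainvoke without await.
def llm_broken(self, state: AgentState):
    msg = self.model.ainvoke(state['messages'])  # a coroutine object, never awaited
    return {'messages': [msg]}  # add_messages stores a coroutine -> downstream failures
```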
The complete code is as follows:
```python
import asyncio
import os
from typing import Annotated, List, TypedDict

import aiosqlite
from dotenv import load_dotenv
from langchain_community.chat_models import ChatTongyi
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
from langchain_core.tools import BaseTool
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.constants import END, START
from langgraph.graph import add_messages, StateGraph

conn = aiosqlite.connect(":memory:", check_same_thread=False)
load_dotenv(dotenv_path='../keys.env')
ts_tool = TavilySearchResults(max_results=2)


class AgentState(TypedDict):
    messages: Annotated[List[AnyMessage], add_messages]
class Agent:
    def __init__(
        self,
        model: BaseChatModel,
        systemMessage: List[SystemMessage],
        tools: List[BaseTool],
        memory,
    ):
        assert all(isinstance(t, BaseTool) for t in tools), 'tools must implement BaseTool'
        graph = StateGraph(AgentState)
        graph.add_node('llm', self.llm)
        graph.add_node('take_action_tool', self.take_action_tool)
        graph.add_conditional_edges(
            'llm',
            self.exist_action,
            {
                True: 'take_action_tool',
                False: END
            }
        )
        graph.set_entry_point('llm')
        graph.add_edge('take_action_tool', 'llm')
        self.app = graph.compile(checkpointer=memory)
        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)
        self.systemMessage = systemMessage

    def exist_action(self, state: AgentState):
        tool_calls = state['messages'][-1].tool_calls
        print(f'tool_calls size {len(tool_calls)}')
        return len(tool_calls) > 0

    async def llm(self, state: AgentState):
        llm_msgs = state['messages']
        if self.systemMessage:
            llm_msgs = self.systemMessage + state['messages']
        print(f'ask llm to handle request msg, msg: {llm_msgs}')
        try:
            # Key fix: await the async call and get the result directly
            msg = await self.model.ainvoke(llm_msgs)
            print(f'msg={msg}')
            return {'messages': [msg]}  # make sure we return a message object, not a coroutine
        except Exception as e:
            print(f"Model invocation error: {e}")
            # Return an error message (it must be a valid Message type)
            from langchain_core.messages import AIMessage
            return {'messages': [AIMessage(content=f"Error: {str(e)}")]}

    async def take_action_tool(self, state: AgentState):
        current_tools: List[ToolCall] = state['messages'][-1].tool_calls
        results = []
        for t in current_tools:
            tool_result = await self.tools[t['name']].ainvoke(t['args'])
            results.append(ToolMessage(
                tool_call_id=t['id'],
                content=str(tool_result),
                name=t['name'],
            ))
        print('Back to model')
        return {'messages': results}

async def work():
    prompt = """You are a smart research assistant. Use the search engine to look up information. \
You are allowed to make multiple calls (either together or in sequence). \
Only look up information when you are sure of what you want. \
If you need to look up some information before asking a follow up question, you are allowed to do that!
"""
    qwen_model = ChatTongyi(
        model=os.getenv('model'),
        api_key=os.getenv('api_key'),
        base_url=os.getenv('base_url'),
    )  # reduce inference cost
    memory = AsyncSqliteSaver(conn)
    agent = Agent(model=qwen_model, tools=[ts_tool], systemMessage=[SystemMessage(content=prompt)], memory=memory)
    messages = [HumanMessage("who is the popular football star in the world?")]
    configurable = {"configurable": {"thread_id": "1"}}
    async for event in agent.app.astream_events({"messages": messages}, configurable, version="v1"):
        kind = event["event"]
        # print(f"kind = {kind}")
        if kind == "on_chat_model_stream":
            content = event["data"]["chunk"].content
            if content:
                # Empty content in the context of OpenAI means
                # that the model is asking for a tool to be invoked.
                # So we only print non-empty content.
                print(content, end="|")


if __name__ == '__main__':
    asyncio.run(work())
```
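To verify that the checkpointer actually persisted the thread, you can read its state back after the run and then continue the same conversation. A minimal sketch, assuming the `agent` and `configurable` objects from `work()` above:

```python
async def inspect_and_resume(agent, configurable):
    # aget_state returns the latest checkpointed snapshot for this thread_id
    snapshot = await agent.app.aget_state(configurable)
    print(f'{len(snapshot.values["messages"])} messages checkpointed')

    # Because history is restored from SQLite, a follow-up question can
    # rely on earlier context within the same thread_id.
    follow_up = {"messages": [HumanMessage("which club does he play for?")]}
    result = await agent.app.ainvoke(follow_up, configurable)
    print(result['messages'][-1].content)
```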