LangGraph异步化sqlite checkpoint

安装

shell 复制代码
pip install langgraph-checkpoint-sqlite

异步checkpoint初始化:

python 复制代码
import aiosqlite
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

# Create a connection object for an in-memory SQLite database and hand it to
# the async checkpointer.
# NOTE(review): nothing in this snippet closes `conn`; the caller should
# `await conn.close()` when finished — confirm against the aiosqlite docs.
conn = aiosqlite.connect(":memory:", check_same_thread=False)
memory = AsyncSqliteSaver(conn)

如果使用异步流式响应,需要确保llm节点以及其他相关节点也都改为异步操作

python 复制代码
    async def llm(self, state: AgentState):
        """Ask the chat model for the next message, asynchronously.

        When ``self.systemMessage`` is non-empty it is placed in front of the
        conversation held in ``state['messages']`` before the model is called.

        Args:
            state: Graph state; only ``state['messages']`` is read.

        Returns:
            ``{'messages': [msg]}`` where ``msg`` is the awaited model reply,
            or a single ``AIMessage`` whose content is ``"Error: ..."`` when
            the model call raised.
        """
        llm_msgs = state['messages']
        if self.systemMessage:
            llm_msgs = self.systemMessage + state['messages']
        # Log every request, not only the ones that carry a system prompt.
        print(f'ask llm to handle request msg, msg: {llm_msgs}')
        try:
            # Await the call so a message object, not a coroutine, is returned.
            msg = await self.model.ainvoke(llm_msgs)
        except Exception as e:
            # Best effort: surface the failure as a message instead of raising.
            print(f"Model invocation error: {e}")
            # Imported locally: the error path is the only user of AIMessage.
            from langchain_core.messages import AIMessage
            return {'messages': [AIMessage(content=f"Error: {e}")]}
        print(f'msg={msg}')
        return {'messages': [msg]}
    async def take_action_tool(self, state: AgentState):
        """Run every tool call requested by the latest message, in order.

        Args:
            state: Graph state; the ``tool_calls`` of ``state['messages'][-1]``
                are executed one after another.

        Returns:
            ``{'messages': results}`` with one ``ToolMessage`` per tool call,
            carrying the stringified tool result. A call that names a tool
            missing from ``self.tools`` yields a ``"bad tool name, retry"``
            message instead of raising ``KeyError``.
        """
        current_tools: List[ToolCall] = state['messages'][-1].tool_calls
        results = []
        for t in current_tools:
            tool = self.tools.get(t['name'])
            if tool is None:
                # The model asked for a tool that was never registered; reply
                # with a retry hint so the run can continue.
                print(f"Unknown tool requested: {t['name']}")
                tool_result = 'bad tool name, retry'
            else:
                tool_result = await tool.ainvoke(t['args'])
            results.append(ToolMessage(
                tool_call_id=t['id'],
                content=str(tool_result),
                name=t['name'],
            ))
        print('Back to model')
        return {'messages': results}

最后的完整代码如下:

python 复制代码
import asyncio
from typing import Annotated, List, TypedDict
import os

import aiosqlite
from langchain_community.chat_models import ChatTongyi
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
from dotenv import load_dotenv
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import BaseTool
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.constants import END, START
from langgraph.graph import add_messages, StateGraph

# Connection object for an in-memory SQLite database; work() passes it to
# AsyncSqliteSaver.
conn = aiosqlite.connect(":memory:", check_same_thread=False)
# Load settings from ../keys.env; work() later reads 'model', 'api_key' and
# 'base_url' through os.getenv.
load_dotenv(dotenv_path='../keys.env')
# Search tool handed to the agent in work(), configured with max_results=2.
# NOTE(review): presumably reads its API key from the environment, which is
# why it is built after load_dotenv — confirm.
ts_tool = TavilySearchResults(max_results=2)


class AgentState(TypedDict):
    """State passed between the graph nodes: the conversation messages."""
    # NOTE(review): add_messages is attached through Annotated; presumably
    # LangGraph uses it to merge each node's returned messages into this
    # list — confirm against the LangGraph docs.
    messages: Annotated[List[AnyMessage], add_messages]


class Agent:
    """Tool-calling agent built as a two-node graph.

    The ``llm`` node asks the chat model for the next message. When that
    message carries tool calls, the ``take_action_tool`` node runs them and
    control returns to ``llm``; otherwise the graph ends. The compiled graph
    is exposed as ``self.app``.
    """

    def __init__(
            self,
            model: BaseChatModel,
            systemMessage: List[SystemMessage],
            tools: List[BaseTool],
            memory,
    ):
        """Wire the graph and bind the tools to the model.

        Args:
            model: Chat model; ``bind_tools(tools)`` is called on it.
            systemMessage: Messages placed before the conversation on every
                model call; may be empty.
            tools: Tools the model may call, looked up later by ``name``.
            memory: Checkpointer passed to ``graph.compile``.

        Raises:
            TypeError: If any element of ``tools`` is not a ``BaseTool``.
        """
        # Raise explicitly rather than assert: assertions vanish under -O.
        if not all(isinstance(t, BaseTool) for t in tools):
            raise TypeError('tools must all be BaseTool instances')

        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)
        self.systemMessage = systemMessage

        graph = StateGraph(AgentState)
        graph.add_node('llm', self.llm)
        graph.add_node('take_action_tool', self.take_action_tool)
        # After 'llm': run the tools if any were requested, otherwise stop.
        graph.add_conditional_edges(
            'llm',
            self.exist_action,
            {
                True: 'take_action_tool',
                False: END
            }
        )
        graph.set_entry_point('llm')
        graph.add_edge('take_action_tool', 'llm')

        self.app = graph.compile(checkpointer=memory)

    def exist_action(self, state: AgentState):
        """Return True when the latest message requests at least one tool call."""
        tool_calls = state['messages'][-1].tool_calls
        print(f'tool_calls size {len(tool_calls)}')
        return len(tool_calls) > 0

    async def llm(self, state: AgentState):
        """Ask the chat model for the next message, asynchronously.

        When ``self.systemMessage`` is non-empty it is placed in front of the
        conversation held in ``state['messages']`` before the model is called.

        Args:
            state: Graph state; only ``state['messages']`` is read.

        Returns:
            ``{'messages': [msg]}`` where ``msg`` is the awaited model reply,
            or a single ``AIMessage`` whose content is ``"Error: ..."`` when
            the model call raised.
        """
        llm_msgs = state['messages']
        if self.systemMessage:
            llm_msgs = self.systemMessage + state['messages']
        # Log every request, not only the ones that carry a system prompt.
        print(f'ask llm to handle request msg, msg: {llm_msgs}')
        try:
            # Await the call so a message object, not a coroutine, is returned.
            msg = await self.model.ainvoke(llm_msgs)
        except Exception as e:
            # Best effort: surface the failure as a message instead of raising.
            print(f"Model invocation error: {e}")
            # Imported locally: the error path is the only user of AIMessage.
            from langchain_core.messages import AIMessage
            return {'messages': [AIMessage(content=f"Error: {e}")]}
        print(f'msg={msg}')
        return {'messages': [msg]}

    async def take_action_tool(self, state: AgentState):
        """Run every tool call requested by the latest message, in order.

        Args:
            state: Graph state; the ``tool_calls`` of ``state['messages'][-1]``
                are executed one after another.

        Returns:
            ``{'messages': results}`` with one ``ToolMessage`` per tool call,
            carrying the stringified tool result. A call that names a tool
            missing from ``self.tools`` yields a ``"bad tool name, retry"``
            message instead of raising ``KeyError``.
        """
        current_tools: List[ToolCall] = state['messages'][-1].tool_calls
        results = []
        for t in current_tools:
            tool = self.tools.get(t['name'])
            if tool is None:
                # The model asked for a tool that was never registered; reply
                # with a retry hint so the run can continue.
                print(f"Unknown tool requested: {t['name']}")
                tool_result = 'bad tool name, retry'
            else:
                tool_result = await tool.ainvoke(t['args'])
            results.append(ToolMessage(
                tool_call_id=t['id'],
                content=str(tool_result),
                name=t['name'],
            ))
        print('Back to model')
        return {'messages': results}


async def work():
    """Send one demo question through the agent and print the streamed reply.

    Builds the chat model from the ``model``, ``api_key`` and ``base_url``
    environment variables, wraps the module-level ``conn`` in an
    ``AsyncSqliteSaver`` and streams graph events for thread id ``"1"``.
    Each non-empty chunk of model output is printed followed by ``|``.

    The module-level ``conn`` is closed before returning, so this coroutine
    is meant to run once per process.
    """
    prompt = """You are a smart research assistant. Use the search engine to look up information. \
    You are allowed to make multiple calls (either together or in sequence). \
    Only look up information when you are sure of what you want. \
    If you need to look up some information before asking a follow up question, you are allowed to do that!
    """

    qwen_model = ChatTongyi(
        model=os.getenv('model'),
        api_key=os.getenv('api_key'),
        base_url=os.getenv('base_url'),
    )
    memory = AsyncSqliteSaver(conn)
    try:
        agent = Agent(model=qwen_model, tools=[ts_tool], systemMessage=[SystemMessage(content=prompt)], memory=memory)

        messages = [HumanMessage("who is the popular football star in the world?")]
        configurable = {"configurable": {"thread_id": "1"}}
        async for event in agent.app.astream_events({"messages": messages}, configurable, version="v1"):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    # Chunks with empty content are skipped; per the original
                    # note they mean the model is requesting a tool call.
                    print(content, end="|")
    finally:
        # Always release the SQLite connection, even when streaming fails.
        # NOTE(review): an unclosed aiosqlite connection reportedly keeps a
        # worker thread alive that can stop the process exiting — confirm.
        await conn.close()

# Script entry point: run the async demo to completion when executed directly.
if __name__ == '__main__':
    asyncio.run(work())
相关推荐
FreakStudio1 小时前
一文速通 Python 并行计算:07 Python 多线程编程-线程池的使用和多线程的性能评估
python·单片机·嵌入式·多线程·面向对象·并行计算·电子diy
小臭希3 小时前
python蓝桥杯备赛常用算法模板
开发语言·python·蓝桥杯
mosaicwang3 小时前
dnf install openssl失败的原因和解决办法
linux·运维·开发语言·python
蹦蹦跳跳真可爱5894 小时前
Python----机器学习(基于PyTorch的乳腺癌逻辑回归)
人工智能·pytorch·python·分类·逻辑回归·学习方法
Bruce_Liuxiaowei4 小时前
基于Flask的Windows事件ID查询系统开发实践
windows·python·flask
carpell4 小时前
二叉树实战篇1
python·二叉树·数据结构与算法
HORSE RUNNING WILD4 小时前
为什么我们需要if __name__ == __main__:
linux·python·bash·学习方法
凡人的AI工具箱4 小时前
PyTorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(三)
人工智能·pytorch·python·深度学习·学习·生成对抗网络
码上通天地5 小时前
Python六大数据类型与可变类型
开发语言·python
Tiger_shl5 小时前
【Python语言基础】19、垃圾回收
java·python