LangGraph持久化管理
引言
LangGraph作为一个用于构建有状态多参与者应用程序的库,其核心优势之一就是持久性。持久性允许应用程序在多次交互中保持状态,这对于构建具有记忆功能的代理系统至关重要。本教程将详细介绍LangGraph中的持久化机制及其管理方法,帮助开发者构建具有记忆能力的智能应用。
1. 持久化的重要性
在代理系统中,持久化主要解决以下几个关键问题:
- 记忆保持:允许代理在多次交互中记住之前的对话和决策
- 人机协作:支持人类检查、中断和批准代理的步骤
- 状态恢复:在系统崩溃或重启后恢复之前的状态
- 长时间运行:支持跨越长时间的任务执行和多轮对话
2. 检查点(Checkpoints)机制
LangGraph具有一个内置的持久化层,通过检查点(Checkpoints)实现。检查点在每个超级步骤中保存图形状态的快照,从而实现一些强大的功能。
2.1 检查点的作用
-
促进人机交互工作流:允许人类检查、中断和批准步骤。检查点对于这些工作流是必需的,因为人类必须能够在任何时候查看图形的状态,并且图形必须能够在人类对状态进行任何更新后恢复执行。
-
实现"记忆"功能:允许在交互之间保持记忆。您可以使用检查点创建线程并在图形执行后保存线程的状态。在重复的人类交互(例如对话)的情况下,任何后续消息都可以发送到该检查点,该检查点将保留对其以前消息的记忆。
2.2 检查点类型
LangGraph提供了几种不同类型的检查点实现:
- MemoryCheckpointer:将状态保存在内存中,适用于开发和测试
- FileCheckpointer:将状态保存到文件系统,适用于持久化存储
- SQLiteCheckpointer:使用SQLite数据库存储状态
- RedisCheckpointer:使用Redis存储状态,适用于分布式系统
3. 添加持久性到LangGraph
在创建LangGraph工作流时,可以通过以下两个步骤设置持久性:
- 创建一个检查点(Checkpoint),例如MemoryCheckpointer
- 在编译图时调用compile(checkpointer=my_checkpointer)
3.1 基本示例
python
from langgraph.checkpoint import MemoryCheckpointer
from langgraph.graph import StateGraph
from typing import TypedDict, List
# 定义状态类型
class ChatState(TypedDict):
messages: List[dict]
# 创建图
graph = StateGraph(ChatState)
# 添加节点和边
# ...
# 创建检查点
checkpointer = MemoryCheckpointer()
# 编译图并添加持久性
app = graph.compile(checkpointer=checkpointer)
这适用于StateGraph及其所有子类,例如MessageGraph。
4. 完整的持久化代理示例
下面我们将创建一个完整的持久化代理示例,展示如何使用检查点来保持对话历史。
4.1 设置环境
首先,我们需要安装所需的软件包:
bash
pip install langgraph langchain-openai langchain-core tavily-python
然后设置必要的API密钥:
python
import os
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
# 设置API密钥
os.environ["OPENAI_API_KEY"] = "your-openai-api-key"
os.environ["TAVILY_API_KEY"] = "your-tavily-api-key" # 可选,用于搜索工具
4.2 设置状态
状态是所有节点的接口:
python
from typing import TypedDict, List, Union, Annotated
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
# 定义状态类型
class AgentState(TypedDict):
messages: List[BaseMessage] # 对话历史
next_steps: List[str] # 可选的下一步操作
4.3 设置工具
我们将定义一个简单的搜索工具:
python
from langchain_core.tools import tool
from langchain_core.pydantic_v1 import BaseModel, Field
from tavily import TavilyClient
# 定义搜索工具
@tool
def search(query: str) -> str:
"""搜索互联网以获取有关查询的最新信息。"""
client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY", ""))
result = client.search(query=query)
return str(result)
# 创建工具列表
tools = [search]
# 创建工具节点
from langgraph.prebuilt import ToolNode
tool_node = ToolNode(tools)
4.4 设置模型
现在我们需要加载聊天模型:
python
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# 创建聊天模型
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# 将工具绑定到模型
model_with_tools = model.bind_tools(tools)
4.5 定义图
我们需要定义代理节点和条件边:
python
from langgraph.graph import StateGraph, START, END
# 定义代理节点
def agent(state: AgentState):
"""代理节点,决定要采取哪些(如果有)操作。"""
messages = state["messages"]
# 调用模型
response = model_with_tools.invoke(messages)
# 返回更新的状态
return {"messages": messages + [response]}
# 定义路由函数
def should_continue(state: AgentState) -> Union[str, list[str]]:
"""决定是否继续执行工具或结束。"""
messages = state["messages"]
last_message = messages[-1]
# 检查是否有工具调用
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
return "tool_node"
else:
return END
# 创建图
graph = StateGraph(AgentState)
# 添加节点
graph.add_node("agent", agent)
graph.add_node("tool_node", tool_node)
# 添加边
graph.add_edge(START, "agent")
graph.add_conditional_edges("agent", should_continue)
graph.add_edge("tool_node", "agent")
4.6 添加持久性
现在我们可以添加持久性:
python
from langgraph.checkpoint import MemoryCheckpointer
# 创建检查点
checkpointer = MemoryCheckpointer()
# 编译图并添加持久性
app = graph.compile(checkpointer=checkpointer)
4.7 与代理交互
现在我们可以与代理进行交互,并看到它会记住以前的消息:
python
# 初始化状态
messages = [HumanMessage(content="什么是量子计算?")]
result = app.invoke({"messages": messages}, {"thread_id": "thread_1"})
# 继续对话
messages = result["messages"] + [HumanMessage(content="它与传统计算有什么区别?")]
result = app.invoke({"messages": messages}, {"thread_id": "thread_1"})
# 开始新的对话
new_messages = [HumanMessage(content="告诉我关于人工智能的信息")]
new_result = app.invoke({"messages": new_messages}, {"thread_id": "thread_2"})
所有检查点都将持久保存,因此您可以随时恢复以前的线程。
5. 管理对话历史
持久性的最常见用例之一是使用它来跟踪对话历史。然而,随着对话越来越长,对话历史会不断累积,并占用越来越多的上下文窗口。这通常是不可取的,因为它会导致对LLM的调用更加昂贵和耗时,并可能导致错误。为了防止这种情况发生,您可能需要管理对话历史记录。
5.1 过滤消息
为了防止对话历史爆炸,最直接的方法是在将消息传递给LLM之前过滤消息列表:
python
def filter_messages(state: AgentState):
"""过滤消息以保持对话历史在合理大小。"""
messages = state["messages"]
# 保留最近的10条消息
if len(messages) > 10:
# 始终保留系统消息(如果有)
system_messages = [msg for msg in messages if msg.type == "system"]
# 保留最近的人类和AI消息
recent_messages = messages[-10:]
# 合并系统消息和最近消息
filtered_messages = system_messages + [
msg for msg in recent_messages if msg.type != "system"
]
return {"messages": filtered_messages}
return {} # 如果消息数量不超过阈值,不做任何更改
# 将过滤节点添加到图中
graph.add_node("filter_messages", filter_messages)
graph.add_edge(START, "filter_messages")
graph.add_edge("filter_messages", "agent")
5.2 高级过滤策略
除了简单的保留最近N条消息外,还可以实现更复杂的过滤策略:
5.2.1 基于令牌计数的过滤
python
import tiktoken
def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
"""计算文本中的令牌数。"""
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(text))
def filter_messages_by_tokens(state: AgentState, max_tokens: int = 3000):
"""基于令牌数过滤消息。"""
messages = state["messages"]
total_tokens = sum(count_tokens(msg.content) for msg in messages)
if total_tokens > max_tokens:
# 始终保留系统消息
system_messages = [msg for msg in messages if msg.type == "system"]
# 保留最近的人类和AI消息对
filtered_messages = system_messages
remaining_tokens = max_tokens - sum(count_tokens(msg.content) for msg in system_messages)
# 从最近的消息开始添加,直到达到令牌限制
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if msg.type != "system":
msg_tokens = count_tokens(msg.content)
if remaining_tokens - msg_tokens >= 0:
filtered_messages.append(msg)
remaining_tokens -= msg_tokens
else:
break
# 反转列表以保持原始顺序
filtered_messages = system_messages + [msg for msg in filtered_messages if msg.type != "system"]
filtered_messages.sort(key=lambda x: messages.index(x))
return {"messages": filtered_messages}
return {}
5.2.2 基于重要性的过滤
python
def filter_messages_by_importance(state: AgentState):
"""基于消息重要性过滤。"""
messages = state["messages"]
if len(messages) > 10:
# 始终保留系统消息和最近的对话
system_messages = [msg for msg in messages if msg.type == "system"]
recent_messages = messages[-4:] # 最近的2轮对话
# 使用LLM评估中间消息的重要性
middle_messages = messages[len(system_messages):-4]
if middle_messages:
# 这里可以使用LLM来评估每条消息的重要性
# 简化版:假设我们已经评估了重要性并保留前3条重要消息
important_messages = middle_messages[:3]
# 合并所有要保留的消息
filtered_messages = system_messages + important_messages + recent_messages
return {"messages": filtered_messages}
return {}
5.3 消息摘要
另一种管理对话历史的方法是创建历史消息的摘要:
python
def summarize_history(state: AgentState):
"""将长对话历史摘要为一条系统消息。"""
messages = state["messages"]
if len(messages) > 10:
# 保留最近的消息
recent_messages = messages[-4:]
# 获取需要摘要的消息
history_to_summarize = messages[:-4]
# 使用LLM创建摘要
history_text = "\n".join([f"{msg.type}: {msg.content}" for msg in history_to_summarize])
summary_prompt = f"请总结以下对话历史,保留关键信息:\n\n{history_text}"
summary = model.invoke([HumanMessage(content=summary_prompt)])
# 创建包含摘要的系统消息
summary_message = SystemMessage(content=f"对话历史摘要:{summary.content}")
# 返回摘要消息和最近消息
filtered_messages = [summary_message] + recent_messages
return {"messages": filtered_messages}
return {}
6. 高级持久化技术
6.1 使用数据库存储检查点
对于生产环境,使用数据库存储检查点通常是更好的选择:
python
from langgraph.checkpoint import SQLiteCheckpointer
# 创建SQLite检查点
sqlite_checkpointer = SQLiteCheckpointer("agent_checkpoints.db")
# 编译图并添加持久性
app = graph.compile(checkpointer=sqlite_checkpointer)
6.2 分布式检查点
对于分布式系统,可以使用Redis存储检查点:
python
from langgraph.checkpoint import RedisCheckpointer
# 创建Redis检查点
redis_checkpointer = RedisCheckpointer(
redis_url="redis://localhost:6379/0",
ttl=3600 # 检查点的生存时间(秒)
)
# 编译图并添加持久性
app = graph.compile(checkpointer=redis_checkpointer)
6.3 检查点事件处理
可以为检查点添加事件处理器,以在保存或加载检查点时执行自定义逻辑:
python
from langgraph.checkpoint import MemoryCheckpointer
class CustomCheckpointer(MemoryCheckpointer):
def __init__(self):
super().__init__()
def save(self, key, state):
# 在保存前执行自定义逻辑
print(f"保存检查点: {key}")
# 可以在这里添加日志记录、监控或其他处理
return super().save(key, state)
def load(self, key):
# 在加载前执行自定义逻辑
print(f"加载检查点: {key}")
return super().load(key)
# 创建自定义检查点
custom_checkpointer = CustomCheckpointer()
# 编译图并添加持久性
app = graph.compile(checkpointer=custom_checkpointer)
7. 实际应用示例:具有记忆管理的客服代理
下面是一个完整的实际应用示例,展示了如何创建一个具有记忆管理功能的客服代理:
python
import os
from typing import TypedDict, List, Union, Annotated
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint import MemoryCheckpointer
import tiktoken
# 定义状态类型
class CustomerServiceState(TypedDict):
messages: List[BaseMessage]
customer_info: dict
# 创建令牌计数函数
def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
"""计算文本中的令牌数。"""
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(text))
# 创建模型
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2)
# 定义节点函数
def filter_and_summarize(state: CustomerServiceState):
"""过滤和摘要消息历史。"""
messages = state["messages"]
# 计算当前消息的总令牌数
total_tokens = sum(count_tokens(msg.content) for msg in messages)
# 如果总令牌数超过阈值,进行摘要
if total_tokens > 2000:
# 保留系统消息和最近的2轮对话(4条消息)
system_messages = [msg for msg in messages if msg.type == "system"]
recent_messages = messages[-4:]
# 获取需要摘要的消息
history_to_summarize = [msg for msg in messages if msg not in system_messages and msg not in recent_messages]
if history_to_summarize:
# 创建摘要提示
history_text = "\n".join([f"{msg.type}: {msg.content}" for msg in history_to_summarize])
summary_prompt = f"请总结以下客服对话历史,保留客户问题、需求和关键信息:\n\n{history_text}"
# 生成摘要
summary_response = model.invoke([HumanMessage(content=summary_prompt)])
# 创建包含摘要的系统消息
summary_message = SystemMessage(content=f"对话历史摘要:{summary_response.content}")
# 返回摘要消息和最近消息
filtered_messages = system_messages + [summary_message] + recent_messages
return {"messages": filtered_messages}
return {}
def update_customer_info(state: CustomerServiceState):
"""更新客户信息。"""
messages = state["messages"]
customer_info = state.get("customer_info", {})
if len(messages) >= 2:
# 最新的对话
recent_messages = messages[-2:]
# 提取客户信息
extraction_prompt = f"""
从以下对话中提取客户信息。如果找到新信息,以JSON格式返回。
仅提取确定的信息,不要猜测。
对话:
{recent_messages[0].content}
{recent_messages[1].content}
已知信息:
{customer_info}
仅返回JSON格式的新信息,如果没有新信息,返回空JSON {{}}。
"""
extraction_response = model.invoke([HumanMessage(content=extraction_prompt)])
# 尝试解析JSON响应
import json
try:
new_info = json.loads(extraction_response.content.strip())
if new_info:
# 更新客户信息
updated_info = {**customer_info, **new_info}
return {"customer_info": updated_info}
except:
pass
return {}
def customer_service_agent(state: CustomerServiceState):
"""客服代理节点。"""
messages = state["messages"]
customer_info = state.get("customer_info", {})
# 创建系统提示
system_prompt = """
你是一位专业的客服代表。请根据客户的问题提供礼貌、准确的回答。
如果客户有具体问题,请尽量解决。如果需要更多信息,请礼貌地询问。
"""
# 如果有客户信息,添加到系统提示
if customer_info:
system_prompt += f"\n\n客户信息:{json.dumps(customer_info, ensure_ascii=False)}"
# 准备消息列表
prepared_messages = [SystemMessage(content=system_prompt)]
# 添加对话历史
for msg in messages:
if msg.type != "system": # 避免重复系统消息
prepared_messages.append(msg)
# 调用模型
response = model.invoke(prepared_messages)
# 返回更新的状态
return {"messages": messages + [response]}
def should_end(state: CustomerServiceState):
"""检查是否应该结束对话。"""
messages = state["messages"]
if len(messages) >= 2:
last_message = messages[-1]
# 检查是否包含结束对话的关键词
end_keywords = ["再见", "谢谢", "感谢", "结束", "拜拜", "goodbye", "thank you", "thanks", "bye"]
if any(keyword in last_message.content.lower() for keyword in end_keywords):
return END
return "customer_service_agent"
# 创建图
graph = StateGraph(CustomerServiceState)
# 添加节点
graph.add_node("filter_and_summarize", filter_and_summarize)
graph.add_node("update_customer_info", update_customer_info)
graph.add_node("customer_service_agent", customer_service_agent)
# 添加边
graph.add_edge(START, "filter_and_summarize")
graph.add_edge("filter_and_summarize", "update_customer_info")
graph.add_edge("update_customer_info", "customer_service_agent")
graph.add_conditional_edges("customer_service_agent", should_end)
# 创建检查点
checkpointer = MemoryCheckpointer()
# 编译图并添加持久性
app = graph.compile(checkpointer=checkpointer)
# 使用示例
def chat_with_agent(message: str, thread_id: str = "default"):
"""与客服代理聊天。"""
# 尝试加载现有状态
try:
current_state = checkpointer.load(thread_id)
messages = current_state["messages"]
messages.append(HumanMessage(content=message))
new_state = {"messages": messages}
except:
# 如果没有现有状态,创建新的
new_state = {
"messages": [HumanMessage(content=message)],
"customer_info": {}
}
# 调用代理
result = app.invoke(new_state, {"thread_id": thread_id})
# 返回最后一条消息
return result["messages"][-1].content
# 示例对话
response1 = chat_with_agent("你好,我想查询我的订单状态", "customer_1")
print(f"代理: {response1}")
response2 = chat_with_agent("我的订单号是ABC123", "customer_1")
print(f"代理: {response2}")
response3 = chat_with_agent("我的邮箱是example@email.com", "customer_1")
print(f"代理: {response3}")
# 几天后继续对话
response4 = chat_with_agent("你好,我之前查询的订单有更新吗?", "customer_1")
print(f"代理: {response4}")
8. 总结
LangGraph的持久化机制通过检查点提供了强大的状态管理能力,使开发者能够构建具有记忆功能的智能应用。主要优势包括:
- 灵活的存储选项:从内存存储到数据库存储,适应不同的应用场景
- 线程管理:通过thread_id管理多个独立的对话或任务
- 状态恢复:能够在任何时候恢复之前的状态
- 记忆管理:提供多种策略来管理长对话历史,防止上下文窗口爆炸
通过合理使用LangGraph的持久化功能,开发者可以创建更加智能、自然的对话代理,提供连贯一致的用户体验。