LangChain自定义Callback组件
引言
在构建LLM应用时,我们常常需要监控应用的运行状态、收集性能指标或实现流式输出等功能。LangChain提供了强大的回调系统,允许开发者连接到应用程序的各个阶段,实现日志记录、监控、流式处理等功能。本教程将深入探讨LangChain的回调系统,并通过实例演示如何创建和使用自定义回调组件。
1. 回调系统概念
什么是回调(Callback)
回调是一种编程模式,允许在特定事件发生时执行自定义代码。在LangChain中,回调系统使开发者能够在LLM应用执行过程中的关键点插入自定义逻辑,例如:
- 在LLM开始生成内容时记录日志
- 在每个新token生成时实现流式输出
- 在链式执行完成时收集性能指标
- 在错误发生时发送警报
回调事件(Callback Events)
LangChain定义了一系列可以订阅的事件,每个事件对应应用执行过程中的特定阶段:
Event | Event Trigger | Associated Method |
---|---|---|
Chat model start | 当聊天模型开始执行时 | on_chat_model_start |
LLM start | 当LLM模型开始执行时 | on_llm_start |
LLM new token | 当LLM或聊天模型生成新token时 | on_llm_new_token |
LLM ends | 当LLM或聊天模型执行结束时 | on_llm_end |
LLM errors | 当LLM或聊天模型发生错误时 | on_llm_error |
Chain start | 当链开始执行时 | on_chain_start |
Chain end | 当链执行结束时 | on_chain_end |
Chain error | 当链执行发生错误时 | on_chain_error |
Tool start | 当工具开始执行时 | on_tool_start |
Tool end | 当工具执行结束时 | on_tool_end |
Tool error | 当工具执行发生错误时 | on_tool_error |
Agent action | 当代理执行动作时 | on_agent_action |
Agent finish | 当代理执行结束时 | on_agent_finish |
Retriever start | 当检索器开始执行时 | on_retriever_start |
Retriever end | 当检索器执行结束时 | on_retriever_end |
Retriever error | 当检索器执行发生错误时 | on_retriever_error |
Text | 当执行任意文本时 | on_text |
Retry | 当执行重试事件时 | on_retry |
2. 回调处理程序(Callback Handlers)
回调处理程序是实现了BaseCallbackHandler
接口的对象,该接口为每个可订阅的事件定义了对应的方法。当事件触发时,CallbackManager会在每个处理程序上调用适当的方法。
传递回调处理程序的方式
在LangChain中,有两种主要方式传递回调处理程序:
1. 构造函数回调
在创建对象时通过构造函数传递回调处理程序:
python
from langchain_openai import ChatOpenAI
from langchain.callbacks import StdOutCallbackHandler
handler = StdOutCallbackHandler()
chat = ChatOpenAI(callbacks=[handler], tags=['chat-model'])
这种方式下,回调处理程序仅用于该特定对象的所有调用。例如,如果在链中使用上述聊天模型,回调处理程序只会在对该模型的调用中被触发。
2. 请求回调
在调用invoke
方法时通过config
参数传递回调处理程序:
python
from langchain.callbacks import StdOutCallbackHandler
handler = StdOutCallbackHandler()
result = chat.invoke("你好", config={"callbacks": [handler]})
这种方式下,回调处理程序仅用于该特定请求及其包含的所有子请求。
3. 运行时回调
在执行运行时通过callbacks
关键字参数传递回调处理程序:
python
from langchain.agents import AgentExecutor
from langchain.callbacks import StdOutCallbackHandler
handler = StdOutCallbackHandler()
agent_executor = AgentExecutor(agent=agent, tools=tools)
agent_executor.invoke({"input": "查询天气"}, callbacks=[handler])
这种方式下,回调处理程序将用于执行过程中涉及的所有嵌套对象,包括代理、工具和LLM等。这避免了手动将处理程序附加到每个单独的嵌套对象上。
3. 创建自定义回调处理程序
LangChain提供了一些内置的回调处理程序,但在实际应用中,我们通常需要创建具有自定义逻辑的处理程序。
基本步骤
创建自定义回调处理程序的基本步骤如下:
- 继承
BaseCallbackHandler
类 - 实现需要处理的事件对应的方法
- 将自定义处理程序附加到LangChain对象上
示例:实现流式输出
下面是一个实现流式输出的自定义回调处理程序示例:
python
from langchain.callbacks.base import BaseCallbackHandler
from langchain_openai import ChatOpenAI
class StreamingCallbackHandler(BaseCallbackHandler):
def on_llm_new_token(self, token: str, **kwargs) -> None:
"""当LLM生成新token时打印出来,实现流式输出"""
print(token, end="", flush=True)
# 创建带有流式处理的聊天模型
streaming_handler = StreamingCallbackHandler()
streaming_llm = ChatOpenAI(
streaming=True,
callbacks=[streaming_handler],
temperature=0
)
# 使用流式模型
response = streaming_llm.invoke("讲一个简短的笑话")
print("\n完成!")
在这个示例中,我们创建了一个名为StreamingCallbackHandler
的自定义处理程序,实现了on_llm_new_token
方法来打印每个新生成的token,从而实现流式输出效果。
示例:记录执行时间
下面是一个记录各组件执行时间的自定义回调处理程序示例:
python
import time
from langchain.callbacks.base import BaseCallbackHandler
from langchain_openai import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
class TimingCallbackHandler(BaseCallbackHandler):
def __init__(self) -> None:
super().__init__()
self.timing_records = {}
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
self.timing_records["llm_start"] = time.time()
print(f"LLM开始执行,提示词长度: {len(prompts[0])}")
def on_llm_end(self, response, **kwargs) -> None:
llm_end = time.time()
llm_start = self.timing_records.get("llm_start")
if llm_start:
execution_time = llm_end - llm_start
print(f"LLM执行完成,耗时: {execution_time:.2f}秒")
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
self.timing_records["chain_start"] = time.time()
chain_type = serialized.get("name", "未知")
print(f"链 '{chain_type}' 开始执行")
def on_chain_end(self, outputs, **kwargs) -> None:
chain_end = time.time()
chain_start = self.timing_records.get("chain_start")
if chain_start:
execution_time = chain_end - chain_start
print(f"链执行完成,耗时: {execution_time:.2f}秒")
# 创建回调处理程序
timing_handler = TimingCallbackHandler()
# 创建LLM和链
llm = ChatOpenAI(temperature=0)
prompt = PromptTemplate.from_template("请解释{concept}是什么?")
chain = LLMChain(llm=llm, prompt=prompt)
# 执行链并传递回调处理程序
response = chain.invoke({"concept": "量子计算"}, callbacks=[timing_handler])
print(f"结果: {response['text'][:50]}...")
这个示例创建了一个TimingCallbackHandler
,记录LLM和链的开始和结束时间,并计算执行耗时。
4. 高级回调功能
多个回调处理程序
可以同时使用多个回调处理程序,每个处理程序负责不同的功能:
python
from langchain.callbacks import StdOutCallbackHandler
from langchain.callbacks.tracers import LangChainTracer
# 标准输出处理程序
stdout_handler = StdOutCallbackHandler()
# 追踪处理程序
tracer = LangChainTracer()
# 同时使用两个处理程序
llm = ChatOpenAI(callbacks=[stdout_handler, tracer])
条件回调
可以创建只在特定条件下触发的回调处理程序:
python
class ConditionalCallbackHandler(BaseCallbackHandler):
def __init__(self, condition_func):
super().__init__()
self.condition_func = condition_func
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
if self.condition_func(prompts[0]):
print(f"触发条件回调: {prompts[0][:50]}...")
# 只有当提示词包含"紧急"时才触发回调
condition = lambda prompt: "紧急" in prompt
conditional_handler = ConditionalCallbackHandler(condition)
异步回调
LangChain也支持异步回调,适用于异步应用场景:
python
from langchain.callbacks.base import AsyncCallbackHandler
class AsyncStreamingCallbackHandler(AsyncCallbackHandler):
async def on_llm_new_token(self, token: str, **kwargs) -> None:
# 异步处理新token
print(token, end="", flush=True)
# 可以在这里执行其他异步操作
5. 实用回调示例
日志记录回调
记录应用程序执行过程中的详细日志:
python
import logging
from langchain.callbacks.base import BaseCallbackHandler
class LoggingCallbackHandler(BaseCallbackHandler):
def __init__(self):
super().__init__()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger("LangChain")
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
self.logger.info(f"开始LLM调用,模型: {serialized.get('name', '未知')}")
def on_llm_end(self, response, **kwargs) -> None:
self.logger.info("LLM调用完成")
def on_llm_error(self, error, **kwargs) -> None:
self.logger.error(f"LLM调用出错: {error}")
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
chain_type = serialized.get("name", "未知")
self.logger.info(f"开始执行链: {chain_type}")
def on_chain_end(self, outputs, **kwargs) -> None:
self.logger.info("链执行完成")
def on_chain_error(self, error, **kwargs) -> None:
self.logger.error(f"链执行出错: {error}")
# 使用日志记录回调
logging_handler = LoggingCallbackHandler()
llm = ChatOpenAI(callbacks=[logging_handler])
性能监控回调
监控和收集性能指标:
python
from langchain.callbacks.base import BaseCallbackHandler
import time
import psutil
class PerformanceMonitorCallback(BaseCallbackHandler):
def __init__(self):
super().__init__()
self.start_time = None
self.records = {
"llm_calls": 0,
"chain_calls": 0,
"total_tokens": 0,
"total_time": 0,
"max_memory": 0,
}
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
self.start_time = time.time()
self.records["llm_calls"] += 1
def on_llm_end(self, response, **kwargs) -> None:
if self.start_time:
execution_time = time.time() - self.start_time
self.records["total_time"] += execution_time
# 估算token数量
if hasattr(response, "llm_output") and response.llm_output and "token_usage" in response.llm_output:
self.records["total_tokens"] += response.llm_output["token_usage"].get("total_tokens", 0)
# 记录内存使用
current_memory = psutil.Process().memory_info().rss / (1024 * 1024) # MB
self.records["max_memory"] = max(self.records["max_memory"], current_memory)
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
self.records["chain_calls"] += 1
def get_performance_summary(self):
return {
"llm_calls": self.records["llm_calls"],
"chain_calls": self.records["chain_calls"],
"total_tokens": self.records["total_tokens"],
"total_time_seconds": round(self.records["total_time"], 2),
"max_memory_mb": round(self.records["max_memory"], 2),
}
错误处理与重试回调
实现错误处理和自动重试逻辑:
python
from langchain.callbacks.base import BaseCallbackHandler
import time
class ErrorHandlingCallback(BaseCallbackHandler):
def __init__(self, max_retries=3, retry_delay=2):
super().__init__()
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retries = {}
def on_llm_error(self, error, **kwargs) -> None:
run_id = kwargs.get("run_id", "default")
retries = self.retries.get(run_id, 0)
if retries < self.max_retries:
self.retries[run_id] = retries + 1
print(f"LLM调用出错: {error}. 正在进行第{retries + 1}次重试...")
time.sleep(self.retry_delay)
# 这里可以触发重试逻辑
else:
print(f"达到最大重试次数({self.max_retries}),放弃重试: {error}")
6. 最佳实践
何时使用回调
回调系统最适合以下场景:
- 监控与日志记录:跟踪应用程序执行流程,记录关键事件
- 性能分析:收集执行时间、token使用量等性能指标
- 流式处理:实现实时输出,提升用户体验
- 错误处理:捕获和处理执行过程中的错误
- 自定义行为:在特定事件发生时执行自定义逻辑
回调设计建议
设计高效的回调处理程序时,请考虑以下建议:
- 保持轻量:回调处理程序应该执行轻量级操作,避免阻塞主执行流程
- 关注点分离:每个回调处理程序应专注于单一职责
- 错误处理:在回调方法中实现适当的错误处理,避免因回调错误导致主应用崩溃
- 避免状态依赖:尽量减少回调处理程序之间的状态依赖
- 考虑异步:对于IO密集型操作,考虑使用异步回调
总结
LangChain的回调系统提供了一种强大的机制,使开发者能够在应用程序执行的关键点插入自定义逻辑。通过创建自定义回调处理程序,可以实现日志记录、性能监控、流式输出等各种功能,大大增强LLM应用的可观测性和用户体验。
掌握回调系统的使用,对于构建高质量、可维护的LLM应用至关重要。无论是简单的流式输出,还是复杂的性能监控系统,回调机制都能帮助开发者更好地控制和理解应用程序的执行过程。