一、使用本地/硅基流动等非openai的模型
openai-agents 框架原生兼容 OpenAI 模型;当我们使用其他模型(例如 vLLM 私有部署的模型,或者硅基流动提供的模型)时,需要进行额外处理。
1、首先引入这两个库:
from agents import (
OpenAIChatCompletionsModel,
...)
和
from openai import AsyncOpenAI
如下图:
2、进行自建连接
# Async OpenAI-compatible client pointed at the SiliconFlow endpoint
# instead of the default OpenAI API.
external_client = AsyncOpenAI(
    api_key="xxxxxxxxxxxxx",  # placeholder value; substitute a real API key
    base_url="https://api.siliconflow.cn/v1/",
)
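3、修改各个agent所使用的模型名称,并通过 OpenAIChatCompletionsModel 把上一步创建的 external_client 传给 agent,例如:model=OpenAIChatCompletionsModel(model="你的模型名称", openai_client=external_client)。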
这样就可以连接到非openai模型,并且使用该框架了。
二、实现流式输出
如果你是用的openai模型,直接使用框架带的流式输出Runner.run_streamed流式方法即可。如果像我们用的第三方或私有部署的模型,则需要修改框架底层实现源码。
2.1 开始改造api.py文件
首先cd到python-backend目录下,修改api.py。
api.py流式接口代码如下:
# =========================
# Streaming output endpoint
# =========================
def _parse_tool_args(raw_args: Any) -> Any:
    """Return tool-call arguments decoded from JSON when possible.

    Values that are not strings, or strings that are not valid JSON, are
    returned unchanged.
    """
    if isinstance(raw_args, str):
        try:
            return json.loads(raw_args)
        except ValueError:
            # Not valid JSON: fall back to the raw string.
            return raw_args
    return raw_args


@app.post("/chat/stream")
async def chat_stream_endpoint(req: ChatRequest):
    """
    Streaming chat endpoint for real-time agent responses.

    Runs the current agent with ``Runner.run_streamed`` and returns a
    ``StreamingResponse`` of server-sent events. Event types emitted:
    ``conversation_start``, ``message_delta``, ``agent_updated``,
    ``tool_call``, ``message``, ``tool_output``, ``message_complete``,
    ``handoff``, ``final_state``, ``conversation_complete``,
    ``guardrail_violation`` and ``error``. Every stream ends with a
    ``data: [DONE]`` line.
    """

    async def stream_chat_events() -> AsyncGenerator[str, None]:
        try:
            # Create a new conversation state or load the stored one.
            is_new = not req.conversation_id or conversation_store.get(req.conversation_id) is None
            if is_new:
                conversation_id: str = uuid4().hex
                ctx = create_initial_context()
                current_agent_name = triage_agent.name
                state: Dict[str, Any] = {
                    "input_items": [],
                    "context": ctx,
                    "current_agent": current_agent_name,
                }
                if req.message.strip() == "":
                    # Empty first message: only announce the new conversation.
                    conversation_store.save(conversation_id, state)
                    yield format_sse_event(StreamEvent(
                        type="conversation_start",
                        data={
                            "conversation_id": conversation_id,
                            "current_agent": current_agent_name,
                            "context": ctx.model_dump(),
                            "agents": _build_agents_list(),
                        },
                        timestamp=time.time(),
                        conversation_id=conversation_id,
                        agent_name=current_agent_name
                    ))
                    yield "data: [DONE]\n\n"
                    return
            else:
                conversation_id = req.conversation_id  # type: ignore
                state = conversation_store.get(conversation_id)

            current_agent = _get_agent_by_name(state["current_agent"])
            state["input_items"].append({"content": req.message, "role": "user"})

            # Announce the start of this turn.
            yield format_sse_event(StreamEvent(
                type="conversation_start",
                data={
                    "conversation_id": conversation_id,
                    "current_agent": current_agent.name,
                    "message": req.message,
                },
                timestamp=time.time(),
                conversation_id=conversation_id,
                agent_name=current_agent.name
            ))

            try:
                # Use the framework's native streaming entry point.
                result = Runner.run_streamed(
                    current_agent,
                    state["input_items"],
                    context=state["context"],
                    run_config=run_config
                )

                # Text of the assistant message currently being streamed.
                current_message_content = ""
                # Run items (tool calls, tool outputs, messages, handoffs)
                # gathered for the history and the final_state payload.
                collected_items = []

                async for event in result.stream_events():
                    if event.type == "raw_response_event":
                        # The raw model event is carried in `event.data`;
                        # only output-text deltas are forwarded to the client.
                        raw_data = getattr(event, "data", None)
                        if getattr(raw_data, "type", None) == "response.output_text.delta":
                            delta_text = getattr(raw_data, "delta", "")
                            if delta_text:
                                current_message_content += delta_text
                                yield format_sse_event(StreamEvent(
                                    type="message_delta",
                                    data={
                                        "delta": delta_text,
                                        "content": current_message_content,
                                        "agent": current_agent.name,
                                    },
                                    timestamp=time.time(),
                                    conversation_id=conversation_id,
                                    agent_name=current_agent.name
                                ))
                                # Yield control so the chunk can be flushed.
                                await asyncio.sleep(0)
                        continue

                    elif event.type == "agent_updated_stream_event":
                        yield format_sse_event(StreamEvent(
                            type="agent_updated",
                            data={
                                "agent_name": event.new_agent.name,
                            },
                            timestamp=time.time(),
                            conversation_id=conversation_id,
                            agent_name=event.new_agent.name
                        ))

                    elif event.type == "run_item_stream_event":
                        if event.item.type == "tool_call_item":
                            # The agent is calling a tool.
                            collected_items.append(event.item)
                            tool_name = getattr(event.item.raw_item, "name", "")
                            tool_args = _parse_tool_args(
                                getattr(event.item.raw_item, "arguments", None)
                            )
                            yield format_sse_event(StreamEvent(
                                type="tool_call",
                                data={
                                    "tool_name": tool_name,
                                    "tool_args": tool_args,
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.agent.name
                            ))
                            # display_seat_map additionally sends a marker message.
                            if tool_name == "display_seat_map":
                                yield format_sse_event(StreamEvent(
                                    type="message",
                                    data={
                                        "content": "DISPLAY_SEAT_MAP",
                                        "agent": event.item.agent.name,
                                    },
                                    timestamp=time.time(),
                                    conversation_id=conversation_id,
                                    agent_name=event.item.agent.name
                                ))

                        elif event.item.type == "tool_call_output_item":
                            # The tool finished and returned a result.
                            collected_items.append(event.item)
                            yield format_sse_event(StreamEvent(
                                type="tool_output",
                                data={
                                    "output": str(event.item.output),
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.agent.name
                            ))

                        elif event.item.type == "message_output_item":
                            # A complete assistant message is available.
                            text = ItemHelpers.text_message_output(event.item)
                            collected_items.append(event.item)
                            if not current_message_content:
                                # No deltas were streamed for this message:
                                # replay it one character at a time.
                                for char in text:
                                    current_message_content += char
                                    yield format_sse_event(StreamEvent(
                                        type="message_delta",
                                        data={
                                            "delta": char,
                                            "content": current_message_content,
                                            "agent": event.item.agent.name,
                                        },
                                        timestamp=time.time(),
                                        conversation_id=conversation_id,
                                        agent_name=event.item.agent.name
                                    ))
                                    # Yield control so the chunk can be flushed.
                                    await asyncio.sleep(0)
                            # Send the full message.
                            yield format_sse_event(StreamEvent(
                                type="message_complete",
                                data={
                                    "content": text,
                                    "agent": event.item.agent.name,
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.agent.name
                            ))
                            # Reset the buffer for the next message.
                            current_message_content = ""

                        elif event.item.type == "handoff_output_item":
                            # Control moved from one agent to another.
                            collected_items.append(event.item)
                            yield format_sse_event(StreamEvent(
                                type="handoff",
                                data={
                                    "source_agent": event.item.source_agent.name,
                                    "target_agent": event.item.target_agent.name,
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.source_agent.name
                            ))
                            current_agent = event.item.target_agent

                # Persist the turn. Only assistant messages are appended to
                # the history; tool calls and tool outputs are not stored.
                for item in collected_items:
                    if isinstance(item, MessageOutputItem):
                        text = ItemHelpers.text_message_output(item)
                        state["input_items"].append({"role": "assistant", "content": text})
                state["current_agent"] = current_agent.name
                conversation_store.save(conversation_id, state)

                # Convert the collected items into the frontend payload.
                events = []
                messages = []
                for item in collected_items:
                    if isinstance(item, MessageOutputItem):
                        text = ItemHelpers.text_message_output(item)
                        messages.append(MessageResponse(content=text, agent=item.agent.name))
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="message",
                            agent=item.agent.name,
                            content=text,
                            timestamp=time.time()
                        ))
                    elif isinstance(item, ToolCallItem):
                        tool_name = getattr(item.raw_item, "name", "")
                        tool_args = _parse_tool_args(
                            getattr(item.raw_item, "arguments", None)
                        )
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="tool_call",
                            agent=item.agent.name,
                            content=tool_name or "",
                            metadata={"tool_args": tool_args},
                            timestamp=time.time()
                        ))
                    elif isinstance(item, ToolCallOutputItem):
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="tool_output",
                            agent=item.agent.name,
                            content=str(item.output),
                            metadata={"tool_result": item.output},
                            timestamp=time.time()
                        ))
                    elif isinstance(item, HandoffOutputItem):
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="handoff",
                            agent=item.source_agent.name,
                            content=f"{item.source_agent.name} -> {item.target_agent.name}",
                            metadata={
                                "source_agent": item.source_agent.name,
                                "target_agent": item.target_agent.name
                            },
                            timestamp=time.time()
                        ))

                # The run completed without a tripwire, so every input
                # guardrail of the current agent is reported as passed.
                guardrail_checks = []
                for g in getattr(current_agent, "input_guardrails", []):
                    guardrail_checks.append(GuardrailCheck(
                        id=uuid4().hex,
                        name=_get_guardrail_name(g),
                        input=req.message,
                        reasoning="",
                        passed=True,
                        timestamp=time.time() * 1000,
                    ))

                # Send the final state snapshot.
                yield format_sse_event(StreamEvent(
                    type="final_state",
                    data={
                        "conversation_id": conversation_id,
                        "current_agent": current_agent.name,
                        "messages": [msg.model_dump() for msg in messages],
                        "events": [event.model_dump() for event in events],
                        "context": state["context"].model_dump(),
                        "agents": _build_agents_list(),
                        "guardrails": [gc.model_dump() for gc in guardrail_checks],
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

                # Signal that the turn is complete.
                yield format_sse_event(StreamEvent(
                    type="conversation_complete",
                    data={
                        "conversation_id": conversation_id,
                        "current_agent": current_agent.name,
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

            except InputGuardrailTripwireTriggered as e:
                # An input guardrail rejected the message: store and send a refusal.
                failed = e.guardrail_result.guardrail
                gr_output = e.guardrail_result.output.output_info
                gr_reasoning = getattr(gr_output, "reasoning", "")
                refusal = "Sorry, I can only answer questions related to airline travel."
                state["input_items"].append({"role": "assistant", "content": refusal})
                conversation_store.save(conversation_id, state)

                yield format_sse_event(StreamEvent(
                    type="guardrail_violation",
                    data={
                        "message": refusal,
                        "agent": current_agent.name,
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

                # Mark the guardrail that tripped as failed, the others as passed.
                guardrail_checks = []
                for g in getattr(current_agent, "input_guardrails", []):
                    guardrail_checks.append(GuardrailCheck(
                        id=uuid4().hex,
                        name=_get_guardrail_name(g),
                        input=req.message,
                        reasoning=(gr_reasoning if g == failed else ""),
                        passed=(g != failed),
                        timestamp=time.time() * 1000,
                    ))

                # Send the final state for the rejected turn.
                yield format_sse_event(StreamEvent(
                    type="final_state",
                    data={
                        "conversation_id": conversation_id,
                        "current_agent": current_agent.name,
                        "messages": [{"content": refusal, "agent": current_agent.name}],
                        "events": [{
                            "id": uuid4().hex,
                            "type": "guardrail_violation",
                            "agent": current_agent.name,
                            "content": refusal,
                            "timestamp": time.time()
                        }],
                        "context": state["context"].model_dump(),
                        "agents": _build_agents_list(),
                        "guardrails": [gc.model_dump() for gc in guardrail_checks],
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

            # Terminate the stream on both the normal and the guardrail path.
            yield "data: [DONE]\n\n"

        except Exception as e:
            # Top-level boundary: log with traceback and report to the client.
            logger.exception("Streaming error: %s", e)
            yield format_sse_event(StreamEvent(
                type="error",
                data={
                    "error": str(e),
                },
                timestamp=time.time(),
                conversation_id=req.conversation_id or "unknown",
                agent_name="system"
            ))
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        stream_chat_events(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
            "X-Accel-Buffering": "no",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
        }
    )
2.2 进行框架源码文件备份
copy ".venv\lib\site-packages\agents\models\chatcmpl_stream_handler.py" ".venv\lib\site-packages\agents\models\chatcmpl_stream_handler.py.backup"
我们一会儿就会修改该文件,所以先做备份。
2.3 执行脚本修改框架
执行fix_logprobs.py脚本,来替换底层.venv/lib/site-packages/agents/models/chatcmpl_stream_handler.py文件中的逻辑,让流式兼容正常的模型api。
fix_logprobs.py脚本源码如下:
#!/usr/bin/env python3
"""
Fix script to add missing logprobs parameter to ResponseTextDeltaEvent
for CGH model compatibility.
"""
import os
import re
def fix_chatcmpl_stream_handler(file_path=None):
    """Insert ``logprobs=[]`` into the SDK's ``ResponseTextDeltaEvent`` call.

    Args:
        file_path: Path of the ``chatcmpl_stream_handler.py`` file to patch.
            When omitted, the file is looked up under ``.venv`` relative to
            the current working directory.

    Returns:
        True if the file was rewritten; False if the pattern did not match
        (which is also the case once the file has been patched).

    Raises:
        FileNotFoundError: If the handler file does not exist.
    """
    if file_path is None:
        file_path = '.venv/lib/site-packages/agents/models/chatcmpl_stream_handler.py'
        if not os.path.exists(file_path):
            # POSIX virtualenvs nest site-packages under lib/pythonX.Y/.
            import glob
            candidates = sorted(glob.glob(
                '.venv/lib/python*/site-packages/agents/models/chatcmpl_stream_handler.py'
            ))
            if candidates:
                file_path = candidates[0]

    # Read the current handler source.
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Match the yield up to (but excluding) `output_index=` so that the new
    # keyword argument can be inserted just before it.
    pattern = r'(yield ResponseTextDeltaEvent\(\s*content_index=state\.text_content_index_and_output\[0\],\s*delta=delta\.content,\s*item_id=FAKE_RESPONSES_ID,\s*)(output_index=)'
    replacement = r'\1logprobs=[], # Add empty logprobs for CGH compatibility\n \2'
    new_content = re.sub(pattern, replacement, content)

    if new_content == content:
        print('Pattern not found or already fixed')
        return False

    # Write the patched source back in place.
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(new_content)
    print('File modified successfully!')
    return True


if __name__ == "__main__":
    fix_chatcmpl_stream_handler()
现在再测试就可以流式输出了: