openai-agent使用本地模型并进行流式输出

一、使用本地/硅基流动等非openai的模型

openai-agent原生兼容openai模型,当我们使用其他模型,例如vllm私有部署的,或者给硅基流动的模型时,需要进行额外处理。

1、首先引入这两个库:

from agents import (

OpenAIChatCompletionsModel,

...)

from openai import AsyncOpenAI

如下图:

2、进行自建连接

external_client = AsyncOpenAI(

api_key="xxxxxxxxxxxxx",

base_url="https://api.siliconflow.cn/v1/",

)

3、修改各个agent所使用的模型名称

这样就可以连接到非openai模型,并且使用该框架了。

二、实现流式输出

如果你是用的openai模型,直接使用框架带的流式输出Runner.run_streamed流式方法即可。如果像我们用的第三方或私有部署的模型,则需要修改框架底层实现源码。

2.1 开始改造api.py文件

首先cd到python-backend目录下,修改api.py

api.py流式接口代码如下:

复制代码
# =========================
# 新增流式输出接口
# =========================

@app.post("/chat/stream")
async def chat_stream_endpoint(req: ChatRequest):
    """
    Streaming chat endpoint for real-time agent responses.
    """
    async def stream_chat_events() -> AsyncGenerator[str, None]:
        try:
            # 初始化和检索对话状态
            is_new = not req.conversation_id or conversation_store.get(req.conversation_id) is None
            if is_new:
                conversation_id: str = uuid4().hex
                ctx = create_initial_context()
                current_agent_name = triage_agent.name
                state: Dict[str, Any] = {
                    "input_items": [],
                    "context": ctx,
                    "current_agent": current_agent_name,
                }
                if req.message.strip() == "":
                    conversation_store.save(conversation_id, state)
                    # 发送初始的响应
                    yield format_sse_event(StreamEvent(
                        type="conversation_start",
                        data={
                            "conversation_id": conversation_id,
                            "current_agent": current_agent_name,
                            "context": ctx.model_dump(),
                            "agents": _build_agents_list(),
                        },
                        timestamp=time.time(),
                        conversation_id=conversation_id,
                        agent_name=current_agent_name
                    ))
                    yield "data: [DONE]\n\n"
                    return
            else:
                conversation_id = req.conversation_id  # type: ignore
                state = conversation_store.get(conversation_id)

            current_agent = _get_agent_by_name(state["current_agent"])
            state["input_items"].append({"content": req.message, "role": "user"})
            old_context = state["context"].model_dump().copy()

            # 发送对话开始事件
            yield format_sse_event(StreamEvent(
                type="conversation_start",
                data={
                    "conversation_id": conversation_id,
                    "current_agent": current_agent.name,
                    "message": req.message,
                },
                timestamp=time.time(),
                conversation_id=conversation_id,
                agent_name=current_agent.name
            ))

            try:
                # 使用逼养的框架原生的Runner.run_streamed流式方法
                result = Runner.run_streamed(
                    current_agent,
                    state["input_items"],
                    context=state["context"],
                    run_config=run_config
                )

                # 这里准备两个变量来记录状态
                current_message_content = ""  # 这个用来一点点拼接AI回复的内容
                collected_items = []  # 用来收集所有发生的事情(工具调用、消息、代理切换等 tool_call_item - 工具调用项、tool_call_output_item - 工具输出项、message_output_item - 消息输出项、handoff_output_item - 代理切换项),后面要用来保存对话历史 

                # 开始流式处理 - 遍历AI代理返回的所有流式事件
                async for event in result.stream_events():
                    # 跳过原始响应事件,处理文本增量
                    if event.type == "raw_response_event":
                        # 检查是不是流的增量事件
                        if hasattr(event, 'response') and hasattr(event.response, 'delta'):
                            delta = event.response.delta
                            if hasattr(delta, 'content') and delta.content:
                                current_message_content += delta.content
                                # 立即发送增量文本
                                yield format_sse_event(StreamEvent(
                                    type="message_delta",
                                    data={
                                        "delta": delta.content,
                                        "content": current_message_content,
                                        "agent": current_agent.name,
                                    },
                                    timestamp=time.time(),
                                    conversation_id=conversation_id,
                                    agent_name=current_agent.name
                                ))
                                # 强制刷新响应
                                await asyncio.sleep(0) 
                        continue

                    elif event.type == "agent_updated_stream_event":
                        yield format_sse_event(StreamEvent(
                            type="agent_updated",
                            data={
                                "agent_name": event.new_agent.name,
                            },
                            timestamp=time.time(),
                            conversation_id=conversation_id,
                            agent_name=event.new_agent.name
                        ))

                    elif event.type == "run_item_stream_event":
                        # 这里处理AI代理运行过程中产生的各种事件
                        if event.item.type == "tool_call_item":
                            # AI在这里要调用工具了
                            collected_items.append(event.item)  # 先收集起来,后面用来保存对话历史
                            tool_name = getattr(event.item.raw_item, "name", "")
                            raw_args = getattr(event.item.raw_item, "arguments", None)
                            tool_args: Any = raw_args
                            # 工具的参数可能是JSON字符串,试着解析一下 逼养的老外真坑啊
                            if isinstance(raw_args, str):
                                try:
                                    tool_args = json.loads(raw_args)
                                except Exception:
                                    pass  # 解析不了就拉倒,用原来的格式

                            yield format_sse_event(StreamEvent(
                                type="tool_call",
                                data={
                                    "tool_name": tool_name,
                                    "tool_args": tool_args,
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.agent.name
                            ))

                            # 特殊处理display_seat_map
                            if tool_name == "display_seat_map":
                                yield format_sse_event(StreamEvent(
                                    type="message",
                                    data={
                                        "content": "DISPLAY_SEAT_MAP",
                                        "agent": event.item.agent.name,
                                    },
                                    timestamp=time.time(),
                                    conversation_id=conversation_id,
                                    agent_name=event.item.agent.name
                                ))

                        elif event.item.type == "tool_call_output_item":
                            # 工具执行完,返回结果
                            collected_items.append(event.item)  # 也要收集起来
                            yield format_sse_event(StreamEvent(
                                type="tool_output",
                                data={
                                    "output": str(event.item.output),
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.agent.name
                            ))

                        elif event.item.type == "message_output_item":
                            # AI要回复消息
                            text = ItemHelpers.text_message_output(event.item)
                            collected_items.append(event.item)  # 收集消息,要保存到对话历史里

                            if not current_message_content:  # 如果还没有开始累积内容
                                for char in text:
                                    current_message_content += char  # 累积内容
                                    yield format_sse_event(StreamEvent(
                                        type="message_delta",
                                        data={
                                            "delta": char,
                                            "content": current_message_content,
                                            "agent": event.item.agent.name,
                                        },
                                        timestamp=time.time(),
                                        conversation_id=conversation_id,
                                        agent_name=event.item.agent.name
                                    ))
                                    # 强制立即发送
                                    await asyncio.sleep(0)

                            # 发送最终的消息事件
                            yield format_sse_event(StreamEvent(
                                type="message_complete",
                                data={
                                    "content": text,
                                    "agent": event.item.agent.name,
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.agent.name
                            ))
                            current_message_content = ""  # 清空内容,准备接收下一条消息

                        elif event.item.type == "handoff_output_item":
                            # 代理切换事件,比如从客服转到座位预订专员
                            collected_items.append(event.item)  # 这个切换事件也要记录下来
                            yield format_sse_event(StreamEvent(
                                type="handoff",
                                data={
                                    "source_agent": event.item.source_agent.name,
                                    "target_agent": event.item.target_agent.name,
                                },
                                timestamp=time.time(),
                                conversation_id=conversation_id,
                                agent_name=event.item.source_agent.name
                            ))
                            current_agent = event.item.target_agent

                # 现在要保存对话状态了,不然下次对话就忘记之前说了什么
                # 把刚才收集到的所有事件转换成对话历史格式
                for item in collected_items:
                    if isinstance(item, MessageOutputItem):
                        # 模型的回复消息要保存到对话历史里
                        text = ItemHelpers.text_message_output(item)
                        state["input_items"].append({"role": "assistant", "content": text})
                    elif isinstance(item, ToolCallItem):
                        # 工具调用暂时不需要保存到对话历史(这个框架应该是自动处理的)
                        pass
                    elif isinstance(item, ToolCallOutputItem):
                        # 工具输出也暂时不需要保存(这个框架应该是自动处理的)
                        pass

                state["current_agent"] = current_agent.name  # 记录当前是哪个代理在处理
                conversation_store.save(conversation_id, state)  # 保存到存储里,下次对话能接着用

                # 现在要把收集到的事件整理成标准格式,返回给前端
                events = []  # 用来存放事件列表
                messages = []  # 用来存放消息列表

                # 遍历刚才收集的所有事件,转换成前端需要的格式
                for item in collected_items:
                    if isinstance(item, MessageOutputItem):
                        # 这是模型的回复消息,要转换成前端能理解的格式
                        text = ItemHelpers.text_message_output(item)
                        messages.append(MessageResponse(content=text, agent=item.agent.name))
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="message",
                            agent=item.agent.name,
                            content=text,
                            timestamp=time.time()
                        ))
                    elif isinstance(item, ToolCallItem):
                        # 这是工具调用事件,也要记录下来
                        tool_name = getattr(item.raw_item, "name", "")
                        raw_args = getattr(item.raw_item, "arguments", None)
                        tool_args: Any = raw_args
                        if isinstance(raw_args, str):
                            try:
                                tool_args = json.loads(raw_args)
                            except Exception:
                                pass
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="tool_call",
                            agent=item.agent.name,
                            content=tool_name or "",
                            metadata={"tool_args": tool_args},
                            timestamp=time.time()
                        ))
                    elif isinstance(item, ToolCallOutputItem):
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="tool_output",
                            agent=item.agent.name,
                            content=str(item.output),
                            metadata={"tool_result": item.output},
                            timestamp=time.time()
                        ))
                    elif isinstance(item, HandoffOutputItem):
                        events.append(AgentEvent(
                            id=uuid4().hex,
                            type="handoff",
                            agent=item.source_agent.name,
                            content=f"{item.source_agent.name} -> {item.target_agent.name}",
                            metadata={
                                "source_agent": item.source_agent.name,
                                "target_agent": item.target_agent.name
                            },
                            timestamp=time.time()
                        ))

                # 构建一下护栏结果
                guardrail_checks = []
                for g in getattr(current_agent, "input_guardrails", []):
                    guardrail_checks.append(GuardrailCheck(
                        id=uuid4().hex,
                        name=_get_guardrail_name(g),
                        input=req.message,
                        reasoning="",
                        passed=True,
                        timestamp=time.time() * 1000,
                    ))

                # 发送最终状态数据
                yield format_sse_event(StreamEvent(
                    type="final_state",
                    data={
                        "conversation_id": conversation_id,
                        "current_agent": current_agent.name,
                        "messages": [msg.model_dump() for msg in messages],
                        "events": [event.model_dump() for event in events],
                        "context": state["context"].model_dump(),
                        "agents": _build_agents_list(),
                        "guardrails": [gc.model_dump() for gc in guardrail_checks],
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

                # 发送完成事件
                yield format_sse_event(StreamEvent(
                    type="conversation_complete",
                    data={
                        "conversation_id": conversation_id,
                        "current_agent": current_agent.name,
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

            except InputGuardrailTripwireTriggered as e:
                # 处理违反护栏的行为
                failed = e.guardrail_result.guardrail
                gr_output = e.guardrail_result.output.output_info
                gr_reasoning = getattr(gr_output, "reasoning", "")
                refusal = "Sorry, I can only answer questions related to airline travel."
                state["input_items"].append({"role": "assistant", "content": refusal})
                conversation_store.save(conversation_id, state)

                yield format_sse_event(StreamEvent(
                    type="guardrail_violation",
                    data={
                        "message": refusal,
                        "agent": current_agent.name,
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

                # 违规事件的护栏结果
                guardrail_checks = []
                for g in getattr(current_agent, "input_guardrails", []):
                    guardrail_checks.append(GuardrailCheck(
                        id=uuid4().hex,
                        name=_get_guardrail_name(g),
                        input=req.message,
                        reasoning=(gr_reasoning if g == failed else ""),
                        passed=(g != failed),
                        timestamp=time.time() * 1000,
                    ))

                # 发送违反护栏的最终状态
                yield format_sse_event(StreamEvent(
                    type="final_state",
                    data={
                        "conversation_id": conversation_id,
                        "current_agent": current_agent.name,
                        "messages": [{"content": refusal, "agent": current_agent.name}],
                        "events": [{
                            "id": uuid4().hex,
                            "type": "guardrail_violation",
                            "agent": current_agent.name,
                            "content": refusal,
                            "timestamp": time.time()
                        }],
                        "context": state["context"].model_dump(),
                        "agents": _build_agents_list(),
                        "guardrails": [gc.model_dump() for gc in guardrail_checks],
                    },
                    timestamp=time.time(),
                    conversation_id=conversation_id,
                    agent_name=current_agent.name
                ))

            yield "data: [DONE]\n\n"

        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield format_sse_event(StreamEvent(
                type="error",
                data={
                    "error": str(e),
                },
                timestamp=time.time(),
                conversation_id=req.conversation_id or "unknown",
                agent_name="system"
            ))
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        stream_chat_events(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
            "X-Accel-Buffering": "no", 
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
        }
    )

2.2 进行框架源码文件备份

copy ".venv\lib\site-packages\agents\models\chatcmpl_stream_handler.py" ".venv\lib\site-packages\agents\models\chatcmpl_stream_handler.py.backup"

我们一会儿就会修改该文件,所以先做备份吗。

2.3 执行脚本修改框架

执行fix_logprobs.py脚本,来替换底层.venv/lib/site-packages/agents/models/chatcmpl_stream_handler.py文件中的逻辑,让流式兼容正常的模型api。

fix_logprobs.py脚本源码如下:

复制代码
#!/usr/bin/env python3
"""
Fix script to add missing logprobs parameter to ResponseTextDeltaEvent
for CGH model compatibility.
"""

import os
import re

def fix_chatcmpl_stream_handler():
    file_path = '.venv/lib/site-packages/agents/models/chatcmpl_stream_handler.py'
    
    # 读取文件内容
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    
    # 使用正则表达式查找并替换
    pattern = r'(yield ResponseTextDeltaEvent\(\s*content_index=state\.text_content_index_and_output\[0\],\s*delta=delta\.content,\s*item_id=FAKE_RESPONSES_ID,\s*)(output_index=)'
    
    replacement = r'\1logprobs=[],  # Add empty logprobs for CGH compatibility\n                    \2'
    
    new_content = re.sub(pattern, replacement, content, flags=re.MULTILINE)
    
    if new_content != content:
        # 写回文件
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(new_content)
        print('File modified successfully!')
        return True
    else:
        print('Pattern not found or already fixed')
        return False

if __name__ == "__main__":
    fix_chatcmpl_stream_handler()

现在在测试就可以流式输出了:

相关推荐
Cl_rown去掉l变成C10 分钟前
第R5周:天气预测
人工智能·python·深度学习·算法·tensorflow2
天下弈星~13 分钟前
变分自编码器VAE的Pytorch实现
图像处理·pytorch·python·深度学习·vae·图像生成·变分自编码器
这里有鱼汤29 分钟前
新型震荡器CyberOsc指标详解及完整策略源码(含图)
python
一百天成为python专家32 分钟前
OpenCV图像平滑处理方法详解
开发语言·人工智能·python·opencv·机器学习·支持向量机·计算机视觉
软测进阶41 分钟前
【Python】Python 函数基本介绍(详细版)
开发语言·python
Q_Q5110082851 小时前
python的滑雪场雪具租赁服务数据可视化分析系统
spring boot·python·信息可视化·django·flask·node.js·php
java1234_小锋1 小时前
一周学会Matplotlib3 Python 数据可视化-绘制散点图(Scatter)
开发语言·python·信息可视化·matplotlib·matplotlib3
MicrosoftReactor2 小时前
技术速递|通过 GitHub Models 在 Actions 中实现项目自动化
ai·自动化·github·copilot
阿群今天学习了吗2 小时前
label studio 服务器端打开+xshell端口转发设置
linux·运维·服务器·笔记·python