Agent平台消息节点输出设计思路

引言

在Agent平台的工作流系统中,我们经常需要将多个节点的输出组合成一条消息展示给用户。比如,一个 AI 分析节点生成流式文本,一个代码执行节点返回计算结果,我们希望将它们按顺序组合成:"AI分析:{流式文本}\n\n计算结果:{标量值}"。

这看似简单的需求,实际上涉及几个核心挑战:

  1. 混合数据源:如何同时处理流式数据(逐块到达)和标量数据(一次性完成)?
  2. 顺序保证:如何确保输出严格按照模板中变量出现的顺序推送?
  3. 依赖管理:如何知道何时可以开始推送?如何等待未就绪的数据?
  4. 解耦设计:如何让消息节点不需要知道具体引用了哪些节点?

本文将深入探讨一个基于编译时-运行时分离拦截机制的流式输出设计方案。

说明 :本文的设计思路参考了 Dify 的 StreamCoordinator 模块设计。Dify 是一个 AI 原生应用开发平台,其工作流引擎中的回答节点(Answer Node)采用了类似的流式输出协调机制,通过拦截上游节点的输出并按模板顺序推送,实现了高效的流式输出体验。本文在此基础上进行了架构优化和实现细节的阐述。


问题场景

假设我们有一个工作流,包含以下节点:

  • LLM 节点:生成 AI 分析,输出是流式的(逐 token 返回)
  • 代码执行节点:执行计算,输出是标量的(一次性返回)
  • 消息节点 :将两者组合,模板为 "AI分析:{llm_output}\n\n计算结果:{code_result}"

我们希望实现的效果是:

swift 复制代码
时间轴 │ 前端接收到的输出
──────┼─────────────────────────────
  T1  │ "AI分析:"
  T2  │ "AI分析:AI"
  T3  │ "AI分析:AI thinks"
  T4  │ "AI分析:AI thinks..."
  T5  │ "AI分析:AI thinks...\n\n计算结果:"
  T6  │ "AI分析:AI thinks...\n\n计算结果:42"

注意几个关键点:

  1. 文本段立即推送"AI分析:" 在第一个变量数据到达前就可以推送
  2. 流式数据实时转发:LLM 的每个 token 到达后立即转发
  3. 等待标量数据"计算结果:" 推送后,必须等待代码节点完成才能推送 42
  4. 严格顺序:即使代码节点先完成,也必须等 LLM 流式输出完成后再推送

设计思路

参考说明:本设计参考了 Dify 的 StreamCoordinator 设计理念,特别是其"按变量在模板中的位置决定流式输出顺序"的核心思想。

核心思想:拦截与协调

在状态驱动的工作流系统中,传统方案可能是:

  • 等待所有上游节点完成:消息节点执行时,等待所有上游节点完成后,一次性组装所有输出
  • 一次性推送:无法实现流式输出,因为必须等待所有节点完成才能推送
  • 时序限制:流式节点(如 LLM)的输出必须等到所有节点完成后才能展示,用户体验差

我们的方案在 workflow 之外额外设计了一个协调层,它独立于 workflow 的运行机制:

  • 独立协调层:协调器(StreamCoordinator)不参与 workflow 的状态流转,而是作为独立的协调组件运行
  • 拦截机制:上游节点在执行过程中主动通知协调器,协调器负责按模板顺序实时组装和推送
  • 解耦设计:消息节点不需要知道上游节点的执行细节,协调层负责处理所有组装逻辑
  • 实时推送:流式数据到达时立即推送,无需等待所有节点完成
  • 灵活组合:可以混合处理流式数据和标量数据,按模板顺序逐段推送

这种设计的优势在于:协调层与 workflow 解耦,workflow 专注于业务逻辑的执行,协调层专注于流式输出的组装和推送,两者通过轻量级的拦截接口交互。

架构分层:编译时 vs 运行时

系统采用编译时-运行时分离的架构:

markdown 复制代码
┌─────────────────────────────────────────────────────────┐
│                   编译时(Graph Build)                  │
├─────────────────────────────────────────────────────────┤
│ 1. 解析模板,提取变量和文本段                            │
│ 2. 建立变量到节点的映射关系                              │
│ 3. 注册拦截规则(哪些节点需要被拦截)                    │
│ 4. 构建反向索引(一个节点可能被多个消息节点引用)        │
└─────────────────────────────────────────────────────────┘
                        │
                        ▼
┌─────────────────────────────────────────────────────────┐
│                  运行时(Graph Execution)                │
├─────────────────────────────────────────────────────────┤
│ 1. 创建协调器实例(维护运行时状态)                       │
│ 2. 上游节点执行时,调用协调器拦截输出                     │
│ 3. 协调器按模板顺序组装并推送                            │
│ 4. 执行完毕后释放协调器                                  │
└─────────────────────────────────────────────────────────┘

为什么这样设计?

  • 编译时:无状态,只存储"订阅关系",可以复用,无需持久化
  • 运行时:有状态,维护缓冲区、会话等,每次执行独立创建,支持序列化(断点续传)
  • 独立于 workflow:协调层不参与 workflow 的状态流转,而是作为独立的协调组件,通过拦截接口与 workflow 交互。这样 workflow 可以专注于业务逻辑,协调层专注于流式输出的组装和推送

核心组件设计

1. 模板解析:从字符串到段序列

模板 "AI分析:{llm_output}\n\n计算结果:{code_result}" 需要解析为段序列:

python 复制代码
[
    TextSegment(text="AI分析:"),
    VariableSegment(variable_name="llm_output", node_id="llm_1", node_variable="text"),
    TextSegment(text="\n\n计算结果:"),
    VariableSegment(variable_name="code_result", node_id="code_1", node_variable="result")
]

设计要点:

  • 文本段:固定文本,可以立即推送
  • 变量段:需要等待上游节点数据,包含变量名、节点ID、输出变量路径

为什么需要段序列?

段序列定义了推送的严格顺序。协调器按索引逐段处理,遇到未就绪的变量段时停止,保证顺序。

2. 注册表(Registry):编译时的订阅关系

注册表在编译时构建,存储以下信息:

python 复制代码
class StreamCoordinatorRegistry:
    # 拦截规则:{message_node_id: InterceptionRule}
    _interception_rules: dict[str, InterceptionRule]
    
    # 反向索引:{node_id: [message_node_ids]}
    # 例如:{"llm_1": ["msg_1", "msg_2"]} 表示 llm_1 被两个消息节点引用
    _node_to_messages: dict[str, list[str]]

关键方法:

  • register_interception():注册拦截规则(消息节点初始化时调用)
  • is_intercepted(node_id):检查节点是否被任何消息节点引用
  • get_message_nodes_for(node_id):获取引用了指定节点的所有消息节点

设计要点:

  • 反向索引:快速判断一个节点是否需要被拦截(避免不必要的拦截开销)
  • 无状态:只存储订阅关系,不存储运行时数据

3. 协调器(Coordinator):运行时的状态管理

协调器在运行时创建,维护以下状态:

python 复制代码
class MessageStreamCoordinator:
    # 活跃会话:{message_node_id: ResponseSession}
    # 每个消息节点对应一个会话,维护当前处理的段索引
    _active_sessions: dict[str, ResponseSession]
    
    # 流式缓冲区:{(node_id, variable): StreamBuffer}
    # 缓存上游节点的流式输出块
    _stream_buffers: dict[tuple[str, str], StreamBuffer]
    
    # 已完成节点:{node_id}
    _completed_nodes: set[str]

核心数据结构:

ResponseSession(响应会话)

python 复制代码
@dataclass
class ResponseSession:
    node_id: str              # 消息节点ID
    rule: InterceptionRule    # 拦截规则(包含段序列)
    index: int = 0            # 当前处理的段索引
    
    def current_segment(self) -> Segment | None:
        """获取当前段"""
        return self.rule.segments[self.index]
    
    def advance(self):
        """推进到下一段"""
        self.index += 1

会话维护了当前处理进度,确保按顺序推进。

StreamBuffer(流式缓冲区)

python 复制代码
@dataclass
class StreamBuffer:
    chunks: deque[StreamChunk]  # 待处理的块队列
    is_closed: bool = False      # 流是否已关闭
    
    def append(self, chunk: StreamChunk):
        """添加数据块"""
        self.chunks.append(chunk)
    
    def pop(self) -> StreamChunk | None:
        """弹出数据块(FIFO)"""
        return self.chunks.popleft() if self.chunks else None

缓冲区采用FIFO队列,保证流式数据的顺序。


执行流程

阶段0:节点开始执行(支持循环)

在节点开始执行时(BaseNode.__call__ 开头),会调用 on_node_start() 进行状态重置:

python 复制代码
# BaseNode.__call__ 开头
await self.on_node_start()  # 通知协调器节点开始执行

# 协调器处理
async def on_node_start(self, node_id: str):
    # 重置该节点的所有buffer,清空上一轮执行的数据
    keys_to_reset = [key for key in self._stream_buffers.keys() if key[0] == node_id]
    for key in keys_to_reset:
        self._stream_buffers[key] = StreamBuffer()
    
    # 从完成节点集合中移除(允许重新流式输出)
    self._completed_nodes.discard(node_id)

设计要点:

  • 循环支持:在循环工作流中,同一节点可能被多次执行。每次执行开始时重置 buffer,确保不会使用上一轮的数据
  • 状态清理:从完成节点集合中移除,允许节点重新进行流式输出
  • 仅重置被引用的节点:只有被消息节点引用的节点才会被重置,避免不必要的开销

阶段1:拦截流式输出

当 LLM 节点产生流式输出时:

python 复制代码
# LLM 节点内部
async for chunk in llm.astream(prompt):
    # 通知协调器
    await coordinator.on_stream_chunk(
        node_id="llm_1",
        variable="text",
        chunk=chunk,
        is_final=False
    )

协调器的处理逻辑:

python 复制代码
async def on_stream_chunk(self, node_id, variable, chunk, is_final):
    # 1. 检查是否被任何消息节点引用
    if not self._registry.is_intercepted(node_id):
        return  # 未被引用,忽略
    
    # 2. 拦截到缓冲区
    selector = (node_id, variable)
    stream_chunk = StreamChunk(selector=selector, chunk=chunk, is_final=is_final)
    self._stream_buffers[selector].append(stream_chunk)
    
    if is_final:
        self._stream_buffers[selector].close()
    
    # 3. 检查是否可以启动会话
    await self._check_and_start_sessions(node_id)
    
    # 4. 尝试推送
    await self._try_push_streams()

关键点:

  • 快速过滤:通过反向索引快速判断是否需要拦截
  • 缓冲机制:流式数据先进入缓冲区,等待推送时机
  • 触发检查:每次收到数据块后,检查是否可以启动或继续推送

阶段2:启动会话

会话启动的时机:当消息节点的第一个依赖节点准备好时

python 复制代码
async def _check_and_start_sessions(self, node_id):
    for message_node_id, rule in self._registry.get_all_rules().items():
        if message_node_id in self._active_sessions:
            continue  # 已启动
        
        # 找到第一个变量段
        first_var_segment = next(
            (seg for seg in rule.segments if isinstance(seg, VariableSegment)),
            None
        )
        
        if first_var_segment and first_var_segment.node_id == node_id:
            selector = (first_var_segment.node_id, first_var_segment.node_variable)
            buffer = self._stream_buffers.get(selector)
            
            # 检查数据是否准备好
            has_stream_data = buffer and (buffer.has_unread() or buffer.is_closed)
            has_scalar_data = self._get_variable_value(...) is not None
            
            if has_stream_data or has_scalar_data:
                # 启动会话
                session = ResponseSession(node_id=message_node_id, rule=rule)
                self._active_sessions[message_node_id] = session
                
                # 立即尝试推送
                await self._try_push_streams()

设计要点:

  • 延迟启动:只有第一个依赖节点准备好时才启动,避免过早推送
  • 支持流式和标量:流式数据从缓冲区读取,标量数据从节点输出中读取

阶段3:按序推送

推送逻辑的核心是按段顺序处理,遇到未就绪的数据时停止

python 复制代码
async def _try_push_streams(self):
    for session_id, session in list(self._active_sessions.items()):
        while not session.is_complete():
            segment = session.current_segment()
            
            if isinstance(segment, TextSegment):
                # 文本段:直接推送
                await stream_writer({
                    "event_type": "MESSAGE_CHUNK",
                    "chunk": segment.text,
                    "is_final": False
                })
                session.advance()  # 推进到下一段
            
            elif isinstance(segment, VariableSegment):
                selector = (segment.node_id, segment.node_variable)
                buffer = self._stream_buffers.get(selector)
                
                if buffer and buffer.has_unread():
                    # 流式数据:从缓冲区读取并推送
                    while buffer.has_unread():
                        chunk = buffer.pop()
                        await stream_writer({
                            "event_type": "MESSAGE_CHUNK",
                            "chunk": chunk.chunk,
                            "is_final": False
                        })
                    
                    # 如果流已关闭,推进到下一段
                    if buffer.is_closed:
                        session.advance()
                    else:
                        break  # 等待更多数据
                
                else:
                    # 尝试从节点输出获取标量值
                    value = self._get_variable_value(...)
                    if value is not None:
                        # 标量值:直接推送
                        await stream_writer({
                            "event_type": "MESSAGE_CHUNK",
                            "chunk": str(value),
                            "is_final": False
                        })
                        session.advance()
                    else:
                        # 数据未准备好,停止推送
                        break
        
        # 会话完成
        if session.is_complete():
            await stream_writer({
                "event_type": "MESSAGE_CHUNK",
                "chunk": "",
                "is_final": True
            })
            del self._active_sessions[session_id]

关键逻辑:

  1. 文本段立即推送:不依赖任何数据,直接推送并推进
  2. 流式数据实时转发:从缓冲区读取所有可用数据,流关闭后推进
  3. 标量数据等待:如果缓冲区为空,尝试从节点输出读取,未就绪时停止
  4. 阻塞机制 :遇到未就绪的数据时,break 跳出循环,等待下次触发

阶段4:节点完成通知

当标量节点(如代码执行节点)完成时:

python 复制代码
async def on_node_completed(self, node_id, outputs):
    # 记录节点完成
    self._completed_nodes.add(node_id)
    
    # 检查是否可以启动会话
    await self._check_and_start_sessions(node_id)
    
    # 尝试推送(可能有新的标量数据可用)
    await self._try_push_streams()

设计要点:

  • 统一入口:流式节点和标量节点都通过协调器通知
  • 触发推送:节点完成可能触发新的推送(如果某个会话正在等待该节点的数据)

完整执行时序

让我们通过一个完整例子理解整个流程:

场景配置

swift 复制代码
模板: "AI分析:{llm_output}\n\n计算结果:{code_result}"
引用关系:
  - llm_output → llm_1.text (流式)
  - code_result → code_1.result (标量)

执行时序图

swift 复制代码
时间 │ LLM节点          │ 代码节点        │ 协调器操作                    │ 前端接收
─────┼──────────────────┼────────────────┼──────────────────────────────┼─────────────
T0   │                  │                │ 初始化                         │
     │                  │                │                               │
T1   │ 输出 "AI"        │                │ 1. 拦截到缓冲区                │
     │                  │                │ 2. 检查依赖(第一个是llm_1)  │
     │                  │                │ 3. 启动会话                   │
     │                  │                │ 4. 推送 TextSegment: "AI分析:"│ "AI分析:"
     │                  │                │ 5. 推送 VariableSegment: "AI" │ "AI分析:AI"
     │                  │                │                               │
T2   │ 输出 " thinks"   │                │ 1. 拦截到缓冲区                │
     │                  │                │ 2. 推送: " thinks"            │ "AI分析:AI thinks"
     │                  │                │                               │
T3   │ 输出 "..."       │                │ 1. 拦截到缓冲区                │
     │                  │                │ 2. 推送: "..."                │ "AI分析:AI thinks..."
     │                  │                │                               │
T4   │ 流关闭           │                │ 1. 标记流关闭                  │
     │                  │                │ 2. 推进到下一段                │
     │                  │                │ 3. 推送 TextSegment: "\n\n计算结果:"│ "AI分析:AI thinks...\n\n计算结果:"
     │                  │                │ 4. 检查 code_result(未就绪)  │
     │                  │                │ 5. 停止推送,等待              │
     │                  │                │                               │
T5   │                  │ 完成,返回42   │ 1. 记录节点完成                │
     │                  │                │ 2. 尝试推送                    │
     │                  │                │ 3. 从输出获取 code_1.result=42 │
     │                  │                │ 4. 推送: "42"                  │ "AI分析:AI thinks...\n\n计算结果:42"
     │                  │                │ 5. 会话完成,发送 is_final=True│

设计优势

1. 解耦设计

  • 变量名与节点ID解耦 :消息节点使用语义化的变量名(如 llm_output),通过 reference_metadata 映射到实际节点ID
  • 前端无需知道节点细节:前端只接收消息节点的输出,不需要知道具体引用了哪些节点

2. 性能优化

  • 快速过滤:通过反向索引快速判断是否需要拦截,避免不必要的开销
  • 按需启动:只有第一个依赖节点准备好时才启动会话,避免过早推送

3. 灵活性

  • 支持混合数据源:同时处理流式数据和标量数据
  • 支持深层属性访问 :变量路径支持 response.data.user.name 等深层访问
  • 支持多消息节点引用同一节点:一个节点的输出可以被多个消息节点引用

4. 可维护性

  • 编译时-运行时分离:职责清晰,易于测试和维护
  • 状态序列化支持:支持断点续传(序列化协调器状态)

关键设计模式

1. 拦截模式(Interception Pattern)

上游节点不直接推送,而是通知协调器,由协调器决定何时、如何推送。这实现了关注点分离

  • 上游节点:专注于生成数据
  • 协调器:专注于组装和推送逻辑

2. 会话模式(Session Pattern)

每个消息节点对应一个会话,维护处理进度。这实现了状态隔离

  • 多个消息节点可以并行处理
  • 每个会话独立推进,互不干扰

3. 缓冲区模式(Buffer Pattern)

流式数据先进入缓冲区,等待推送时机。这实现了异步解耦

  • 数据生成和数据消费解耦
  • 支持背压(backpressure)控制

4. 状态机模式(State Machine Pattern)

会话通过段索引推进,本质上是一个状态机:

ini 复制代码
状态0: 处理 TextSegment[0]
  ↓
状态1: 处理 VariableSegment[1](等待数据)
  ↓
状态2: 处理 TextSegment[2]
  ↓
状态3: 处理 VariableSegment[3](等待数据)
  ↓
完成

扩展思考

1. 如何支持循环执行?

在工作流中,可能存在循环结构,同一个节点可能被多次执行。为了支持这种情况,协调器在节点开始执行时会进行状态重置:

python 复制代码
async def on_node_start(self, node_id: str):
    """节点开始执行时调用(在BaseNode.__call__开头调用)"""
    # 重置该节点的所有buffer,清空上一轮执行的数据
    keys_to_reset = [key for key in self._stream_buffers.keys() if key[0] == node_id]
    for key in keys_to_reset:
        self._stream_buffers[key] = StreamBuffer()
    
    # 从完成节点集合中移除(允许重新流式输出)
    self._completed_nodes.discard(node_id)

设计要点:

  • 节点级重置:每次节点开始执行时,重置该节点的所有 buffer,确保不会使用上一轮的数据
  • 状态清理:从完成节点集合中移除,允许节点重新进行流式输出
  • 会话保持:消息节点的会话状态保持不变,因为消息节点本身不会被循环执行

示例场景:

ini 复制代码
工作流: LLM节点 → 条件判断 → [循环] → LLM节点 → 消息节点

在循环中,LLM 节点可能被多次执行。每次执行开始时,协调器会重置 LLM 节点的 buffer,确保每次循环都能正确进行流式输出。

2. 如何处理循环依赖?

如果消息节点引用了多个节点,且这些节点之间存在依赖关系,协调器会自然处理:

  • 按模板顺序推进,遇到未就绪的数据时停止
  • 节点完成后触发检查,继续推进

3. 如何处理错误?

如果上游节点执行失败:

  • 协调器可以从节点输出中读取错误信息
  • 或者等待超时后推送错误提示

4. 如何支持更复杂的模板语法?

当前方案支持简单的变量替换。如果需要支持条件判断、循环等,可以:

  • 扩展模板解析器,生成更复杂的段序列
  • 在推送时根据条件决定是否跳过某些段

5. 如何优化性能?

  • 批量推送:将多个小块合并为一个大块推送(减少网络开销)
  • 预取机制:提前启动会话,预取第一个变量段的数据
  • 缓存机制:缓存已完成的节点输出,避免重复查询

总结

消息节点流式输出的核心设计思路可以概括为:

  1. 编译时构建订阅关系:解析模板,建立变量到节点的映射,注册拦截规则
  2. 运行时拦截和协调:上游节点通知协调器,协调器按模板顺序组装并推送
  3. 状态管理:通过会话维护处理进度,通过缓冲区缓存流式数据
  4. 阻塞机制:遇到未就绪的数据时停止推送,等待数据就绪后继续

这种设计实现了解耦、灵活、高效的流式输出机制,适用于复杂的工作流场景。


实现要点总结

基于本文的设计思路,实现一个流式输出系统需要关注以下几个核心接口:

编译时接口

python 复制代码
# 注册拦截规则(在消息节点初始化时调用)
registry.register_interception(
    message_node_id="msg_1",
    template="AI分析:{llm_output}\n\n计算结果:{code_result}",
    variable_to_node_map={
        "llm_output": ("llm_1", "text"),
        "code_result": ("code_1", "result")
    }
)

运行时接口

python 复制代码
# 节点开始执行时(支持循环)
await coordinator.on_node_start(node_id="llm_1")

# 拦截流式输出(在流式节点产生数据时调用)
await coordinator.on_stream_chunk(
    node_id="llm_1",
    variable="text",
    chunk="AI",
    is_final=False
)

# 通知节点完成(在节点执行完成时调用)
await coordinator.on_node_completed(
    node_id="code_1",
    outputs={"result": 42}
)

核心数据结构

以下是核心类的简化实现,展示了关键的数据结构和逻辑:

1. StreamBuffer(流式缓冲区)

python 复制代码
from collections import deque
from dataclasses import dataclass, field

@dataclass
class StreamChunk:
    """流式数据块"""
    selector: tuple[str, str]  # (node_id, variable_name)
    chunk: str                 # 增量文本
    is_final: bool             # 是否为最后一块

@dataclass
class StreamBuffer:
    """流式缓冲区,用于缓存上游节点的流式输出"""
    chunks: deque[StreamChunk] = field(default_factory=deque)  # FIFO队列
    is_closed: bool = False  # 流是否已关闭
    
    def append(self, chunk: StreamChunk):
        """添加数据块到队列尾部"""
        self.chunks.append(chunk)
    
    def pop(self) -> StreamChunk | None:
        """从队列头部弹出数据块(FIFO)"""
        return self.chunks.popleft() if self.chunks else None
    
    def has_unread(self) -> bool:
        """检查是否有未读数据"""
        return len(self.chunks) > 0
    
    def close(self):
        """标记流已关闭"""
        self.is_closed = True

设计要点:

  • 使用 deque 实现 FIFO 队列,保证流式数据的顺序
  • is_closed 标记流是否结束,用于判断是否可以推进到下一段

2. ResponseSession(响应会话)

python 复制代码
from dataclasses import dataclass

@dataclass
class Segment:
    """模板段基类"""
    pass

@dataclass
class TextSegment(Segment):
    """文本段:固定文本,可以直接推送"""
    text: str

@dataclass
class VariableSegment(Segment):
    """变量段:需要等待上游节点数据"""
    variable_name: str  # 模板中的变量名(如 llm_output)
    node_id: str        # 实际的节点ID(如 llm_1)
    node_variable: str  # 节点的输出变量名(如 text)

@dataclass
class InterceptionRule:
    """拦截规则:记录消息节点的配置"""
    message_node_id: str
    template: str
    segments: list[Segment]  # 解析后的模板段序列
    variable_to_node_map: dict[str, tuple[str, str]]  # 变量映射

@dataclass
class ResponseSession:
    """响应会话:维护单个消息节点的流式输出状态"""
    node_id: str              # 消息节点ID
    rule: InterceptionRule    # 拦截规则(包含段序列)
    index: int = 0            # 当前处理的段索引
    
    def is_complete(self) -> bool:
        """检查是否处理完所有段"""
        return self.index >= len(self.rule.segments)
    
    def current_segment(self) -> Segment | None:
        """获取当前需要处理的段"""
        if self.is_complete():
            return None
        return self.rule.segments[self.index]
    
    def advance(self):
        """推进到下一段"""
        self.index += 1

设计要点:

  • 通过 index 维护处理进度,确保按顺序推进
  • segments 定义了推送的严格顺序(文本段和变量段交替)

3. StreamCoordinatorRegistry(注册表)

python 复制代码
from collections import defaultdict

class StreamCoordinatorRegistry:
    """流式输出协调器注册表(编译时,无状态)"""
    
    def __init__(self):
        # 拦截规则:{message_node_id: InterceptionRule}
        self._interception_rules: dict[str, InterceptionRule] = {}
        
        # 反向索引:{node_id: [message_node_ids]}
        # 例如:{"llm_1": ["msg_1", "msg_2"]} 表示 llm_1 被两个消息节点引用
        self._node_to_messages: dict[str, list[str]] = defaultdict(list)
    
    def register_interception(
        self,
        message_node_id: str,
        template: str,
        variable_to_node_map: dict[str, tuple[str, str]]
    ):
        """
        注册拦截规则(在消息节点初始化时调用)
        
        Args:
            message_node_id: 消息节点ID
            template: 模板字符串,如 "AI: {llm_output}\\n结果: {code_result}"
            variable_to_node_map: 变量映射,如 {
                "llm_output": ("llm_1", "text"),
                "code_result": ("code_1", "result")
            }
        """
        # 解析模板为段序列
        segments = self._parse_template(template, variable_to_node_map)
        
        # 创建拦截规则
        rule = InterceptionRule(
            message_node_id=message_node_id,
            template=template,
            segments=segments,
            variable_to_node_map=variable_to_node_map
        )
        
        # 存储规则
        self._interception_rules[message_node_id] = rule
        
        # 建立反向索引(快速查询哪些节点需要被拦截)
        for node_id in rule.referenced_nodes:
            self._node_to_messages[node_id].append(message_node_id)
    
    def _parse_template(self, template: str, variable_to_node_map: dict) -> list[Segment]:
        """解析模板为段序列"""
        segments: list[Segment] = []
        # ... 解析逻辑:按变量位置分割模板,生成 TextSegment 和 VariableSegment
        return segments
    
    def is_intercepted(self, node_id: str) -> bool:
        """检查节点是否被任何消息节点引用(快速过滤)"""
        return node_id in self._node_to_messages
    
    def get_all_rules(self) -> dict[str, InterceptionRule]:
        """获取所有拦截规则"""
        return self._interception_rules.copy()

设计要点:

  • 反向索引:快速判断节点是否需要被拦截,避免不必要的处理
  • 无状态:只存储订阅关系,不存储运行时数据,可复用

4. MessageStreamCoordinator(协调器)

python 复制代码
import asyncio
from collections import defaultdict

class MessageStreamCoordinator:
    """消息流式输出协调器(运行时,有状态)"""
    
    def __init__(self, registry: StreamCoordinatorRegistry):
        self._registry = registry
        self._lock = asyncio.Lock()  # 异步锁,保证线程安全
        
        # 活跃会话:{message_node_id: ResponseSession}
        # 每个消息节点对应一个会话,维护处理进度
        self._active_sessions: dict[str, ResponseSession] = {}
        
        # 流式缓冲区:{(node_id, variable): StreamBuffer}
        # 缓存上游节点的流式输出
        self._stream_buffers: dict[tuple[str, str], StreamBuffer] = defaultdict(StreamBuffer)
        
        # 已完成节点:{node_id}
        self._completed_nodes: set[str] = set()
    
    async def on_node_start(self, node_id: str):
        """节点开始执行时调用(支持循环)"""
        async with self._lock:
            if not self._registry.is_intercepted(node_id):
                return  # 未被引用,无需处理
            
            # 重置该节点的所有buffer(清空上一轮执行的数据)
            keys_to_reset = [key for key in self._stream_buffers.keys() if key[0] == node_id]
            for key in keys_to_reset:
                self._stream_buffers[key] = StreamBuffer()
            
            # 从完成节点集合中移除(允许重新流式输出)
            self._completed_nodes.discard(node_id)
    
    async def on_stream_chunk(
        self,
        node_id: str,
        variable: str,
        chunk: str,
        is_final: bool = False
    ):
        """处理流式数据块(在流式节点产生数据时调用)"""
        async with self._lock:
            # 快速过滤:检查是否被任何消息节点引用
            if not self._registry.is_intercepted(node_id):
                return  # 未被引用,忽略
            
            # 拦截到缓冲区
            selector = (node_id, variable)
            stream_chunk = StreamChunk(selector=selector, chunk=chunk, is_final=is_final)
            self._stream_buffers[selector].append(stream_chunk)
            
            if is_final:
                self._stream_buffers[selector].close()
            
            # 检查是否可以启动会话,并尝试推送
            await self._check_and_start_sessions(node_id)
            await self._try_push_streams()
    
    async def on_node_completed(self, node_id: str, outputs: dict):
        """处理节点完成事件(在节点执行完成时调用)"""
        async with self._lock:
            # 记录节点完成
            self._completed_nodes.add(node_id)
            
            # 检查是否可以启动会话,并尝试推送(可能有新的标量数据可用)
            await self._check_and_start_sessions(node_id)
            await self._try_push_streams()
    
    async def _check_and_start_sessions(self, node_id: str):
        """检查并启动可以开始的会话"""
        # 遍历所有注册的消息节点
        for message_node_id, rule in self._registry.get_all_rules().items():
            if message_node_id in self._active_sessions:
                continue  # 已启动,跳过
            
            # 找到第一个变量段
            first_var_segment = next(
                (seg for seg in rule.segments if isinstance(seg, VariableSegment)),
                None
            )
            
            # 如果第一个变量段对应的节点已准备好,启动会话
            if first_var_segment and first_var_segment.node_id == node_id:
                selector = (first_var_segment.node_id, first_var_segment.node_variable)
                buffer = self._stream_buffers.get(selector)
                
                # 检查数据是否准备好(流式数据或标量数据)
                has_stream_data = buffer and (buffer.has_unread() or buffer.is_closed)
                has_scalar_data = self._get_variable_value(node_id, ...) is not None
                
                if has_stream_data or has_scalar_data:
                    # 启动会话
                    session = ResponseSession(node_id=message_node_id, rule=rule)
                    self._active_sessions[message_node_id] = session
                    
                    # 立即尝试推送
                    await self._try_push_streams()
    
    async def _try_push_streams(self):
        """尝试推送流式数据(核心逻辑)"""
        # 遍历所有活跃会话
        for session_id, session in list(self._active_sessions.items()):
            while not session.is_complete():
                segment = session.current_segment()
                if segment is None:
                    break
                
                if isinstance(segment, TextSegment):
                    # 文本段:直接推送
                    await self._push_chunk(session.node_id, segment.text, is_final=False)
                    session.advance()  # 推进到下一段
                
                elif isinstance(segment, VariableSegment):
                    selector = (segment.node_id, segment.node_variable)
                    buffer = self._stream_buffers.get(selector)
                    
                    if buffer and buffer.has_unread():
                        # 流式数据:从缓冲区读取并推送
                        while buffer.has_unread():
                            chunk = buffer.pop()
                            await self._push_chunk(session.node_id, chunk.chunk, is_final=False)
                        
                        # 如果流已关闭,推进到下一段
                        if buffer.is_closed:
                            session.advance()
                        else:
                            break  # 等待更多数据
                    else:
                        # 尝试从节点输出获取标量值
                        value = self._get_variable_value(...)
                        if value is not None:
                            await self._push_chunk(session.node_id, str(value), is_final=False)
                            session.advance()
                        else:
                            break  # 数据未准备好,停止推送
            
            # 会话完成
            if session.is_complete():
                await self._push_chunk(session.node_id, "", is_final=True)
                del self._active_sessions[session_id]
    
    async def _push_chunk(self, node_id: str, chunk: str, is_final: bool):
        """推送数据块到前端(通过事件系统)"""
        # 通过事件系统发送 MESSAGE_CHUNK 事件
        # stream_writer(CustomWorkflowEvent(...))
        pass

设计要点:

  • 异步锁:保证多线程/多协程环境下的线程安全
  • 会话管理:每个消息节点对应一个会话,独立推进
  • 阻塞机制:遇到未就绪的数据时停止推送,等待数据就绪后继续
  • 混合处理:同时支持流式数据和标量数据

希望本文的设计思路和代码示例能为你构建类似的流式输出系统提供参考。

相关推荐
申阳1 小时前
Day 20:开源个人项目时的一些注意事项
前端·后端·程序员
shepherd1111 小时前
一文带你掌握MyBatis-Plus代码生成器:从入门到精通,实现原理与自定义模板全解析
java·spring boot·后端
盼哥PyAI实验室1 小时前
【超详细教程】Python 连接 MySQL 全流程实战
python·mysql·oracle
程序员西西1 小时前
作为开发,你真的懂 OOM 吗?实测 3 种场景,搞懂 JVM 崩溃真相
java·后端
棒棒的皮皮1 小时前
【OpenCV】Python图像处理之按位逻辑运算
图像处理·python·opencv·计算机视觉
小周在成长1 小时前
Java 内部类指南
后端
拾贰_C1 小时前
【ML|DL |python|pytorch|】基础学习
pytorch·python·学习
橘子编程1 小时前
仓颉语言变量与表达式解析
java·linux·服务器·开发语言·数据库·python·mysql
开心就好20251 小时前
Fiddler抓包与接口调试实战,HTTPHTTPS配置、代理设置与移动端抓包详解
后端