【RL】async_engine 远离

好的,我们来深入剖析这段代码,它是 consumer 函数中处理单个 LLM 生成请求的核心循环,非常精妙地融合了流式处理、超时控制、并发操作和中止逻辑

宏观目标

这段代码的目标是:以流式(streaming)的方式从 sglang 的异步生成器中获取文本块(chunk),同时在循环的每一轮都检查是否有中止信号,并处理可能发生的超时。


关键组件和模式解析

1. generator = llm.tokenizer_manager.generate_request(obj, None)
  • generate_request 被调用时,由于 stream=True,它不会立即开始生成并阻塞,而是返回一个异步生成器 (async generator) 。你可以把它想象成一个"承诺",它承诺会陆续产生数据块,但你需要用 async for 或者 await generator.__anext__() 来逐个获取它们。
2. asyncio.wait_for + asyncio.shield + asyncio.create_task 的组合模式

这是这段代码中最复杂也最核心的部分。让我们一步步看它的意图。

简单的方式(有问题):

python 复制代码
# 一个简单的、没有超时和并发检查的循环
async for chunk in generator:
    # process chunk

这个太简单了,无法在生成过程中插入中止检查。

稍微好一点的方式(仍有问题):

python 复制代码
while True:
    try:
        chunk = await generator.__anext__()
        # process chunk
        # check for abort signal
    except StopAsyncIteration:
        break

问题在于 await generator.__anext__() 可能会卡住很久。如果 sglang 后端因为某些原因(比如负载过高)迟迟不返回下一个 chunk,那么这个 await 就会一直等待,中止检查就无法及时执行。

当前代码的精妙之处:

python 复制代码
next_task = asyncio.create_task(generator.__anext__())
while True:
    try:
        chunk = await asyncio.wait_for(asyncio.shield(next_task), timeout=10)
        next_task = asyncio.create_task(generator.__anext__())
    except asyncio.TimeoutError:
        is_timeout = True
    ...
  • next_task = asyncio.create_task(generator.__anext__()):

    • 它不直接 await 下一个结果 ,而是把获取下一个 chunk 的操作 (generator.__anext__()) 包装成一个独立的 asyncio.Task
    • 这就像是说:"嘿,事件循环,请你在后台开始准备获取下一个数据块,但不要让我在这里干等着。"
    • create_task 会立即返回,代码可以继续执行。
  • await asyncio.wait_for(..., timeout=10):

    • 现在,主循环说:"好了,我现在愿意等一下那个后台任务 (next_task),但我的耐心只有 10 秒。"
    • 如果 next_task 在 10 秒内完成了(即 sglang 返回了一个 chunk),await 就会返回 chunk 的结果,一切正常。
    • 如果 10 秒过去了,next_task 还没完成,asyncio.wait_for 就会抛出 asyncio.TimeoutError
  • asyncio.shield(next_task):

    • 这是一个保护层。asyncio.wait_for 在超时时,默认行为是取消它正在等待的任务。
    • shield 的作用是防止 next_taskwait_for 取消。当超时发生时,wait_for 仍然会抛出 TimeoutError,但 next_task 本身没有被取消,它仍在后台继续运行,等待 sglang 的结果。
    • 为什么需要这样做? 这里的意图可能是:即使我(consumer的主循环)因为超时而不再等待了(我要去做中止检查),我也不想粗暴地打断 sglang 内部可能正在进行的工作。让它继续跑,也许下次循环我再来取结果。这增加了一点鲁棒性,但在这个特定代码的后续逻辑中,这个 shield 的效果不是特别明显,因为超时后并没有对 next_task 做特殊处理。不过,这是一个防御性的编程模式。
  • 循环的整体效果:

    • 这个组合拳实现了一个"带超时的轮询"机制。
    • 循环每最多 10 秒就会醒来一次。要么是因为成功收到了一个 chunk,要么是因为超时。
    • 无论哪种情况,它都能保证有机会去执行循环体后面的中止检查,从而使得中止信号的响应延迟最多为 10 秒,而不是可能无限长。
3. 中止逻辑 (if need_abort: ...)
  • need_abort = stop_flag: 首先检查全局停止标志。
  • async with abort_lock: ...: 然后在锁的保护下,检查当前请求的 rid_str 是否在 abort_rid_set 中。
  • 如果 need_abortTrue,说明这个正在进行中的生成任务需要被中止。
  • 中止步骤 :
    1. llm.tokenizer_manager.abort_request(obj.rid): 最重要的一步 。调用 sglang 的 API,明确通知后端引擎:"停止为这个请求 ID 生成更多 token"。这会释放后端 GPU 资源。
    2. generate_success = False: 标记这次生成是不成功的。
    3. next_task.cancel(): 取消我们正在等待的那个获取下一个 chunk 的 asyncio.Task。因为我们已经告诉后端中止了,所以这个任务永远不会有结果了。取消它可以让它立即以 CancelledError 异常退出。
    4. with contextlib.suppress(asyncio.CancelledError): await next_task:
      • 这是一个优雅地处理取消操作的方式。当我们 cancel() 一个任务后,await 它会抛出 asyncio.CancelledError
      • contextlib.suppress 的作用就是捕获并"吞掉"这个预期的 CancelledError,防止它冒泡到外层导致程序崩溃。我们只是想确保这个被取消的任务已经被清理掉了。
    5. break: 退出 while 循环,结束对这个请求的处理。
4. 结果收集
  • if not is_timeout:: 只有在没有超时,即成功获取了一个 chunk 的情况下,才处理它。
  • final_chunks[chunk_index] = chunk: sglang 在处理 n > 1 的采样时,返回的 chunk 会有一个 index 字段来区分是哪个序列的。这里将收到的 chunk 存入对应索引的 final_chunks 列表中,以便最后将所有序列的结果拼接起来。

流程总结

把所有部分串起来,整个循环的生命周期如下:

  1. 启动 : 提交请求给 sglang,得到一个异步生成器。并立即创建一个后台任务 next_task 去获取第一个 chunk。
  2. 循环开始 :
    • 等待 : await next_task,但最多等 10 秒。
    • 分支 A (成功) : 10 秒内拿到了 chunk
      • chunk 存起来。
      • 立即创建新的 next_task 去获取下一个 chunk。
    • 分支 B (超时) : 10 秒过去了,还没拿到 chunk
      • is_timeout 设为 True
    • 分支 C (结束) : generator 已经耗尽,__anext__ 抛出 StopAsyncIteration
      • break 退出循环。
  3. 检查 : 无论上述分支如何,都向下执行。检查全局停止标志和 abort_rid_set
  4. 中止 : 如果需要中止:
    • 通知 sglang 后端中止。
    • 取消正在等待的 next_task
    • break 退出循环。
  5. 继续: 如果不需要中止,回到第 2 步,开始下一轮循环。
  6. 循环结束 : 无论是正常完成还是被中止,循环都会退出,后续代码会根据 generate_successcollect_unfinished 的状态来决定如何调用 process_sglang_output
相关推荐
夏幻灵5 小时前
C++ 中手动重载赋值运算符(operator=)时实现部分复制的思路和方法
开发语言·c++·算法
亚林瓜子5 小时前
nodejs里面的百分号解码之URLSearchParams
开发语言·javascript·ecmascript·node·url·百分号编码
superman超哥5 小时前
仓颉语言中包与模块系统的深度剖析与工程实践
c语言·开发语言·c++·python·仓颉
尤老师FPGA5 小时前
使用ZYNQ芯片和LVGL框架实现用户高刷新UI设计系列教程(第四十二讲)
android·java·ui
x70x805 小时前
C++中不同容器的用法及接口(vector / deque / stack / queue / priority_queue)
开发语言·c++
再__努力1点5 小时前
LBP纹理特征提取:高鲁棒性的纹理特征算法
开发语言·人工智能·python·算法·计算机视觉
LCG米5 小时前
车载以太网SOME/IP协议栈在TI TDA4VM平台上的实现与测试
网络·网络协议·tcp/ip
lsx2024066 小时前
Bootstrap4 卡片布局指南
开发语言
老朱佩琪!6 小时前
Unity备忘录模式
java·unity·备忘录模式
嵌入式×边缘AI:打怪升级日志6 小时前
USB协议详解:从物理连接到数据传输的完整解析
网络·学习·usb