大模型推理引擎vLLM(19): vLLM中的DBO(Dual Batch Overlap)功能代码实现分析

文章目录

  • [1 什么是DBO--Dual Batch Overlap](#1 什么是DBO--Dual Batch Overlap)
  • [2 先看下vllm/docs/design/dbo.md 文档](#2 先看下vllm/docs/design/dbo.md 文档)
    • [2.1 Introduction](#2.1 Introduction)
    • [2.2 哪些代码是跟DBO相关的](#2.2 哪些代码是跟DBO相关的)
    • [2.3 DBO是把哪些过程做了overlap](#2.3 DBO是把哪些过程做了overlap)
    • [2.4 GPU Model Runner](#2.4 GPU Model Runner)
    • [2.5 UBatchWrapper](#2.5 UBatchWrapper)
    • [2.6 UBatchContext](#2.6 UBatchContext)
  • [3 代码中具体是怎么实现DBO的](#3 代码中具体是怎么实现DBO的)
    • [3.0 先看基础类或者说工具类:UBatchContext类--vllm/vllm/v1/worker/ubatching.py](#3.0 先看基础类或者说工具类:UBatchContext类--vllm/vllm/v1/worker/ubatching.py)
      • [3.0.1 def yield_(self)是什么意思](#3.0.1 def yield_(self)是什么意思)
      • [3.0.2 dbo_switch_to_compute和dbo_switch_to_comm()什么意思](#3.0.2 dbo_switch_to_compute和dbo_switch_to_comm()什么意思)
      • [3.0.3 dbo_switch_to_compute_sync什么意思](#3.0.3 dbo_switch_to_compute_sync什么意思)
      • [3.0.4 yield_and_switch_from_compute_to_comm和yield_and_switch_from_comm_to_compute函数什么意思](#3.0.4 yield_and_switch_from_compute_to_comm和yield_and_switch_from_comm_to_compute函数什么意思)
      • [3.0.5 dbo_get_previous_event函数什么意思](#3.0.5 dbo_get_previous_event函数什么意思)
      • [3.0.6 dbo_register_recv_hook和maybe_run_recv_hook什么意思](#3.0.6 dbo_register_recv_hook和maybe_run_recv_hook什么意思)
    • [3.1 modular_kernel.py文件](#3.1 modular_kernel.py文件)
      • [3.1.1 调用dbo_register_recv_hook(hook)和dbo_yield()](#3.1.1 调用dbo_register_recv_hook(hook)和dbo_yield())
      • [3.1.2 疑问:为什么这里面还有个dbo_maybe_run_recv_hook()](#3.1.2 疑问:为什么这里面还有个dbo_maybe_run_recv_hook())
    • [3.2 deepep_ll_prepare_finalize.py](#3.2 deepep_ll_prepare_finalize.py)
      • [3.2.1 疑问:为什么deepep_ll_prepare_finalize.py不像ht那样用那两个switch函数](#3.2.1 疑问:为什么deepep_ll_prepare_finalize.py不像ht那样用那两个switch函数)
    • [3.3 deepep_ht_prepare_finalize.py](#3.3 deepep_ht_prepare_finalize.py)
      • [3.3.1 _do_dispatch函数](#3.3.1 _do_dispatch函数)
      • [3.3.2 _finalize函数](#3.3.2 _finalize函数)
  • [4 相关疑问](#4 相关疑问)
    • [4.1 vllm中是在哪里创建UBatchWrapper的](#4.1 vllm中是在哪里创建UBatchWrapper的)
    • [4.2 一个ubatch线程里面有两个stream,那么是怎么创建两个stream的,](#4.2 一个ubatch线程里面有两个stream,那么是怎么创建两个stream的,)
    • [4.3 deepep_ht_prepare_finalize.py中的previous_event是做什么用的](#4.3 deepep_ht_prepare_finalize.py中的previous_event是做什么用的)

本文基于 GitHub 上开源的 vLLM 0.15.1 版本代码和文档,对其中的 DBO(Dual Batch Overlap)功能进行理解整理。

1 什么是DBO--Dual Batch Overlap

DBO相当于是将一个batch切分成两个micro-batch,然后每个线程跑一个micro-batch,然后中间用类似yield的函数,轮流交出控制权,让两个线程轮流执行,这样可以实现比如线程1的计算和线程2的通信overlap,或者线程2的计算和线程1的通信进行overlap。

2 先看下vllm/docs/design/dbo.md 文档

2.1 Introduction

The Dual Batch Overlap system works by splitting the batch in the model runner, creating two worker threads, and then running the model on each of these worker threads. When DBO is enabled, yield points within the FusedMoEModularKernel allow the two CPU worker threads (also called UBatch threads) to ping-pong between each other so that when one is running compute, the other is waiting on communication. Throughout the code, ubatch may be used as a short form of microbatch; this is an ASCII-friendly version of the short form µ-batch.

这里就是说将一个batch分成两个microbatch,然后一个线程负责一个umatch,然后线程中,通过在FusedMoEModularKernel加入yield使两个线程进行乒乓交互,这样就可以实现一个在做通信,另一个在做计算。

2.2 哪些代码是跟DBO相关的

The DBO system includes modifications to GpuModelRunner and ModularKernel, and defines two utility classes: UBatchWrapper and UBatchContext. UBatchWrapper manages thread lifecycle and CUDA graph execution of the model. UBatchContext wraps ForwardContext to coordinate synchronization between the two UBatch threads.

在文档中有上面的介绍,那么代码中实际修改的应该主要跟这几个类有关系,

  • GPUModelRunner:位于 vllm/vllm/v1/worker/gpu_model_runner.py,对应文档中的 GpuModelRunner,负责 batch 是否切成两个 micro-batch 以及相关元数据处理。
  • FusedMoEModularKernel:位于 vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py,对应文档中的 ModularKernel,负责 MoE 层内部执行流程,并挂接 DBO 的 yield 点。
  • UBatchWrapper:位于 vllm/vllm/v1/worker/gpu_ubatch_wrapper.py,负责起两个 UBatch 线程以及 CUDA Graph 的 capture/replay。
  • UBatchContext:位于 vllm/vllm/v1/worker/ubatching.py,包装 ForwardContext,用于两个 UBatch 线程之间的同步和 ping-pong 协调。

2.3 DBO是把哪些过程做了overlap

python 复制代码
# Schedule notation legend:
#    S = Shared expert
#    A0 = MLA qkv proj,
#    A1 = Core attn + out proj + MoE gate
#    D = Dispatch
#    C = Combine

# Comp: |-A0₀-A1₀-||-MLP₁-||-S₁-MLP₀-||-S₀-A0₁-A1₁-|
# Comm: |----D₁---||--D₀--||----C₁---||-----C₀-----|
# Order: D₁ send, A0₀, A1₀, D₁ recv, D₀ send, MLP₁, D₀ recv,
#        C₁ send, S₁, MLP₀, C₁ recv, C₀ send, S₀, A0₁, A1₁, C₀ recv.
# MLP_SHARED_OVERLAP = "mlp_shared_overlap"
  • S这里就是指共享专家的计算
  • 这里的A0是指的是MLA里面求Q K V三个矩阵的运算
  • A1里面,
    • Core attn就是注意力机制的运算,就是QKT,还有softmax,还有乘以V;
    • 然后out proj就是多头注意力算完之后的再乘以一个矩阵的那次运算;
    • 然后MoEgate,就是router相当于,就是个小分类器决定token路由到哪个专家上。
  • D 就是dispatch的那次all-to-all通信。
  • C 就是combine的那次all-to-all通信。
  • 然后这里面其实还少了个MLP没写,可能是因为MLP大家都知道不需要单独写出来吧。

2.4 GPU Model Runner

The batch is split into microbatches by the GPUModelRunner class. This is accomplished in two steps. First, coordination across all DP ranks is performed to determine whether microbatching will be applied. Microbatching must be uniform across all DP ranks. If microbatching is not feasible for any DP rank, it is disabled for all ranks. If all DP ranks are going to microbatch, the total number of tokens is padded up to the max number of tokens amongst all ranks. If any rank would end up with an empty second microbatch after the padding is applied, microbatching will be aborted and no ranks will microbatch. Once microbatching has been initiated by all ranks, the second step is performed. The CommonAttentionMetadata is sliced in half by the GPUModelRunner so that there is one attention metadata per-microbatch.

其实就是说判断能不能做DBO的流程

  • 首先,每个 DP rank 根据 --dbo-decode-token-threshold / --dbo-prefill-token-threshold 做本地判断,决定本 rank 是否尝试 DBO(ubatching)。
  • 然后做跨 DP 同步:只有当所有 rank 都同意尝试时,才继续走 DBO。
  • 接着进行二次可行性检查:按 DP 维度把 token 数对齐到全局最大值后,按 2 个 ubatch 切分;如果会导致第二个 ubatch 没有有效 token(只有 padding),则放弃 DBO。
  • 上述条件都满足时,GPUModelRunner 才会执行 microbatch 切分并进入 DBO 路径。

2.5 UBatchWrapper

其实这个类就是具体干活的。

2.6 UBatchContext

The current implementation has all dbo_yield and dbo_maybe_run_recv_hook calls in the FusedMoEModularKernel.forward method.

这里意思是目前,dbo_yield and dbo_maybe_run_recv_hook 只在moe的forward里面调用了。但其实后面看代码不是这样的,其实在deepep_ll_prepare_finalize.py里面也用了一次这个函数。

The dbo_register_recv_hook method registers a callback that can be returned by the FusedMoEPrepareAndFinalize class in the other UBatch thread's UBatchContext. The callback will be run when the other thread calls dbo_maybe_run_recv_hook. This is typically used to wait on an all-to-all kernel.

The dbo_maybe_run_recv_hook method runs a callback that's set by the dbo_register_recv_hook function if that callback exists.

  • 这里的意思其实就是本来是只有一个线程的,然后线程返回了hook函数,然后在合适的地方调用hook函数,那么就可以等待拿到receiver,
  • 现在是两个线程了,那么相当于是线程A拿到了hook后,利用dbo_register_recv_hook(hook) 把它登记到 线程 B 的 UBatchContext 里;
  • 线程 B 在它自己的执行点调用 dbo_maybe_run_recv_hook(),从自己的 context 里取出并执行这个 hook。
python 复制代码
            if hook is not None:
                if dbo_enabled():
                    # If DBO is being used, register the hook with the ubatch
                    # context and call it in dbo_maybe_run_recv_hook instead of
                    #  passing it to the receiver.
                    dbo_register_recv_hook(hook)
                    dbo_yield()
                else:
                    hook()

3 代码中具体是怎么实现DBO的

要想看代码中是怎么实现DBO的,最好的方法其实搜索下面三个

  • dbo_yield: 在modular_kernel.py中被用到。
  • dbo_register_recv_hook:在modular_kernel.py中被用到。
  • dbo_maybe_run_recv_hook:在deepep_ll_prepare_finalize.py和modular_kernel.py中被用到。
  • deepep_ht_prepare_finalize.py用到了下面这些
    • dbo_switch_to_compute和dbo_switch_to_comm()
    • dbo_switch_to_compute_sync,没用到switch_to_comm_sync
    • yield_and_switch_from_compute_to_comm, 和yield_and_switch_from_comm_to_compute,
    • dbo_get_previous_event

3.0 先看基础类或者说工具类:UBatchContext类--vllm/vllm/v1/worker/ubatching.py

python 复制代码
    def update_stream(self, stream):
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.cuda.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
......

def dbo_register_recv_hook(recv_hook):
    if len(_THREAD_ID_TO_CONTEXT) > 0:
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
        next_ctx.recv_hook = recv_hook


def dbo_get_previous_event(func, *args, **kwargs):
    if len(_THREAD_ID_TO_CONTEXT) > 0:
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        ctx = _CURRENT_CONTEXTS[ctx_idx]
        # execute callable on the ubatch compute stream to record/wait events there
        with torch.cuda.stream(ctx.compute_stream):
            return func(*args, **kwargs)

3.0.1 def yield_(self)是什么意思

这的update_stream就是更新或者说切换stream,那么yield函数其实就是先保存当前stream,然后调用_cpu_yield()相当于交出cpu控制权,然后等过会重新获取到cpu控制权之后,再切换到之前保存的stream。

3.0.2 dbo_switch_to_compute和dbo_switch_to_comm()什么意思

就是指切换stream,最简单的函数。

3.0.3 dbo_switch_to_compute_sync什么意思

这个就是先往comm通信流中注册一个event,然后切换到compute流,然后等待刚才通信流中注册的那个event事件。

3.0.4 yield_and_switch_from_compute_to_comm和yield_and_switch_from_comm_to_compute函数什么意思

  • assert current_stream() == self.compute_stream:确认当前是计算stream,
  • self._signal_compute_done(): 发送一个event信号,当前计算stream上的任务执行结束后,这个event标记完成。
  • self._cpu_yield():让出cpu控制权
  • assert self.current_stream == self.compute_stream:等重新获取到cpu控制权后,确认当前是计算stream,
  • self.update_stream(self.comm_stream):切换到通信stream,
  • self._wait_compute_done():等待,等待的就是前面计算流上的任务结束。

那么这个函数其实就是相当于完成当前线程上计算流的任务,然后让出cpu控制权给别的线程,然后等会cpu控制权重新回来之后,切换到通信流,等到计算流的工作完成,然后就可以开始通信了。
那么就相当于cpu调度下当前线程的计算任务,然后让他执行着,然后cpu给另一个线程,等会回来,cpu再切换到通信流,并且还要等待计算流的任务结束。
这里切换stream,其实还是当前这个线程上的不同stream,然后yield让出cpu控制权是指让给另一个线程
其实他跟前面函数的区别就是他不仅切换了流,还让除了CPU控制权。

3.0.5 dbo_get_previous_event函数什么意思

这个函数是这么用的

python 复制代码
        previous_event = dbo_get_previous_event(self.buffer.capture)
        combined_x, _, event = self.buffer.combine(
            # HT combine only supports BF16
            x=fused_expert_output,
            handle=handle,
            topk_weights=None,
            config=self._get_combine_config(),
            previous_event=previous_event,
            async_finish=do_async and not dbo_enabled(),
            allocate_on_comm_stream=False,
        )

这里的意思,其实这个capture就相当于是在计算流上注册了一个event,然后将这个previous传入deepep的接口,那么deepep内部就会wait这个event。

3.0.6 dbo_register_recv_hook和maybe_run_recv_hook什么意思

这个dbo_register_recv_hook相当于当前线程将hook函数注册到另一个ubatch线程里面,然后另一个线程就可以调用maybe_run_recv_hook去等待当前线程的通信结束。

3.1 modular_kernel.py文件

3.1.1 调用dbo_register_recv_hook(hook)和dbo_yield()

这个文件中出现的dbo_yield()和dbo_register_recv_hook(hook)分别是在_prepare和_finalize函数中,在调用了子类的prepare_async以及finalize函数后面,也就是在调用的deepep_ll_prepare_finalize.py或者deepep_ht_prepare_finalize.py中的prepare_async以及finalize函数后面。

python 复制代码
            if hook is not None:
                if dbo_enabled():
                    # If DBO is being used, register the hook with the ubatch
                    # context and call it in dbo_maybe_run_recv_hook instead of
                    #  passing it to the receiver.
                    dbo_register_recv_hook(hook)
                    dbo_yield()
                else:
                    hook()

那么这里其实就是相当于调用了dispatch或者combine通信函数之后,不直接去执行返回的hook函数,而是将这个hook函数注册到另一个线程中,然后调用yield让当前线程让出cpu控制权,然后让另一个线程在合适的时候去调用这个hook函数。

3.1.2 疑问:为什么这里面还有个dbo_maybe_run_recv_hook()

python 复制代码
            # Overlap shared expert compute with all2all dispatch.
            dbo_maybe_run_recv_hook()
            prepare_ret = self.prepare_finalize.prepare_async(
                hidden_states,
                topk_weights,
                topk_ids,
                global_num_experts,
                local_num_experts,
                expert_map,
                apply_router_weight_on_input,
                self.fused_experts.quant_config,
            )

他是在prepare之前有一个dbo_maybe_run_recv_hook(),然后就是在deepep_ll_prepare_finalize.py文件中finalize函数里面的low_latency_combine之前也有一个dbo_maybe_run_recv_hook()函数,

那么合起来就是相当于:

  • 当前线程在某个 DBO 检查点(例如 _prepare 前或 LL _finalize 内)调用 dbo_maybe_run_recv_hook(),先消费并等待对方线程之前注册过来的通信 hook;对方那次通信可能是 dispatch,也可能是 combine。
  • 然后当前线程继续推进自己这轮 MoE 阶段(例如发起 prepare_async/dispatch)。若返回了 hook,则在 DBO 下执行 dbo_register_recv_hook(hook) 并 dbo_yield():把"等待这轮通信完成"的动作托管给对方线程,同时让出 CPU 控制权。
  • 对方线程被唤醒后,从它自己的当前位置继续执行(可能是 MLP/shared expert/下一层计算,也可能先到通信相关步骤),不是固定单一类型任务。
  • 对方线程在后续到达下一个检查点时,也会用 dbo_maybe_run_recv_hook() 消费当前线程注册过去的 hook;随后发起自己这轮通信,并再次 register + yield 把等待动作交回去。
  • 两个线程就这样在"通信等待托管 + CPU 交棒 + 各自阶段推进"中循环,实现跨 ubatch 的通信/计算重叠

3.2 deepep_ll_prepare_finalize.py

这个文件里面就一个dbo_maybe_run_recv_hook()函数,前面已经讨论过了。

3.2.1 疑问:为什么deepep_ll_prepare_finalize.py不像ht那样用那两个switch函数

为什么deepep_ll_prepare_finalize.py不像ht那样用那两个yield_and_switch_from_compute_to_comm 和 yield_and_switch_from_comm_to_compute函数,而是只用了一个dbo_maybe_run_recv_hook()函数,再配合modular_kernel.py里面的文件就可以了。

看了下代码,以及理解了一下yield_and_switch_from_compute_to_commyield_and_switch_from_comm_to_compute,我才发现,其实在ht里面他是用了2个stream,一个计算stream,一个通信stream,然后它相当于在一个ubatch线程中,计算和通信也是在不同的stream上并行的,而ll里面,他的一个ubatch线程中,计算和通信都是在同一个stream上串行的(这里只是说DBO,不说sbo,ll中sbo的话也是用了两个流的),。

  • 也就是对于ht的DBO,他是用了两个ubatch并行,并且每个ubatch内部又用了两个流分别做计算和通信,
  • 而对于ll的DBO,他也是用了两个ubatch线程并行,但是每个ubatch内部只有一个stream串行的处理计算和通信任务。

3.3 deepep_ht_prepare_finalize.py

首先有个问题,就是对于ht,依然是会跑到modular_kernel.py里面,这是肯定的,前面分析过,modular_kernel.py文件里面其实已经

  • 在prepare和finalize的函数后面都有一个dbo_register_recv_hook和dbo_yield函数了,
  • 在prepare前面已经有个dbo_maybe_run_recv_hook函数了,

但是对于ht来说其实这三个函数都没用,因为比如

python 复制代码
            if hook is not None:
                if dbo_enabled():
                    # If DBO is being used, register the hook with the ubatch
                    # context and call it in dbo_maybe_run_recv_hook instead of
                    #  passing it to the receiver.
                    dbo_register_recv_hook(hook)
                    dbo_yield()
                else:
                    hook()

对于ht来说这里的hook是空的,所以根本就没用到dbo_register_recv_hook和dbo_yield函数。

然后对于dbo_maybe_run_recv_hook

python 复制代码
    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

这里是None,那么就什么不执行,

ht文件中关于dbo功能的代码纯靠他自己文件中的这些函数实现的,

  • dbo_switch_to_compute(1次)和dbo_switch_to_comm(1次)
  • dbo_switch_to_compute_sync(1次),没用到switch_to_comm_sync
  • yield_and_switch_from_compute_to_comm(2次), 和yield_and_switch_from_comm_to_compute(1次),

3.3.1 _do_dispatch函数

python 复制代码
        # We yield before launching the dispatch kernel since the dispatch
        # kernel will block the CPU so we want to queue up all the compute
        # for the other ubatch before the dispatch kernel starts.
        dbo_yield_and_switch_from_compute_to_comm()

        # capture a DeepEP event and pass it as previous_event so
        # DeepEP honors the dependency internally.
        previous_event = dbo_get_previous_event(self.buffer.capture)    

        调用 self.buffer.dispatch( 函数,并把previous_event传进去

        dbo_switch_to_compute_sync()
  • dbo_yield_and_switch_from_compute_to_comm会先往计算流中加一个event,然后交出cpu控制权给另一个线程u调度,另一个线程可能调度计算任务也可能是通信任务,不一定,然后等会cpu控制权回来,切换到通信流,并等待前面的计算流任务结束,
  • previous_event = dbo_get_previous_event(self.buffer.capture),然后调用 self.buffer.dispatch( 函数,并把previous_event传进去,然后这两个函数相当于也是往计算流中增加一个event,并且传给deepep,deepep内部就会等待这个event,然后做通信,
  • 调用dbo_switch_to_compute_sync,这个往通信流中加一个event,然后切换到计算流,在计算流中会等待这个event结束,

3.3.2 _finalize函数

python 复制代码
        dbo_yield_and_switch_from_compute_to_comm()

        previous_event = dbo_get_previous_event(self.buffer.capture)

        combined_x, _, event = self.buffer.combine(
            # HT combine only supports BF16
            x=fused_expert_output,
            handle=handle,
            topk_weights=None,
            config=self._get_combine_config(),
            previous_event=previous_event,
            async_finish=do_async and not dbo_enabled(),
            allocate_on_comm_stream=False,
        )

        dbo_switch_to_compute()

        if do_async:
            def _receiver():
                if event.event is not None:
                    event.current_stream_wait()
                dbo_switch_to_comm()
                # Respect inplace outputs.
                output.copy_(combined_x, non_blocking=True)

                # TODO(lucas): refactor the modular kernel so this will be
                # handled there
                dbo_yield_and_switch_from_comm_to_compute()
            return _receiver
  • dbo_yield_and_switch_from_compute_to_comm() 首先是在计算流上 record 一个 done event,然后让出 CPU 控制权;这时另一个线程获得 CPU 去调度它自己的后续任务(可能是计算也可能是通信,不固定)。等当前线程重新拿回 CPU 后,会切换到通信流,并在通信流上 wait 前面计算流的 done event,确保 compute->comm 依赖成立。

  • 然后 previous_event = dbo_get_previous_event(self.buffer.capture) 是在当前 ubatch 的计算流上再 capture 一个 DeepEP event,并传给 combine。combine 内部会在真正执行通信前等待这个 previous_event,用于库内部依赖约束。

  • dbo_switch_to_compute() 这里是轻量切回计算流(不做 sync / 不交出 CPU)。之所以不是 dbo_switch_to_compute_sync(),是因为后面 async 路径会由 _receiver 通过 event.current_stream_wait() 在需要的时机再做等待,不想在这里过早把流程同步死,在 _receiver 里:

    • if event.event is not None: event.current_stream_wait():等待的是 combine 返回的 completion event(不是"直接等通信流本身");
    • 然后 dbo_switch_to_comm(),把后续 output.copy_ 放到通信流上执行;
    • 最后 dbo_yield_and_switch_from_comm_to_compute():在通信流上 signal done,交出 CPU,回来后切回计算流并在计算流上 wait 通信 done,恢复 DBO 的 ping-pong 节奏。
    • 最后这一步dbo_yield_and_switch_from_comm_to_compute,又相当于是在当前通信流上加一个event,然后交出cpu,让其他线程调度。 可是这里其实前面if event.event is not None: event.current_stream_wait()就已经等待通信流完成了,所以这个函数其实里面的等待没效果,更多的我感觉就是交出了cpu控制权而已。

4 相关疑问

4.1 vllm中是在哪里创建UBatchWrapper的

vllm/vllm/v1/worker/gpu_model_runner.py中 def load_model(self, eep_scale_up: bool = False) -> None:函数里,创建的UBatchWrapper

python 复制代码
        # for other compilation modes, cudagraph behavior is controlled by
        # CudagraphWraper and CudagraphDispatcher of vllm.

        # wrap the model with full cudagraph wrapper if needed.
        cudagraph_mode = self.compilation_config.cudagraph_mode
        assert cudagraph_mode is not None
        if (
            cudagraph_mode.has_full_cudagraphs()
            and not self.parallel_config.use_ubatching
        ):
            self.model = CUDAGraphWrapper(
                self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
            )
        elif self.parallel_config.use_ubatching:
            if cudagraph_mode.has_full_cudagraphs():
                self.model = UBatchWrapper(
                    self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
                )
            else:
                self.model = UBatchWrapper(
                    self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
                )

4.2 一个ubatch线程里面有两个stream,那么是怎么创建两个stream的,

DBO其实是默认的当前主stream用来做计算,然后额外创建了一个stream用来做通信,

其中通信stream在构造函数中,self.comm_stream = torch.cuda.Stream(device=device)

python 复制代码
class UBatchWrapper:
    def __init__(
        self,
        runnable: Callable,
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        device: torch.cuda.device,
    ):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.comm_stream = torch.cuda.Stream(device=device)
        # Ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(
            self.vllm_config.parallel_config.num_ubatches + 1
        )

        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

        self.cudagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not CUDAGraphMode.NONE:
            self.cudagraph_wrapper = CUDAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode
            )
            self.graph_pool = current_platform.get_global_graph_pool()

        self.sm_control = self._create_sm_control_context(vllm_config)
        self.device = device

然后默认的那个计算流是在call函数中每次计算时候获取的compute_stream = torch.cuda.current_stream()

python 复制代码
    def __call__(self, *args, **kwargs):
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:
            # This is to account for the case where ubatching was aborted.
            # When we capture full graphs we only capture one graph per shape,
            # meaning that if we have a ubatched  cudagraph for the current
            # num_tokens, we don't have a non-ubatched one. Without this
            # check, the cudagraph wrapper will try to capture a cudagraph
            # for this shape during a normal run.
            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                assert batch_descriptor is not None
                if batch_descriptor.num_tokens in self.cudagraphs:
                    cudagraph_runtime_mode = CUDAGraphMode.NONE

            if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            else:
                assert self.cudagraph_wrapper is not None
                return self.cudagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        slot_mapping = forward_context.slot_mapping
        num_tokens = (
            ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
        ) * 2
        input_ids = kwargs["input_ids"]
        positions = kwargs["positions"]
        intermediate_tensors = kwargs["intermediate_tensors"]
        inputs_embeds = kwargs["inputs_embeds"]
        compute_stream = torch.cuda.current_stream()

        dp_metadata = forward_context.dp_metadata

4.3 deepep_ht_prepare_finalize.py中的previous_event是做什么用的

在deepep_ht_prepare_finalize.py文件中,在dispatch以及combine中都出现了这个previous_event。

比如

python 复制代码
    def _do_dispatch(
        self,
        tokens: torch.Tensor,
        token_scales: torch.Tensor | None,
        rank_topk_ids: torch.Tensor,
        rank_topk_weights: torch.Tensor,
        num_experts: int,
        a1_scale: torch.Tensor | None,
        quant_config: FusedMoEQuantConfig,
    ) -> Callable:
        has_scales = token_scales is not None

        # We yield before launching the dispatch kernel since the dispatch
        # kernel will block the CPU so we want to queue up all the compute
        # for the other ubatch before the dispatch kernel starts.
        dbo_yield_and_switch_from_compute_to_comm()

        # capture a DeepEP event and pass it as previous_event so
        # DeepEP honors the dependency internally.
        previous_event = dbo_get_previous_event(self.buffer.capture)

        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            dispatch_expert_num_tokens,
            is_token_in_rank,
            event,
        ) = self.buffer.get_dispatch_layout(
            topk_idx=rank_topk_ids,
            num_experts=num_experts,
            previous_event=previous_event,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

比如

python 复制代码
    def _finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
        do_async: bool,
    ) -> Callable | None:
        a2a_idx = dbo_current_ubatch_id()
        handle = self.handles[a2a_idx]
        assert handle is not None

        # fused_expert_output can have 0 tokens - This happens when none of the
        # tokens from the all2all reach this EP rank.
        if fused_expert_output.numel() != 0:
            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
            fused_expert_output = weight_and_reduce_impl.apply(
                output=None,
                fused_expert_output=fused_expert_output,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        dbo_yield_and_switch_from_compute_to_comm()
        assert fused_expert_output.dtype == torch.bfloat16, (
            f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
        )
        previous_event = dbo_get_previous_event(self.buffer.capture)
        combined_x, _, event = self.buffer.combine(
            # HT combine only supports BF16
            x=fused_expert_output,
            handle=handle,
            topk_weights=None,
            config=self._get_combine_config(),
            previous_event=previous_event,
            async_finish=do_async and not dbo_enabled(),
            allocate_on_comm_stream=False,
        )

        dbo_switch_to_compute()      

然后这个dbo_get_previous_event函数定义如下

python 复制代码
def dbo_get_previous_event(func, *args, **kwargs):
    if len(_THREAD_ID_TO_CONTEXT) > 0:
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        ctx = _CURRENT_CONTEXTS[ctx_idx]
        # execute callable on the ubatch compute stream to record/wait events there
        with torch.cuda.stream(ctx.compute_stream):
            return func(*args, **kwargs)
相关推荐
AIGC_北苏20 小时前
Qwen3.5开源模型实测
vllm
npupengsir1 天前
nano vllm代码详解
人工智能·算法·vllm
冰封剑心1 天前
容器参数错误,更换参数
人工智能·计算机视觉·vllm
吴佳浩 Alben2 天前
GPU 生产环境实践:硬件拓扑、显存管理与完整运维体系
运维·人工智能·pytorch·语言模型·transformer·vllm
m0_564876842 天前
nano-vllm学习
学习·vllm
谢白羽4 天前
vllm实践
android·vllm
冰封剑心4 天前
VLLM部署
vllm
翱翔的苍鹰4 天前
通过LangChain Agent模拟实现美团外卖下单场景
人工智能·深度学习·语言模型·自然语言处理·langchain·vllm
susu10830189116 天前
LiteLLM + vLLM模型调用引擎架构
vllm