【RL】 ROLL中负载均衡

好的,这是一个非常好的问题,它触及了这个复杂系统设计的核心。我们来逐步拆解 load_balance_coordinator 的作用、实现方式,以及这些类之间的关系和整体流程。

1. load_balance_coordinator 在哪个类下?为什么?

load_balance_coordinatorDynamicSamplingSchedulerGenerateScheduler 这两个调度器类自身的成员变量(一个字典)。

为什么在这里?

因为 调度器(Scheduler)是整个系统的"大脑"和"交通指挥官" 。它掌握着全局信息,知道有多少个 worker(actor_cluster),也知道自己向每个 worker 分派了多少任务。因此,跟踪每个 worker 的负载并据此做出决策,是调度器最核心的职责之一。

load_balance_coordinator 放在调度器内部,遵循了"谁决策,谁持有信息"的设计原则。

  • load_balance_coordinator 的结构 :
    它是一个字典,key 是 worker 的 dp_rank(可以看作是 worker 的唯一标识),value 是一个整数,表示当前已分派给该 worker 且尚未收到完成回调的请求数量。它是一个"在途请求计数器"。

2. 负载均衡是如何实现的?

负载均衡的实现主要依赖于两个部分:get_available_dp_rank 方法和 report_response 方法中的计数器更新。

A. get_available_dp_rank() 方法 - 决策

python 复制代码
def get_available_dp_rank(self):
    while True:
        # 1. 排序:找出最空闲的 worker
        sorted_ranks = sorted(
            self.load_balance_coordinator.keys(), 
            key=lambda rank: (self.load_balance_coordinator[rank], rank)
        )
        
        # 2. 检查:最空闲的 worker 是否还有容量
        if self.load_balance_coordinator[sorted_ranks[0]] < self.max_running_requests:
            # 3. 产出:返回这个空闲 worker 的 rank
            yield sorted_ranks[0]
  1. 排序 (sorted) : 这是负载均衡的核心。它对所有的 dp_rank 进行排序。
    • key=lambda rank: (self.load_balance_coordinator[rank], rank): 这个 key 非常巧妙。它首先按照每个 rank 的在途请求数(load_balance_coordinator[rank]升序排列。这意味着负载最低的 worker 会排在最前面。
    • 如果两个 worker 的负载相同,它会再按照 rank 本身的大小排序,这确保了排序结果的稳定性(一个 tie-breaking 规则)。
  2. 检查 (if ...) : 排序后,sorted_ranks[0] 就是当前最空闲的 worker。代码会检查它的负载是否小于预设的单个 worker 的最大并发数 self.max_running_requests
  3. 产出 (yield) : 如果最空闲的 worker 还有容量,就通过 yield 返回它的 dp_rankyield 会让这个函数变成一个生成器,调用方用 next() 就可以拿到这个 rank。如果最空闲的 worker 也满了,if 条件不满足,while True 循环会继续执行, effectively 阻塞调用者,直到有 worker 变空闲。

B. 计数器更新 - 维护状态

  • 请求分发时(get_batch:

    python 复制代码
    # ... 选择了 dp_rank ...
    with self.lock:
        # ...
        self.actor_cluster.workers[dp_rank].add_request.remote(...)
        self.load_balance_coordinator[dp_rank] += 1 # <<< 增加计数

    当调度器向一个 worker 发送一个请求后,它会立即 将该 worker 的负载计数器 +1

  • 收到响应时(report_response:

    python 复制代码
    @ray.method(concurrency_group="multi_thread")
    def report_response(self, data: DataProto):
        # ...
        with self.lock:
            self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1 # <<< 减少计数
        # ...

    当 worker 完成任务并通过回调 report_response 返回结果时,调度器会根据 request_id 找到当初这个请求被发往的 dp_rank,然后将该 worker 的负载计数器 -1

这个 +1-1 的闭环操作,精确地维护了每个 worker 当前的在途任务数量,为 get_available_dp_rank 的决策提供了准确的数据。


3. 类之间的关系和整体流程

让我们以 DynamicSamplingScheduler 为例,梳理一下它与其他组件的关系和完整的工作流程。

类关系图:

复制代码
+---------------------------+       (1. get_batch)
|      Training Loop        | <--------------------------+
+---------------------------+                            |
          ^                                              |
          | (5. returns batch data)                      |
          |                                              |
+---------------------------+                            | (6. data is ready)
| DynamicSamplingScheduler  |----------------------------+
| - load_balance_coordinator|       (2. add_request)     |
| - get_available_dp_rank() |--------------------------->|
| - report_response()       |       (3. compute_rewards) |
+---------------------------+                            |
          |         ^                                    |
 (4. response_callback) |                                    |
          |         |                                    |
+---------------------------+       +---------------------------+
|    Actor Cluster (LLMs)   |       |   Reward Cluster (RMs)    |
| - Worker 1 (dp_rank=0)    |       | - RM Worker A             |
| - Worker 2 (dp_rank=1)    |       | - RM Worker B             |
| ...                       |       | ...                       |
+---------------------------+       +---------------------------+

整体流程:

  1. 启动 (Training Loop -> Scheduler):

    • 主训练循环需要一批新的训练数据。它调用 DynamicSamplingSchedulerget_batch(batch_size=...) 方法。
  2. 调度与分发 (Scheduler -> Actor Cluster):

    • DynamicSamplingSchedulerget_batch 方法进入其主循环。
    • 它从数据集中取出一个 prompt
    • 它调用 get_available_dp_rank() 来决策:"我应该把这个新任务给谁?"。假设 dp_rank=0 (Worker 1) 是最空闲的。
    • 它准备好请求数据,并将请求通过 ray.remote() 异步地 发送给 Actor Cluster 中的 Worker 1
    • 发送后,它立刻 更新 self.load_balance_coordinator[0] += 1
    • 这个过程会持续进行,直到达到流控上限。
  3. 生成 (Actor Cluster):

    • Worker 1 接收到请求,调用其内部的 LLM(如 sglang 引擎)进行文本生成。
    • 生成完成后,Worker 1 查找请求元信息中的 response_callback_fn,发现它指向 DynamicSamplingSchedulerreport_response 方法。
  4. 回调 (Actor Cluster -> Scheduler):

    • Worker 1 调用 scheduler.report_response.remote(response_data)
  5. 接收与奖励计算 (Scheduler -> Reward Cluster):

    • DynamicSamplingSchedulerreport_response 方法被触发。
    • 立刻 更新 self.load_balance_coordinator[0] -= 1,表示 Worker 1 的一个任务完成了,负载下降了。
    • 然后,它将收到的生成结果立即转发Reward Cluster 中的一个 worker,去计算奖励分数。
  6. 过滤与缓冲 (Scheduler 内部):

    • 当 Reward 模型返回奖励后,report_response 继续执行。
    • 它会进行 response_filter_fnquery_filter_fn 过滤。
    • 只有通过过滤的、包含了(prompt, generation, reward)的完整"经验"才会被存入 self.completed_buffers
    • report_response 检查 completed_buffers 的大小。
  7. 返回批次 (Scheduler -> Training Loop):

    • get_batch 的主循环一直在检查 completed_buffers 的大小。
    • completed_buffers 中的数据足够凑成一个 batch_size 时,get_batch 停止采样,将这些数据打包成一个批次,然后 return 给最初的调用者------训练循环。

这个流程完美地展示了 load_balance_coordinator 作为调度器核心状态,如何驱动请求的智能分发,并通过异步回调机制实现了高效的、流水线式的分布式计算。

相关推荐
吕了了1 小时前
165 Windows 系统在 UEFI 和 Legacy BIOS 上的启动流程详解
运维·windows·系统
星辞树1 小时前
从计数到预测:深入浅出词向量 (Word Vectors) —— Stanford CS224n 作业实战记录
算法
JarryStudy1 小时前
自动调优在Triton-on-Ascend中的应用:从参数优化到性能极致挖掘
人工智能·算法·昇腾·cann·ascend c
CoderYanger1 小时前
递归、搜索与回溯-穷举vs暴搜vs深搜vs回溯vs剪枝:13.子集
java·算法·leetcode·机器学习·剪枝·1024程序员节
The_cute_cat1 小时前
Ubuntu指令的初步学习
linux·运维·ubuntu
python百炼成钢1 小时前
40.linux自带LED驱动
linux·运维·服务器
黑客思维者1 小时前
底层冗余性原理探秘模型剪枝(Pruning)为何能“无损”压缩模型?
算法·机器学习·剪枝
会飞的土拨鼠呀1 小时前
linux 重新运行NetworkManager
linux·运维·服务器
shawnyz1 小时前
RHCSE--SHELL02--变量
linux·运维·服务器