好的,这是一个非常好的问题,它触及了这个复杂系统设计的核心。我们来逐步拆解 load_balance_coordinator 的作用、实现方式,以及这些类之间的关系和整体流程。
1. load_balance_coordinator 在哪个类下?为什么?
load_balance_coordinator 是 DynamicSamplingScheduler 和 GenerateScheduler 这两个调度器类自身的成员变量(一个字典)。
为什么在这里?
因为 调度器(Scheduler)是整个系统的"大脑"和"交通指挥官" 。它掌握着全局信息,知道有多少个 worker(actor_cluster),也知道自己向每个 worker 分派了多少任务。因此,跟踪每个 worker 的负载并据此做出决策,是调度器最核心的职责之一。
将 load_balance_coordinator 放在调度器内部,遵循了"谁决策,谁持有信息"的设计原则。
load_balance_coordinator的结构 :
它是一个字典,key是 worker 的dp_rank(可以看作是 worker 的唯一标识),value是一个整数,表示当前已分派给该 worker 且尚未收到完成回调的请求数量。它是一个"在途请求计数器"。
2. 负载均衡是如何实现的?
负载均衡的实现主要依赖于两个部分:get_available_dp_rank 方法和 report_response 方法中的计数器更新。
A. get_available_dp_rank() 方法 - 决策
python
def get_available_dp_rank(self):
while True:
# 1. 排序:找出最空闲的 worker
sorted_ranks = sorted(
self.load_balance_coordinator.keys(),
key=lambda rank: (self.load_balance_coordinator[rank], rank)
)
# 2. 检查:最空闲的 worker 是否还有容量
if self.load_balance_coordinator[sorted_ranks[0]] < self.max_running_requests:
# 3. 产出:返回这个空闲 worker 的 rank
yield sorted_ranks[0]
- 排序 (
sorted) : 这是负载均衡的核心。它对所有的dp_rank进行排序。key=lambda rank: (self.load_balance_coordinator[rank], rank): 这个key非常巧妙。它首先按照每个 rank 的在途请求数(load_balance_coordinator[rank])升序排列。这意味着负载最低的 worker 会排在最前面。- 如果两个 worker 的负载相同,它会再按照
rank本身的大小排序,这确保了排序结果的稳定性(一个 tie-breaking 规则)。
- 检查 (
if ...) : 排序后,sorted_ranks[0]就是当前最空闲的 worker。代码会检查它的负载是否小于预设的单个 worker 的最大并发数self.max_running_requests。 - 产出 (
yield) : 如果最空闲的 worker 还有容量,就通过yield返回它的dp_rank。yield会让这个函数变成一个生成器,调用方用next()就可以拿到这个 rank。如果最空闲的 worker 也满了,if条件不满足,while True循环会继续执行, effectively 阻塞调用者,直到有 worker 变空闲。
B. 计数器更新 - 维护状态
-
请求分发时(
get_batch):python# ... 选择了 dp_rank ... with self.lock: # ... self.actor_cluster.workers[dp_rank].add_request.remote(...) self.load_balance_coordinator[dp_rank] += 1 # <<< 增加计数当调度器向一个 worker 发送一个请求后,它会立即 将该 worker 的负载计数器
+1。 -
收到响应时(
report_response):python@ray.method(concurrency_group="multi_thread") def report_response(self, data: DataProto): # ... with self.lock: self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1 # <<< 减少计数 # ...当 worker 完成任务并通过回调
report_response返回结果时,调度器会根据request_id找到当初这个请求被发往的dp_rank,然后将该 worker 的负载计数器-1。
这个 +1 和 -1 的闭环操作,精确地维护了每个 worker 当前的在途任务数量,为 get_available_dp_rank 的决策提供了准确的数据。
3. 类之间的关系和整体流程
让我们以 DynamicSamplingScheduler 为例,梳理一下它与其他组件的关系和完整的工作流程。
类关系图:
+---------------------------+ (1. get_batch)
| Training Loop | <--------------------------+
+---------------------------+ |
^ |
| (5. returns batch data) |
| |
+---------------------------+ | (6. data is ready)
| DynamicSamplingScheduler |----------------------------+
| - load_balance_coordinator| (2. add_request) |
| - get_available_dp_rank() |--------------------------->|
| - report_response() | (3. compute_rewards) |
+---------------------------+ |
| ^ |
(4. response_callback) | |
| | |
+---------------------------+ +---------------------------+
| Actor Cluster (LLMs) | | Reward Cluster (RMs) |
| - Worker 1 (dp_rank=0) | | - RM Worker A |
| - Worker 2 (dp_rank=1) | | - RM Worker B |
| ... | | ... |
+---------------------------+ +---------------------------+
整体流程:
-
启动 (Training Loop -> Scheduler):
- 主训练循环需要一批新的训练数据。它调用
DynamicSamplingScheduler的get_batch(batch_size=...)方法。
- 主训练循环需要一批新的训练数据。它调用
-
调度与分发 (Scheduler -> Actor Cluster):
DynamicSamplingScheduler的get_batch方法进入其主循环。- 它从数据集中取出一个
prompt。 - 它调用
get_available_dp_rank()来决策:"我应该把这个新任务给谁?"。假设dp_rank=0(Worker 1) 是最空闲的。 - 它准备好请求数据,并将请求通过
ray.remote()异步地 发送给Actor Cluster中的Worker 1。 - 发送后,它立刻 更新
self.load_balance_coordinator[0] += 1。 - 这个过程会持续进行,直到达到流控上限。
-
生成 (Actor Cluster):
Worker 1接收到请求,调用其内部的 LLM(如sglang引擎)进行文本生成。- 生成完成后,
Worker 1查找请求元信息中的response_callback_fn,发现它指向DynamicSamplingScheduler的report_response方法。
-
回调 (Actor Cluster -> Scheduler):
Worker 1调用scheduler.report_response.remote(response_data)。
-
接收与奖励计算 (Scheduler -> Reward Cluster):
DynamicSamplingScheduler的report_response方法被触发。- 它立刻 更新
self.load_balance_coordinator[0] -= 1,表示Worker 1的一个任务完成了,负载下降了。 - 然后,它将收到的生成结果立即转发 给
Reward Cluster中的一个 worker,去计算奖励分数。
-
过滤与缓冲 (Scheduler 内部):
- 当 Reward 模型返回奖励后,
report_response继续执行。 - 它会进行
response_filter_fn和query_filter_fn过滤。 - 只有通过过滤的、包含了(prompt, generation, reward)的完整"经验"才会被存入
self.completed_buffers。 report_response检查completed_buffers的大小。
- 当 Reward 模型返回奖励后,
-
返回批次 (Scheduler -> Training Loop):
get_batch的主循环一直在检查completed_buffers的大小。- 当
completed_buffers中的数据足够凑成一个batch_size时,get_batch停止采样,将这些数据打包成一个批次,然后return给最初的调用者------训练循环。
这个流程完美地展示了 load_balance_coordinator 作为调度器核心状态,如何驱动请求的智能分发,并通过异步回调机制实现了高效的、流水线式的分布式计算。