The `expand_requests` function expands and replicates generation requests according to configuration parameters, supporting two different sampling strategies.
Function Purpose
`expand_requests` decides how to handle multi-sample requests based on the `is_num_return_sequences_expand` setting 1:
Expand mode (`is_num_return_sequences_expand=True`)
- Sets each request's `num_return_sequences` to 1
- Replicates the request `num_return_sequences` times
- Each request independently generates 1 response
Non-expand mode (`is_num_return_sequences_expand=False`)
- Sets `num_return_sequences` directly to the requested value
- Sends only 1 request, which generates multiple responses
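The difference between the two modes can be illustrated with a minimal, self-contained sketch. Plain dicts stand in for the framework's `DataProto` objects, and `expand_requests_sketch` is a hypothetical stand-in for the real method, not library code:
python
import copy

def expand_requests_sketch(request: dict, num_return_sequences: int, is_expand: bool) -> list:
    """Illustrative stand-in for expand_requests, operating on a plain dict."""
    if is_expand:
        # Expand mode: n independent copies, each set to generate a single response.
        request["generation_config"]["num_return_sequences"] = 1
        return [copy.deepcopy(request) for _ in range(num_return_sequences)]
    # Non-expand mode: a single request set to generate n responses.
    request["generation_config"]["num_return_sequences"] = num_return_sequences
    return [copy.deepcopy(request)]

req = {"prompt": "hello", "generation_config": {}}
assert len(expand_requests_sketch(copy.deepcopy(req), 4, is_expand=True)) == 4   # 4 requests x 1 response
assert len(expand_requests_sketch(copy.deepcopy(req), 4, is_expand=False)) == 1  # 1 request x 4 responses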
Code Implementation
python
def expand_requests(self, data: DataProto):
    # Simplified view; the full implementation (see Citations) also asserts
    # generate_opt_level > 0 and that data.meta_info contains "generation_config".
    is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
    num_return_sequences = self.generation_config["num_return_sequences"]
    generation_config = data.meta_info["generation_config"]
    target_requests = []
    if is_num_return_sequences_expand:
        # Expand mode: n deep copies, each generating a single response
        generation_config["num_return_sequences"] = 1
        for _ in range(num_return_sequences):
            target_requests.append(copy.deepcopy(data))
    else:
        # Non-expand mode: one request generating n responses
        generation_config["num_return_sequences"] = num_return_sequences
        target_requests.append(copy.deepcopy(data))
    return target_requests
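In both modes the total number of responses produced per prompt is the same (`num_return_sequences`); only the request granularity changes, i.e., whether sampling happens as many single-response requests or as one multi-response request.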
Configuration Parameters
- `is_num_return_sequences_expand` (bool, default `False`): whether to replicate the prompt `num_return_sequences` times instead of multi-sampling from a single request
- `num_return_sequences_in_group` (int, default `1`): the number of return sequences in one group; in the RLVR pipeline this drives the `num_return_sequences` used in the generation args
- `generate_opt_level` (int): must be greater than 0, since `expand_requests` asserts that the scheduler mode is in use
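As a hedged illustration of how these knobs fit together, the sketch below uses a hypothetical `SketchPipelineConfig` dataclass that mirrors the cited fields; it is not one of the real config classes:
python
from dataclasses import dataclass

# Hypothetical stand-in mirroring the cited config fields; not the real classes.
@dataclass
class SketchPipelineConfig:
    generate_opt_level: int = 1                   # expand_requests asserts this is > 0
    is_num_return_sequences_expand: bool = False  # default per roll/configs/base_config.py
    num_return_sequences_in_group: int = 1        # default per roll/pipeline/rlvr/rlvr_config.py

# Enable expand mode with 8 samples per prompt group:
cfg = SketchPipelineConfig(is_num_return_sequences_expand=True, num_return_sequences_in_group=8)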
Position in the Scheduling Flow
`expand_requests` is called from `DynamicSamplingScheduler.get_batch` 4:
- Fetch a prompt from the dataset
- Call `expand_requests` to expand the request
- Assign a unique `request_id` to each expanded request
- Send the requests to SGLang for inference
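A simplified sketch of the id-assignment step; a plain `itertools` counter replaces the Ray-based request-counter actor, and worker dispatch is stubbed out (see the cited `get_batch` for the real flow):
python
import copy
import itertools

request_counter = itertools.count()  # stands in for the Ray request-counter actor

def assign_request_ids(requests: list) -> None:
    """Give each expanded request copy its own unique request_id before dispatch."""
    for req in requests:
        req["meta_info"]["request_id"] = str(next(request_counter))
        # ... the real scheduler then calls worker.add_request.remote(...) per request

expanded = [copy.deepcopy({"prompt": "p", "meta_info": {}}) for _ in range(4)]
assign_request_ids(expanded)
assert len({r["meta_info"]["request_id"] for r in expanded}) == 4  # all ids are unique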
Design Advantages
- Flexibility: supports two sampling strategies, adapting to the characteristics of different inference engines
- Load balancing: expand mode lets requests be spread across different workers
- Performance: some engines perform better in single-sample mode
Notes
- The asynchronous scheduler contains an identical implementation 5
- Requires `generate_opt_level > 0`, i.e., the scheduler mode must be in use 6
- In the RLVR Pipeline this is controlled through the `num_return_sequences_in_group` parameter 3
Citations
File: roll/distributed/scheduler/generate_scheduler.py (L487-549)
python
def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
    """
    Sample a batch from the dataset using the given strategy:
    1. regular sampling, no filtering
    2. dynamic filtering
    """
    self.batch_size = batch_size
    self.reset_status()
    self.running = True
    prompt_id_counter = itertools.count()
    self.generation_config = copy.deepcopy(data.meta_info["generation_config"])
    num_return_sequences = self.generation_config["num_return_sequences"]
    while True:
        if (
            sum([len(v) for v in list(self.completed_buffers.values())[:]])
            >= self.batch_size * num_return_sequences
        ):
            self.running = False
            break
        self.check_worker_alive(self.actor_cluster)
        self.check_response_callback()
        if not self.check_send_new_request():
            time.sleep(1)
            continue
        # get a query from dataset
        prompt_id = next(prompt_id_counter)
        dataset_item = self.get_next_dataset_item()
        if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
            prompt_digest = hashlib.md5(
                (dataset_item.get('prompt', '') + dataset_item.get('messages', '')).encode()
            ).digest()
        domain = dataset_item.get("domain", "default")
        collect_data = self.collect_fn([dataset_item])
        request_data: DataProto = DataProto.from_single_dict(collect_data, meta_info=data.meta_info)
        # replica, redundancy
        request_data_list = self.expand_requests(request_data)
        dp_rank = next(self.get_available_dp_rank())
        with self.lock:
            self.prompt_use_count += 1
            self.running_prompts += 1
        for req in request_data_list:
            # get an available worker; max_running_request must be controlled,
            # and the current strategy always keeps workers fully loaded
            request_id = ray.get(self.request_counter.get_value.remote())
            req.meta_info["request_id"] = f"{request_id}"
            req.meta_info["response_callback_fn"] = self.response_callback_fn
            self.request_id_2_prompt_id[req.meta_info["request_id"]] = prompt_id
            self.request_id_2_dp_rank[req.meta_info["request_id"]] = dp_rank
            self.prompt_id_2_request_ids[prompt_id].add(req.meta_info["request_id"])  # used for the replica case
            if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
                self.prompt_id_2_hash_str[prompt_id] = base64.urlsafe_b64encode(prompt_digest).decode().rstrip('=')  # prompt_id maps to a unique prompt
            self.requests_buffers[req.meta_info["request_id"]] = req
            ray.get(
                self.actor_cluster.workers[dp_rank].add_request.remote(
                    command=GenerateRequestType.ADD, data=req
                )
            )
            req.meta_info.pop("response_callback_fn")
            self.load_balance_coordinator[dp_rank] += 1
            self.dp_fetch_count[dp_rank] += 1
File: roll/distributed/scheduler/generate_scheduler.py (L741-764)
python
def expand_requests(self, data: DataProto):
    """
    replica, as well as redundancy
    """
    generate_opt_level = self.pipeline_config.generate_opt_level
    is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
    num_return_sequences = self.generation_config["num_return_sequences"]
    assert generate_opt_level > 0, (
        f"generate_opt_level {generate_opt_level} should > 0, " f"in dynamic sampling scheduler."
    )
    assert "generation_config" in data.meta_info, f"data {data.meta_info} should have key 'generation_config'"
    generation_config = data.meta_info["generation_config"]
    target_requests = []
    if is_num_return_sequences_expand:
        generation_config["num_return_sequences"] = 1
        for _ in range(num_return_sequences):
            target_requests.append(copy.deepcopy(data))
    else:
        generation_config["num_return_sequences"] = num_return_sequences
        target_requests.append(copy.deepcopy(data))
    return target_requests
File: roll/configs/base_config.py (L23-26)
python
is_num_return_sequences_expand: bool = field(
    default=False,
    metadata={"help": "whether replicate `num_return_sequences` times in prompts or not."}
)
File: roll/pipeline/rlvr/rlvr_config.py (L90-93)
python
num_return_sequences_in_group: int = field(
    default=1,
    metadata={"help": "The number of return sequences in one group, used in generation_args."}
)
File: roll/distributed/scheduler/async_generate_scheduler.py (L804-824)
python
def expand_requests(self, data: DataProto):
    generate_opt_level = self.pipeline_config.generate_opt_level
    is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
    num_return_sequences = self.generation_config["num_return_sequences"]
    assert generate_opt_level > 0, (
        f"generate_opt_level {generate_opt_level} should > 0, " f"in dynamic sampling scheduler."
    )
    assert "generation_config" in data.meta_info, f"data {data.meta_info} should have key 'generation_config'"
    generation_config = data.meta_info["generation_config"]
    target_requests = []
    if is_num_return_sequences_expand:
        generation_config["num_return_sequences"] = 1
        for _ in range(num_return_sequences):
            target_requests.append(copy.deepcopy(data))
    else:
        generation_config["num_return_sequences"] = num_return_sequences
        target_requests.append(copy.deepcopy(data))
    return target_requests