[RL] What does expand_requests do?

The expand_requests function expands and replicates generation requests according to configuration parameters, supporting two different sampling strategies.

Purpose

expand_requests decides how to handle multi-sample requests based on the is_num_return_sequences_expand setting 1

Expand mode (is_num_return_sequences_expand=True)

  • Sets each request's num_return_sequences to 1
  • Replicates the request num_return_sequences times
  • Each request independently generates 1 response

Non-expand mode (is_num_return_sequences_expand=False)

  • Sets num_return_sequences directly to the specified value
  • Sends only 1 request, which generates multiple responses
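The two modes can be sketched as a standalone function. This is a minimal sketch, not the framework's code: a plain dict stands in for DataProto, and the config values are passed as arguments.

```python
import copy

def expand_requests(data: dict, num_return_sequences: int,
                    is_num_return_sequences_expand: bool) -> list:
    """Standalone sketch; `data` is a plain dict standing in for DataProto."""
    generation_config = data["meta_info"]["generation_config"]
    target_requests = []
    if is_num_return_sequences_expand:
        # Expand mode: N independent single-sample requests.
        generation_config["num_return_sequences"] = 1
        for _ in range(num_return_sequences):
            target_requests.append(copy.deepcopy(data))
    else:
        # Non-expand mode: one request that samples N responses.
        generation_config["num_return_sequences"] = num_return_sequences
        target_requests.append(copy.deepcopy(data))
    return target_requests
```

Either way, one prompt ultimately yields num_return_sequences responses; only the request count differs.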

Implementation

python
def expand_requests(self, data: DataProto):
    is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
    num_return_sequences = self.generation_config["num_return_sequences"]
    generation_config = data.meta_info["generation_config"]

    target_requests = []
    if is_num_return_sequences_expand:
        generation_config["num_return_sequences"] = 1
        for _ in range(num_return_sequences):
            target_requests.append(copy.deepcopy(data))
    else:
        generation_config["num_return_sequences"] = num_return_sequences
        target_requests.append(copy.deepcopy(data))

    return target_requests

Configuration parameters

  • is_num_return_sequences_expand: controls whether expand mode is enabled 2
  • num_return_sequences: the number of responses to generate per prompt 3
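The interaction between the two parameters can be summarized with a small helper (hypothetical, for illustration only): for a single prompt, the product of requests sent and sequences per request is the same in both modes.

```python
def request_plan(num_return_sequences: int, is_expand: bool):
    """Hypothetical helper: (requests sent, sequences per request) for one prompt."""
    if is_expand:
        return num_return_sequences, 1  # N requests, 1 sample each
    return 1, num_return_sequences      # 1 request, N samples
```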

Where it sits in the scheduling flow

expand_requests is called from DynamicSamplingScheduler.get_batch 4

  1. Fetch a prompt from the dataset
  2. Call expand_requests to expand the request
  3. Assign each expanded request a unique request_id
  4. Send the requests to SGLang for inference
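The steps above can be sketched as follows. This is a simplified stand-in for the scheduler's bookkeeping; `dispatch` and the dict-shaped requests are hypothetical names, with expand mode assumed.

```python
import itertools

def dispatch(prompts, num_return_sequences):
    """Sketch of the expand-and-tag flow in get_batch (expand mode assumed)."""
    request_counter = itertools.count()  # globally unique request ids
    request_id_2_prompt_id = {}          # lets responses be regrouped by prompt
    outbox = []
    for prompt_id, prompt in enumerate(prompts):
        # One single-sample request per return sequence.
        for _ in range(num_return_sequences):
            request_id = next(request_counter)
            request_id_2_prompt_id[request_id] = prompt_id
            outbox.append({"request_id": request_id, "prompt": prompt})
    return outbox, request_id_2_prompt_id
```

The request_id-to-prompt_id map matters because, in expand mode, the N responses for one prompt come back as N independent requests and must be regrouped.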

Design advantages

  • Flexibility: supports two sampling strategies to match different inference-engine characteristics
  • Load balancing: expand mode allows requests to be spread across different workers
  • Performance: some engines perform better in single-sample mode

Notes

  • The async scheduler has an identical implementation 5
  • Requires generate_opt_level > 0, i.e. scheduler mode must be used 6
  • In the RLVR pipeline this is controlled by the num_return_sequences_in_group parameter 3


Citations

File: roll/distributed/scheduler/generate_scheduler.py (L487-549)

python
    def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
        """
        Sample a batch from the dataset using the given strategy:
        1. plain, no filtering
        2. dynamic filtering
        """
        self.batch_size = batch_size
        self.reset_status()
        self.running = True
        prompt_id_counter = itertools.count()
        self.generation_config = copy.deepcopy(data.meta_info["generation_config"])
        num_return_sequences = self.generation_config["num_return_sequences"]
        while True:
            if (
                sum([len(v) for v in list(self.completed_buffers.values())[:]])
                >= self.batch_size * num_return_sequences
            ):
                self.running = False
                break
            self.check_worker_alive(self.actor_cluster)
            self.check_response_callback()
            if not self.check_send_new_request():
                time.sleep(1)
                continue

            # get a query from dataset
            prompt_id = next(prompt_id_counter)
            dataset_item = self.get_next_dataset_item()
            if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
                prompt_digest = hashlib.md5(
                    (dataset_item.get('prompt', '') + dataset_item.get('messages', '')).encode()
                ).digest()
            domain = dataset_item.get("domain", "default")
            collect_data = self.collect_fn([dataset_item])
            request_data: DataProto = DataProto.from_single_dict(collect_data, meta_info=data.meta_info)

            # replica, redundancy
            request_data_list = self.expand_requests(request_data)

            dp_rank = next(self.get_available_dp_rank())
            with self.lock:
                self.prompt_use_count += 1
                self.running_prompts += 1
                for req in request_data_list:
                    # get an available worker; max_running_request must be bounded,
                    # and the current policy keeps workers fully loaded
                    request_id = ray.get(self.request_counter.get_value.remote())
                    req.meta_info["request_id"] = f"{request_id}"
                    req.meta_info["response_callback_fn"] = self.response_callback_fn
                    self.request_id_2_prompt_id[req.meta_info["request_id"]] = prompt_id
                    self.request_id_2_dp_rank[req.meta_info["request_id"]] = dp_rank
                    self.prompt_id_2_request_ids[prompt_id].add(req.meta_info["request_id"])  # used in the replica case
                    if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
                        self.prompt_id_2_hash_str[prompt_id] = base64.urlsafe_b64encode(prompt_digest).decode().rstrip('=')  # prompt_id maps to a unique prompt
                    self.requests_buffers[req.meta_info["request_id"]] = req
                    ray.get(
                        self.actor_cluster.workers[dp_rank].add_request.remote(
                            command=GenerateRequestType.ADD, data=req
                        )
                    )
                    req.meta_info.pop("response_callback_fn")
                    self.load_balance_coordinator[dp_rank] += 1
                    self.dp_fetch_count[dp_rank] += 1

File: roll/distributed/scheduler/generate_scheduler.py (L741-764)

python
    def expand_requests(self, data: DataProto):
        """
        replica and redundancy
        """
        generate_opt_level = self.pipeline_config.generate_opt_level
        is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
        num_return_sequences = self.generation_config["num_return_sequences"]

        assert generate_opt_level > 0, (
            f"generate_opt_level {generate_opt_level} should > 0, " f"in dynamic sampling scheduler."
        )
        assert "generation_config" in data.meta_info, f"data {data.meta_info} should have key 'generation_config'"
        generation_config = data.meta_info["generation_config"]

        target_requests = []
        if is_num_return_sequences_expand:
            generation_config["num_return_sequences"] = 1
            for _ in range(num_return_sequences):
                target_requests.append(copy.deepcopy(data))
        else:
            generation_config["num_return_sequences"] = num_return_sequences
            target_requests.append(copy.deepcopy(data))

        return target_requests

File: roll/configs/base_config.py (L23-26)

python
    is_num_return_sequences_expand: bool = field(
        default=False,
        metadata={"help": "whether replicate `num_return_sequences` times in prompts or not."}
    )

File: roll/pipeline/rlvr/rlvr_config.py (L90-93)

python
    num_return_sequences_in_group: int = field(
        default=1,
        metadata={"help": "The number of return sequences in one group, used in generation_args."}
    )

File: roll/distributed/scheduler/async_generate_scheduler.py (L804-824)

python
    def expand_requests(self, data: DataProto):
        generate_opt_level = self.pipeline_config.generate_opt_level
        is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
        num_return_sequences = self.generation_config["num_return_sequences"]

        assert generate_opt_level > 0, (
            f"generate_opt_level {generate_opt_level} should > 0, " f"in dynamic sampling scheduler."
        )
        assert "generation_config" in data.meta_info, f"data {data.meta_info} should have key 'generation_config'"
        generation_config = data.meta_info["generation_config"]

        target_requests = []
        if is_num_return_sequences_expand:
            generation_config["num_return_sequences"] = 1
            for _ in range(num_return_sequences):
                target_requests.append(copy.deepcopy(data))
        else:
            generation_config["num_return_sequences"] = num_return_sequences
            target_requests.append(copy.deepcopy(data))

        return target_requests