[RL] What does expand_requests do?

The expand_requests function expands and replicates generation requests according to configuration parameters, and supports two different sampling strategies.

What the function does

expand_requests decides how to handle multi-sample requests based on the is_num_return_sequences_expand setting 1

Expand mode (is_num_return_sequences_expand=True)

  • Sets each request's num_return_sequences to 1
  • Makes num_return_sequences deep copies of the request
  • Each copy independently generates 1 response

Non-expand mode (is_num_return_sequences_expand=False)

  • Sets num_return_sequences directly to the configured value
  • Sends only 1 request, which generates multiple responses

Code implementation

python
def expand_requests(self, data: DataProto):
    is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
    num_return_sequences = self.generation_config["num_return_sequences"]
    # the mutable generation_config travels inside the request's meta_info
    # (the validation asserts from the full source are elided here)
    generation_config = data.meta_info["generation_config"]

    target_requests = []
    if is_num_return_sequences_expand:
        # expand mode: N independent copies, each sampling once
        generation_config["num_return_sequences"] = 1
        for _ in range(num_return_sequences):
            target_requests.append(copy.deepcopy(data))
    else:
        # non-expand mode: one request that samples N times
        generation_config["num_return_sequences"] = num_return_sequences
        target_requests.append(copy.deepcopy(data))

    return target_requests

Configuration parameters

  • is_num_return_sequences_expand: controls whether expand mode is enabled 2
  • num_return_sequences: the number of responses to generate per prompt 3
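A minimal sketch of the two knobs as a dataclass (SamplingConfig is a hypothetical container; the field names and defaults come from roll/configs/base_config.py and roll/pipeline/rlvr/rlvr_config.py, quoted in the citations):

```python
from dataclasses import dataclass

@dataclass
class SamplingConfig:
    # whether to replicate the request num_return_sequences times (expand mode)
    is_num_return_sequences_expand: bool = False
    # responses per prompt group, fed into generation_args
    num_return_sequences_in_group: int = 1

cfg = SamplingConfig()
print(cfg.is_num_return_sequences_expand, cfg.num_return_sequences_in_group)  # False 1
```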

Where it fits in the scheduling flow

expand_requests is called from DynamicSamplingScheduler.get_batch 4

  1. Fetch a prompt from the dataset
  2. Call expand_requests to expand it into one or more requests
  3. Assign a unique request_id to each expanded request
  4. Send the requests to SGLang for inference
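The four steps above can be sketched as a hypothetical mini-scheduler (itertools.count stands in for the Ray-based request counter actor, and the final dispatch to workers is omitted):

```python
import copy
import itertools

request_counter = itertools.count()  # stand-in for the shared request-id counter

def dispatch(prompt: str, num_return_sequences: int, is_expand: bool) -> list:
    # step 1: a prompt fetched from the dataset
    req = {"prompt": prompt,
           "generation_config": {"num_return_sequences": num_return_sequences}}
    # step 2: expand (expand mode: N copies, 1 sample each)
    if is_expand:
        req["generation_config"]["num_return_sequences"] = 1
        requests = [copy.deepcopy(req) for _ in range(num_return_sequences)]
    else:
        requests = [copy.deepcopy(req)]
    # step 3: assign a unique request_id to every copy
    for r in requests:
        r["request_id"] = str(next(request_counter))
    return requests  # step 4 would ship these to the inference workers

batch = dispatch("2+2=?", 3, is_expand=True)
print([r["request_id"] for r in batch])  # ['0', '1', '2']
```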

Design advantages

  • Flexibility: supports two sampling strategies, adapting to the characteristics of different inference engines
  • Load balancing: expand mode allows requests to be spread across different workers
  • Performance: some engines perform better in single-sample mode
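The load-balancing point can be illustrated with a purely hypothetical round-robin assignment (the real scheduler picks a dp_rank via get_available_dp_rank and load counters, and sends all replicas of one prompt to the same dp_rank; this sketch only shows why many small requests are easier to spread than one large one):

```python
from collections import Counter
from itertools import cycle

def assign(num_requests: int, workers: list) -> Counter:
    """Assign requests to workers round-robin; return per-worker load."""
    rr = cycle(workers)
    return Counter(next(rr) for _ in range(num_requests))

# expand mode: 4 single-sample requests spread over 2 workers
print(assign(4, ["w0", "w1"]))  # Counter({'w0': 2, 'w1': 2})
# non-expand mode: 1 four-sample request pinned to a single worker
print(assign(1, ["w0", "w1"]))  # Counter({'w0': 1})
```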

Notes

  • The asynchronous scheduler has an identical implementation 5
  • Requires generate_opt_level > 0, i.e. the scheduler mode must be used 6
  • In the RLVR pipeline this is controlled by the num_return_sequences_in_group parameter 3

Citations

File: roll/distributed/scheduler/generate_scheduler.py (L487-549)

python
    def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
        """
        Sample a batch from the dataset according to the given strategy:
        1. plain sampling, no filtering
        2. dynamic filtering
        """
        self.batch_size = batch_size
        self.reset_status()
        self.running = True
        prompt_id_counter = itertools.count()
        self.generation_config = copy.deepcopy(data.meta_info["generation_config"])
        num_return_sequences = self.generation_config["num_return_sequences"]
        while True:
            if (
                sum([len(v) for v in list(self.completed_buffers.values())[:]])
                >= self.batch_size * num_return_sequences
            ):
                self.running = False
                break
            self.check_worker_alive(self.actor_cluster)
            self.check_response_callback()
            if not self.check_send_new_request():
                time.sleep(1)
                continue

            # get a query from dataset
            prompt_id = next(prompt_id_counter)
            dataset_item = self.get_next_dataset_item()
            if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
                prompt_digest = hashlib.md5(
                    (dataset_item.get('prompt', '') + dataset_item.get('messages', '')).encode()
                ).digest()
            domain = dataset_item.get("domain", "default")
            collect_data = self.collect_fn([dataset_item])
            request_data: DataProto = DataProto.from_single_dict(collect_data, meta_info=data.meta_info)

            # replica, redundancy
            request_data_list = self.expand_requests(request_data)

            dp_rank = next(self.get_available_dp_rank())
            with self.lock:
                self.prompt_use_count += 1
                self.running_prompts += 1
                for req in request_data_list:
                    # get an available worker; max_running_request must be capped, and the current policy keeps workers fully loaded
                    request_id = ray.get(self.request_counter.get_value.remote())
                    req.meta_info["request_id"] = f"{request_id}"
                    req.meta_info["response_callback_fn"] = self.response_callback_fn
                    self.request_id_2_prompt_id[req.meta_info["request_id"]] = prompt_id
                    self.request_id_2_dp_rank[req.meta_info["request_id"]] = dp_rank
                    self.prompt_id_2_request_ids[prompt_id].add(req.meta_info["request_id"])  # used in the replica case
                    if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
                        self.prompt_id_2_hash_str[prompt_id] = base64.urlsafe_b64encode(prompt_digest).decode().rstrip('=')  # prompt_id maps to a unique prompt
                    self.requests_buffers[req.meta_info["request_id"]] = req
                    ray.get(
                        self.actor_cluster.workers[dp_rank].add_request.remote(
                            command=GenerateRequestType.ADD, data=req
                        )
                    )
                    req.meta_info.pop("response_callback_fn")
                    self.load_balance_coordinator[dp_rank] += 1
                    self.dp_fetch_count[dp_rank] += 1

File: roll/distributed/scheduler/generate_scheduler.py (L741-764)

python
    def expand_requests(self, data: DataProto):
        """
        Replica and redundancy.
        """
        generate_opt_level = self.pipeline_config.generate_opt_level
        is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
        num_return_sequences = self.generation_config["num_return_sequences"]

        assert generate_opt_level > 0, (
            f"generate_opt_level {generate_opt_level} should > 0, " f"in dynamic sampling scheduler."
        )
        assert "generation_config" in data.meta_info, f"data {data.meta_info} should have key 'generation_config'"
        generation_config = data.meta_info["generation_config"]

        target_requests = []
        if is_num_return_sequences_expand:
            generation_config["num_return_sequences"] = 1
            for _ in range(num_return_sequences):
                target_requests.append(copy.deepcopy(data))
        else:
            generation_config["num_return_sequences"] = num_return_sequences
            target_requests.append(copy.deepcopy(data))

        return target_requests

File: roll/configs/base_config.py (L23-26)

python
    is_num_return_sequences_expand: bool = field(
        default=False,
        metadata={"help": "whether replicate `num_return_sequences` times in prompts or not."}
    )

File: roll/pipeline/rlvr/rlvr_config.py (L90-93)

python
    num_return_sequences_in_group: int = field(
        default=1,
        metadata={"help": "The number of return sequences in one group, used in generation_args."}
    )

File: roll/distributed/scheduler/async_generate_scheduler.py (L804-824)

python
    def expand_requests(self, data: DataProto):
        generate_opt_level = self.pipeline_config.generate_opt_level
        is_num_return_sequences_expand = self.pipeline_config.is_num_return_sequences_expand
        num_return_sequences = self.generation_config["num_return_sequences"]

        assert generate_opt_level > 0, (
            f"generate_opt_level {generate_opt_level} should > 0, " f"in dynamic sampling scheduler."
        )
        assert "generation_config" in data.meta_info, f"data {data.meta_info} should have key 'generation_config'"
        generation_config = data.meta_info["generation_config"]

        target_requests = []
        if is_num_return_sequences_expand:
            generation_config["num_return_sequences"] = 1
            for _ in range(num_return_sequences):
                target_requests.append(copy.deepcopy(data))
        else:
            generation_config["num_return_sequences"] = num_return_sequences
            target_requests.append(copy.deepcopy(data))

        return target_requests