【RL】slime创建actor的流程

SharedStorage的作用

SharedStorage是ROLL框架中的分布式键值存储组件,基于Ray Actor实现,用于在集群中的所有worker之间共享数据 1

核心功能

1. 分布式键值存储

SharedStorage提供简单的put/get接口 2

python 复制代码
def put(self, key, data):
    ref = ray.put(data)
    self._storage[key] = ref

def get(self, key):
    ref = self._storage.get(key)
    if ref is None:
        logger.warning(f"{key} is not found in storage")
        return None
    return ray.get(ref)

使用示例

示例1:Worker初始化时的地址协调

在Worker初始化时,rank 0会将master地址和端口存储到SharedStorage,其他worker从中读取 3

python 复制代码
if self.rank == 0:
    master_addr = self.get_node_ip()
    master_port = str(self.get_free_port())
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

# 所有worker都会将地址信息存储到SharedStorage
self.shared_storage.put.remote(
    self.cluster_name, {"MASTER_ADDR": self.master_addr, "MASTER_PORT": self.master_port}
)

示例2:端口分配协调

在分配端口时,worker使用SharedStorage避免端口冲突 4

python 复制代码
# 检查端口是否已被占用
master_addr_port_key = f"MASTER_ADDR_PORT:{master_addr}:{master_port}"
if ray.get(shared_storage.get.remote(master_addr_port_key)) is None:
    ray.get(shared_storage.put.remote(master_addr_port_key, True))
    break  # 端口可用,占用并退出循环

示例3:模型下载缓存

在模型下载时,使用SharedStorage缓存已下载的模型路径,避免重复下载 5

python 复制代码
@model_path_cache
def download_model(model_name_or_path: str, local_dir: Optional[str] = None):
    # 检查缓存
    cached_path = ray.get(shared_storage.get.remote(key=f"{node_ip}:{model_name_or_path}"))
    if cached_path is None or not os.path.exists(cached_path):
        cached_path = func(model_name_or_path, local_dir)
        ray.get(shared_storage.put.remote(key=f"{node_ip}:{model_name_or_path}", data=cached_path))
    return cached_path

Notes

  • SharedStorage是单例Actor,在ROLL命名空间下运行 6
  • 主要用于分布式协调和缓存,减少重复计算和网络请求
  • 支持任意可序列化的数据存储
  • 在ROLL框架的分布式训练中起到关键的协调作用

Wiki pages you may want to explore:

Wiki pages you might want to explore:

Citations

File: roll/distributed/scheduler/storage.py (L8-23)

python 复制代码
@ray.remote
class SharedStorage:

    def __init__(self):
        self._storage = {}

    def put(self, key, data):
        ref = ray.put(data)
        self._storage[key] = ref

    def get(self, key):
        ref = self._storage.get(key)
        if ref is None:
            logger.warning(f"{key} is not found in storage")
            return None
        return ray.get(ref)

File: roll/distributed/executor/worker.py (L59-69)

python 复制代码
        if self.rank == 0:
            master_addr = self.get_node_ip()
            master_port = str(self.get_free_port())
            os.environ["MASTER_ADDR"] = master_addr
            os.environ["MASTER_PORT"] = master_port

        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = int(os.environ["MASTER_PORT"])
        self.shared_storage.put.remote(
            self.cluster_name, {"MASTER_ADDR": self.master_addr, "MASTER_PORT": self.master_port}
        )

File: roll/distributed/executor/worker.py (L96-113)

python 复制代码
    def get_free_port():
        shared_storage = SharedStorage.options(
            name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
        ).remote()
        master_addr = Worker.get_node_ip()
        max_retry_count = int(os.environ.get("MAX_PORT_RETRY_COUNT", 1000))
        retry_count = 0
        master_port = collect_free_port()
        while retry_count < max_retry_count:
            master_addr_port_key = f"MASTER_ADDR_PORT:{master_addr}:{master_port}"
            if ray.get(shared_storage.get.remote(master_addr_port_key)) is None:
                ray.get(shared_storage.put.remote(master_addr_port_key, True))
                break
            master_port = collect_free_port()
            retry_count += 1
        if retry_count >= max_retry_count:
            raise RuntimeError(f"Can not allocate unique MASTER_PORT on {master_addr}.")
        return master_port

File: roll/utils/checkpoint_manager.py (L40-53)

python 复制代码
def model_path_cache(func):
    node_ip = get_node_ip()
    def wrapper(model_name_or_path: str, local_dir: Optional[str] = None):
        global shared_storage
        if shared_storage is None:
            shared_storage = SharedStorage.options(
                name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
            ).remote()
        cached_path = ray.get(shared_storage.get.remote(key=f"{node_ip}:{model_name_or_path}"))
        if cached_path is None or not os.path.exists(cached_path):
            cached_path = func(model_name_or_path, local_dir)
            ray.get(shared_storage.put.remote(key=f"{node_ip}:{model_name_or_path}", data=cached_path))
        return cached_path
    return wrapper

File: roll/utils/constants.py (L5-6)

python 复制代码
RAY_NAMESPACE = "roll"
STORAGE_NAME = "SHARED_STORAGE_ACTOR"

SharedStorage.remote() 的首次创建位置

SharedStorage.remote() 第一次是在第一个初始化的Worker进程 中创建的,具体位置是 Worker.__init__() 方法 1

创建机制分析

1. Worker初始化时的创建

每个Worker在初始化时都会尝试获取SharedStorage引用 1

python 复制代码
self.shared_storage = SharedStorage.options(
    name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
).remote()

2. get_if_exists=True 的作用

关键参数 get_if_exists=True 控制创建行为 1

  • 第一个Worker:SharedStorage不存在,创建新的Actor实例
  • 后续Worker:SharedStorage已存在,获取现有实例的引用

3. 其他获取位置

SharedStorage在其他地方也会被获取,但都是基于已存在的实例:

  • 模型下载缓存 :在 model_path_cache 装饰器中获取 2
  • 端口分配 :在 get_free_port() 静态方法中获取 3

4. 执行时序

Ray Cluster Worker 1 (第一个) Worker 2 (后续) SharedStorage 创建Worker进程 init() SharedStorage.remote() [get_if_exists=True] 创建新实例 创建Worker进程 init() SharedStorage.remote() [get_if_exists=True] 返回现有实例 Ray Cluster Worker 1 (第一个) Worker 2 (后续) SharedStorage

Notes

  • SharedStorage是Ray命名Actor,在 ROLL_NAMESPACE 下运行 4
  • 使用单例模式确保整个集群只有一个SharedStorage实例
  • 第一个创建的Worker进程决定了SharedStorage的运行位置

Wiki pages you may want to explore:

Wiki pages you might want to explore:

Citations

File: roll/distributed/executor/worker.py (L55-57)

python 复制代码
        self.shared_storage = SharedStorage.options(
            name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
        ).remote()

File: roll/distributed/executor/worker.py (L97-99)

python 复制代码
        shared_storage = SharedStorage.options(
            name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
        ).remote()

File: roll/utils/checkpoint_manager.py (L44-47)

python 复制代码
        if shared_storage is None:
            shared_storage = SharedStorage.options(
                name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
            ).remote()

File: roll/utils/constants.py (L5-6)

python 复制代码
RAY_NAMESPACE = "roll"
STORAGE_NAME = "SHARED_STORAGE_ACTOR"

Actor 架构概览

slime 框架中确实有 Megatron Train ActorRollout Actor ,但没有独立的 Reward Actor

现有 Actor 结构

1. Megatron Train Actor

MegatronTrainRayActor 是训练的主要 actor 1

python 复制代码
class MegatronTrainRayActor(TrainRayActor):
    def init(self, args: Namespace, role: str, with_ref: bool = False) -> int | None:
2. Rollout Manager Actor

RolloutManager 是 rollout 的核心 actor 2

python 复制代码
@ray.remote
class RolloutManager:
    """The class to run rollout and convert rollout data to training data."""
3. SGLang Engine Actors

SGLangEngine 是实际的推理引擎 actors 3

python 复制代码
RolloutRayActor = ray.remote(SGLangEngine)

Reward 处理方式

Reward 不是独立的 actor,而是集成在 rollout 流程中:

1. 在 RolloutManager 中处理

RolloutManager 包含 reward 后处理逻辑 4

python 复制代码
def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
    if self.custom_reward_post_process_func is not None:
        return self.custom_reward_post_process_func(self.args, samples)
    
    raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
2. 在生成过程中计算

Reward 在 generate_rollout_async 函数中通过 generate_and_rm 计算 5

python 复制代码
tasks.append(
    generate_and_rm(
        args,
        sample,
        sampling_params=sampling_params,
        evaluation=True,
    )
)

Actor 创建流程

train.py create_rollout_manager RolloutManager.remote init_rollout_engines SGLangEngine.remote actors create_training_models MegatronTrainRayActor actors

Notes

  • RolloutManager 是唯一的 rollout actor,负责协调整个数据生成流程
  • Reward 计算集成在 rollout 生成过程中,不是独立的 actor
  • SGLangEngine actors 负责实际的推理,由 RolloutManager 管理
  • 这种设计简化了架构,将 reward 作为数据生成流水线的一部分

Wiki pages you might want to explore:

Wiki pages you might want to explore:

Citations

File: slime/backends/megatron_utils/actor.py (L46-53)

python 复制代码
class MegatronTrainRayActor(TrainRayActor):
    @with_defer(lambda: Timer().start("train_wait"))
    def init(
        self,
        args: Namespace,
        role: str,
        with_ref: bool = False,
    ) -> int | None:

File: slime/ray/rollout.py (L37-41)

python 复制代码
@ray.remote
class RolloutManager:
    """The class to run rollout and convert rollout data to training data."""

    def __init__(self, args, pg):

File: slime/ray/rollout.py (L188-193)

python 复制代码
    def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
        if self.custom_reward_post_process_func is not None:
            return self.custom_reward_post_process_func(self.args, samples)

        raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
        if (

File: slime/ray/rollout.py (L346-347)

python 复制代码
    RolloutRayActor = ray.remote(SGLangEngine)

File: slime/rollout/sglang_rollout.py (L527-532)

python 复制代码
                generate_and_rm(
                    args,
                    sample,
                    sampling_params=sampling_params,
                    evaluation=True,
                )
相关推荐
YongCheng_Liang2 分钟前
从零开始学 Python:自动化 / 运维开发实战(核心库 + 3 大实战场景)
python·自动化·运维开发
鸽芷咕12 分钟前
为什么越来越多开发者转向 CANN 仓库中的 Python 自动化方案?
python·microsoft·自动化·cann
秋邱13 分钟前
用 Python 写出 C++ 的性能?用CANN中PyPTO 算子开发硬核上手指南
开发语言·c++·python
wazmlp0018873691 小时前
python第三次作业
开发语言·python
深蓝电商API1 小时前
住宅代理与数据中心代理在爬虫中的选择
爬虫·python
历程里程碑2 小时前
普通数组----合并区间
java·数据结构·python·算法·leetcode·职场和发展·tornado
weixin_395448912 小时前
mult_yolov5_post_copy.c_cursor_0205
c语言·python·yolo
执风挽^3 小时前
Python基础编程题2
开发语言·python·算法·visual studio code
纤纡.3 小时前
PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战
人工智能·pytorch·python
kjkdd3 小时前
6.1 核心组件(Agent)
python·ai·语言模型·langchain·ai编程