The Role of SharedStorage
SharedStorage is the distributed key-value storage component in the ROLL framework. It is implemented as a Ray actor and is used to share data among all workers in the cluster 1 .
Core Functionality
1. Distributed key-value store
SharedStorage exposes a simple put/get interface 2 :
```python
def put(self, key, data):
    ref = ray.put(data)
    self._storage[key] = ref

def get(self, key):
    ref = self._storage.get(key)
    if ref is None:
        logger.warning(f"{key} is not found in storage")
        return None
    return ray.get(ref)
```
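Because SharedStorage is a Ray actor, put/get are invoked as remote calls. A minimal usage sketch, assuming the module paths match the cited file locations and that Ray is initialized; the key and payload here are illustrative:

```python
import ray
from roll.distributed.scheduler.storage import SharedStorage
from roll.utils.constants import RAY_NAMESPACE, STORAGE_NAME

ray.init(ignore_reinit_error=True)

# Get (or lazily create) the singleton storage actor by name.
storage = SharedStorage.options(
    name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
).remote()

# put/get are remote method calls; ray.get blocks on the result.
ray.get(storage.put.remote("run_config", {"lr": 1e-4}))
print(ray.get(storage.get.remote("run_config")))  # {'lr': 0.0001}
```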
Usage Examples
Example 1: Address coordination during worker initialization
During worker initialization, rank 0 picks the master address and port; the address info is published to SharedStorage, from which the other workers can read it 3 :
```python
if self.rank == 0:
    master_addr = self.get_node_ip()
    master_port = str(self.get_free_port())
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
# The address info is written to SharedStorage so other workers can read it
self.shared_storage.put.remote(
    self.cluster_name, {"MASTER_ADDR": self.master_addr, "MASTER_PORT": self.master_port}
)
```
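The citations only show the write side. The read side for rank != 0 presumably polls the same key until rank 0 has published it; a hypothetical in-worker sketch (the polling loop and sleep interval are illustrative, not verbatim from the source):

```python
import os
import time

import ray

# Hypothetical read side for rank != 0: wait for the entry, then adopt it.
info = None
while info is None:
    info = ray.get(self.shared_storage.get.remote(self.cluster_name))
    if info is None:
        time.sleep(0.5)  # rank 0 may not have written the entry yet
os.environ["MASTER_ADDR"] = info["MASTER_ADDR"]
os.environ["MASTER_PORT"] = info["MASTER_PORT"]
```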
Example 2: Coordinated port allocation
When allocating ports, workers use SharedStorage to avoid port conflicts 4 :
```python
# Check whether this address:port pair has already been claimed
master_addr_port_key = f"MASTER_ADDR_PORT:{master_addr}:{master_port}"
if ray.get(shared_storage.get.remote(master_addr_port_key)) is None:
    ray.get(shared_storage.put.remote(master_addr_port_key, True))
    break  # port is free: claim it and exit the loop
```
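This check sits inside a retry loop; the version below is condensed from the cited get_free_port (shown in full in the citations). Note that the get-then-put check is two separate actor calls rather than one atomic operation, so two workers probing the same port at the same instant could in principle still race; the retry loop makes conflicts unlikely rather than impossible:

```python
master_port = collect_free_port()
while retry_count < max_retry_count:
    key = f"MASTER_ADDR_PORT:{master_addr}:{master_port}"
    if ray.get(shared_storage.get.remote(key)) is None:
        ray.get(shared_storage.put.remote(key, True))  # claim the port cluster-wide
        break
    master_port = collect_free_port()  # already claimed: probe another port
    retry_count += 1
```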
Example 3: Model download cache
When downloading models, SharedStorage caches the paths of already downloaded models to avoid repeated downloads 5 :
```python
@model_path_cache
def download_model(model_name_or_path: str, local_dir: Optional[str] = None):
    ...

# Inside the model_path_cache wrapper: check the cache first
cached_path = ray.get(shared_storage.get.remote(key=f"{node_ip}:{model_name_or_path}"))
if cached_path is None or not os.path.exists(cached_path):
    cached_path = func(model_name_or_path, local_dir)  # func is the wrapped download_model
    ray.get(shared_storage.put.remote(key=f"{node_ip}:{model_name_or_path}", data=cached_path))
return cached_path
```
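Usage then looks like an ordinary function call; a hypothetical example (the model name is illustrative):

```python
# The first call on a node downloads the model and caches the local path in
# SharedStorage; later calls on the same node return the cached path directly.
path_a = download_model("Qwen/Qwen2.5-7B-Instruct")
path_b = download_model("Qwen/Qwen2.5-7B-Instruct")
assert path_a == path_b  # second call hit the {node_ip}:{model_name} cache key
```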
Notes
- SharedStorage is a singleton named actor running in the ROLL Ray namespace ("roll") 6
- It is mainly used for distributed coordination and caching, cutting down redundant computation and network requests
- It can store any data that Ray can serialize
- It plays a key coordinating role in ROLL's distributed training
Citations
File: roll/distributed/scheduler/storage.py (L8-23)
```python
@ray.remote
class SharedStorage:
    def __init__(self):
        self._storage = {}

    def put(self, key, data):
        ref = ray.put(data)
        self._storage[key] = ref

    def get(self, key):
        ref = self._storage.get(key)
        if ref is None:
            logger.warning(f"{key} is not found in storage")
            return None
        return ray.get(ref)
```
File: roll/distributed/executor/worker.py (L59-69)
```python
if self.rank == 0:
    master_addr = self.get_node_ip()
    master_port = str(self.get_free_port())
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
self.master_addr = os.environ["MASTER_ADDR"]
self.master_port = int(os.environ["MASTER_PORT"])
self.shared_storage.put.remote(
    self.cluster_name, {"MASTER_ADDR": self.master_addr, "MASTER_PORT": self.master_port}
)
```
File: roll/distributed/executor/worker.py (L96-113)
```python
def get_free_port():
    shared_storage = SharedStorage.options(
        name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
    ).remote()
    master_addr = Worker.get_node_ip()
    max_retry_count = int(os.environ.get("MAX_PORT_RETRY_COUNT", 1000))
    retry_count = 0
    master_port = collect_free_port()
    while retry_count < max_retry_count:
        master_addr_port_key = f"MASTER_ADDR_PORT:{master_addr}:{master_port}"
        if ray.get(shared_storage.get.remote(master_addr_port_key)) is None:
            ray.get(shared_storage.put.remote(master_addr_port_key, True))
            break
        master_port = collect_free_port()
        retry_count += 1
    if retry_count >= max_retry_count:
        raise RuntimeError(f"Can not allocate unique MASTER_PORT on {master_addr}.")
    return master_port
```
File: roll/utils/checkpoint_manager.py (L40-53)
```python
def model_path_cache(func):
    node_ip = get_node_ip()

    def wrapper(model_name_or_path: str, local_dir: Optional[str] = None):
        global shared_storage
        if shared_storage is None:
            shared_storage = SharedStorage.options(
                name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
            ).remote()
        cached_path = ray.get(shared_storage.get.remote(key=f"{node_ip}:{model_name_or_path}"))
        if cached_path is None or not os.path.exists(cached_path):
            cached_path = func(model_name_or_path, local_dir)
            ray.get(shared_storage.put.remote(key=f"{node_ip}:{model_name_or_path}", data=cached_path))
        return cached_path

    return wrapper
```
File: roll/utils/constants.py (L5-6)
```python
RAY_NAMESPACE = "roll"
STORAGE_NAME = "SHARED_STORAGE_ACTOR"
```
Where SharedStorage.remote() Is First Created
SharedStorage.remote() is first created in the first worker process to initialize, specifically in the Worker.__init__() method 1 .
Creation Mechanism
1. Creation during worker initialization
Every worker tries to obtain a SharedStorage handle when it initializes 1 :
```python
self.shared_storage = SharedStorage.options(
    name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
).remote()
```
2. The role of get_if_exists=True
The key argument get_if_exists=True controls the creation behavior 1 , as the sketch after this list demonstrates:
- First worker: SharedStorage does not exist yet, so a new actor instance is created
- Subsequent workers: SharedStorage already exists, so a handle to the existing instance is returned
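These are plain Ray named-actor semantics and can be demonstrated without ROLL at all; a self-contained sketch (the Counter class is illustrative):

```python
import ray

@ray.remote
class Counter:
    def __init__(self):
        self.n = 0

    def bump(self):
        self.n += 1
        return self.n

ray.init(ignore_reinit_error=True)

# The first call creates the named actor; the second finds and reuses it.
a = Counter.options(name="demo", get_if_exists=True, namespace="roll").remote()
b = Counter.options(name="demo", get_if_exists=True, namespace="roll").remote()
print(ray.get(a.bump.remote()))  # 1
print(ray.get(b.bump.remote()))  # 2 -- both handles share one actor's state
```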
3. Other acquisition sites
SharedStorage is also obtained in other places (see the citations below), but those calls always resolve to the already existing instance.
4. Execution timeline
Sequence (originally a diagram): the Ray cluster starts Worker 1, whose __init__() call to SharedStorage.options(name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE).remote() creates a new actor instance because no actor with that name exists yet; Worker 2 and all later workers issue the identical call and get back a handle to the existing instance.
Notes
- SharedStorage is a named Ray actor running in the "roll" namespace (RAY_NAMESPACE) 4
- The singleton pattern ensures the whole cluster has exactly one SharedStorage instance
- The first worker process to create it determines where SharedStorage runs
Citations
File: roll/distributed/executor/worker.py (L55-57)
```python
self.shared_storage = SharedStorage.options(
    name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
).remote()
```
File: roll/distributed/executor/worker.py (L97-99)
```python
shared_storage = SharedStorage.options(
    name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
).remote()
```
File: roll/utils/checkpoint_manager.py (L44-47)
```python
if shared_storage is None:
    shared_storage = SharedStorage.options(
        name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE
    ).remote()
```
File: roll/utils/constants.py (L5-6)
```python
RAY_NAMESPACE = "roll"
STORAGE_NAME = "SHARED_STORAGE_ACTOR"
```
Actor Architecture Overview
The slime framework does have a Megatron train actor and rollout actors, but there is no standalone reward actor.
Existing Actor Structure
1. Megatron Train Actor
MegatronTrainRayActor is the main training actor 1 :
```python
class MegatronTrainRayActor(TrainRayActor):
    def init(self, args: Namespace, role: str, with_ref: bool = False) -> int | None:
        ...
```
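The heavy setup lives in an explicit init() method rather than __init__, so a driver can create the actor first and initialize it later. A hedged sketch of that call pattern (the ray.remote options and the role string are illustrative assumptions, not slime's actual launch code):

```python
import ray

# Hypothetical driver-side launch; resources and arguments are illustrative.
TrainActor = ray.remote(num_gpus=1)(MegatronTrainRayActor)
trainer = TrainActor.remote()
ray.get(trainer.init.remote(args, role="actor", with_ref=False))
```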
2. Rollout Manager Actor
RolloutManager is the core rollout actor 2 :
```python
@ray.remote
class RolloutManager:
    """The class to run rollout and convert rollout data to training data."""
```
3. SGLang Engine Actors
SGLangEngine provides the actual inference-engine actors 3 :
```python
RolloutRayActor = ray.remote(SGLangEngine)
```
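ray.remote(SGLangEngine) turns a plain class into a Ray actor class without modifying it, so the engine class stays usable outside Ray. The same pattern in a self-contained example (the Engine class is illustrative):

```python
import ray

class Engine:
    def infer(self, prompt: str) -> str:
        return prompt.upper()  # stand-in for real inference

ray.init(ignore_reinit_error=True)

EngineActor = ray.remote(Engine)           # analogous to ray.remote(SGLangEngine)
engine = EngineActor.remote()              # launch one engine actor process
print(ray.get(engine.infer.remote("hi")))  # "HI"
```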
How Rewards Are Handled
Rewards are not computed by a standalone actor; reward handling is integrated into the rollout flow:
1. Post-processing inside RolloutManager
RolloutManager contains the reward post-processing logic 4 :
```python
def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
    if self.custom_reward_post_process_func is not None:
        return self.custom_reward_post_process_func(self.args, samples)
    raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
```
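The first branch is a hook for user-supplied post-processing. A hypothetical custom hook with the matching (args, samples) signature; the mean-centering below is purely illustrative, not slime's built-in behavior, and the expected return shape should be checked against the caller:

```python
def my_reward_post_process(args, samples):
    # Hypothetical: mean-center raw rewards within the batch.
    rewards = [sample.get_reward_value(args) for sample in samples]
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]
```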
2. Computed during generation
Rewards are computed in the generate_rollout_async function via generate_and_rm 5 :
```python
tasks.append(
    generate_and_rm(
        args,
        sample,
        sampling_params=sampling_params,
        evaluation=True,
    )
)
```
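Putting it together, each task generates a completion and scores it within one coroutine, which is why no separate reward actor is needed. A hypothetical shape of such a step (generate and compute_reward are illustrative names, not slime's actual internals):

```python
async def generate_and_rm(args, sample, sampling_params, evaluation=False):
    # Hypothetical sketch: inference via the SGLang engines, then reward inline.
    sample = await generate(args, sample, sampling_params)
    sample.reward = await compute_reward(args, sample)  # no standalone reward actor
    return sample
```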
Actor Creation Flow
Flow (originally a diagram):
- train.py → create_rollout_manager → RolloutManager.remote() → init_rollout_engines → SGLangEngine.remote() actors
- train.py → create_training_models → MegatronTrainRayActor actors
Notes
- RolloutManager is the single rollout actor and coordinates the entire data-generation flow
- Reward computation is embedded in the rollout generation process rather than running as a standalone actor
- The SGLangEngine actors perform the actual inference and are managed by RolloutManager
- This design simplifies the architecture by treating reward computation as one stage of the data-generation pipeline
Citations
File: slime/backends/megatron_utils/actor.py (L46-53)
```python
class MegatronTrainRayActor(TrainRayActor):
    @with_defer(lambda: Timer().start("train_wait"))
    def init(
        self,
        args: Namespace,
        role: str,
        with_ref: bool = False,
    ) -> int | None:
```
File: slime/ray/rollout.py (L37-41)
```python
@ray.remote
class RolloutManager:
    """The class to run rollout and convert rollout data to training data."""

    def __init__(self, args, pg):
```
File: slime/ray/rollout.py (L188-193)
```python
def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
    if self.custom_reward_post_process_func is not None:
        return self.custom_reward_post_process_func(self.args, samples)
    raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
    if (
```
File: slime/ray/rollout.py (L346-347)
```python
RolloutRayActor = ray.remote(SGLangEngine)
```
File: slime/rollout/sglang_rollout.py (L527-532)
```python
generate_and_rm(
    args,
    sample,
    sampling_params=sampling_params,
    evaluation=True,
)
```