[RL] How ROLL Downloads Models

The full call flow from YAML config to model download

In RLVR, model_args.model_name_or_path travels from the YAML config to a downloaded local model as follows:

1. YAML config parsing

The YAML config is loaded into the RLVRConfig class via Hydra 1:

```yaml
actor_train:
  model_args:
    model_name_or_path: Qwen/Qwen2.5-7B  # model path (hub ID or local directory)
```
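
The result of this stage is a tree of typed config objects, so later code can read actor_train.model_args.model_name_or_path as an attribute. The minimal sketch below shows that shape. It is not ROLL's loader: ROLL does this through Hydra. The dataclasses ModelArgs, WorkerConfig, PipelineConfig and the helper build_config are hypothetical, simplified stand-ins.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


# Hypothetical, simplified stand-ins for ROLL's config classes (assumption).
@dataclass
class ModelArgs:
    model_name_or_path: str = ""


@dataclass
class WorkerConfig:
    name: str = ""
    model_args: ModelArgs = field(default_factory=ModelArgs)


@dataclass
class PipelineConfig:
    actor_train: WorkerConfig = field(default_factory=WorkerConfig)


def build_config(raw: Dict[str, Any]) -> PipelineConfig:
    """Convert a nested dict (as parsed from the YAML file) into typed dataclasses."""
    actor_raw = raw.get("actor_train", {})
    model_args = ModelArgs(**actor_raw.get("model_args", {}))
    actor = WorkerConfig(name=actor_raw.get("name", "actor_train"), model_args=model_args)
    return PipelineConfig(actor_train=actor)


# Same structure as the YAML snippet above, already parsed into a dict.
raw = {"actor_train": {"model_args": {"model_name_or_path": "Qwen/Qwen2.5-7B"}}}
cfg = build_config(raw)
print(cfg.actor_train.model_args.model_name_or_path)  # Qwen/Qwen2.5-7B
```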

2. Pipeline initialization creates the Clusters

RLVRPipeline.__init__() creates the Cluster instances and passes each role's worker config to its Cluster 2:

```python
self.actor_train: Any = Cluster(
    name=self.pipeline_config.actor_train.name,
    worker_cls=self.pipeline_config.actor_train.worker_cls,
    resource_manager=self.resource_manager,
    worker_config=self.pipeline_config.actor_train,  # carries model_args
)
```

3. Triggering the model download

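The pipeline calls its download_models() method 3. This method picks one worker per node and sends it the set of every model path needed on that node. The full version cited below also adds pipeline_config.resume_from_checkpoint when it is set.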

```python
from collections import defaultdict
from typing import Any, Dict

import ray


def download_models(self, *clusters: Cluster):
    node2worker: Dict[str, Any] = {}
    node2model_names: Dict[str, set[str]] = defaultdict(set)
    for cluster in clusters:
        for worker, node_ip in cluster.worker2nodes.items():
            # keep one worker per node (the last one seen)
            node2worker[node_ip] = worker
            if cluster.worker_config.model_args.model_name_or_path:
                node2model_names[node_ip].add(cluster.worker_config.model_args.model_name_or_path)
    # one download_models.remote call per node, carrying that node's set of model names
    ray.get([node2worker[node_ip].download_models.remote(model_name_or_paths=model_names)
             for node_ip, model_names in node2model_names.items()])
```
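
The sketch below shows the grouping logic without Ray. Worker handles are plain strings, and each cluster is reduced to a (worker2nodes, model_name_or_path) pair. Both simplifications are assumptions made for illustration. The node IPs and the reward model ID are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

# (worker2nodes, model_name_or_path): a simplified stand-in for a Cluster (assumption).
ClusterInfo = Tuple[Dict[str, str], Optional[str]]


def plan_downloads(clusters: List[ClusterInfo]) -> Tuple[Dict[str, str], Dict[str, Set[str]]]:
    node2worker: Dict[str, str] = {}
    node2model_names: Dict[str, Set[str]] = defaultdict(set)
    for worker2nodes, model_name in clusters:
        for worker, node_ip in worker2nodes.items():
            # Later workers overwrite earlier ones, so exactly one worker per node remains.
            node2worker[node_ip] = worker
            if model_name:
                # A set deduplicates a model that several clusters share on the same node.
                node2model_names[node_ip].add(model_name)
    return node2worker, dict(node2model_names)


actor_train = ({"train-0": "10.0.0.1", "train-1": "10.0.0.2"}, "Qwen/Qwen2.5-7B")
actor_infer = ({"infer-0": "10.0.0.1"}, "Qwen/Qwen2.5-7B")
reward = ({"reward-0": "10.0.0.2"}, "my-org/reward-model")  # hypothetical model ID

node2worker, node2models = plan_downloads([actor_train, actor_infer, reward])
print(node2worker)  # {'10.0.0.1': 'infer-0', '10.0.0.2': 'reward-0'}
```

Here the pipeline would issue only two remote calls, one per node. Qwen/Qwen2.5-7B is requested once per node, even though two clusters use it.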

4. Resolving the model during Worker initialization

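Worker.initialize() also calls download_model and writes the returned local path back into model_name_or_path 4. If the pipeline has already fetched the model on this node, the call is served from the cache instead of downloading again.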

```python
def initialize(self, pipeline_config, *args, **kwargs):
    self.pipeline_config = pipeline_config

    model_name = self.worker_config.model_args.model_name_or_path
    if model_name:
        self.worker_config.model_args.model_name_or_path = download_model(model_name)
```

5. How download_model executes

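download_model is wrapped by the model_path_cache decorator, which caches downloads 5. A local directory is returned unchanged. Otherwise, the MODEL_DOWNLOAD_TYPE environment variable (default MODELSCOPE) selects the download backend, and a file lock guards the download: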

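```python
import os
from typing import Optional

# model_path_cache, model_download_registry and file_lock_context are ROLL helpers (not shown)
@model_path_cache
def download_model(model_name_or_path: str, local_dir: Optional[str] = None):
    if os.path.isdir(model_name_or_path):
        return model_name_or_path  # already a local directory: return it as-is

    model_download_type = os.getenv("MODEL_DOWNLOAD_TYPE", "MODELSCOPE")
    model_download_func = model_download_registry[model_download_type]

    # serialize concurrent downloads of the same model
    with file_lock_context(model_name_or_path):
        return model_download_func(model_name_or_path, local_dir=local_dir)
```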
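
The sketch below makes the dispatch and locking concrete. It assumes a POSIX system, because it uses fcntl.flock as the cross-process lock. The downloaders, the registry key "HUGGINGFACE_HUB" and the lock-file naming are hypothetical. The real ROLL registry keys and file_lock_context implementation may differ.

```python
import fcntl
import hashlib
import os
import tempfile
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Optional


# Hypothetical downloaders: real ones would call the ModelScope / Hugging Face clients.
def fake_modelscope_download(name: str, local_dir: Optional[str] = None) -> str:
    return os.path.join(local_dir or tempfile.gettempdir(), "modelscope", name)


def fake_hf_download(name: str, local_dir: Optional[str] = None) -> str:
    return os.path.join(local_dir or tempfile.gettempdir(), "hf", name)


# Registry keyed by MODEL_DOWNLOAD_TYPE (key names are assumptions).
registry: Dict[str, Callable[..., str]] = {
    "MODELSCOPE": fake_modelscope_download,
    "HUGGINGFACE_HUB": fake_hf_download,
}


@contextmanager
def file_lock(name: str) -> Iterator[None]:
    """Exclusive cross-process lock on a lock file derived from the model name (POSIX only)."""
    digest = hashlib.sha256(name.encode()).hexdigest()[:16]
    path = os.path.join(tempfile.gettempdir(), f"model_{digest}.lock")
    with open(path, "w") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # blocks while another process holds the lock
        try:
            yield
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


def download_model(name: str, local_dir: Optional[str] = None) -> str:
    if os.path.isdir(name):
        return name  # already a local directory
    func = registry[os.getenv("MODEL_DOWNLOAD_TYPE", "MODELSCOPE")]  # KeyError if unknown
    with file_lock(name):
        return func(name, local_dir=local_dir)


local = tempfile.mkdtemp()
print(download_model(local) == local)  # True
os.environ["MODEL_DOWNLOAD_TYPE"] = "HUGGINGFACE_HUB"
print(download_model("org/some-model"))
```

A registry keyed by an environment variable lets the download backend be switched per job without code changes. An unknown value fails loudly with a KeyError.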

6. Caching

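The model_path_cache decorator implements a per-node cache 6. Entries are stored in SharedStorage under the key {node_ip}:{model_name_or_path}: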

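```python
import os
from typing import Optional

import ray

# get_node_ip() and shared_storage (a handle to the SharedStorage Ray actor) come from ROLL (not shown)
def model_path_cache(func):
    node_ip = get_node_ip()  # resolved once, when the decorator is applied; part of the cache key
    def wrapper(model_name_or_path: str, local_dir: Optional[str] = None):
        # look up this node's cached path in SharedStorage
        cached_path = ray.get(shared_storage.get.remote(key=f"{node_ip}:{model_name_or_path}"))
        # download again if nothing is cached or the cached path no longer exists
        if cached_path is None or not os.path.exists(cached_path):
            cached_path = func(model_name_or_path, local_dir)
            ray.get(shared_storage.put.remote(key=f"{node_ip}:{model_name_or_path}", data=cached_path))
        return cached_path
    return wrapper
```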

Execution sequence diagram

Participants: YAML config, RLVRPipeline, Cluster, Worker, download_model, SharedStorage

1. Load the config into RLVRConfig
2. Create Cluster(worker_config)
3. download_models()
4. download_models.remote(model_names)
5. initialize()
6. download_model(model_name_or_path)
7. Check the cache entry {node_ip}:{model_name}
8. [alt: cache miss] Download from ModelScope/HF, then update the cache
9. Return the local path

Notes

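  • Model download is transparent: models are downloaded automatically on the first run 7
  • Two download sources are supported: ModelScope and HuggingFace 8
  • SharedStorage provides a per-node cache, so the same model is not downloaded twice 9
  • A file lock prevents conflicts between concurrent downloads 10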

Citations

File: roll/pipeline/rlvr/rlvr_config.py (L90-93)

```python
    num_return_sequences_in_group: int = field(
        default=1,
        metadata={"help": "The number of return sequences in one group, used in generation_args."}
    )
```

File: roll/pipeline/rlvr/rlvr_pipeline.py (L207-218)

```python
        self.actor_train: Any = Cluster(
            name=self.pipeline_config.actor_train.name,
            worker_cls=self.pipeline_config.actor_train.worker_cls,
            resource_manager=self.resource_manager,
            worker_config=self.pipeline_config.actor_train,
        )
        self.actor_infer: Any = Cluster(
            name=self.pipeline_config.actor_infer.name,
            worker_cls=self.pipeline_config.actor_infer.worker_cls,
            resource_manager=self.resource_manager,
            worker_config=self.pipeline_config.actor_infer,
        )
```

File: roll/pipeline/base_pipeline.py (L97-107)

```python
    def download_models(self, *clusters: Cluster):
        node2worker: Dict[str, Any] = {}
        node2model_names: Dict[str, set[str]] = defaultdict(set)
        for cluster in clusters:
            for worker, node_ip in cluster.worker2nodes.items():
                node2worker[node_ip] = worker
                if cluster.worker_config.model_args.model_name_or_path:
                    node2model_names[node_ip].add(cluster.worker_config.model_args.model_name_or_path)
                if self.pipeline_config.resume_from_checkpoint:
                    node2model_names[node_ip].add(self.pipeline_config.resume_from_checkpoint)
        ray.get([node2worker[node_ip].download_models.remote(model_name_or_paths=model_names) for node_ip, model_names in node2model_names.items()])
```

File: roll/distributed/executor/worker.py (L132-137)

```python
    def initialize(self, pipeline_config, *args, **kwargs):
        self.pipeline_config = pipeline_config

        model_name = self.worker_config.model_args.model_name_or_path
        if model_name:
            self.worker_config.model_args.model_name_or_path = download_model(model_name)
```

Complete flowchart of model download and Actor creation in the ROLL framework

Flowchart

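  1. Config loading: YAML config file → Hydra loads it into RLVRConfig → parse model_args.model_name_or_path and data_args.file_name
  2. Pipeline initialization: RLVRPipeline.init → create Cluster instances: actor_train, actor_infer, reference, critic, and the rewards Clusters
  3. Model download: collect all Clusters → Pipeline.download_models → Cluster.download_models → Worker.download_models.remote → Worker.initialize → download_model function → model_path_cache decorator → SharedStorage cache check → cache exists? Yes: return the cached path; No: download from ModelScope/HF, then update the SharedStorage cache
  4. Actor creation: Ray.remote wraps the Worker class → Cluster._create_workers → create Ray Actor processes → set resource allocation → collect worker info → build the worker mappings
  5. Worker initialization: Worker.init → set environment variables → create the SharedStorage reference → get the node IP → set up distributed communication
  6. Strategy initialization: Worker.initialize → download the model → create the Strategy instance → Strategy.initialize → load the model onto the GPU → offload_states (optional)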

Detailed walkthrough

1. Config loading

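The YAML config is loaded into the RLVRConfig class through the Hydra framework 1. This step parses the model path (model_args.model_name_or_path) and the dataset path (data_args.file_name).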

2. Pipeline initialization

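RLVRPipeline.__init__() creates several Cluster instances 2. reference is created only when LoRA is not used and enable_reference is set. critic is created only when adv_estimator is "gae".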

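  • actor_train: cluster for the training (policy) model
  • actor_infer: cluster for the inference model
  • reference: cluster for the reference model
  • critic: cluster for the value function
  • rewards: one cluster per reward model, stored in a dict keyed by reward name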
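
The conditions above decide which clusters end up in download_clusters. The sketch below restates that logic as a plain function returning cluster names. The function name and its flat parameters are hypothetical; in ROLL these values come from pipeline_config and self.is_lora.

```python
from typing import List


def select_download_clusters(is_lora: bool, enable_reference: bool,
                             adv_estimator: str, reward_keys: List[str]) -> List[str]:
    clusters = ["actor_train", "actor_infer"]
    # For LoRA training the unwrapped base model serves as the reference: no extra cluster.
    if not is_lora and enable_reference:
        clusters.append("reference")
    # A critic (value model) is only needed for GAE advantage estimation.
    if adv_estimator == "gae":
        clusters.append("critic")
    # Reward clusters are named f"reward-{key}".
    clusters.extend(f"reward-{key}" for key in reward_keys)
    return clusters


# "grpo" here just means any estimator other than "gae".
print(select_download_clusters(False, True, "grpo", ["math", "code"]))
# ['actor_train', 'actor_infer', 'reference', 'reward-math', 'reward-code']
print(select_download_clusters(True, True, "gae", ["math"]))
# ['actor_train', 'actor_infer', 'critic', 'reward-math']
```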

3. Model download

Pipeline.download_models() receives every Cluster whose model needs downloading 3. Per-node caching is handled by the model_path_cache decorator 4.

4. Actor creation

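Cluster._create_workers() creates the Ray Actors 5, and each Actor runs in its own process. Cluster.__init__ calls _create_workers() directly (see the cluster.py L29-61 excerpt). The actors therefore already exist before download_models() runs. This ordering matters because download_models() reads cluster.worker2nodes, which is built from these workers.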
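
Each actor receives its distributed-training environment through Ray's runtime_env. The sketch below rebuilds only the environment-variable part of the cited _create_workers() loop as a pure function. It leaves out Ray options and device visibility. The helper build_env_vars is hypothetical, and the master address and port values are made up.

```python
from typing import Dict, Optional


def build_env_vars(cluster_name: str, rank: int, world_size: int,
                   master_addr: Optional[str], master_port: Optional[int],
                   system_envs: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    env = {
        "WORLD_SIZE": str(world_size),
        "RANK": str(rank),
        "LOCAL_RANK": "0",  # always "0" in the cited code; visible devices are set per actor
        "CLUSTER_NAME": cluster_name,
        "WORKER_NAME": f"{cluster_name}-{rank}",
    }
    # Rank 0 is created first and reports the master addr/port; other ranks point at it.
    if rank != 0:
        env["MASTER_ADDR"] = str(master_addr)
        env["MASTER_PORT"] = str(master_port)
    env.update(system_envs or {})  # worker_config.system_envs is applied last
    return env


envs = []
master_addr, master_port = None, None
for rank in range(3):
    envs.append(build_env_vars("actor_train", rank, 3, master_addr, master_port))
    if rank == 0:
        # Stand-in for ray.get(worker.get_master_addr_and_port.remote()) (assumed values).
        master_addr, master_port = "10.0.0.1", 29500
print(envs[0].get("MASTER_ADDR"), envs[2]["MASTER_ADDR"], envs[2]["WORKER_NAME"])
# None 10.0.0.1 actor_train-2
```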

5. Worker initialization

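When a Worker actor is constructed (Worker.__init__), it sets up the distributed communication environment 6 and obtains a SharedStorage reference and its node IP.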

6. Strategy initialization

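Once the Worker is initialized, it creates the training Strategy 7. The Strategy loads the model onto the GPU and can optionally offload its states (offload_states).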

Key component interactions

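  • SharedStorage: a global singleton used for data sharing and coordination across Workers 8
  • Cluster: manages a group of Workers of the same type and exposes a uniform interface for operating on them
  • Worker: the concrete execution unit, wrapping a model and its strategy
  • Strategy: the backend training strategy that performs the actual model operations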

Notes

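  • Calls into Workers are asynchronous Ray remote calls, so several Clusters can be initialized in parallel
  • Model downloads are cached, so a model is not downloaded twice on the same node
  • Actor creation follows Ray's distributed actor model
  • State offloading can be used when GPU memory is limited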

Citations

File: roll/pipeline/rlvr/rlvr_config.py (L90-93)

(Same excerpt as the roll/pipeline/rlvr/rlvr_config.py (L90-93) citation in the first section.)

File: roll/pipeline/rlvr/rlvr_pipeline.py (L207-246)

```python
        self.actor_train: Any = Cluster(
            name=self.pipeline_config.actor_train.name,
            worker_cls=self.pipeline_config.actor_train.worker_cls,
            resource_manager=self.resource_manager,
            worker_config=self.pipeline_config.actor_train,
        )
        self.actor_infer: Any = Cluster(
            name=self.pipeline_config.actor_infer.name,
            worker_cls=self.pipeline_config.actor_infer.worker_cls,
            resource_manager=self.resource_manager,
            worker_config=self.pipeline_config.actor_infer,
        )
        download_clusters = [self.actor_train, self.actor_infer]
        # use unwrapped model as reference for lora training
        if not self.is_lora and self.pipeline_config.enable_reference:
            self.reference: Any = Cluster(
                name=self.pipeline_config.reference.name,
                worker_cls=self.pipeline_config.reference.worker_cls,
                resource_manager=self.resource_manager,
                worker_config=self.pipeline_config.reference,
            )
            download_clusters.append(self.reference)
        if self.pipeline_config.adv_estimator == "gae":
            self.critic: Any = Cluster(
                name=self.pipeline_config.critic.name,
                worker_cls=self.pipeline_config.critic.worker_cls,
                resource_manager=self.resource_manager,
                worker_config=self.pipeline_config.critic,
            )
            download_clusters.append(self.critic)
        self.rewards: Dict[str, Any] = {
            key: Cluster(
                name=f"reward-{key}",
                worker_cls=worker_config.worker_cls,
                resource_manager=self.resource_manager,
                worker_config=worker_config,
            )
            for key, worker_config in self.pipeline_config.rewards.items()
        }
        download_clusters.extend(self.rewards.values())
```

File: roll/distributed/executor/cluster.py (L100-145)

```python
        for rank, pgs in enumerate(placement_groups):
            deploy_pg = pgs[0]
            pg_zero_gpu_ranks = sorted([pg["gpu_rank"] for pg in pgs if pg["node_rank"] == deploy_pg["node_rank"]])
            worker_name = f"{self.cluster_name}-{rank}"
            env_vars = {
                "WORLD_SIZE": str(self.world_size),
                "RANK": str(rank),
                "LOCAL_RANK": str(0),
                "CLUSTER_NAME": self.cluster_name,
                "WORKER_NAME": worker_name,
            }

            if rank != 0:
                env_vars["MASTER_ADDR"] = self.master_addr
                env_vars["MASTER_PORT"] = str(self.master_port)
            if deploy_pg["gpu_rank"] is not None:
                current_platform.update_env_vars_for_visible_devices(env_vars=env_vars, gpu_ranks=pg_zero_gpu_ranks)
            if "ROLL_LOG_DIR" in os.environ:
                env_vars["ROLL_LOG_DIR"] = os.environ["ROLL_LOG_DIR"]
            env_vars.update(self.worker_config.system_envs)

            runtime_env = RuntimeEnv(env_vars=env_vars)
            self.worker_config.resource_placement_groups = pgs

            worker_options = {
                "scheduling_strategy": PlacementGroupSchedulingStrategy(placement_group=deploy_pg["placement_group"]),
                "name": worker_name,
                "namespace": RAY_NAMESPACE,
                "runtime_env": runtime_env,
                "num_cpus": 0.01,
            }

            if current_platform.ray_device_key == "GPU":
                worker_options.update({"num_gpus": 0.01 if self.worker_config.device_mapping else 0})
            elif current_platform.ray_device_key == "NPU":
                worker_options.update(
                    {
                        "num_gpus": 0,
                        "resources": {
                            current_platform.ray_device_key: 0.01 if self.worker_config.device_mapping else 0
                        },
                    }
                )

            worker = self.worker_cls.options(**worker_options).remote(worker_config=self.worker_config)
            self.workers.append(worker)
```

In the ROLL framework, a Cluster is a distributed cluster manager that manages a group of Workers of the same type.

Definition of Cluster

Cluster is a core class in ROLL that creates and manages a group of Ray Actor workers 1:

```python
from typing import Any, List

class Cluster:
    def __init__(self, name, worker_cls, resource_manager, worker_config):
        self.cluster_name = name
        self.worker_cls = worker_cls  # Worker class (wrapped with ray.remote if needed)
        self.workers: List[Any] = []  # handles of the Worker actors
```

What Cluster does

1. Worker lifecycle management

Cluster creates, initializes, and manages multiple Worker processes 2. It creates one Ray actor per allocated placement group:

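```python
def _create_workers(self):
    # placement_groups comes from self.resource_manager.allocate_placement_group(...);
    # worker_options carries the scheduling strategy, name, namespace, runtime_env
    # and resource requests (full version in the citations below)
    for rank, pgs in enumerate(placement_groups):
        worker = self.worker_cls.options(**worker_options).remote(worker_config=self.worker_config)
        self.workers.append(worker)
```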

2. Distributed method calls

Cluster provides a uniform interface for calling the same method on all of its Workers 3:

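```python
def execute_all_async(self, method_name: str, *args, **kwargs):
    # Simplified: the full version also scatters list arguments whose length
    # equals the number of workers (worker i receives element i).
    # Returns one Ray ObjectRef per worker; ray.get() them to wait for the results.
    return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self.workers]
```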
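
The full execute_all_async in the citations either scatters or broadcasts its arguments. The sketch below restates that rule with plain Python callables in place of Ray actors, so it returns results directly instead of ObjectRefs. The function execute_all and the make_worker factory are hypothetical.

```python
from typing import Any, Callable, List


def execute_all(workers: List[Callable[..., Any]], *args: Any, **kwargs: Any) -> List[Any]:
    n = len(workers)
    values = list(args) + list(kwargs.values())
    if all(isinstance(v, list) for v in values) and all(len(v) == n for v in values):
        # Scatter: worker i receives the i-th element of every argument.
        return [w(*(a[i] for a in args), **{k: v[i] for k, v in kwargs.items()})
                for i, w in enumerate(workers)]
    # Broadcast: every worker receives identical arguments.
    return [w(*args, **kwargs) for w in workers]


def make_worker(rank: int) -> Callable[..., Any]:
    def run(batch: Any, scale: int = 1) -> Any:
        return (rank, batch, scale)
    return run


workers = [make_worker(r) for r in range(2)]
print(execute_all(workers, ["a", "b"], scale=[10, 20]))  # [(0, 'a', 10), (1, 'b', 20)]
print(execute_all(workers, "same", scale=3))             # [(0, 'same', 3), (1, 'same', 3)]
```

A caveat of this design: a list argument whose length happens to equal the number of workers is always scattered. A list of any other length is passed whole to every worker.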

Cluster types in RLVR

The RLVR pipeline creates Clusters for several different roles 4:

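| Cluster name | Purpose | Worker type |
|---|---|---|
| actor_train | Policy model training | ActorWorker |
| actor_infer | Policy model inference | ActorWorker |
| reference | Reference model | ActorWorker |
| critic | Value function | CriticWorker |
| rewards | Reward computation | RewardWorker |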

Relationship between Cluster and Actor

  • Cluster: the manager, responsible for creating and coordinating multiple Actors
  • Actor/Worker: the actual execution unit, running in its own process

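Each Cluster holds multiple Worker instances that together form a distributed compute group. It keeps rank-to-worker and worker-to-rank mappings 5: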

```python
# rank -> worker handle, and worker handle -> rank
self.rank2worker = {k: self.workers[k] for k in range(len(self.workers))}
self.worker2rank = {self.workers[k]: k for k in range(len(self.workers))}
```

Notes

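  • Cluster is the basic abstraction of ROLL's distributed architecture
  • Clusters isolate and coordinate the resources of Workers in different roles
  • Multiple distributed parallelism strategies are supported (e.g., DP, TP, and PP)
  • Each Cluster has its own resource configuration and GPU allocation (device_mapping)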

Citations

File: roll/distributed/executor/cluster.py (L29-61)

```python
class Cluster:

    def __init__(
        self,
        name,
        worker_cls: Union[RemoteFunctionNoArgs[Worker], Type[Worker], str],
        resource_manager: ResourceManager,
        worker_config: WorkerConfig,
    ):

        self.cluster_name = name
        if isinstance(worker_cls, str):
            worker_cls = safe_import_class(worker_cls)

        if not hasattr(worker_cls, "__ray_actor_class__"):
            logger.info(f"wrap {worker_cls.__name__} to ray.remote()")
            self.worker_cls = ray.remote(worker_cls)
        else:
            self.worker_cls = worker_cls
        self.resource_manager = resource_manager
        self.worker_config = worker_config

        self.workers: List[Any] = []

        self.master_addr = None
        self.master_port = None
        self.world_size = self.worker_config.world_size

        self._create_workers()
        self._bind_worker_method()
        self._worker_rank_info = None
        self.initialized = False
```

File: roll/distributed/executor/cluster.py (L62-70)

```python
        self.rank2worker = {k: self.workers[k] for k in range(len(self.workers))}
        self.worker2rank = {self.workers[k]: k for k in range(len(self.workers))}
        self.rank2devices = dict(zip(map(lambda worker: self.worker2rank[worker], self.workers),
                                     ray.get([worker.get_devices_info.remote() for worker in self.workers])))
        self.worker2nodes = dict(zip(self.workers, ray.get([worker.get_node_ip.remote() for worker in self.workers])))
        logger.debug(f"{self.cluster_name} rank2devices {self.rank2devices}")
        # for cluster object can transfer by ray rpc.
        del self.worker_cls
```

File: roll/distributed/executor/cluster.py (L94-148)

```python
    def _create_workers(self):
        placement_groups: List[List[Dict]] = self.resource_manager.allocate_placement_group(
            device_mapping=self.worker_config.device_mapping, world_size=self.worker_config.world_size
        )
        logger.debug(f"placement_groups: {placement_groups}")

        for rank, pgs in enumerate(placement_groups):
            deploy_pg = pgs[0]
            pg_zero_gpu_ranks = sorted([pg["gpu_rank"] for pg in pgs if pg["node_rank"] == deploy_pg["node_rank"]])
            worker_name = f"{self.cluster_name}-{rank}"
            env_vars = {
                "WORLD_SIZE": str(self.world_size),
                "RANK": str(rank),
                "LOCAL_RANK": str(0),
                "CLUSTER_NAME": self.cluster_name,
                "WORKER_NAME": worker_name,
            }

            if rank != 0:
                env_vars["MASTER_ADDR"] = self.master_addr
                env_vars["MASTER_PORT"] = str(self.master_port)
            if deploy_pg["gpu_rank"] is not None:
                current_platform.update_env_vars_for_visible_devices(env_vars=env_vars, gpu_ranks=pg_zero_gpu_ranks)
            if "ROLL_LOG_DIR" in os.environ:
                env_vars["ROLL_LOG_DIR"] = os.environ["ROLL_LOG_DIR"]
            env_vars.update(self.worker_config.system_envs)

            runtime_env = RuntimeEnv(env_vars=env_vars)
            self.worker_config.resource_placement_groups = pgs

            worker_options = {
                "scheduling_strategy": PlacementGroupSchedulingStrategy(placement_group=deploy_pg["placement_group"]),
                "name": worker_name,
                "namespace": RAY_NAMESPACE,
                "runtime_env": runtime_env,
                "num_cpus": 0.01,
            }

            if current_platform.ray_device_key == "GPU":
                worker_options.update({"num_gpus": 0.01 if self.worker_config.device_mapping else 0})
            elif current_platform.ray_device_key == "NPU":
                worker_options.update(
                    {
                        "num_gpus": 0,
                        "resources": {
                            current_platform.ray_device_key: 0.01 if self.worker_config.device_mapping else 0
                        },
                    }
                )

            worker = self.worker_cls.options(**worker_options).remote(worker_config=self.worker_config)
            self.workers.append(worker)
            if rank == 0:
                self.master_addr, self.master_port = ray.get(worker.get_master_addr_and_port.remote())
```

File: roll/distributed/executor/cluster.py (L213-231)

```python
    def execute_all(self, method_name: str, *args, **kwargs):
        return self.execute_all_async(method_name, *args, **kwargs)

    def execute_all_sync(self, method_name: str, *args, **kwargs):
        return ray.get(self.execute_all_async(method_name, *args, **kwargs))

    def execute_all_async(self, method_name: str, *args, **kwargs):
        length = len(self.workers)
        if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
            if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
                result = []
                for i in range(length):
                    sliced_args = tuple(arg[i] for arg in args)
                    sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
                    remote_call = getattr(self.workers[i], method_name)
                    result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
                return result

        return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self.workers]
```

File: roll/pipeline/rlvr/rlvr_pipeline.py (L207-246)

(Same excerpt as the roll/pipeline/rlvr/rlvr_pipeline.py (L207-246) citation in the previous section.)