【大模型训练】加载load_state 中的一些技巧 工厂设计模式

Answer

您询问的是 WorkerState 类的作用,特别是 load_from_json 方法的使用方式 1

WorkerState 的作用

WorkerState 是一个用于保存和恢复训练状态 的数据类,主要用于checkpoint和训练恢复 2 。它包含三个核心字段:

  • step: 当前训练步数
  • log_history: 训练过程中的指标历史记录列表
  • kv: 键值对字典,用于存储额外的状态信息

load_from_json 方法的实现

load_from_json 是一个类方法,用于从JSON文件中恢复 WorkerState 实例 3 :

python 复制代码
@classmethod
def load_from_json(cls, load_dir: str, tag):
    """Create an instance from the content of `json_path`."""
    json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))
    with open(json_path, "r", encoding="utf-8") as f:
        text = f.read()
    return cls(**json.loads(text))

该方法:

  1. 根据 load_dirtag 构造JSON文件路径(格式为 worker_state_{tag}.json) 4
  2. 读取JSON文件内容
  3. 将JSON反序列化为字典,并用 ** 解包传递给类构造函数

在 Pipeline 中的使用

BasePipeline 中,WorkerState 用于保存和恢复整个pipeline的训练状态 5 :

初始化时恢复状态

当从checkpoint恢复训练时,pipeline会加载之前保存的状态 6 :

python 复制代码
if self.pipeline_config.resume_from_checkpoint:
    self.resume_from_checkpoint = download_model(self.pipeline_config.resume_from_checkpoint)
    logger.info(f"resume_from_checkpoint: {self.resume_from_checkpoint}")
    load_dir = os.path.join(self.resume_from_checkpoint, "pipeline")
    self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

这里使用 tag="pipeline" 来标识这是pipeline级别的状态文件 7

恢复训练指标

加载状态后,pipeline会将历史指标重新记录到tracker中 8 :

python 复制代码
def resume_metrics():
    for metrics in self.state.log_history:
        self.tracker.log(values=metrics, step=metrics["system/step"])

self.resume_futures.append(self.executor.submit(resume_metrics))

保存状态

在checkpoint时,pipeline会保存当前状态 9 :

python 复制代码
save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id, "pipeline")
self.state.save_to_json(save_dir=save_dir, tag="pipeline")
self.state.save_rng_state(save_dir=save_dir, tag="pipeline")

配合 RNG 状态保存

WorkerState 还提供了静态方法来保存和加载随机数生成器(RNG)状态,确保训练的可重现性 10 :

  • save_rng_state: 保存Python、NumPy、PyTorch的随机状态
  • load_rng_state: 恢复这些随机状态

这些方法在checkpoint时与 save_to_json/load_from_json 配合使用 11

在 Trainer 中的类似用法

McaTrainer 中也有类似的状态恢复逻辑 12 :

python 复制代码
if resume_from_checkpoint is not None and os.path.isfile(
    os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
    self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))

这里使用的是 TrainerState 而非 WorkerState,但模式相同。

Notes

WorkerState 的设计遵循了简单的序列化/反序列化模式,使用JSON格式存储训练状态 13tag 参数允许在同一目录下保存多个不同角色的状态文件(如 "pipeline", "actor_train" 等)。这种设计使得ROLL能够在训练中断后精确恢复到之前的状态,包括训练步数、指标历史和随机数状态 14

Wiki pages you might want to explore:

Citations

File: roll/utils/worker_state.py (L17-17)

python 复制代码
WORKER_STATE_NAME = "worker_state_{tag}.json"

File: roll/utils/worker_state.py (L20-30)

python 复制代码
@dataclass
class WorkerState:
    step: int = -1
    log_history: List[Dict[str, float]] = None
    kv: Dict[str, Union[float, Dict]] = None

    def __post_init__(self):
        if self.log_history is None:
            self.log_history = []
        if self.kv is None:
            self.kv = {}

File: roll/utils/worker_state.py (L32-38)

python 复制代码
    def save_to_json(self, save_dir: str, tag):
        """Save the content of this instance in JSON format inside `json_path`."""
        json_path = os.path.join(save_dir, WORKER_STATE_NAME.format(tag=tag))
        json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
        os.makedirs(os.path.dirname(json_path), exist_ok=True)
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

File: roll/utils/worker_state.py (L40-46)

python 复制代码
    @classmethod
    def load_from_json(cls, load_dir: str, tag):
        """Create an instance from the content of `json_path`."""
        json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))

File: roll/utils/worker_state.py (L48-77)

python 复制代码
    @staticmethod
    def save_rng_state(save_dir, tag):
        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
            current_platform.device_type: current_platform.random.get_rng_state_all(),
        }
        os.makedirs(save_dir, exist_ok=True)
        torch.save(rng_states, os.path.join(save_dir, f"rng_state_{tag}.pth"))

    @staticmethod
    def load_rng_state(load_dir, tag):
        # Load RNG states from `checkpoint`
        if load_dir is None:
            return
        rng_file = os.path.join(load_dir, f"rng_state_{tag}.pth")
        if not os.path.isfile(rng_file):
            logger.info(
                f"Didn't find an RNG file for process {tag}, if you are resuming a training that "
                "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
            )
            return

        checkpoint_rng_state = torch.load(rng_file, weights_only=False)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        current_platform.random.set_rng_state_all(checkpoint_rng_state[current_platform.device_type])

File: roll/pipeline/base_pipeline.py (L32-32)

python 复制代码
        self.state = WorkerState()

File: roll/pipeline/base_pipeline.py (L43-48)

python 复制代码
        if self.pipeline_config.resume_from_checkpoint:
            self.resume_from_checkpoint = download_model(self.pipeline_config.resume_from_checkpoint)

            logger.info(f"resume_from_checkpoint: {self.resume_from_checkpoint}")
            load_dir = os.path.join(self.resume_from_checkpoint, "pipeline")
            self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

File: roll/pipeline/base_pipeline.py (L50-54)

python 复制代码
            def resume_metrics():
                for metrics in self.state.log_history:
                    self.tracker.log(values=metrics, step=metrics["system/step"])

            self.resume_futures.append(self.executor.submit(resume_metrics))

File: roll/pipeline/base_pipeline.py (L88-92)

python 复制代码
            pipeline_save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id)
            save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id, "pipeline")
            self.state.save_to_json(save_dir=save_dir, tag="pipeline")
            self.state.save_rng_state(save_dir=save_dir, tag="pipeline")
            self.checkpoint_manager.upload(ckpt_id=ckpt_id, local_state_path=pipeline_save_dir)

File: mcore_adapter/src/mcore_adapter/trainer/trainer.py (L664-667)

python 复制代码
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))

Answer

您询问的是为什么 load_from_json 要设计成类方法(@classmethod)而不是普通方法。

类方法的必要性

load_from_json 必须 是类方法,因为它的作用是创建并返回一个新的类实例 1 。这是一个典型的工厂方法模式

关键区别

普通实例方法:

  • 需要先有一个实例才能调用: instance.method()
  • 第一个参数是 self,指向已存在的实例
  • 用于操作已有对象的状态

类方法:

  • 直接通过类调用: WorkerState.load_from_json(load_dir, tag)
  • 第一个参数是 cls,指向类本身
  • 用于创建新实例

实际使用场景

BasePipeline 中,从checkpoint恢复时需要从零创建 一个 WorkerState 实例 2 :

python 复制代码
self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

此时还没有 现成的 WorkerState 实例,所以不可能用普通方法调用。类方法允许直接通过类名调用,并返回一个新构造的实例。

与构造函数的对比

您可能会问:为什么不直接用 __init__ 构造函数?

类方法提供了更灵活的构造方式:

  1. 语义清晰 : load_from_json 明确表达"从JSON文件加载"的意图,比直接调用 __init__ 更易读
  2. 多种构造方式 : 可以有多个类方法提供不同的构造路径(如 from_dict, from_json, from_yaml),而 __init__ 只能有一个
  3. 预处理逻辑 : 类方法可以在调用 __init__ 前进行文件读取、数据解析等操作 3

相同模式在其他地方的应用

这个模式在ROLL代码库中被广泛使用:

  • MegatronArguments.from_json_file 4 : 从JSON文件创建训练参数实例
  • PretrainedConfig.from_json_file 5 : 从JSON文件创建模型配置实例
  • McaModelConfig.from_pretrained 6 : 从预训练路径创建配置实例

所有这些都使用 @classmethod 装饰器,因为它们都是替代构造函数 (alternative constructors)。

如果改成普通方法会怎样

假设改成普通方法:

python 复制代码
def load_from_json(self, load_dir: str, tag):
    # ... 读取JSON ...
    # 但这里无法返回新实例,只能修改self

这样就无法实现从文件创建新实例的功能,因为:

  1. 调用前必须先有实例: state = WorkerState(); state.load_from_json(...)
  2. 无法返回新对象,只能修改现有对象的属性
  3. 违反了"从文件加载"的语义

Notes

类方法是Python中实现工厂模式和替代构造函数的标准做法。@classmethod 装饰器让方法接收类本身作为第一个参数(cls),从而可以调用 cls(...) 来创建新实例 1 。这种设计在需要多种对象创建方式时非常有用,是面向对象设计的最佳实践。

Wiki pages you might want to explore:

Citations

File: roll/utils/worker_state.py (L40-46)

python 复制代码
    @classmethod
    def load_from_json(cls, load_dir: str, tag):
        """Create an instance from the content of `json_path`."""
        json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))

File: roll/pipeline/base_pipeline.py (L48-48)

python 复制代码
            self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

File: mcore_adapter/src/mcore_adapter/training_args.py (L317-321)

python 复制代码
    @classmethod
    def from_json_file(cls, json_file_path) -> "MegatronArguments":
        with open(json_file_path, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls(**json.loads(text))

File: mcore_adapter/src/mcore_adapter/models/model_config.py (L69-73)

python 复制代码
    @classmethod
    def from_json_file(cls, json_file_path):
        with open(json_file_path, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls(**json.loads(text))

File: mcore_adapter/src/mcore_adapter/models/model_config.py (L124-125)

python 复制代码
    @classmethod
    def from_pretrained(cls, model_name_or_path: str, args: Optional["TrainingArguments"] = None):
相关推荐
gladiator+7 小时前
Java中的设计模式------策略设计模式
java·开发语言·设计模式
在未来等你10 小时前
AI Agent设计模式 Day 2:Plan-and-Execute模式:先规划后执行的智能策略
设计模式·llm·react·ai agent·plan-and-execute
在未来等你16 小时前
AI Agent设计模式 Day 3:Self-Ask模式:自我提问驱动的推理链
设计模式·llm·react·ai agent·plan-and-execute
xiaodaidai丶1 天前
设计模式之策略模式
设计模式·策略模式
_院长大人_1 天前
设计模式-工厂模式
java·开发语言·设计模式
王道长服务器 | 亚马逊云2 天前
AWS + 苹果CMS:影视站建站的高效组合方案
服务器·数据库·搜索引擎·设计模式·云计算·aws
在未来等你2 天前
AI Agent设计模式 Day 1:ReAct模式:推理与行动的完美结合
设计模式·llm·react·ai agent·plan-and-execute
乐悠小码2 天前
Java设计模式精讲---03建造者模式
java·设计模式·建造者模式
_院长大人_2 天前
设计模式-代理模式
设计模式·代理模式