Table of Contents

- [1. Background](#1-background)
- [2. DCP Core Concepts and Advantages](#2-dcp-core-concepts-and-advantages)
  - [2.1 DCP vs. Traditional Approaches](#21-dcp-vs-traditional-approaches)
  - [2.2 How DCP Works](#22-how-dcp-works)
- [3. Complete Save/Load Implementation for TorchRec DMP Models](#3-complete-saveload-implementation-for-torchrec-dmp-models)
  - [3.1 Environment and Dependencies](#31-environment-and-dependencies)
  - [3.2 Application State Manager (Key Component)](#32-application-state-manager-key-component)
  - [3.3 OSS Storage Adapters](#33-oss-storage-adapters)
  - [3.4 Distributed Model Save Function](#34-distributed-model-save-function)
  - [3.5 Distributed Model Load Function](#35-distributed-model-load-function)
  - [3.6 End-to-End Training Example](#36-end-to-end-training-example)
- [4. Advanced Features and Best Practices](#4-advanced-features-and-best-practices)
  - [4.1 Checkpoint Validation](#41-checkpoint-validation)
  - [4.2 Asynchronous Checkpoint Saving](#42-asynchronous-checkpoint-saving)
  - [4.3 Format Conversion Utility](#43-format-conversion-utility)
- [5. Common Problems and Solutions](#5-common-problems-and-solutions)
  - [5.1 Sharding Mismatch](#51-sharding-mismatch)
  - [5.2 Memory Optimization](#52-memory-optimization)
- [6. Summary](#6-summary)
1. Background
When training large-scale recommendation models in TorchRec with DistributedModelParallel (DMP), the traditional single-process approach to saving a model no longer works: each rank only holds a shard of the embedding tables. PyTorch's Distributed Checkpoint (DCP) is designed for exactly this scenario. It provides:
- Parallel multi-rank I/O: every rank saves and loads its own shards at the same time
- Load-time resharding: a checkpoint saved under one GPU configuration can be loaded under another
- Memory efficiency: tensors are loaded in place into pre-allocated storage, avoiding extra copies
- State management: objects implementing the Stateful protocol are saved and restored automatically
2. DCP Core Concepts and Advantages
2.1 DCP vs. Traditional Approaches

| Feature | DCP | torch.save / torch.load |
|---|---|---|
| File layout | Multiple files (at least one per rank) | Single file |
| Memory | In-place (uses pre-allocated storage) | Needs extra memory |
| Distributed support | Native SPMD | Manual gather required |
| Resharding | Load-time resharding | Not supported |
| State management | Handles Stateful objects automatically | Manual |

2.2 How DCP Works
DCP's core strengths are that it can:
- Parallel I/O: all ranks read and write simultaneously, maximizing I/O throughput
- Shard awareness: it understands the model's sharding topology
- Elastic training: checkpoints can be saved and loaded under different world sizes
- State management: complex objects are handled automatically through the Stateful protocol
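The snippet below is a minimal, self-contained sketch of this workflow (my own example, assuming torch >= 2.2, a single process, the gloo backend, and a writable `/tmp/dcp_demo` path). In a real DMP job the same `dcp.save` / `dcp.load` calls run on every rank, and each rank only reads and writes its own shards.

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# Minimal DCP round trip on one process (illustrative only).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 4)                      # stand-in for a DMP-wrapped model
state_dict = {"model": model.state_dict()}   # values may be sharded tensors under DMP

dcp.save(state_dict, checkpoint_id="/tmp/dcp_demo")  # each rank writes its own shard files
dcp.load(state_dict, checkpoint_id="/tmp/dcp_demo")  # loads in place into the existing tensors

dist.destroy_process_group()
```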
3. Complete Save/Load Implementation for TorchRec DMP Models
3.1 Environment and Dependencies
```python
import io       # used by the OSS adapters below
import json     # used by the checkpoint validation helper below
import logging
import os
import time
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
    StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.types import ModuleSharder
```
3.2 Application State Manager (Key Component)
```python
class TorchRecAppState(Stateful):
    """
    Application state manager for TorchRec that implements the Stateful protocol.
    It handles state-dict generation and restoration for the DMP model and optimizer.
    """

    def __init__(
        self,
        model: DistributedModelParallel,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        extra_state: Optional[Dict[str, Any]] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.extra_state = extra_state or {}
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

    def state_dict(self) -> Dict[str, Any]:
        """
        Build the state dict, keeping the DMP model's sharded state.
        """
        logging.info(f"[Rank {self.rank}] building state dict...")
        # get_state_dict handles FQN mapping and sharded tensors for us
        model_state_dict, optim_state_dict = get_state_dict(
            self.model,
            self.optimizer,
            options=StateDictOptions(
                full_state_dict=False,  # keep the sharded state
                strict=False,
            ),
        )
        state = {
            "model": model_state_dict,
            "world_size": self.world_size,
            "rank": self.rank,
            "timestamp": time.time(),
            "version": "1.0",
        }
        if optim_state_dict is not None:
            state["optimizer"] = optim_state_dict
        if self.scheduler is not None:
            state["scheduler"] = self.scheduler.state_dict()
        # attach any extra user-supplied state
        for key, value in self.extra_state.items():
            if isinstance(value, Stateful):
                state[f"extra_{key}"] = value.state_dict()
            else:
                state[f"extra_{key}"] = value
        logging.info(f"[Rank {self.rank}] state dict built, {len(state)} entries")
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the state dict, keeping the DMP model's sharded state.
        """
        logging.info(f"[Rank {self.rank}] loading state dict...")
        try:
            # check world-size compatibility
            saved_world_size = state_dict.get("world_size", 1)
            current_world_size = self.world_size
            if saved_world_size != current_world_size:
                logging.warning(
                    f"[Rank {self.rank}] world size mismatch: saved={saved_world_size}, current={current_world_size}\n"
                    "DCP will reshard automatically at load time..."
                )
            # restore model and optimizer state
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict.get("model", {}),
                optim_state_dict=state_dict.get("optimizer", {}),
                options=StateDictOptions(
                    full_state_dict=False,  # keep the sharded state
                    strict=False,
                ),
            )
            # restore scheduler state
            if self.scheduler is not None and "scheduler" in state_dict:
                self.scheduler.load_state_dict(state_dict["scheduler"])
                logging.info(f"[Rank {self.rank}] scheduler state restored")
            # restore extra user-supplied state
            for key in list(state_dict.keys()):
                if key.startswith("extra_"):
                    orig_key = key[6:]  # drop the "extra_" prefix
                    value = state_dict[key]
                    if orig_key in self.extra_state and isinstance(self.extra_state[orig_key], Stateful):
                        self.extra_state[orig_key].load_state_dict(value)
                    else:
                        self.extra_state[orig_key] = value
            logging.info(f"[Rank {self.rank}] state dict loaded")
        except Exception as e:
            logging.error(f"[Rank {self.rank}] failed to load state dict: {str(e)}")
            raise
```
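As a quick illustration of how this class plugs into DCP (a sketch only: `dmp_model` and `optimizer` are assumed to already exist, the process group to be initialized, and the path to be a placeholder):

```python
# Sketch: DCP calls app_state.state_dict() on save and app_state.load_state_dict()
# on load -- we never assemble or restore the dictionaries by hand.
app_state = TorchRecAppState(model=dmp_model, optimizer=optimizer, extra_state={"step": 0})
dcp.save({"app": app_state}, checkpoint_id="/tmp/torchrec_ckpt")
dcp.load({"app": app_state}, checkpoint_id="/tmp/torchrec_ckpt")
```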
3.3 OSS Storage Adapters
```python
class OSSStorageWriter:
    """
    OSS storage writer, used as the storage backend on the DCP save path.
    """

    def __init__(
        self,
        bucket_name: str,
        prefix: str,
        endpoint: Optional[str] = None,
        access_key_id: Optional[str] = None,
        access_key_secret: Optional[str] = None,
    ):
        try:
            import oss2
        except ImportError:
            raise ImportError("Please install oss2: pip install oss2")
        self.bucket_name = bucket_name
        self.prefix = prefix.rstrip('/') + '/'
        self.endpoint = endpoint or os.environ.get("OSS_ENDPOINT", "https://oss-cn-hangzhou.aliyuncs.com")
        self.access_key_id = access_key_id or os.environ.get("OSS_ACCESS_KEY_ID")
        self.access_key_secret = access_key_secret or os.environ.get("OSS_ACCESS_KEY_SECRET")
        if not all([self.access_key_id, self.access_key_secret]):
            raise ValueError("OSS_ACCESS_KEY_ID and OSS_ACCESS_KEY_SECRET must be set")
        self.auth = oss2.Auth(self.access_key_id, self.access_key_secret)
        self.bucket = oss2.Bucket(self.auth, self.endpoint, bucket_name)
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        logging.info(f"[Rank {self.rank}] OSS storage writer ready: bucket={bucket_name}, prefix={prefix}")

    def write(self, data: Union[bytes, torch.Tensor], path: str) -> None:
        """
        Write a blob or tensor to OSS.
        """
        # substitute the rank / world_size placeholders in the path
        path = path.replace("{rank}", str(self.rank))
        path = path.replace("{world_size}", str(self.world_size))
        oss_path = f"{self.prefix}{path}"
        try:
            start_time = time.time()
            if isinstance(data, torch.Tensor):
                # serialize the tensor into bytes
                buffer = io.BytesIO()
                torch.save(data, buffer)
                data_bytes = buffer.getvalue()
            else:
                data_bytes = data
            self.bucket.put_object(oss_path, data_bytes)
            elapsed = time.time() - start_time
            logging.debug(f"[Rank {self.rank}] OSS write done: {oss_path}, size={len(data_bytes)}, {elapsed:.2f}s")
        except Exception as e:
            logging.error(f"[Rank {self.rank}] OSS write failed: {oss_path}, error={str(e)}")
            raise

    def __getstate__(self):
        """Support pickling by dropping the live OSS connection objects."""
        state = self.__dict__.copy()
        state.pop('auth', None)
        state.pop('bucket', None)
        return state

    def __setstate__(self, state):
        """Re-create the OSS connection after unpickling."""
        self.__dict__.update(state)
        import oss2
        self.auth = oss2.Auth(self.access_key_id, self.access_key_secret)
        self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)


class OSSStorageReader:
    """
    OSS storage reader, used as the storage backend on the DCP load path.
    """

    def __init__(
        self,
        bucket_name: str,
        prefix: str,
        endpoint: Optional[str] = None,
        access_key_id: Optional[str] = None,
        access_key_secret: Optional[str] = None,
    ):
        try:
            import oss2
        except ImportError:
            raise ImportError("Please install oss2: pip install oss2")
        self.bucket_name = bucket_name
        self.prefix = prefix.rstrip('/') + '/'
        self.endpoint = endpoint or os.environ.get("OSS_ENDPOINT", "https://oss-cn-hangzhou.aliyuncs.com")
        self.access_key_id = access_key_id or os.environ.get("OSS_ACCESS_KEY_ID")
        self.access_key_secret = access_key_secret or os.environ.get("OSS_ACCESS_KEY_SECRET")
        if not all([self.access_key_id, self.access_key_secret]):
            raise ValueError("OSS_ACCESS_KEY_ID and OSS_ACCESS_KEY_SECRET must be set")
        self.auth = oss2.Auth(self.access_key_id, self.access_key_secret)
        self.bucket = oss2.Bucket(self.auth, self.endpoint, bucket_name)
        self.rank = dist.get_rank()
        logging.info(f"[Rank {self.rank}] OSS storage reader ready: bucket={bucket_name}, prefix={prefix}")

    def read(self, path: str) -> bytes:
        """
        Read a blob from OSS.
        """
        # substitute the rank placeholder in the path
        path = path.replace("{rank}", str(self.rank))
        oss_path = f"{self.prefix}{path}"
        try:
            start_time = time.time()
            result = self.bucket.get_object(oss_path)
            data = result.read()
            elapsed = time.time() - start_time
            logging.debug(f"[Rank {self.rank}] OSS read done: {oss_path}, size={len(data)}, {elapsed:.2f}s")
            return data
        except Exception as e:
            logging.error(f"[Rank {self.rank}] OSS read failed: {oss_path}, error={str(e)}")
            raise

    def __getstate__(self):
        """Support pickling by dropping the live OSS connection objects."""
        state = self.__dict__.copy()
        state.pop('auth', None)
        state.pop('bucket', None)
        return state

    def __setstate__(self, state):
        """Re-create the OSS connection after unpickling."""
        self.__dict__.update(state)
        import oss2
        self.auth = oss2.Auth(self.access_key_id, self.access_key_secret)
        self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
```
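A small round-trip smoke test for these adapters might look like the following (my own sketch: it assumes the process group is already initialized, the OSS credentials are exported as environment variables, and a bucket named `torchrec-checkpoints` exists).

```python
# Hypothetical round-trip check for the OSS adapters.
writer = OSSStorageWriter(bucket_name="torchrec-checkpoints", prefix="smoke_test")
writer.write(torch.randn(4), "probe_rank{rank}.pt")   # {rank} is filled in by write()

reader = OSSStorageReader(bucket_name="torchrec-checkpoints", prefix="smoke_test")
blob = reader.read("probe_rank{rank}.pt")
print(f"round-trip ok: {len(blob)} bytes")
```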
3.4 Distributed Model Save Function
```python
def save_distributed_model(
    model: DistributedModelParallel,
    checkpoint_dir: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    extra_state: Optional[Dict[str, Any]] = None,
    storage_type: str = "oss",
    **storage_kwargs,
) -> bool:
    """
    Save a distributed TorchRec model to the given location.

    Args:
        model: the DMP-wrapped model
        checkpoint_dir: checkpoint directory / prefix
        optimizer: optional optimizer
        scheduler: optional LR scheduler
        extra_state: extra state to save alongside the model
        storage_type: storage backend ('oss' or 'filesystem')
        **storage_kwargs: extra storage-specific arguments

    Returns:
        bool: whether the save succeeded
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    try:
        logging.info(f"[Rank {rank}] saving distributed model to: {checkpoint_dir}")
        start_time = time.time()
        # wrap everything in the application state manager
        app_state = TorchRecAppState(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            extra_state=extra_state,
        )
        state_dict = {"app": app_state}
        # pick the storage writer
        if storage_type == "oss":
            storage_writer = OSSStorageWriter(
                bucket_name=storage_kwargs.pop("bucket_name", "torchrec-checkpoints"),
                prefix=checkpoint_dir,
                **storage_kwargs,
            )
            checkpoint_id = checkpoint_dir  # for OSS the prefix doubles as the checkpoint id
        elif storage_type == "filesystem":
            from torch.distributed.checkpoint import FileSystemWriter
            os.makedirs(checkpoint_dir, exist_ok=True)
            storage_writer = FileSystemWriter(checkpoint_dir)
            checkpoint_id = checkpoint_dir
        else:
            raise ValueError(f"Unsupported storage type: {storage_type}")
        # save the checkpoint
        dcp.save(
            state_dict,
            checkpoint_id=checkpoint_id,
            storage_writer=storage_writer,
        )
        # synchronize all ranks
        dist.barrier()
        elapsed = time.time() - start_time
        if rank == 0:
            logging.info("✅ distributed model saved successfully!")
            logging.info(f"  - location: {checkpoint_dir}")
            logging.info(f"  - world size: {world_size}")
            logging.info(f"  - elapsed: {elapsed:.2f}s")
        return True
    except Exception as e:
        logging.error(f"[Rank {rank}] save failed: {str(e)}")
        # try to synchronize so the other ranks do not deadlock
        try:
            dist.barrier()
        except Exception:
            pass
        return False
```
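A typical call, here against the local filesystem backend (a sketch: `model` and `optimizer` are assumed to exist, and `/mnt/checkpoints` is a placeholder for a path visible to every rank):

```python
# Sketch: save to a shared filesystem path reachable from all ranks.
ok = save_distributed_model(
    model=model,
    optimizer=optimizer,
    checkpoint_dir="/mnt/checkpoints/step_100",
    extra_state={"step": 100},
    storage_type="filesystem",
)
assert ok, "checkpoint save failed"
```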
3.5 Distributed Model Load Function
```python
def load_distributed_model(
    model: DistributedModelParallel,
    checkpoint_dir: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    extra_state: Optional[Dict[str, Any]] = None,
    storage_type: str = "oss",
    strict: bool = False,
    **storage_kwargs,
) -> Dict[str, Any]:
    """
    Load a distributed TorchRec model from a checkpoint.

    Args:
        model: the DMP-wrapped model (must be constructed and sharded first)
        checkpoint_dir: checkpoint directory / prefix
        optimizer: optional optimizer
        scheduler: optional LR scheduler
        extra_state: extra state to load alongside the model
        storage_type: storage backend ('oss' or 'filesystem')
        strict: whether keys must match exactly
        **storage_kwargs: extra storage-specific arguments

    Returns:
        Dict[str, Any]: metadata recovered from the checkpoint
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    try:
        logging.info(f"[Rank {rank}] loading distributed model from {checkpoint_dir}")
        start_time = time.time()
        # wrap everything in the application state manager
        app_state = TorchRecAppState(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            extra_state=extra_state or {},
        )
        # the state dict provides the pre-allocated storage to load into
        state_dict = {"app": app_state}
        # pick the storage reader
        if storage_type == "oss":
            storage_reader = OSSStorageReader(
                bucket_name=storage_kwargs.pop("bucket_name", "torchrec-checkpoints"),
                prefix=checkpoint_dir,
                **storage_kwargs,
            )
            checkpoint_id = checkpoint_dir
        elif storage_type == "filesystem":
            from torch.distributed.checkpoint import FileSystemReader
            storage_reader = FileSystemReader(checkpoint_dir)
            checkpoint_id = checkpoint_dir
        else:
            raise ValueError(f"Unsupported storage type: {storage_type}")
        # load the checkpoint (single-process runs are handled automatically
        # as long as the default process group is initialized)
        dcp.load(
            state_dict,
            checkpoint_id=checkpoint_id,
            storage_reader=storage_reader,
        )
        # synchronize all ranks
        dist.barrier()
        elapsed = time.time() - start_time
        logging.info(f"[Rank {rank}] model loaded in {elapsed:.2f}s")
        # collect the metadata restored into extra_state
        metadata = {}
        if hasattr(app_state, 'extra_state'):
            metadata = app_state.extra_state.copy()
        if rank == 0:
            logging.info("✅ distributed model loaded successfully!")
            logging.info(f"  - location: {checkpoint_dir}")
            logging.info(f"  - world size: {world_size}")
            logging.info(f"  - elapsed: {elapsed:.2f}s")
        return metadata
    except Exception as e:
        logging.error(f"[Rank {rank}] load failed: {str(e)}")
        # try to synchronize so the other ranks do not deadlock
        try:
            dist.barrier()
        except Exception:
            pass
        raise
```
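The matching resume call (a sketch with the same placeholder path; the new run may use a different world size, in which case DCP reshards at load time as long as the model is wrapped in DistributedModelParallel before loading):

```python
# Sketch: resume from the checkpoint written above.
metadata = load_distributed_model(
    model=model,
    optimizer=optimizer,
    checkpoint_dir="/mnt/checkpoints/step_100",
    storage_type="filesystem",
)
start_step = metadata.get("step", 0) + 1   # extra_state comes back as metadata
```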
3.6 End-to-End Training Example
```python
def train_torchrec_model_with_checkpointing():
    """
    Complete TorchRec training example with checkpoint saving and loading.
    """
    # initialize the distributed environment
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)  # assumes a single node; use LOCAL_RANK on multi-node jobs

    # configure logging
    logging.basicConfig(
        level=logging.INFO if rank == 0 else logging.WARNING,
        format=f'[Rank {rank}] %(asctime)s - %(levelname)s - %(message)s',
    )

    # build the TorchRec model (example)
    logging.info(f"[Rank {rank}] building TorchRec model...")
    model = create_torchrec_model()  # implement this for your own model
    model = DistributedModelParallel(model)

    # optimizer and LR scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)

    # checkpoint configuration
    checkpoint_dir = "your-oss-bucket/checkpoints/model_v1"
    resume_training = True
    save_interval = 100  # save every 100 steps

    # try to resume from a checkpoint
    start_step = 0
    if resume_training:
        try:
            logging.info(f"[Rank {rank}] trying to load checkpoint from {checkpoint_dir}...")
            metadata = load_distributed_model(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                checkpoint_dir=checkpoint_dir,
                storage_type="oss",
                bucket_name="your-oss-bucket",
            )
            start_step = metadata.get("step", 0) + 1
            logging.info(f"[Rank {rank}] checkpoint loaded, resuming from step {start_step}")
        except Exception as e:
            logging.warning(f"[Rank {rank}] checkpoint load failed, training from scratch: {str(e)}")
            start_step = 0

    # training loop
    total_steps = 10000
    for step in range(start_step, total_steps):
        # one training step (implement your own training logic)
        loss = training_step(model, optimizer)  # implement this
        if rank == 0 and step % 10 == 0:
            logging.info(f"Step {step}/{total_steps}, Loss: {loss:.4f}")
        # advance the LR scheduler
        scheduler.step()
        # periodic checkpointing
        if step > 0 and step % save_interval == 0:
            logging.info(f"[Rank {rank}] saving checkpoint step_{step}...")
            success = save_distributed_model(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                checkpoint_dir=f"{checkpoint_dir}/step_{step}",
                extra_state={"step": step, "loss": float(loss)},
                storage_type="oss",
                bucket_name="your-oss-bucket",
            )
            if success and rank == 0:
                logging.info(f"✅ checkpoint step_{step} saved")

    # save the final model
    logging.info(f"[Rank {rank}] saving final model...")
    save_distributed_model(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        checkpoint_dir=f"{checkpoint_dir}/final",
        extra_state={"step": total_steps, "completed": True},
        storage_type="oss",
        bucket_name="your-oss-bucket",
    )
    logging.info(f"[Rank {rank}] training finished!")
    dist.destroy_process_group()
```
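The example above is written to be launched with one process per GPU, typically via torchrun. A minimal entry point might be (the script name is a placeholder):

```python
# Hypothetical entry point; launch with e.g.
#   torchrun --nproc_per_node=8 train.py
if __name__ == "__main__":
    train_torchrec_model_with_checkpointing()
```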
4. Advanced Features and Best Practices
4.1 Checkpoint Validation
```python
def validate_checkpoint(
    checkpoint_dir: str,
    storage_type: str = "oss",
    **storage_kwargs,
) -> bool:
    """
    Check that a checkpoint is complete and readable.
    """
    try:
        if storage_type == "oss":
            reader = OSSStorageReader(
                bucket_name=storage_kwargs.pop("bucket_name", None),
                prefix=checkpoint_dir,
                **storage_kwargs,
            )
        else:
            from torch.distributed.checkpoint import FileSystemReader
            reader = FileSystemReader(checkpoint_dir)
        # read the metadata file
        try:
            metadata = reader.read("metadata.json")
            logging.info("✅ metadata file OK")
        except Exception as e:
            logging.error(f"❌ metadata file check failed: {str(e)}")
            return False
        # check that every rank's file is present
        world_size = json.loads(metadata).get("world_size", 1)
        for rank in range(world_size):
            try:
                # try to read this rank's file
                reader.read(f"state_dict_rank{rank}.pt")
                logging.debug(f"✅ rank {rank} file OK")
            except Exception as e:
                logging.error(f"❌ rank {rank} file check failed: {str(e)}")
                return False
        logging.info("✅ checkpoint is valid, all files present")
        return True
    except Exception as e:
        logging.error(f"❌ checkpoint validation failed: {str(e)}")
        return False
```
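For example, you might validate before resuming a long run (a sketch: it assumes the process group is initialized, OSS credentials are exported, and the prefix/bucket names are placeholders):

```python
# Sketch: fail fast if the checkpoint is unreadable or missing rank files.
if not validate_checkpoint(
    "checkpoints/model_v1/step_100",
    storage_type="oss",
    bucket_name="torchrec-checkpoints",
):
    raise RuntimeError("checkpoint validation failed; refusing to resume")
```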
4.2 Asynchronous Checkpoint Saving
```python
from concurrent.futures import ThreadPoolExecutor
import threading


class AsyncCheckpointSaver:
    """
    Asynchronous checkpoint saver that keeps checkpointing off the training critical path.
    """

    def __init__(self, max_workers: int = 2):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.lock = threading.Lock()
        self.pending_futures = []

    def async_save(
        self,
        model: DistributedModelParallel,
        checkpoint_dir: str,
        **save_kwargs,
    ) -> None:
        """
        Submit a checkpoint save to the background thread pool.
        """
        with self.lock:
            future = self.executor.submit(
                save_distributed_model,
                model=model,
                checkpoint_dir=checkpoint_dir,
                **save_kwargs,
            )
            self.pending_futures.append(future)
            # drop futures that have already finished
            self.pending_futures = [
                f for f in self.pending_futures if not f.done()
            ]

    def wait_all(self) -> None:
        """
        Block until every pending save has finished.
        """
        for future in self.pending_futures:
            future.result()
        self.pending_futures.clear()

    def __del__(self):
        self.executor.shutdown(wait=False)
```
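A usage sketch follows (assumptions: `training_step` is user-provided and the checkpoint path is a placeholder on a shared filesystem). Because the save runs in a background thread while training continues, the snapshot may mix weights from adjacent steps unless you pause the loop or copy state before submitting.

```python
# Sketch: save every 100 steps without blocking the loop, then drain on exit.
saver = AsyncCheckpointSaver(max_workers=1)
for step in range(1, 1001):
    loss = training_step(model, optimizer)   # user-provided training step
    if step % 100 == 0:
        saver.async_save(
            model,
            checkpoint_dir=f"/mnt/checkpoints/step_{step}",
            optimizer=optimizer,
            storage_type="filesystem",
        )
saver.wait_all()   # make sure no pending save is dropped on exit
```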
4.3 Format Conversion Utility
```python
def convert_dcp_to_torch_save(
    dcp_checkpoint_dir: str,
    torch_save_path: str,
    storage_type: str = "oss",
    **storage_kwargs,
) -> None:
    """
    Convert a DCP checkpoint into torch.save format (for single-machine inference).
    """
    try:
        from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

        if storage_type == "oss":
            # download the OSS checkpoint into a local temporary directory first
            import tempfile

            with tempfile.TemporaryDirectory() as temp_dir:
                logging.info(f"downloading checkpoint from OSS to temp dir: {temp_dir}")
                # OSS download logic goes here...
                # ...
                # then convert the local copy
                dcp_to_torch_save(temp_dir, torch_save_path)
        else:
            dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)
        logging.info(f"✅ conversion succeeded: {dcp_checkpoint_dir} -> {torch_save_path}")
    except Exception as e:
        logging.error(f"❌ conversion failed: {str(e)}")
        raise
```
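After conversion the artifact is an ordinary `torch.save` file, so a quick check might look like this (a sketch; the paths are placeholders):

```python
# Sketch: convert a filesystem DCP checkpoint, then inspect the result.
convert_dcp_to_torch_save(
    dcp_checkpoint_dir="/mnt/checkpoints/final",
    torch_save_path="/mnt/checkpoints/final_single.pt",
    storage_type="filesystem",
)
state = torch.load("/mnt/checkpoints/final_single.pt", map_location="cpu")
print(list(state.keys()))   # inspect the layout before wiring it into an inference model
```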
5. Common Problems and Solutions
5.1 Sharding Mismatch
```python
def load_with_fallback(
    model: DistributedModelParallel,
    checkpoint_dir: str,
    **load_kwargs,
) -> Dict[str, Any]:
    """
    Load with a fallback path that handles sharding mismatches.
    """
    try:
        return load_distributed_model(model, checkpoint_dir, **load_kwargs)
    except RuntimeError as e:
        if "sharding" in str(e).lower() or "shape" in str(e).lower():
            logging.warning("sharding mismatch detected, retrying with strict=False...")
            return load_distributed_model(model, checkpoint_dir, strict=False, **load_kwargs)
        raise
```
5.2 Memory Optimization
```python
def optimize_checkpoint_memory(
    model: DistributedModelParallel,
    optimizer: torch.optim.Optimizer,
) -> None:
    """
    Reduce memory pressure before taking a checkpoint.
    """
    # 1. put the model in eval mode
    model.eval()
    # 2. drop stale gradients held by the parameters
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad = None
    # 3. trigger garbage collection and release cached GPU memory
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    logging.info("✅ checkpoint memory optimization done")
```
6. Summary
By combining the official PyTorch DCP tutorial with TorchRec best practices, we built a complete solution for saving and loading distributed models. The key points:
- Use the Stateful protocol: the TorchRecAppState class handles complex state management automatically
- OSS integration: custom OSS storage adapters enable checkpointing to cloud storage
- Elastic training: DCP's load-time resharding supports changing the GPU configuration between runs
- Asynchronous saving: checkpoints can be written without blocking the training loop
- Error handling: robust exception handling and fallback logic throughout