1. Background and Core Concepts
1.1 Why Distributed Training and Inference?
In large-scale recommendation systems, models routinely reach billions or even tens of billions of parameters, far beyond what a single GPU can hold. TorchRec's DistributedModelParallel (DMP) addresses this by sharding the model across multiple GPUs, but sharding introduces new problems of its own: distributed checkpointing and inference deployment.
1.2 Key Differences: DCP vs. Traditional Checkpointing

| Feature | Distributed Checkpoint (DCP) | torch.save/load |
|---|---|---|
| Architecture | Designed for distributed training | Designed for single-process use |
| File layout | Multiple files (at least one per rank) | Single file |
| Memory management | In-place, using pre-allocated storage | Requires extra in-memory copies |
| Sharding | Native support for sharding and resharding | No sharding support |
| Elastic training | ✅ Save/load across different GPU counts | ❌ Tied to a specific GPU configuration |
| State management | Handles Stateful objects automatically | Manual handling required |
Key insight: DCP is not merely a different file format; it is a complete solution designed for the challenges unique to distributed training. As the official PyTorch documentation puts it, DCP can "reshard at load time, enabling conversion between different cluster topologies."
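The file-layout row above can be made concrete with a small experiment. A minimal sketch, relying on DCP falling back to non-distributed mode when no process group is initialized; the paths are illustrative:

```python
# Minimal sketch: contrast torch.save's single file with DCP's directory.
import os
import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(8, 4)

# torch.save: one opaque file.
torch.save(model.state_dict(), "single_file.pth")

# DCP: a directory with a .metadata index plus one or more shard files
# (e.g. "__0_0.distcp"); in a real distributed run, at least one per rank.
dcp.save(model.state_dict(), checkpoint_id="dcp_checkpoint")
print(os.listdir("dcp_checkpoint"))
```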
2. Distributed Training: From Zero to Checkpoint
2.1 Configuring the Training Environment
```python
import os
import datetime
import time
import logging
from typing import Dict, Any, Optional

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_state_dict,
    set_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torchrec.distributed.model_parallel import DistributedModelParallel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

def setup_distributed_environment():
    """
    Initialize the distributed training environment,
    following the setup in the official PyTorch DCP tutorial.
    """
    logging.info("🚀 Initializing distributed environment...")

    # 1. Read the configuration from environment variables
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = os.environ.get("MASTER_PORT", "12355")
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    rank = int(os.environ.get("RANK", 0))

    logging.info(f"  - Master address: {master_addr}")
    logging.info(f"  - Master port: {master_port}")
    logging.info(f"  - World size: {world_size}")
    logging.info(f"  - Rank: {rank}")

    # 2. Set the environment variables
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

    # 3. Initialize the process group
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=60)
    )

    # 4. Bind this process to its GPU
    torch.cuda.set_device(rank)
    logging.info(f"[Rank {rank}] ✅ Distributed environment initialized!")
    return world_size, rank
```
2.2 Defining the Training Model
```python
import torch.nn as nn
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.model_parallel import DistributedModelParallel

class TrainingModel(nn.Module):
    """
    Training model definition, following the pattern in the
    PyTorch DCP tutorial but adapted for TorchRec.
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config

        # 1. Embedding collection
        self.embedding_bag_collection = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="user_features",
                    embedding_dim=config.get("embedding_dim", 128),
                    num_embeddings=config.get("num_user_embeddings", 10_000_000),
                    feature_names=["user_id", "user_age", "user_gender"]
                ),
                EmbeddingBagConfig(
                    name="item_features",
                    embedding_dim=config.get("embedding_dim", 128),
                    num_embeddings=config.get("num_item_embeddings", 5_000_000),
                    feature_names=["item_id", "item_category", "item_price"]
                )
            ],
            # Build on the meta device so no memory is allocated until
            # DistributedModelParallel decides where each shard lives.
            device=torch.device("meta")
        )

        # 2. Dense network: 6 embedded features concatenated with the dense input
        input_dim = config.get("embedding_dim", 128) * 6 + config.get("dense_dim", 16)
        self.dense_net = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self, sparse_features, dense_features):
        """Forward pass."""
        # 1. Embedding lookup (sparse_features is a KeyedJaggedTensor)
        sparse_embeddings = self.embedding_bag_collection(sparse_features)

        # 2. Concatenate the per-feature embeddings
        concatenated_embeddings = torch.cat(
            [sparse_embeddings[name] for name in sparse_embeddings.keys()],
            dim=1
        )

        # 3. Append the dense features
        combined_features = torch.cat([concatenated_embeddings, dense_features], dim=1)

        # 4. Dense network
        return self.dense_net(combined_features)
```
2.3 The Application State Manager
```python
class TrainingAppState(Stateful):
    """
    Application state manager, based on the AppState class
    from the official PyTorch DCP tutorial.
    """
    def __init__(self, model, optimizer=None, scheduler=None, extra_state=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.extra_state = extra_state or {}
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

    def state_dict(self):
        """
        Produce the state dict. get_state_dict handles FQN mapping
        and sharding information automatically.
        """
        logging.info(f"[Rank {self.rank}] Building state dict...")
        model_state_dict, optim_state_dict = get_state_dict(
            self.model,
            self.optimizer,
            options=StateDictOptions(full_state_dict=False)  # keep sharded
        )
        state = {
            "model": model_state_dict,
            "world_size": self.world_size,
            "rank": self.rank,
            "timestamp": time.time(),
            "version": "1.0"
        }
        if optim_state_dict is not None:
            state["optimizer"] = optim_state_dict
        if self.scheduler is not None:
            state["scheduler"] = self.scheduler.state_dict()
        # Merge in any extra state
        state.update(self.extra_state)
        logging.info(f"[Rank {self.rank}] State dict built, {len(state)} entries")
        return state

    def load_state_dict(self, state_dict):
        """
        Load the state dict. Sharded state is handled automatically.
        """
        logging.info(f"[Rank {self.rank}] Loading state dict...")
        try:
            # Check world-size compatibility
            saved_world_size = state_dict.get("world_size", 1)
            if saved_world_size != self.world_size:
                logging.warning(
                    f"[Rank {self.rank}] World size mismatch: saved={saved_world_size}, "
                    f"current={self.world_size}\n"
                    "DCP will reshard automatically..."
                )
            # Restore model and optimizer state
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict.get("model", {}),
                optim_state_dict=state_dict.get("optimizer", {}),
                options=StateDictOptions(strict=False)
            )
            # Restore scheduler state
            if self.scheduler is not None and "scheduler" in state_dict:
                self.scheduler.load_state_dict(state_dict["scheduler"])
                logging.info(f"[Rank {self.rank}] Scheduler state restored")
            # Restore extra state
            for key, value in state_dict.items():
                if key not in ["model", "optimizer", "scheduler", "world_size", "rank", "timestamp", "version"]:
                    self.extra_state[key] = value
            logging.info(f"[Rank {self.rank}] State dict loaded")
        except Exception as e:
            logging.error(f"[Rank {self.rank}] Failed to load state dict: {str(e)}")
            raise
```
2.4 The Complete Training Workflow
```python
def distributed_training_workflow():
    """
    The complete distributed training workflow,
    following the pattern of the official PyTorch DCP tutorial.
    """
    print("\n" + "=" * 60)
    print("🚀 TorchRec distributed training workflow (PyTorch 2.5.0 + TorchRec 1.0.0)")
    print("=" * 60)

    # Step 1: initialize the distributed environment
    print("\n" + "-" * 40)
    print("1️⃣ Initialize the distributed environment")
    print("-" * 40)
    world_size, rank = setup_distributed_environment()

    # Step 2: configure the model
    print("\n" + "-" * 40)
    print("2️⃣ Configure the model")
    print("-" * 40)
    model_config = {
        "embedding_dim": 128,
        "dense_dim": 16,
        "num_user_embeddings": 10_000_000,
        "num_item_embeddings": 5_000_000,
        "task_type": "classification"
    }
    if rank == 0:
        logging.info("📊 Model configuration:")
        for key, value in model_config.items():
            logging.info(f"  - {key}: {value}")

    # Step 3: create and initialize the model
    print("\n" + "-" * 40)
    print("3️⃣ Create and initialize the model")
    print("-" * 40)
    logging.info(f"[Rank {rank}] Creating base model...")
    base_model = TrainingModel(model_config)
    logging.info(f"[Rank {rank}] Wrapping with DistributedModelParallel...")
    model = DistributedModelParallel(
        base_model,
        device=torch.device(f"cuda:{rank}")
    )

    # Step 4: create the optimizer and scheduler
    print("\n" + "-" * 40)
    print("4️⃣ Create the optimizer and scheduler")
    print("-" * 40)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)

    # Step 5: training loop
    print("\n" + "-" * 40)
    print("5️⃣ Training loop")
    print("-" * 40)
    num_epochs = 5
    steps_per_epoch = 100
    checkpoint_dir = "checkpoint/final_model"

    # Application state manager
    app_state = TrainingAppState(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        extra_state={"epoch": 0, "step": 0, "best_loss": float('inf')}
    )
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for step in range(steps_per_epoch):
            # 1. Generate random data (replace with a real dataloader in practice)
            batch_size = 128
            sparse_features = generate_random_sparse_features(batch_size, model_config)
            dense_features = torch.randn(batch_size, model_config["dense_dim"]).to(f"cuda:{rank}")
            targets = torch.randint(0, 2, (batch_size, 1), dtype=torch.float32).to(f"cuda:{rank}")

            # 2. Forward pass
            outputs = model(sparse_features, dense_features)

            # 3. Compute the loss
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(outputs, targets)

            # 4. Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 5. Step the scheduler
            scheduler.step()
            epoch_loss += loss.item()

            # 6. Save intermediate checkpoints periodically
            if step % 20 == 0 and step > 0:
                logging.info(f"[Rank {rank}] Saving intermediate checkpoint...")
                save_checkpoint(app_state, f"{checkpoint_dir}/epoch_{epoch}_step_{step}")

        avg_loss = epoch_loss / steps_per_epoch
        logging.info(f"[Rank {rank}] Epoch {epoch+1}/{num_epochs}, average loss: {avg_loss:.4f}")

        # 7. Save a checkpoint at the end of each epoch
        app_state.extra_state.update({
            "epoch": epoch + 1,
            "step": steps_per_epoch,
            "avg_loss": avg_loss
        })
        save_checkpoint(app_state, f"{checkpoint_dir}/epoch_{epoch+1}")
    # Step 6: save the final checkpoint
    print("\n" + "-" * 40)
    print("6️⃣ Save the final checkpoint")
    print("-" * 40)
    app_state.extra_state.update({
        "total_epochs": num_epochs,
        "total_steps": num_epochs * steps_per_epoch,
        "completed": True
    })
    save_checkpoint(app_state, checkpoint_dir)

    # Step 7: cleanup
    print("\n" + "-" * 40)
    print("7️⃣ Clean up resources")
    print("-" * 40)
    dist.destroy_process_group()
    logging.info(f"[Rank {rank}] ✅ Training finished, resources released")
    print("\n" + "=" * 60)
    print("🎉 Distributed training workflow complete!")
    print("=" * 60)
def save_checkpoint(app_state, checkpoint_dir):
    """
    Save a checkpoint, following the saving pattern
    from the PyTorch DCP tutorial.
    """
    state_dict = {"app": app_state}
    try:
        # Make sure the directory exists
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Save the checkpoint; DCP calls app_state.state_dict() for us
        dcp.save(
            state_dict,
            checkpoint_id=checkpoint_dir
        )
        logging.info(f"[Rank {dist.get_rank()}] ✅ Checkpoint saved: {checkpoint_dir}")
    except Exception as e:
        logging.error(f"[Rank {dist.get_rank()}] ❌ Checkpoint save failed: {str(e)}")
        raise
```
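The training loop above calls a `generate_random_sparse_features` helper that is never shown. One possible implementation is sketched below, assuming each sample contributes exactly one id per feature; `KeyedJaggedTensor.from_lengths_sync` is the standard TorchRec constructor for this input format.

```python
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def generate_random_sparse_features(batch_size: int, config: Dict[str, Any]) -> KeyedJaggedTensor:
    """Build a random KeyedJaggedTensor matching TrainingModel's six features.
    A sketch only: every sample gets exactly one id per feature."""
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    # Cardinality of each feature; must stay below the table sizes above.
    cardinalities = {
        "user_id": config.get("num_user_embeddings", 10_000_000),
        "user_age": 100,
        "user_gender": 2,
        "item_id": config.get("num_item_embeddings", 5_000_000),
        "item_category": 1000,
        "item_price": 10_000,
    }
    keys = list(cardinalities.keys())
    # One id per sample per feature; values are concatenated key by key.
    values = torch.cat([
        torch.randint(0, cardinalities[k], (batch_size,), device=device)
        for k in keys
    ])
    lengths = torch.ones(len(keys) * batch_size, dtype=torch.int64, device=device)
    return KeyedJaggedTensor.from_lengths_sync(keys=keys, values=values, lengths=lengths)
```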
3. Inference Deployment: From Checkpoint to Production Service
3.1 Converting DCP to the torch.save Format
```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
import os

def convert_dcp_to_torch_save(
    dcp_checkpoint_dir: str,
    output_path: str,
    device: str = "cpu"
) -> str:
    """
    Convert a DCP checkpoint into torch.save format,
    using the format-conversion utility from the PyTorch docs.
    """
    logging.info("🔄 Converting DCP → torch.save")
    logging.info(f"  - DCP checkpoint directory: {dcp_checkpoint_dir}")
    logging.info(f"  - Output path: {output_path}")
    try:
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # 1. Run the format conversion.
        # Per the PyTorch docs, DCP runs in "non-distributed" mode
        # automatically when no process group has been initialized.
        dcp_to_torch_save(dcp_checkpoint_dir, output_path)
        logging.info(f"✅ Conversion succeeded! Output: {output_path}")

        # 2. Validate the result (the checkpoint contains pickled
        # objects, so weights_only must stay False here)
        checkpoint = torch.load(output_path, map_location=device, weights_only=False)
        logging.info(f"  - Converted checkpoint keys: {list(checkpoint.keys())}")
        app_state = checkpoint.get("app", {})
        if "model" in app_state:
            logging.info(f"  - Model state dict entries: {len(app_state['model'])}")

        # 3. Report the file size
        file_size_mb = os.path.getsize(output_path) / 1024 / 1024
        logging.info(f"  - Converted file size: {file_size_mb:.2f} MB")
        return output_path
    except Exception as e:
        logging.error(f"❌ Conversion failed: {str(e)}")
        raise
```
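The reverse direction also exists: `torch_save_to_dcp` turns a torch.save file back into a DCP checkpoint, which is handy for warm-starting distributed training from a single-file model. A brief sketch with illustrative paths:

```python
from torch.distributed.checkpoint.format_utils import torch_save_to_dcp

# Reverse of the conversion above; paths are illustrative.
torch_save_to_dcp("models/inference_model.pth", "checkpoint/from_torch_save")
```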
3.2 Defining the Single-Machine Inference Model
```python
class InferenceModel(nn.Module):
    """
    Inference model definition: the distributed training model
    restructured as a plain single-machine module.
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config

        # 1. Embedding tables (complete, unsharded copies)
        self.embedding_tables = nn.ModuleDict({
            "user_id": nn.EmbeddingBag(config.get("num_user_embeddings", 10_000_000), config.get("embedding_dim", 128)),
            "user_age": nn.EmbeddingBag(100, config.get("embedding_dim", 128)),        # assumes ages 0-99
            "user_gender": nn.EmbeddingBag(2, config.get("embedding_dim", 128)),       # gender: 0 or 1
            "item_id": nn.EmbeddingBag(config.get("num_item_embeddings", 5_000_000), config.get("embedding_dim", 128)),
            "item_category": nn.EmbeddingBag(1000, config.get("embedding_dim", 128)),  # assumes 1000 categories
            "item_price": nn.EmbeddingBag(10_000, config.get("embedding_dim", 128))    # assumes 10000 price buckets
        })

        # 2. Dense network (identical to training)
        input_dim = config.get("embedding_dim", 128) * 6 + config.get("dense_dim", 16)
        self.dense_net = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self, sparse_features, dense_features):
        """Inference forward pass."""
        with torch.no_grad():  # no gradients needed at inference time
            # 1. Embedding lookup
            sparse_embeddings = []
            for feature_name, indices in sparse_features.items():
                if feature_name in self.embedding_tables:
                    # EmbeddingBag needs 2-D input (or explicit offsets);
                    # treat each sample as a bag of one id.
                    if indices.dim() == 1:
                        indices = indices.unsqueeze(1)
                    embedding = self.embedding_tables[feature_name](indices)
                    sparse_embeddings.append(embedding)

            # 2. Concatenate the embeddings
            concatenated_embeddings = torch.cat(sparse_embeddings, dim=1)

            # 3. Append the dense features
            combined_features = torch.cat([concatenated_embeddings, dense_features], dim=1)

            # 4. Dense network
            output = self.dense_net(combined_features)

            # 5. Sigmoid activation (for CTR-style classification)
            if self.config.get("task_type", "classification") == "classification":
                output = torch.sigmoid(output)
        return output

    @classmethod
    def from_checkpoint(cls, checkpoint_path: str, config: Optional[Dict[str, Any]] = None, device: str = "cuda"):
        """Load the model from a converted checkpoint."""
        # 1. Load the checkpoint (contains pickled objects, so weights_only=False)
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # 2. Extract the state dict. Note: extra_state entries were merged
        # into the top level of "app" at save time, so pass config in
        # explicitly when it cannot be recovered from there.
        app_state = checkpoint["app"]
        config = config or {k: v for k, v in app_state.items()
                            if k not in ("model", "optimizer", "scheduler")}
        state_dict = app_state["model"]

        # 3. Build the model
        model = cls(config)
        model = model.to(device)
        model.eval()

        # 4. Normalize the key names
        cleaned_state_dict = clean_state_dict_keys(state_dict)

        # 5. Load the state dict (strict=False because the table layout
        # differs between training and inference; unmatched keys are skipped)
        model.load_state_dict(cleaned_state_dict, strict=False)
        return model
```
3.3 Cleaning State-Dict Key Names
```python
def clean_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Clean up state-dict key names by stripping the prefixes
    added by distributed-training wrappers.
    """
    cleaned = {}
    patterns = [
        "module.",                    # DDP prefix
        "_fsdp_wrapped_module.",      # FSDP prefix
        "_orig_mod.",                 # TorchDynamo prefix
        "model.",                     # nested-model prefix
        "dmp_wrapped_module.",        # DMP prefix
        "embedding_bag_collection."   # TorchRec prefix
    ]
    for key, value in state_dict.items():
        new_key = key
        # Strip prefixes repeatedly, since wrappers can be nested
        # (e.g. "module.dmp_wrapped_module.")
        stripped = True
        while stripped:
            stripped = False
            for pattern in patterns:
                if new_key.startswith(pattern):
                    new_key = new_key[len(pattern):]
                    stripped = True
        cleaned[new_key] = value
    logging.info("🧹 State-dict key cleanup done:")
    logging.info(f"  - Original key count: {len(state_dict)}")
    logging.info(f"  - Cleaned key count: {len(cleaned)}")
    return cleaned
```
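A quick illustration of what the cleanup does (the keys here are hypothetical):

```python
raw = {
    "module.dmp_wrapped_module.dense_net.0.weight": torch.zeros(1),
    "_orig_mod.dense_net.0.bias": torch.zeros(1),
}
print(list(clean_state_dict_keys(raw).keys()))
# ['dense_net.0.weight', 'dense_net.0.bias']
```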
3.4 The Complete Inference Workflow
```python
def inference_workflow():
    """
    The complete inference workflow:
    from a DCP checkpoint to single-machine inference.
    """
    print("\n" + "=" * 60)
    print("🚀 TorchRec inference workflow (PyTorch 2.5.0 + TorchRec 1.0.0)")
    print("=" * 60)

    # Configuration
    DCP_CHECKPOINT_DIR = "checkpoint/final_model"
    INFERENCE_MODEL_PATH = "models/inference_model.pth"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Step 1: convert the DCP checkpoint into inference format
    print("\n" + "-" * 40)
    print("1️⃣ Convert the DCP checkpoint into inference format")
    print("-" * 40)
    if not os.path.exists(INFERENCE_MODEL_PATH):
        print(f"🔄 Converting checkpoint: {DCP_CHECKPOINT_DIR} → {INFERENCE_MODEL_PATH}")
        convert_dcp_to_torch_save(
            dcp_checkpoint_dir=DCP_CHECKPOINT_DIR,
            output_path=INFERENCE_MODEL_PATH,
            device="cpu"
        )
    else:
        print(f"✅ Inference model already exists: {INFERENCE_MODEL_PATH}")
        print(f"  - File size: {os.path.getsize(INFERENCE_MODEL_PATH) / 1024 / 1024:.2f} MB")

    # Step 2: load the inference model
    print("\n" + "-" * 40)
    print("2️⃣ Load the inference model")
    print("-" * 40)
    model = InferenceModel.from_checkpoint(INFERENCE_MODEL_PATH, device=DEVICE)
    logging.info("✅ Inference model loaded")

    # Step 3: prepare inference data
    print("\n" + "-" * 40)
    print("3️⃣ Prepare inference data")
    print("-" * 40)
    # Generate sample data
    batch_size = 8
    sample_batch = {
        "sparse_features": {
            "user_id": torch.randint(0, 10_000_000, (batch_size,), dtype=torch.long).to(DEVICE),
            "user_age": torch.randint(0, 100, (batch_size,), dtype=torch.long).to(DEVICE),
            "user_gender": torch.randint(0, 2, (batch_size,), dtype=torch.long).to(DEVICE),
            "item_id": torch.randint(0, 5_000_000, (batch_size,), dtype=torch.long).to(DEVICE),
            "item_category": torch.randint(0, 1000, (batch_size,), dtype=torch.long).to(DEVICE),
            "item_price": torch.randint(0, 10_000, (batch_size,), dtype=torch.long).to(DEVICE)
        },
        "dense_features": torch.randn(batch_size, 16).to(DEVICE)
    }
    print("📊 Input batch:")
    print(f"  - Batch size: {batch_size}")
    for key, tensor in sample_batch["sparse_features"].items():
        print(f"  - Sparse feature '{key}': {tensor.shape}")
    print(f"  - Dense feature shape: {sample_batch['dense_features'].shape}")
    # Step 4: run inference
    print("\n" + "-" * 40)
    print("4️⃣ Run inference")
    print("-" * 40)
    start_time = time.time()
    predictions = model(sample_batch["sparse_features"], sample_batch["dense_features"])
    inference_time = time.time() - start_time
    print("✅ Inference complete!")
    print(f"  - Inference time: {inference_time:.4f} s")
    print(f"  - Average time per sample: {inference_time/batch_size:.6f} s")
    print(f"  - Prediction range: [{predictions.min().item():.4f}, {predictions.max().item():.4f}]")

    # Step 5: show the results
    print("\n" + "-" * 40)
    print("5️⃣ Inference results")
    print("-" * 40)
    print(f"🎯 Predictions (first {min(5, batch_size)} samples):")
    for i, pred in enumerate(predictions[:5]):
        print(f"  Sample {i+1}: {pred.item():.4f}")

    # Step 6: performance test
    print("\n" + "-" * 40)
    print("6️⃣ Performance test")
    print("-" * 40)
    performance_results = test_inference_performance(model, DEVICE)
    analyze_performance(performance_results)

    print("\n" + "=" * 60)
    print("🎉 Inference workflow complete!")
    print("=" * 60)
    return model
def test_inference_performance(model, device):
    """Benchmark inference performance across batch sizes."""
    batch_sizes = [1, 8, 32, 64, 128]
    results = []
    print(f"{'Batch size':<10} | {'Batch time (ms)':<15} | {'Per sample (ms)':<15} | {'Throughput (samples/s)':<22}")
    print("-" * 70)
    for batch_size in batch_sizes:
        # Generate test data
        test_batch = {
            "sparse_features": {
                "user_id": torch.randint(0, 10_000_000, (batch_size,)).to(device),
                "user_age": torch.randint(0, 100, (batch_size,)).to(device),
                "user_gender": torch.randint(0, 2, (batch_size,)).to(device),
                "item_id": torch.randint(0, 5_000_000, (batch_size,)).to(device),
                "item_category": torch.randint(0, 1000, (batch_size,)).to(device),
                "item_price": torch.randint(0, 10_000, (batch_size,)).to(device)
            },
            "dense_features": torch.randn(batch_size, 16).to(device)
        }
        # Warm-up
        for _ in range(10):
            _ = model(test_batch["sparse_features"], test_batch["dense_features"])
        # Timed runs; synchronize so CUDA kernels are actually counted
        if device == "cuda":
            torch.cuda.synchronize()
        num_iterations = 100
        start_time = time.time()
        for _ in range(num_iterations):
            _ = model(test_batch["sparse_features"], test_batch["dense_features"])
        if device == "cuda":
            torch.cuda.synchronize()
        total_time = time.time() - start_time
        avg_time_per_batch = total_time / num_iterations
        avg_time_per_sample = avg_time_per_batch / batch_size
        throughput = batch_size * num_iterations / total_time
        results.append({
            "batch_size": batch_size,
            "avg_time_per_batch": avg_time_per_batch,
            "avg_time_per_sample": avg_time_per_sample,
            "throughput": throughput
        })
        print(f"{batch_size:<10} | {avg_time_per_batch*1000:<15.2f} | {avg_time_per_sample*1000:<15.4f} | {throughput:<22.1f}")
    return results
def analyze_performance(results):
    """Analyze the benchmark results."""
    if not results:
        return
    # Find the best configurations
    best_throughput = max(results, key=lambda x: x["throughput"])
    best_latency = min(results, key=lambda x: x["avg_time_per_sample"])
    print(f"\n🏆 Best throughput: batch size = {best_throughput['batch_size']}")
    print(f"  - Throughput: {best_throughput['throughput']:.1f} samples/s")
    print(f"  - Latency: {best_throughput['avg_time_per_sample']*1000:.4f} ms/sample")
    print(f"\n⚡ Best latency: batch size = {best_latency['batch_size']}")
    print(f"  - Latency: {best_latency['avg_time_per_sample']*1000:.4f} ms/sample")
    print(f"  - Throughput: {best_latency['throughput']:.1f} samples/s")
```
4. Distributed Inference: Very Large Models
4.1 Why Distributed Inference?
When a model exceeds the memory of a single GPU (typically >80 GB), inference itself must be distributed. Per the PyTorch documentation, DCP can load checkpoints across different world sizes, which is exactly the foundation distributed inference needs.
4.2 分布式推理模型加载
python
复制代码
def load_distributed_model_for_inference(
checkpoint_dir: str,
model_config: Dict[str, Any],
device: str = "cuda"
) -> DistributedModelParallel:
"""
加载分布式模型用于推理
基于PyTorch DCP教程中的加载方法
"""
logging.info("🔍 加载分布式推理模型...")
logging.info(f" - 检查点目录: {checkpoint_dir}")
logging.info(f" - 设备: {device}")
logging.info(f" - World size: {dist.get_world_size()}")
logging.info(f" - Rank: {dist.get_rank()}")
try:
# 1. 创建基础模型
base_model = TrainingModel(model_config)
# 2. 应用DMP封装
model = DistributedModelParallel(
base_model,
device=torch.device(device)
)
# 3. 创建优化器(推理不需要,但DCP需要)
dummy_optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# 4. 创建应用状态
app_state = TrainingAppState(model=model, optimizer=dummy_optimizer)
# 5. 准备状态字典
state_dict = {"app": app_state}
# 6. 加载检查点
dcp.load(
state_dict=state_dict,
checkpoint_id=checkpoint_dir
)
# 7. 设置为评估模式
model.eval()
logging.info("✅ 分布式模型加载完成,准备推理")
return model
except Exception as e:
logging.error(f"❌ 分布式模型加载失败: {str(e)}")
raise
4.3 Running Distributed Inference
```python
def distributed_inference(model, inputs):
    """
    Run one distributed inference step. Rank 0 owns the request;
    the inputs are broadcast so every rank can take part in the
    sharded embedding lookups.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{torch.cuda.current_device()}")

    # 1. Broadcast the inputs from rank 0 to all ranks.
    # broadcast_object_list pickles arbitrary Python objects, which is
    # far more robust than hand-rolled tensor serialization.
    obj_list = [inputs if rank == 0 else None]
    dist.broadcast_object_list(obj_list, src=0)
    inputs = obj_list[0]

    # Move the received data onto this rank's GPU
    # (sparse_features is a KeyedJaggedTensor, which supports .to())
    sparse_features = inputs["sparse_features"].to(device)
    dense_features = inputs["dense_features"].to(device)

    # 2. Run the forward pass with gradients disabled
    with torch.no_grad():
        outputs = model(sparse_features, dense_features)

    # 3. DMP's internal input/output redistribution leaves every rank
    # holding predictions for the full (broadcast) batch, so rank 0
    # can simply return its own copy.
    return outputs if rank == 0 else None
```
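For completeness, here is a minimal sketch of how the two pieces above might be wired together on each rank, assuming a torchrun launch, the same model_config as in training, and the `generate_random_sparse_features` helper sketched earlier:

```python
def distributed_inference_main():
    """Hypothetical per-rank driver; names mirror the functions above."""
    world_size, rank = setup_distributed_environment()
    model_config = {"embedding_dim": 128, "dense_dim": 16,
                    "num_user_embeddings": 10_000_000,
                    "num_item_embeddings": 5_000_000}
    model = load_distributed_model_for_inference("checkpoint/final_model", model_config)
    inputs = None
    if rank == 0:
        # Rank 0 prepares the request (random data here for illustration);
        # distributed_inference moves it onto each rank's GPU after broadcast.
        inputs = {
            "sparse_features": generate_random_sparse_features(8, model_config),
            "dense_features": torch.randn(8, 16),
        }
    predictions = distributed_inference(model, inputs)
    if rank == 0:
        print(predictions)
    dist.destroy_process_group()
```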
5. Kubernetes Deployment Configuration
5.1 Training Job YAML
```yaml
# training-job.yaml
apiVersion: batch/v1
kind: Job
metadata:
  name: torchrec-training
  namespace: torchrec-prod
spec:
  backoffLimit: 3
  template:
    spec:
      restartPolicy: OnFailure
      containers:
      - name: training
        image: harbor.your-company.com/torchrec/training:2.5.0
        # Launch one process per GPU on this node
        command: ["torchrun", "--nproc_per_node=8", "distributed_training.py"]
        resources:
          limits:
            nvidia.com/gpu: 8
            memory: 128Gi
            cpu: 32
          requests:
            nvidia.com/gpu: 8
            memory: 96Gi
            cpu: 24
        env:
        - name: MASTER_ADDR
          value: torchrec-training
        - name: MASTER_PORT
          value: "29500"
        - name: WORLD_SIZE
          value: "8"
        volumeMounts:
        - name: models-volume
          mountPath: /models
        - name: data-volume
          mountPath: /data
      volumes:
      - name: models-volume
        persistentVolumeClaim:
          claimName: torchrec-models-pvc
      - name: data-volume
        persistentVolumeClaim:
          claimName: torchrec-data-pvc
```
5.2 Inference Service YAML
```yaml
# inference-service.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: torchrec-inference
  namespace: torchrec-prod
spec:
  replicas: 3
  selector:
    matchLabels:
      app: torchrec-inference
  template:
    metadata:
      labels:
        app: torchrec-inference
    spec:
      containers:
      - name: inference
        image: harbor.your-company.com/torchrec/inference:2.5.0
        ports:
        - containerPort: 8000
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: 32Gi
            cpu: 8
          requests:
            nvidia.com/gpu: 1
            memory: 24Gi
            cpu: 6
        env:
        - name: MODEL_PATH
          value: /models/inference_model.pth
        - name: DEVICE
          value: cuda
        volumeMounts:
        - name: models-volume
          mountPath: /models
          readOnly: true
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 10
          periodSeconds: 5
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 15
      volumes:
      - name: models-volume
        persistentVolumeClaim:
          claimName: torchrec-models-pvc
---
apiVersion: v1
kind: Service
metadata:
  name: torchrec-inference
  namespace: torchrec-prod
spec:
  selector:
    app: torchrec-inference
  ports:
  - port: 80
    targetPort: 8000
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: torchrec-inference
  namespace: torchrec-prod
spec:
  rules:
  - http:
      paths:
      - path: /predict
        pathType: Prefix
        backend:
          service:
            name: torchrec-inference
            port:
              number: 80
```
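Applying and checking these manifests is standard kubectl (the torchrec-prod namespace is assumed to exist):

```bash
# Apply the manifests
kubectl apply -f training-job.yaml
kubectl apply -f inference-service.yaml

# Watch rollout and job progress
kubectl -n torchrec-prod rollout status deployment/torchrec-inference
kubectl -n torchrec-prod logs job/torchrec-training --follow
```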
6. The Complete Training + Inference Pipeline
6.1 End-to-End Workflow
```python
def end_to_end_workflow():
    """End-to-end training + inference workflow."""
    print("\n" + "=" * 60)
    print("🚀 TorchRec end-to-end training + inference workflow")
    print("=" * 60)

    # Step 1: distributed training
    print("\n" + "-" * 40)
    print("1️⃣ Distributed training")
    print("-" * 40)
    distributed_training_workflow()

    # Step 2: checkpoint conversion
    print("\n" + "-" * 40)
    print("2️⃣ Checkpoint conversion")
    print("-" * 40)
    DCP_CHECKPOINT_DIR = "checkpoint/final_model"
    INFERENCE_MODEL_PATH = "models/inference_model.pth"
    convert_dcp_to_torch_save(
        dcp_checkpoint_dir=DCP_CHECKPOINT_DIR,
        output_path=INFERENCE_MODEL_PATH
    )

    # Step 3: single-machine inference
    print("\n" + "-" * 40)
    print("3️⃣ Single-machine inference")
    print("-" * 40)
    inference_workflow()

    # Step 4: performance analysis
    print("\n" + "-" * 40)
    print("4️⃣ Performance analysis")
    print("-" * 40)
    analyze_end_to_end_performance()

    print("\n" + "=" * 60)
    print("🎉 End-to-end workflow complete!")
    print("=" * 60)
def analyze_end_to_end_performance():
    """Summarize end-to-end performance.
    The figures below are illustrative placeholders, not measurements."""
    logging.info("📊 End-to-end performance analysis")
    # Training metrics (placeholder values)
    training_metrics = {
        "samples_per_second": 15000,
        "gpu_utilization": 85,
        "memory_utilization": 90
    }
    # Inference metrics (placeholder values)
    inference_metrics = {
        "latency_p99_ms": 15.2,
        "throughput_samples_per_second": 8500,
        "gpu_utilization": 65
    }
    logging.info("📈 Training performance:")
    for key, value in training_metrics.items():
        logging.info(f"  - {key}: {value}")
    logging.info("\n📈 Inference performance:")
    for key, value in inference_metrics.items():
        logging.info(f"  - {key}: {value}")
    logging.info("\n💡 Optimization suggestions:")
    logging.info("  - Larger training batches can raise training throughput")
    logging.info("  - TorchScript can lower inference latency")
    logging.info("  - Mixed-precision training can raise GPU utilization")
```
6.2 Launch Script
```bash
#!/bin/bash
# run_end_to_end.sh
echo "🚀 Starting the TorchRec end-to-end workflow"
echo "  - PyTorch version: 2.5.0"
echo "  - TorchRec version: 1.0.0"

# 1. Check GPU availability
if ! command -v nvidia-smi &> /dev/null; then
    echo "⚠️ Warning: no NVIDIA GPU detected"
    exit 1
fi

# 2. Count the GPUs
GPU_COUNT=$(nvidia-smi -L | wc -l)
echo "📊 GPU count: $GPU_COUNT"

# 3. Set environment variables
export WORLD_SIZE=$GPU_COUNT
export MASTER_ADDR=localhost
export MASTER_PORT=29500

# 4. Run training
echo "🏃 Launching distributed training..."
python -m torch.distributed.run \
    --nproc_per_node=$GPU_COUNT \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    distributed_training.py

# 5. Convert the checkpoint
echo "🔄 Converting the checkpoint for inference..."
python convert_checkpoint.py

# 6. Run inference
echo "🔍 Starting the inference service..."
python inference_server.py

echo "✅ End-to-end workflow complete!"
```
7. Best Practices and Troubleshooting
7.1 Training Best Practices

| Problem | Solution | Relevant PyTorch DCP feature |
|---|---|---|
| Slow checkpoint saves | Save asynchronously to cut training stalls (see the sketch below) | dcp.async_save() |
| Out-of-memory errors | Tune the sharding plan and batch size | get_state_dict() memory management |
| Load failures | Ensure world size and sharding plan are compatible | set_state_dict() resharding |
| Throughput regressions | Tune the NCCL configuration to cut communication overhead | DCP parallel I/O optimization |
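A minimal sketch of the asynchronous save path, reusing the save_checkpoint structure from section 2.4; `dcp.async_save` returns a Future that should be resolved before the next save starts:

```python
import torch.distributed.checkpoint as dcp

checkpoint_future = None

def async_save_checkpoint(app_state, checkpoint_dir):
    """Asynchronous variant of save_checkpoint (a sketch): training
    continues while the checkpoint is written in the background."""
    global checkpoint_future
    # Wait for the previous save so writes never overlap
    if checkpoint_future is not None:
        checkpoint_future.result()
    checkpoint_future = dcp.async_save(
        {"app": app_state},
        checkpoint_id=checkpoint_dir,
    )
```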
7.2 Inference Best Practices

| Problem | Solution | Relevant PyTorch DCP feature |
|---|---|---|
| High latency | Use TorchScript, enable CUDA graphs | format-conversion optimizations |
| Low throughput | Tune batch size, enable dynamic batching | non-distributed load optimizations |
| Accuracy drift | Verify output consistency before and after conversion (see the sketch below) | state-dict validation |
| Deployment complexity | Containerize with Docker, orchestrate with Kubernetes | format interoperability |
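For the accuracy-drift row, a minimal consistency check, assuming a reference model and the converted InferenceModel can be run on the same batch:

```python
def check_output_consistency(reference_model, converted_model,
                             sparse_features, dense_features,
                             atol: float = 1e-5) -> bool:
    """Compare predictions from two model variants on the same batch."""
    with torch.no_grad():
        ref = reference_model(sparse_features, dense_features)
        out = converted_model(sparse_features, dense_features)
    ok = torch.allclose(ref, out, atol=atol)
    max_diff = (ref - out).abs().max().item()
    logging.info(f"Consistency check: max diff = {max_diff:.2e}, pass = {ok}")
    return ok
```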
7.3 Troubleshooting Guide
```python
def diagnose_issues():
    """Diagnose common TorchRec training + inference problems."""
    logging.info("🔍 Diagnosing TorchRec training + inference issues")

    # 1. Check PyTorch and TorchRec versions
    try:
        import torch
        import torchrec
        logging.info(f"✅ PyTorch version: {torch.__version__}")
        logging.info(f"✅ TorchRec version: {torchrec.__version__}")
    except Exception as e:
        logging.error(f"❌ Version check failed: {str(e)}")

    # 2. Check GPU availability
    if torch.cuda.is_available():
        logging.info(f"✅ GPUs available: {torch.cuda.device_count()}")
        logging.info(f"✅ CUDA version: {torch.version.cuda}")
    else:
        logging.error("❌ No GPU available")

    # 3. Check NCCL
    try:
        torch.cuda.nccl.version()
        logging.info("✅ NCCL available")
    except Exception as e:
        logging.error(f"❌ NCCL unavailable: {str(e)}")

    # 4. Check the distributed environment
    try:
        if dist.is_initialized():
            logging.info(f"✅ Distributed environment initialized, world size: {dist.get_world_size()}")
        else:
            logging.warning("⚠️ Distributed environment not initialized")
    except Exception as e:
        logging.error(f"❌ Distributed environment check failed: {str(e)}")

    # 5. Check the checkpoint
    checkpoint_dir = "checkpoint/final_model"
    if os.path.exists(checkpoint_dir):
        logging.info(f"✅ Checkpoint directory exists: {checkpoint_dir}")
        # DCP writes a .metadata index plus per-rank "__<rank>_<n>.distcp" shards
        files = os.listdir(checkpoint_dir)
        if any(f == ".metadata" or f.endswith(".distcp") for f in files):
            logging.info("✅ DCP checkpoint files present")
        else:
            logging.warning("⚠️ No DCP checkpoint files found")
    else:
        logging.warning(f"⚠️ Checkpoint directory does not exist: {checkpoint_dir}")
```
8. Summary and Key Takeaways
8.1 Core Concepts Recap
- DCP is the cornerstone of distributed training: per the PyTorch docs, DCP handles FQN mapping and sharded state automatically on both save and load
- Training and inference are separated: training uses a distributed model; inference typically converts it to a single-machine model
- Format conversion is the bridge: converting DCP to torch.save is what makes single-machine inference possible
- Elastic training support: DCP can save and load across different world sizes, giving training real flexibility
8.2 Best-Practice Summary

| Stage | Recommended approach | Rationale |
|---|---|---|
| Training | DCP + DMP | Supports very large models and elastic training |
| Checkpointing | Save in DCP format | Preserves sharded state, supports resharding |
| Inference | Convert to torch.save | Simplifies deployment, improves inference performance |
| Deployment | Kubernetes + Docker | Elastic scaling, high availability |
8.3 Future Directions
- Automatic sharding plans: deriving an optimized sharding layout from the model structure
- Mixed-precision training: further training speedups
- A unified API: simplifying the training-to-inference hand-off
- Cloud-native integration: deeper integration with cloud services to ease deployment