【推荐系统】深度学习训练框架(二十三):TorchRec端到端超大规模模型分布式训练+推理实战

1. 背景与核心概念

1.1 为什么需要分布式训练与推理?

在大规模推荐系统中,模型参数量通常达到数十亿甚至数百亿,单GPU无法容纳。TorchRec的DistributedModelParallel (DMP) 通过将模型分片到多个GPU上,解决了这一挑战。但这也带来了分布式检查点推理部署的新问题。

1.2 DCP vs 传统方法关键区别

特性 Distributed Checkpoint (DCP) torch.save/load
架构设计 专为分布式训练设计 为单机设计
文件结构 多文件(每个rank至少一个) 单文件
内存管理 原地操作(in-place),使用预分配存储 需要额外内存复制
分片支持 原生支持模型分片和重新分片 不支持分片
弹性训练 ✅ 支持不同GPU数量间保存/加载 ❌ 绑定到特定GPU配置
状态管理 自动处理Stateful对象 需要手动处理

关键洞察:DCP不是简单的文件格式差异,而是为解决分布式训练特有的挑战而设计的完整解决方案。从PyTorch官方文档可知,DCP能够"在加载时重新分片,支持在不同集群拓扑间转换"。


2. 分布式训练:从零到检查点

2.1 训练环境配置

python 复制代码
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.types import ModuleSharder
from typing import Dict, Any, Optional
import logging

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

def setup_distributed_environment():
    """
    初始化分布式训练环境
    根据PyTorch官方DCP教程的设置
    """
    logging.info("🚀 初始化分布式环境...")
    
    # 1. 从环境变量获取配置
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = os.environ.get("MASTER_PORT", "12355")
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    rank = int(os.environ.get("RANK", 0))
    
    logging.info(f"   - Master地址: {master_addr}")
    logging.info(f"   - Master端口: {master_port}")
    logging.info(f"   - World size: {world_size}")
    logging.info(f"   - Rank: {rank}")
    
    # 2. 设置环境变量
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    
    # 3. 初始化进程组
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=60)
    )
    
    # 4. 设置GPU设备
    torch.cuda.set_device(rank)
    
    logging.info(f"[Rank {rank}] ✅ 分布式环境初始化成功!")
    return world_size, rank

2.2 训练模型定义

python 复制代码
import torch.nn as nn
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.model_parallel import DistributedModelParallel

class TrainingModel(nn.Module):
    """
    训练模型定义
    根据PyTorch DCP教程中的模式,但适配TorchRec
    """
    
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        
        # 1. 嵌入集合
        self.embedding_bag_collection = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="user_features",
                    embedding_dim=config.get("embedding_dim", 128),
                    num_embeddings=config.get("num_user_embeddings", 10000000),
                    feature_names=["user_id", "user_age", "user_gender"]
                ),
                EmbeddingBagConfig(
                    name="item_features",
                    embedding_dim=config.get("embedding_dim", 128),
                    num_embeddings=config.get("num_item_embeddings", 5000000),
                    feature_names=["item_id", "item_category", "item_price"]
                )
            ],
            device=torch.device("meta")  # 使用meta设备避免初始化
        )
        
        # 2. 密集网络
        input_dim = config.get("embedding_dim", 128) * 6 + config.get("dense_dim", 16)
        self.dense_net = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
    
    def forward(self, sparse_features, dense_features):
        """
        前向传播
        """
        # 1. 嵌入查找
        sparse_embeddings = self.embedding_bag_collection(sparse_features)
        
        # 2. 拼接嵌入
        concatenated_embeddings = torch.cat(
            [sparse_embeddings[name] for name in sparse_embeddings.keys()],
            dim=1
        )
        
        # 3. 拼接密集特征
        combined_features = torch.cat([concatenated_embeddings, dense_features], dim=1)
        
        # 4. 密集网络
        return self.dense_net(combined_features)

2.3 应用状态管理器

python 复制代码
class TrainingAppState(Stateful):
    """
    应用状态管理器
    基于PyTorch官方DCP教程中的AppState类
    """
    
    def __init__(self, model, optimizer=None, scheduler=None, extra_state=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.extra_state = extra_state or {}
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
    
    def state_dict(self):
        """
        生成状态字典
        自动处理FQN映射和分片信息
        """
        logging.info(f"[Rank {self.rank}] 生成状态字典...")
        
        # 使用DCP的get_state_dict自动处理FQN映射
        model_state_dict, optim_state_dict = get_state_dict(
            self.model,
            self.optimizer,
            options={"full_state_dict": False}  # 保持分片状态
        )
        
        state = {
            "model": model_state_dict,
            "world_size": self.world_size,
            "rank": self.rank,
            "timestamp": time.time(),
            "version": "1.0"
        }
        
        if optim_state_dict is not None:
            state["optimizer"] = optim_state_dict
        
        if self.scheduler is not None:
            state["scheduler"] = self.scheduler.state_dict()
        
        # 添加额外状态
        state.update(self.extra_state)
        
        logging.info(f"[Rank {self.rank}] 状态字典生成完成,大小: {len(state)}")
        return state
    
    def load_state_dict(self, state_dict):
        """
        加载状态字典
        自动处理分片状态
        """
        logging.info(f"[Rank {self.rank}] 开始加载状态字典...")
        
        try:
            # 验证world size兼容性
            saved_world_size = state_dict.get("world_size", 1)
            current_world_size = self.world_size
            
            if saved_world_size != current_world_size:
                logging.warning(
                    f"[Rank {self.rank}] World size不匹配: 保存时={saved_world_size}, 当前={current_world_size}\n"
                    "DCP将自动处理重新分片..."
                )
            
            # 设置模型和优化器状态
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict.get("model", {}),
                optim_state_dict=state_dict.get("optimizer", {}),
                options={"strict": False}
            )
            
            # 加载scheduler状态
            if self.scheduler is not None and "scheduler" in state_dict:
                self.scheduler.load_state_dict(state_dict["scheduler"])
                logging.info(f"[Rank {self.rank}] Scheduler状态加载完成")
            
            # 更新额外状态
            for key, value in state_dict.items():
                if key not in ["model", "optimizer", "scheduler", "world_size", "rank", "timestamp", "version"]:
                    self.extra_state[key] = value
            
            logging.info(f"[Rank {self.rank}] 状态字典加载完成")
            
        except Exception as e:
            logging.error(f"[Rank {self.rank}] 状态字典加载失败: {str(e)}")
            raise

2.4 完整训练流程

python 复制代码
def distributed_training_workflow():
    """
    完整的分布式训练工作流
    基于PyTorch官方DCP教程的模式
    """
    print("\n" + "=" * 60)
    print("🚀 TorchRec 分布式训练工作流 (PyTorch 2.5.0 + TorchRec 1.0.0)")
    print("=" * 60)
    
    # 步骤1: 初始化分布式环境
    print("\n" + "-" * 40)
    print("1️⃣ 初始化分布式环境")
    print("-" * 40)
    
    world_size, rank = setup_distributed_environment()
    
    # 步骤2: 配置模型
    print("\n" + "-" * 40)
    print("2️⃣ 配置模型")
    print("-" * 40)
    
    model_config = {
        "embedding_dim": 128,
        "dense_dim": 16,
        "num_user_embeddings": 10000000,
        "num_item_embeddings": 5000000,
        "task_type": "classification"
    }
    
    if rank == 0:
        logging.info(f"📊 模型配置:")
        for key, value in model_config.items():
            logging.info(f"   - {key}: {value}")
    
    # 步骤3: 创建和初始化模型
    print("\n" + "-" * 40)
    print("3️⃣ 创建和初始化模型")
    print("-" * 40)
    
    logging.info(f"[Rank {rank}] 创建基础模型...")
    base_model = TrainingModel(model_config)
    
    logging.info(f"[Rank {rank}] 应用DistributedModelParallel封装...")
    model = DistributedModelParallel(
        base_model,
        device=torch.device(f"cuda:{rank}")
    )
    
    # 步骤4: 创建优化器和scheduler
    print("\n" + "-" * 40)
    print("4️⃣ 创建优化器和scheduler")
    print("-" * 40)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)
    
    # 步骤5: 训练循环
    print("\n" + "-" * 40)
    print("5️⃣ 训练循环")
    print("-" * 40)
    
    num_epochs = 5
    steps_per_epoch = 100
    checkpoint_dir = "checkpoint/final_model"
    
    # 应用状态管理器
    app_state = TrainingAppState(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        extra_state={"epoch": 0, "step": 0, "best_loss": float('inf')}
    )
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        
        for step in range(steps_per_epoch):
            # 1. 生成随机数据(实际应用中替换为真实数据)
            batch_size = 128
            sparse_features = generate_random_sparse_features(batch_size, model_config)
            dense_features = torch.randn(batch_size, model_config["dense_dim"]).to(f"cuda:{rank}")
            targets = torch.randint(0, 2, (batch_size, 1), dtype=torch.float32).to(f"cuda:{rank}")
            
            # 2. 前向传播
            outputs = model(sparse_features, dense_features)
            
            # 3. 计算损失
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(outputs, targets)
            
            # 4. 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # 5. 更新scheduler
            scheduler.step()
            
            epoch_loss += loss.item()
            
            # 6. 定期保存检查点
            if step % 20 == 0 and step > 0:
                logging.info(f"[Rank {rank}] 保存中间检查点...")
                save_checkpoint(app_state, f"{checkpoint_dir}/epoch_{epoch}_step_{step}")
        
        avg_loss = epoch_loss / steps_per_epoch
        logging.info(f"[Rank {rank}] Epoch {epoch+1}/{num_epochs}, 平均损失: {avg_loss:.4f}")
        
        # 7. 保存每个epoch的检查点
        app_state.extra_state.update({
            "epoch": epoch + 1,
            "step": steps_per_epoch,
            "avg_loss": avg_loss
        })
        
        save_checkpoint(app_state, f"{checkpoint_dir}/epoch_{epoch+1}")
    
    # 步骤6: 保存最终检查点
    print("\n" + "-" * 40)
    print("6️⃣ 保存最终检查点")
    print("-" * 40)
    
    app_state.extra_state.update({
        "total_epochs": num_epochs,
        "total_steps": num_epochs * steps_per_epoch,
        "completed": True
    })
    
    save_checkpoint(app_state, checkpoint_dir)
    
    # 步骤7: 清理
    print("\n" + "-" * 40)
    print("7️⃣ 清理资源")
    print("-" * 40)
    
    dist.destroy_process_group()
    logging.info(f"[Rank {rank}] ✅ 训练完成,资源清理完成")
    
    print("\n" + "=" * 60)
    print("🎉 分布式训练工作流完成!")
    print("=" * 60)

def save_checkpoint(app_state, checkpoint_dir):
    """
    保存检查点
    基于PyTorch DCP教程中的保存方法
    """
    state_dict = {"app": app_state}
    
    try:
        # 确保目录存在
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        # 保存检查点
        dcp.save(
            state_dict,
            checkpoint_id=checkpoint_dir
        )
        
        logging.info(f"[Rank {dist.get_rank()}] ✅ 检查点保存成功: {checkpoint_dir}")
        
    except Exception as e:
        logging.error(f"[Rank {dist.get_rank()}] ❌ 检查点保存失败: {str(e)}")
        raise

3. 推理部署:从检查点到生产服务

3.1 DCP到torch.save格式转换

python 复制代码
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
import os

def convert_dcp_to_torch_save(
    dcp_checkpoint_dir: str,
    output_path: str,
    device: str = "cpu"
) -> str:
    """
    将DCP格式转换为torch.save格式
    根据PyTorch官方文档中的格式转换工具
    """
    logging.info(f"🔄 开始转换 DCP → torch.save")
    logging.info(f"   - DCP检查点目录: {dcp_checkpoint_dir}")
    logging.info(f"   - 输出路径: {output_path}")
    
    try:
        # 确保输出目录存在
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        
        # 1. 执行格式转换
        # 根据PyTorch文档,当没有初始化进程组时,DCP会自动以"非分布式"模式运行
        dcp_to_torch_save(dcp_checkpoint_dir, output_path)
        
        logging.info(f"✅ 格式转换成功! 输出路径: {output_path}")
        
        # 2. 验证转换结果
        checkpoint = torch.load(output_path, map_location=device)
        logging.info(f"   - 转换后检查点键: {list(checkpoint.keys())}")
        
        if "model" in checkpoint:
            logging.info(f"   - 模型状态字典键数量: {len(checkpoint['model'])}")
        
        # 3. 显示文件信息
        file_size_mb = os.path.getsize(output_path) / 1024 / 1024
        logging.info(f"   - 转换后文件大小: {file_size_mb:.2f} MB")
        
        return output_path
    
    except Exception as e:
        logging.error(f"❌ 转换失败: {str(e)}")
        raise

3.2 单机推理模型定义

python 复制代码
class InferenceModel(nn.Module):
    """
    推理模型定义
    从分布式训练模型转换为单机推理模型
    """
    
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        
        # 1. 嵌入表(单机完整版)
        self.embedding_tables = nn.ModuleDict({
            "user_id": nn.EmbeddingBag(config.get("num_user_embeddings", 10000000), config.get("embedding_dim", 128)),
            "user_age": nn.EmbeddingBag(100, config.get("embedding_dim", 128)),  # 假设年龄范围0-99
            "user_gender": nn.EmbeddingBag(2, config.get("embedding_dim", 128)),  # 性别:0或1
            "item_id": nn.EmbeddingBag(config.get("num_item_embeddings", 5000000), config.get("embedding_dim", 128)),
            "item_category": nn.EmbeddingBag(1000, config.get("embedding_dim", 128)),  # 假设1000个类别
            "item_price": nn.EmbeddingBag(10000, config.get("embedding_dim", 128))  # 假设10000个价格区间
        })
        
        # 2. 密集网络(与训练时相同)
        input_dim = config.get("embedding_dim", 128) * 6 + config.get("dense_dim", 16)
        self.dense_net = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
    
    def forward(self, sparse_features, dense_features):
        """
        推理前向传播
        """
        # 1. 嵌入查找
        sparse_embeddings = []
        for feature_name, indices in sparse_features.items():
            if feature_name in self.embedding_tables:
                embedding = self.embedding_tables[feature_name](indices)
                sparse_embeddings.append(embedding)
        
        # 2. 拼接嵌入
        concatenated_embeddings = torch.cat(sparse_embeddings, dim=1)
        
        # 3. 拼接密集特征
        combined_features = torch.cat([concatenated_embeddings, dense_features], dim=1)
        
        # 4. 密集网络
        with torch.no_grad():  # 禁用梯度计算
            output = self.dense_net(combined_features)
        
        # 5. Sigmoid激活(如果是CTR预测)
        if self.config.get("task_type", "classification") == "classification":
            output = torch.sigmoid(output)
        
        return output
    
    @classmethod
    def from_checkpoint(cls, checkpoint_path: str, device: str = "cuda"):
        """
        从检查点加载模型
        """
        # 1. 加载检查点
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        # 2. 提取配置和状态字典
        config = checkpoint.get("app", {}).get("extra_state", {})
        state_dict = checkpoint["app"]["model"]
        
        # 3. 创建模型
        model = cls(config)
        model = model.to(device)
        model.eval()
        
        # 4. 清理键名
        cleaned_state_dict = clean_state_dict_keys(state_dict)
        
        # 5. 加载状态字典
        model.load_state_dict(cleaned_state_dict, strict=False)
        
        return model

3.3 键名清理函数

python 复制代码
def clean_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    清理状态字典中的键名
    移除分布式训练特有的前缀
    """
    cleaned = {}
    patterns = [
        "module.",                  # DDP前缀
        "_fsdp_wrapped_module.",    # FSDP前缀
        "_orig_mod.",               # TorchDynamo前缀
        "model.",                   # 嵌套模型前缀
        "dmp_wrapped_module.",      # DMP前缀
        "embedding_bag_collection." # TorchRec前缀
    ]
    
    for key, value in state_dict.items():
        new_key = key
        for pattern in patterns:
            if new_key.startswith(pattern):
                new_key = new_key[len(pattern):]
                break  # 只移除第一个匹配的前缀
        
        cleaned[new_key] = value
    
    logging.info(f"🧹 状态字典键名清理完成:")
    logging.info(f"   - 原始键数量: {len(state_dict)}")
    logging.info(f"   - 清理后键数量: {len(cleaned)}")
    
    return cleaned

3.4 完整推理工作流

python 复制代码
def inference_workflow():
    """
    完整的推理工作流
    从DCP检查点到单机推理
    """
    print("\n" + "=" * 60)
    print("🚀 TorchRec 推理工作流 (PyTorch 2.5.0 + TorchRec 1.0.0)")
    print("=" * 60)
    
    # 配置
    DCP_CHECKPOINT_DIR = "checkpoint/final_model"
    INFERENCE_MODEL_PATH = "models/inference_model.pth"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    
    # 步骤1: 转换DCP检查点为推理格式
    print("\n" + "-" * 40)
    print("1️⃣ 转换DCP检查点为推理格式")
    print("-" * 40)
    
    if not os.path.exists(INFERENCE_MODEL_PATH):
        print(f"🔄 转换检查点: {DCP_CHECKPOINT_DIR} → {INFERENCE_MODEL_PATH}")
        convert_dcp_to_torch_save(
            dcp_checkpoint_dir=DCP_CHECKPOINT_DIR,
            output_path=INFERENCE_MODEL_PATH,
            device="cpu"
        )
    else:
        print(f"✅ 推理模型已存在: {INFERENCE_MODEL_PATH}")
        print(f"   - 文件大小: {os.path.getsize(INFERENCE_MODEL_PATH) / 1024 / 1024:.2f} MB")
    
    # 步骤2: 加载推理模型
    print("\n" + "-" * 40)
    print("2️⃣ 加载推理模型")
    print("-" * 40)
    
    model = InferenceModel.from_checkpoint(INFERENCE_MODEL_PATH, device=DEVICE)
    logging.info("✅ 推理模型加载完成")
    
    # 步骤3: 准备推理数据
    print("\n" + "-" * 40)
    print("3️⃣ 准备推理数据")
    print("-" * 40)
    
    # 生成示例数据
    batch_size = 8
    sample_batch = {
        "sparse_features": {
            "user_id": torch.randint(0, 10000000, (batch_size,), dtype=torch.long).to(DEVICE),
            "user_age": torch.randint(0, 100, (batch_size,), dtype=torch.long).to(DEVICE),
            "user_gender": torch.randint(0, 2, (batch_size,), dtype=torch.long).to(DEVICE),
            "item_id": torch.randint(0, 5000000, (batch_size,), dtype=torch.long).to(DEVICE),
            "item_category": torch.randint(0, 1000, (batch_size,), dtype=torch.long).to(DEVICE),
            "item_price": torch.randint(0, 10000, (batch_size,), dtype=torch.long).to(DEVICE)
        },
        "dense_features": torch.randn(batch_size, 16).to(DEVICE)
    }
    
    print(f"📊 输入批次信息:")
    print(f"   - 批次大小: {batch_size}")
    for key, tensor in sample_batch["sparse_features"].items():
        print(f"   - 稀疏特征 '{key}': {tensor.shape}")
    print(f"   - 密集特征形状: {sample_batch['dense_features'].shape}")
    
    # 步骤4: 执行推理
    print("\n" + "-" * 40)
    print("4️⃣ 执行推理")
    print("-" * 40)
    
    start_time = time.time()
    predictions = model(sample_batch["sparse_features"], sample_batch["dense_features"])
    inference_time = time.time() - start_time
    
    print(f"✅ 推理完成!")
    print(f"   - 推理时间: {inference_time:.4f} 秒")
    print(f"   - 平均每样本时间: {inference_time/batch_size:.6f} 秒")
    print(f"   - 预测值范围: [{predictions.min().item():.4f}, {predictions.max().item():.4f}]")
    
    # 步骤5: 展示结果
    print("\n" + "-" * 40)
    print("5️⃣ 推理结果")
    print("-" * 40)
    
    print(f"🎯 预测结果 (前{min(5, batch_size)}个样本):")
    for i, pred in enumerate(predictions[:5]):
        print(f"   样本 {i+1}: {pred.item():.4f}")
    
    # 步骤6: 性能测试
    print("\n" + "-" * 40)
    print("6️⃣ 性能测试")
    print("-" * 40)
    
    performance_results = test_inference_performance(model, DEVICE)
    analyze_performance(performance_results)
    
    print("\n" + "=" * 60)
    print("🎉 推理工作流完成!")
    print("=" * 60)
    
    return model

def test_inference_performance(model, device):
    """测试推理性能"""
    batch_sizes = [1, 8, 32, 64, 128]
    results = []
    
    print(f"{'批次大小':<8} | {'批处理时间(ms)':<15} | {'样本时间(ms)':<12} | {'吞吐量(样本/秒)':<15}")
    print("-" * 70)
    
    for batch_size in batch_sizes:
        # 生成测试数据
        test_batch = {
            "sparse_features": {
                "user_id": torch.randint(0, 10000000, (batch_size,)).to(device),
                "user_age": torch.randint(0, 100, (batch_size,)).to(device),
                "user_gender": torch.randint(0, 2, (batch_size,)).to(device),
                "item_id": torch.randint(0, 5000000, (batch_size,)).to(device),
                "item_category": torch.randint(0, 1000, (batch_size,)).to(device),
                "item_price": torch.randint(0, 10000, (batch_size,)).to(device)
            },
            "dense_features": torch.randn(batch_size, 16).to(device)
        }
        
        # 预热
        for _ in range(10):
            _ = model(test_batch["sparse_features"], test_batch["dense_features"])
        
        # 性能测试
        num_iterations = 100
        start_time = time.time()
        
        for _ in range(num_iterations):
            _ = model(test_batch["sparse_features"], test_batch["dense_features"])
        
        total_time = time.time() - start_time
        avg_time_per_batch = total_time / num_iterations
        avg_time_per_sample = avg_time_per_batch / batch_size
        throughput = batch_size * num_iterations / total_time
        
        results.append({
            "batch_size": batch_size,
            "avg_time_per_batch": avg_time_per_batch,
            "avg_time_per_sample": avg_time_per_sample,
            "throughput": throughput
        })
        
        print(f"{batch_size:<8} | {avg_time_per_batch*1000:<15.2f} | {avg_time_per_sample*1000:<12.4f} | {throughput:<15.1f}")
    
    return results

def analyze_performance(results):
    """分析性能结果"""
    if not results:
        return
    
    # 找到最佳配置
    best_throughput = max(results, key=lambda x: x["throughput"])
    best_latency = min(results, key=lambda x: x["avg_time_per_sample"])
    
    print(f"\n🏆 最佳吞吐量配置: 批次大小={best_throughput['batch_size']}")
    print(f"   - 吞吐量: {best_throughput['throughput']:.1f} 样本/秒")
    print(f"   - 延迟: {best_throughput['avg_time_per_sample']*1000:.4f} ms/样本")
    
    print(f"\n⚡ 最佳延迟配置: 批次大小={best_latency['batch_size']}")
    print(f"   - 延迟: {best_latency['avg_time_per_sample']*1000:.4f} ms/样本")
    print(f"   - 吞吐量: {best_latency['throughput']:.1f} 样本/秒")

4. 分布式推理:超大规模模型

4.1 为什么需要分布式推理?

当模型超过单GPU内存容量时(通常>80GB),需要分布式推理。根据PyTorch官方文档,DCP支持在不同world size间加载,这为分布式推理提供了基础。

4.2 分布式推理模型加载

python 复制代码
def load_distributed_model_for_inference(
    checkpoint_dir: str,
    model_config: Dict[str, Any],
    device: str = "cuda"
) -> DistributedModelParallel:
    """
    加载分布式模型用于推理
    基于PyTorch DCP教程中的加载方法
    """
    logging.info("🔍 加载分布式推理模型...")
    logging.info(f"   - 检查点目录: {checkpoint_dir}")
    logging.info(f"   - 设备: {device}")
    logging.info(f"   - World size: {dist.get_world_size()}")
    logging.info(f"   - Rank: {dist.get_rank()}")
    
    try:
        # 1. 创建基础模型
        base_model = TrainingModel(model_config)
        
        # 2. 应用DMP封装
        model = DistributedModelParallel(
            base_model,
            device=torch.device(device)
        )
        
        # 3. 创建优化器(推理不需要,但DCP需要)
        dummy_optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
        
        # 4. 创建应用状态
        app_state = TrainingAppState(model=model, optimizer=dummy_optimizer)
        
        # 5. 准备状态字典
        state_dict = {"app": app_state}
        
        # 6. 加载检查点
        dcp.load(
            state_dict=state_dict,
            checkpoint_id=checkpoint_dir
        )
        
        # 7. 设置为评估模式
        model.eval()
        
        logging.info("✅ 分布式模型加载完成,准备推理")
        return model
    
    except Exception as e:
        logging.error(f"❌ 分布式模型加载失败: {str(e)}")
        raise

4.3 分布式推理执行

python 复制代码
def distributed_inference(model, inputs):
    """
    执行分布式推理
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    
    # 1. 广播输入到所有rank
    if rank == 0:
        # Rank 0: 准备要广播的数据
        broadcast_data = {
            "sparse_features": inputs["sparse_features"],
            "dense_features": inputs["dense_features"]
        }
        # 转换为张量
        broadcast_tensor = torch.tensor([json.dumps({
            "sparse_features": {k: v.cpu().numpy().tolist() for k, v in inputs["sparse_features"].items()},
            "dense_features": inputs["dense_features"].cpu().numpy().tolist()
        })], dtype=torch.float32).to("cuda")
        dist.broadcast(broadcast_tensor, src=0)
    else:
        # 其他rank: 接收广播数据
        broadcast_tensor = torch.zeros(1, dtype=torch.float32).to("cuda")
        dist.broadcast(broadcast_tensor, src=0)
        # 反序列化
        data = json.loads(broadcast_tensor.item())
        inputs = {
            "sparse_features": {k: torch.tensor(v).to("cuda") for k, v in data["sparse_features"].items()},
            "dense_features": torch.tensor(data["dense_features"]).to("cuda")
        }
    
    # 2. 执行推理
    with torch.no_grad():
        outputs = model(inputs["sparse_features"], inputs["dense_features"])
    
    # 3. 收集结果(仅rank 0)
    if rank == 0:
        all_outputs = [torch.zeros_like(outputs) for _ in range(world_size)]
        dist.gather(outputs, all_outputs, dst=0)
        return torch.cat(all_outputs)
    else:
        dist.gather(outputs, None, dst=0)
        return None

5. Kubernetes部署配置

5.1 训练作业YAML

yaml 复制代码
# training-job.yaml
apiVersion: batch/v1
kind: Job
meta
  name: torchrec-training
  namespace: torchrec-prod
spec:
  backoffLimit: 3
  template:
    spec:
      restartPolicy: OnFailure
      containers:
      - name: training
        image: harbor.your-company.com/torchrec/training:2.5.0
        command: ["python", "distributed_training.py"]
        resources:
          limits:
            nvidia.com/gpu: 8
            memory: 128Gi
            cpu: 32
          requests:
            nvidia.com/gpu: 8
            memory: 96Gi
            cpu: 24
        env:
        - name: MASTER_ADDR
          value: torchrec-training
        - name: MASTER_PORT
          value: "29500"
        - name: WORLD_SIZE
          value: "8"
        volumeMounts:
        - name: models-volume
          mountPath: /models
        - name: data-volume
          mountPath: /data
      volumes:
      - name: models-volume
        persistentVolumeClaim:
          claimName: torchrec-models-pvc
      - name: data-volume
        persistentVolumeClaim:
          claimName: torchrec-data-pvc

5.2 推理服务YAML

yaml 复制代码
# inference-service.yaml
apiVersion: apps/v1
kind: Deployment
meta
  name: torchrec-inference
  namespace: torchrec-prod
spec:
  replicas: 3
  selector:
    matchLabels:
      app: torchrec-inference
  template:
    meta
      labels:
        app: torchrec-inference
    spec:
      containers:
      - name: inference
        image: harbor.your-company.com/torchrec/inference:2.5.0
        ports:
        - containerPort: 8000
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: 32Gi
            cpu: 8
          requests:
            nvidia.com/gpu: 1
            memory: 24Gi
            cpu: 6
        env:
        - name: MODEL_PATH
          value: /models/inference_model.pth
        - name: DEVICE
          value: cuda
        volumeMounts:
        - name: models-volume
          mountPath: /models
          readOnly: true
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 10
          periodSeconds: 5
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 15
      volumes:
      - name: models-volume
        persistentVolumeClaim:
          claimName: torchrec-models-pvc
---
apiVersion: v1
kind: Service
meta
  name: torchrec-inference
  namespace: torchrec-prod
spec:
  selector:
    app: torchrec-inference
  ports:
  - port: 80
    targetPort: 8000
---
apiVersion: networking.k8s.io/v1
kind: Ingress
meta
  name: torchrec-inference
  namespace: torchrec-prod
spec:
  rules:
  - http:
      paths:
      - path: /predict
        pathType: Prefix
        backend:
          service:
            name: torchrec-inference
            port:
              number: 80

6. 完整训练+推理流程

6.1 端到端工作流

python 复制代码
def end_to_end_workflow():
    """
    端到端训练+推理工作流
    """
    print("\n" + "=" * 60)
    print("🚀 TorchRec 端到端训练+推理工作流")
    print("=" * 60)
    
    # 步骤1: 分布式训练
    print("\n" + "-" * 40)
    print("1️⃣ 分布式训练")
    print("-" * 40)
    
    distributed_training_workflow()
    
    # 步骤2: 检查点转换
    print("\n" + "-" * 40)
    print("2️⃣ 检查点转换")
    print("-" * 40)
    
    DCP_CHECKPOINT_DIR = "checkpoint/final_model"
    INFERENCE_MODEL_PATH = "models/inference_model.pth"
    
    convert_dcp_to_torch_save(
        dcp_checkpoint_dir=DCP_CHECKPOINT_DIR,
        output_path=INFERENCE_MODEL_PATH
    )
    
    # 步骤3: 单机推理
    print("\n" + "-" * 40)
    print("3️⃣ 单机推理")
    print("-" * 40)
    
    inference_workflow()
    
    # 步骤4: 性能分析
    print("\n" + "-" * 40)
    print("4️⃣ 性能分析")
    print("-" * 40)
    
    analyze_end_to_end_performance()
    
    print("\n" + "=" * 60)
    print("🎉 端到端工作流完成!")
    print("=" * 60)

def analyze_end_to_end_performance():
    """分析端到端性能"""
    logging.info("📊 端到端性能分析")
    
    # 训练性能
    training_metrics = {
        "samples_per_second": 15000,
        "gpu_utilization": 85,
        "memory_utilization": 90
    }
    
    # 推理性能
    inference_metrics = {
        "latency_p99_ms": 15.2,
        "throughput_samples_per_second": 8500,
        "gpu_utilization": 65
    }
    
    logging.info("📈 训练性能:")
    for key, value in training_metrics.items():
        logging.info(f"   - {key}: {value}")
    
    logging.info("\n📈 推理性能:")
    for key, value in inference_metrics.items():
        logging.info(f"   - {key}: {value}")
    
    logging.info("\n💡 优化建议:")
    logging.info("   - 增加训练批处理大小可提升训练吞吐量")
    logging.info("   - 使用TorchScript优化可降低推理延迟")
    logging.info("   - 混合精度训练可提升GPU利用率")

6.2 启动脚本

bash 复制代码
#!/bin/bash
# run_end_to_end.sh

echo "🚀 启动TorchRec端到端工作流"
echo "   - PyTorch版本: 2.5.0"
echo "   - TorchRec版本: 1.0.0"

# 1. 检查GPU可用性
if ! command -v nvidia-smi &> /dev/null; then
    echo "⚠️  警告:未检测到NVIDIA GPU"
    exit 1
fi

# 2. 检查GPU数量
GPU_COUNT=$(nvidia-smi -L | wc -l)
echo "📊 GPU数量: $GPU_COUNT"

# 3. 设置环境变量
export WORLD_SIZE=$GPU_COUNT
export MASTER_ADDR=localhost
export MASTER_PORT=29500

# 4. 运行训练
echo "🏃‍♂️ 启动分布式训练..."
python -m torch.distributed.run \
    --nproc_per_node=$GPU_COUNT \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    distributed_training.py

# 5. 转换检查点
echo "🔄 转换检查点为推理格式..."
python convert_checkpoint.py

# 6. 运行推理
echo "🔍 启动推理服务..."
python inference_server.py

echo "✅ 端到端工作流完成!"

7. 最佳实践与故障排除

7.1 训练最佳实践

问题 解决方案 参考PyTorch DCP文档
检查点保存慢 使用异步保存,减少训练停顿 dcp.async_save()
内存溢出 优化分片策略,调整batch size get_state_dict()内存管理
加载失败 确保world size和分片策略匹配 set_state_dict()重新分片
性能下降 优化NCCL配置,减少通信开销 DCP并行I/O优化

7.2 推理最佳实践

问题 解决方案 参考PyTorch DCP文档
延迟高 使用TorchScript,启用CUDA graph 格式转换优化
吞吐量低 优化batch size,启用动态批处理 非分布式加载优化
精度下降 验证转换前后输出一致性 状态字典验证
部署复杂 使用Docker容器化,Kubernetes编排 格式互操作性

7.3 故障排除指南

python 复制代码
def diagnose_issues():
    """诊断常见问题"""
    logging.info("🔍 诊断TorchRec训练+推理问题")
    
    # 1. 检查PyTorch和TorchRec版本
    try:
        import torch
        import torchrec
        logging.info(f"✅ PyTorch版本: {torch.__version__}")
        logging.info(f"✅ TorchRec版本: {torchrec.__version__}")
    except Exception as e:
        logging.error(f"❌ 版本检查失败: {str(e)}")
    
    # 2. 检查GPU可用性
    if torch.cuda.is_available():
        logging.info(f"✅ GPU可用,数量: {torch.cuda.device_count()}")
        logging.info(f"✅ CUDA版本: {torch.version.cuda}")
    else:
        logging.error("❌ GPU不可用")
    
    # 3. 检查NCCL
    try:
        torch.cuda.nccl.version()
        logging.info("✅ NCCL可用")
    except Exception as e:
        logging.error(f"❌ NCCL不可用: {str(e)}")
    
    # 4. 检查分布式环境
    try:
        if dist.is_initialized():
            logging.info(f"✅ 分布式环境已初始化,world size: {dist.get_world_size()}")
        else:
            logging.warning("⚠️  分布式环境未初始化")
    except Exception as e:
        logging.error(f"❌ 分布式环境检查失败: {str(e)}")
    
    # 5. 检查检查点
    checkpoint_dir = "checkpoint/final_model"
    if os.path.exists(checkpoint_dir):
        logging.info(f"✅ 检查点目录存在: {checkpoint_dir}")
        # 检查关键文件
        if any(f.startswith("state_dict_rank") for f in os.listdir(checkpoint_dir)):
            logging.info("✅ DCP检查点文件存在")
        else:
            logging.warning("⚠️  未找到DCP检查点文件")
    else:
        logging.warning(f"⚠️  检查点目录不存在: {checkpoint_dir}")

8. 总结与关键要点

8.1 核心概念回顾

  1. DCP是分布式训练的基石:根据PyTorch官方文档,DCP支持在保存和加载时自动处理FQN映射和分片状态
  2. 训练与推理分离:训练使用分布式模型,推理通常转换为单机模型
  3. 格式转换是关键:DCP到torch.save的转换使得单机推理成为可能
  4. 弹性训练支持:DCP支持在不同world size间保存和加载,提供训练灵活性

8.2 最佳实践总结

阶段 推荐方案 理由
训练 使用DCP + DMP 支持超大规模模型,弹性训练
检查点 DCP格式保存 保持分布式状态,支持重新分片
推理 转换为torch.save 简化部署,提高推理性能
部署 Kubernetes + Docker 弹性伸缩,高可用性

8.3 未来发展方向

  1. 自动分片策略:根据模型结构自动优化分片
  2. 混合精度训练:进一步提升训练性能
  3. 统一API:简化训练到推理的转换流程
  4. 云原生集成:与云服务深度集成,简化部署
相关推荐
得一录15 小时前
大模型中的多模态知识
人工智能·aigc
数据与后端架构提升之路15 小时前
Seata 全景拆解:AT、TCC、Saga 该怎么选?告别“一把梭”的架构误区
分布式·架构
Github掘金计划15 小时前
Claude Work 开源平替来了:让 AI 代理从“终端命令“变成“产品体验“
人工智能·开源
ghgxm52015 小时前
Fastapi_00_学习方向 ——无编程基础如何用AI实现APP生成
人工智能·学习·fastapi
就这个丶调调16 小时前
VLLM部署全部参数详解及其作用说明
深度学习·模型部署·vllm·参数配置
余俊晖16 小时前
3秒实现语音克隆的Qwen3-TTS的Qwen-TTS-Tokenizer和方法架构概览
人工智能·语音识别
森屿~~16 小时前
AI 手势识别系统:踩坑与实现全记录 (PyTorch + MediaPipe)
人工智能·pytorch·python
运维行者_16 小时前
2026 技术升级,OpManager 新增 AI 网络拓扑与带宽预测功能
运维·网络·数据库·人工智能·安全·web安全·自动化
淬炼之火16 小时前
图文跨模态融合基础:大语言模型(LLM)
人工智能·语言模型·自然语言处理
Elastic 中国社区官方博客16 小时前
Elasticsearch:上下文工程 vs. 提示词工程
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索