【推荐系统】深度学习训练框架(十六):模型并行——推荐系统的TorchRec和大语言模型的FSDP(Fully Sharded Data Parallel)

📦 第一部分:TorchRec 实战教程

TorchRec是PyTorch的领域库,专为大规模推荐系统设计。其核心是解决超大规模嵌入表在多GPU/多节点上的高效训练问题。

1. 安装与环境配置

首先安装TorchRec及其依赖。推荐使用CUDA环境以获得最佳性能。

bash 复制代码
# 1. 安装对应CUDA版本的PyTorch (以CUDA 12.1为例)
pip install torch --index-url https://download.pytorch.org/whl/cu121

# 2. 安装FBGEMM GPU版本和TorchRec
pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121
pip install torchrec --index-url https://download.pytorch.org/whl/cu121

# 3. 如果是纯CPU环境(性能较低)
# pip uninstall fbgemm-gpu -y
# pip install fbgemm-gpu-cpu
# pip install torchrec

2. 核心概念:分片与并行

理解以下两个关键模块是使用TorchRec的基础:

  • 分片器(Sharder) :定义如何将巨大的嵌入表切割并分布到不同设备上。TorchRec支持多种分片策略,如按行(row_wise)、按表(table_wise)等。
  • 分布式模型并行(DistributedModelParallel, DMP) :这是TorchRec最核心的高级API。它类似于PyTorch的DistributedDataParallel,但专为封装已分片的稀疏模型部分(嵌入表)和稠密模型部分(如MLP)而设计。

3. 实战:构建一个分布式推荐模型

下面通过一个简化的代码示例,展示如何使用TorchRec的关键组件。

python 复制代码
import torch
import torch.nn as nn
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.distributed.model_parallel import (
    get_default_sharders,
)

# 1. 定义模型(以最简单的稠密-稀疏交互为例)
class SimpleRecModel(nn.Module):
    def __init__(self, embedding_bag_collection):
        super().__init__()
        self.ebc = embedding_bag_collection
        # 假设稀疏特征维度总和为512
        self.dense = nn.Linear(512, 1)

    def forward(self, sparse_features):
        embeddings = self.ebc(sparse_features)  # 获得稀疏特征嵌入向量
        concatenated = torch.cat([emb for emb in embeddings.values()], dim=1)
        return self.dense(concatenated)

# 2. 初始化分布式环境(必须在代码最开头)
import torch.distributed as dist
dist.init_process_group(backend="nccl")  # GPU用NCCL,CPU用gloo
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")

# 3. 配置嵌入表
embedding_configs = [
    EmbeddingBagConfig(
        name="user",
        embedding_dim=128,
        num_embeddings=100_0000,  # 一百万用户
        feature_names=["user_feature"],
    ),
    EmbeddingBagConfig(
        name="item",
        embedding_dim=128,
        num_embeddings=50_0000,   # 五十万物品
        feature_names=["item_feature"],
    ),
]

# 4. 在CPU上实例化模型(重要!DMP会处理设备移动)
ebc = EmbeddingBagCollection(
    tables=embedding_configs,
    device=torch.device("cpu")
)
model = SimpleRecModel(ebc)

# 5. 使用分布式模型并行(DMP)包装模型
# get_default_sharders() 提供了适用于常见嵌入模块的分片器
model = DistributedModelParallel(
    module=model,
    device=device,
    sharders=get_default_sharders(),
    # planner=EmbeddingShardingPlanner()  # 可选的自动规划器,用于生成优化的分片计划
)

# 6. 定义优化器(TorchRec的优化器支持稀疏更新,高效处理嵌入梯度)
from torchrec.optim import apply_optimizer_in_backward
from torch.optim import SGD
# 为嵌入参数设置稀疏优化器
apply_optimizer_in_backward(
    SGD,
    model.module.ebc.parameters(),
    {"lr": 0.1}
)
# 为稠密参数设置标准优化器
dense_optimizer = SGD(model.module.dense.parameters(), lr=0.01)

# 此后,在训练循环中,前向传播、反向传播和优化器步骤与非分布式模型基本一致。
# DMP会自动处理跨设备的梯度同步和稀疏参数的更新。

⚙️ 第二部分:FSDP 快速指南

FSDP是PyTorch原生的分布式训练策略,核心思想是将模型的参数、梯度和优化器状态全部分片存储,在需要时再通过通信收集,从而极大节省单卡显存。

1. 基本使用模式

以下是使用FSDP包装一个Transformer模型的典型代码:

python 复制代码
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

# 1. 初始化分布式环境 (同上,略)
# ...

# 2. 定义一个大的模型 (例如Transformer)
model = nn.Transformer(
    d_model=2048,
    nhead=16,
    num_encoder_layers=12,
    num_decoder_layers=12
)

# 3. 定义自动包装策略(按子模块分片)
my_auto_wrap_policy = default_auto_wrap_policy(
    transformer_layer_cls={nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
)

# 4. 用FSDP包装模型
fsdp_model = FSDP(
    model,
    auto_wrap_policy=my_auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)

# 5. 定义优化器(FSDP会自动处理优化器状态分片)
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=1e-4)

# 6. 训练循环与非分布式模型一致
# FSDP会在前向传播时透明地收集所需参数,并在反向传播后同步梯度和更新分片。

🤔 第三部分:TorchRec 与 FSDP 核心对比

这两种技术都是为"大模型"设计,但目标完全不同。下表清晰地展示了两者的区别:

对比维度 TorchRec FSDP (Fully Sharded Data Parallel)
核心目标 专门用于大规模推荐系统 ,解决稀疏嵌入表的并行训练。 通用的大规模稠密模型(如LLM、CV大模型)训练,解决参数显存瓶颈。
主要并行范式 混合并行 :嵌入表常采用模型并行/张量并行切分,稠密部分使用数据并行。 增强的数据并行 :在数据并行的基础上,对参数、梯度、优化器状态进行分片
优化核心 嵌入表的分片策略 (行、列、表),以及稀疏梯度的高效聚合与更新 通信与计算的重叠 ,以及分片策略(全分片、混合分片)以平衡显存和通信开销。
关键优势 1. 原生支持超大规模嵌入 (十亿/万亿级)。 2. 为推荐系统提供专用原语 (如EmbeddingBagCollection)。 3. 优化器支持稀疏更新,计算高效。 1. 通用性强 ,几乎适用于任何PyTorch模型。 2. 显存节省显著 ,是训练千亿参数大模型的标配技术 。 3. 与PyTorch生态无缝集成
典型应用场景 电商推荐、广告点击率(CTR)预估、社交网络推荐等具有海量稀疏特征的场景。 大语言模型(LLM)预训练与微调、大规模视觉模型训练、稠密科学计算模型。
关系 互补 。一个复杂模型可同时使用两者:其稀疏嵌入部分用TorchRec分片 ,而稠密神经网络部分用FSDP分片

💡 第四部分:如何选择与后续建议

如何选择:

  • 如果你的模型核心是处理用户ID、商品ID等海量离散特征 ,嵌入表参数占模型绝大部分,请直接选择 TorchRec
  • 如果你的模型是Transformer、ResNet等稠密结构 ,参数巨大但并非稀疏特征,应选择 FSDP
  • 对于混合模型 (大嵌入表+大稠密网络),可以研究组合使用两者
相关推荐
Xy-unu10 小时前
Analog optical computer for AI inference and combinatorial optimization
论文阅读·人工智能
小马过河R11 小时前
混元世界模型1.5架构原理初探
人工智能·语言模型·架构·nlp
三万棵雪松11 小时前
【AI小智后端部分(一)】
人工智能·python·ai小智
编程小Y11 小时前
Adobe Animate 2024:2D 矢量动画与交互创作利器下载安装教程
人工智能
laplace012311 小时前
Part 3:模型调用、记忆管理与工具调用流程(LangChain 1.0)笔记(Markdown)
开发语言·人工智能·笔记·python·langchain·prompt
mys551811 小时前
杨建允:AI搜索优化对汽车服务行业获客的影响
人工智能·aigc·geo·ai搜索优化·ai引擎优化
2501_9361460411 小时前
鱼类识别与分类:基于freeanchor_x101-32x4d_fpn_1x_coco的三种鱼类自动检测
人工智能·分类·数据挖掘
鲨莎分不晴11 小时前
拯救暗淡图像:深度解析直方图均衡化(原理、公式与计算)
人工智能·算法·机器学习
好奇龙猫11 小时前
【人工智能学习-AI-MIT公开课-10. 学习介绍、最近邻】
人工智能·学习
智算菩萨11 小时前
2026马年新岁:拥抱智能时代,共谱科技华章
人工智能·科技