六、分布式嵌入

六、分布式嵌入


文章目录


前言

  • 我们已经使用了TorchRec 的主模块:EmbeddedBagCollection 。我们在上一节研究了它是如何工作的,以及数据在TorchRec 中是如何表示的。然而,我们还没有探索TorchRec 的主要部分之一,即分布式嵌入

一、先要配置torch.distributed环境

  • EmbeddingBagCollectionSharder 依赖于 PyTorch 的分布式通信库(torch.distributed)来管理跨进程/GPU 的分片和通信。

首先初始化分布式环境

python 复制代码
import torch.distributed as dist

# 初始化进程组
dist.init_process_group(
    backend="nccl",          # GPU 推荐 NCCL 后端, CPU就是 gloo
    init_method="env://",    # 从环境变量读取节点信息
    rank=rank,               # 当前进程的全局唯一标识(从 0 开始)
    world_size=world_size,   # 总进程数(总 GPU 数)
)

pg = dist.GroupMember.WORLD

设置环境变量(多节点训练时必须)

bash 复制代码
import torch.distributed as dist

# 初始化进程组
# 在每个节点上设置以下环境变量
export MASTER_ADDR="主节点IP"   # 如 "192.168.1.1"
export MASTER_PORT="66666"     # 任意未占用端口
export WORLD_SIZE=4            # 总 GPU 数
export RANK=0                  # 当前节点的全局 rank

二、Distributed Embeddings

  • 先回顾一下我们上一节的EmbeddingBagCollection module

代码演示:

python 复制代码
print(ebc)
"""
EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (product_table): EmbeddingBag(4096, 64, mode='sum')
    (user_table): EmbeddingBag(4096, 64, mode='sum')
  )
)
"""

2.1 EmbeddingBagCollectionSharder

  • 策略制定者 ,决定如何分片。
  • 决定如何将 EmbeddingBagCollection 的嵌入表(Embedding Tables )分布到多个 GPU/节点。
    核心功能 :根据配置(如 ShardingType )生成分片计划(Sharding Plan

代码演示:

python 复制代码
from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

# 定义分片器:指定分片策略(如按表分片)
sharder = EmbeddingBagCollectionSharder(
    sharding_type=ShardingType.TABLE_WISE.value,  # 每个表分配到一个 GPU
    kernel_type=EmbeddingComputeKernel.FUSED.value,  # 使用 fused 优化
)
  • 关键参数
    • sharding_type :分片策略,如:
      • TABLE_WISE:整个表放在一个 GPU。
      • ROW_WISE:按行分片到多个 GPU。
      • COLUMN_WISE:按列分片(适用于超大表)。
    • kernel_type:计算内核类型(如 FUSED 优化显存)

2.2 ShardedEmbeddingBagCollection

  • 策略执行者 ,实际管理分片后的嵌入表
  • 根据 EmbeddingBagCollectionSharder 生成的分片计划,实际管理分布在多设备上的嵌入表。
  • 核心功能 :在分布式环境中执行前向传播、梯度聚合和参数更新

代码演示:

python 复制代码
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection

# 根据分片器生成分片后的模块
sharded_ebc = ShardedEmbeddingBagCollection(
    module=ebc,        # 原始 EmbeddingBagCollection
    sharder=sharder,   # 分片策略
    device=device,     # 目标设备(如 GPU:0)
)

三、Planner

  • 它可以帮助我们确定最佳的分片配置。
  • Planner能够根据嵌入表的数量和GPU的数量来确定最佳配置。事实证明,这很难手动完成,工程师必须考虑大量因素来确保最佳的分片计划。
  • TorchRec 在提供的这个Planner ,可以帮助我们:
    • 评估硬件的内存限制
    • 将基于存储器获取的计算估计为嵌入查找
    • 解决数据特定因素
    • 考虑其他硬件细节,如带宽,以生成最佳分片计划

演示代码:

python 复制代码
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

# 初始化Planner
planner = EmbeddingShardingPlanner(
    topology=Topology(  # 硬件拓扑信息
        world_size=4,  # 总 GPU 数
        compute_device="cuda",
        local_world_size=2,  # 单机 GPU 数
        batch_size=1024,  
    ),
    constraints={  # 可选约束(如强制某些表使用特定策略)
        "user_id": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE]),
    },
)


# 生成分片计划
plan = planner.collective_plan(ebc, [sharder], pg)

# 分片后的模型
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection

sharded_ebc = ShardedEmbeddingBagCollection(
    module=ebc,
    sharder=sharder,
    device=torch.device("cuda:0"),
    plan=plan,  # 应用自动生成的分片计划
)

总结

  • TorchRec中的分布式嵌入以及训练设置。
相关推荐
A懿轩A16 分钟前
2025年十六届蓝桥杯Python B组原题及代码解析
python·算法·蓝桥杯·idle·b组
程序媛徐师姐22 分钟前
Python Django基于协同过滤算法的招聘信息推荐系统【附源码、文档说明】
python·django·协同过滤算法·招聘信息推荐系统·招聘信息·python招聘信息推荐系统·python招聘信息
2401_8906658640 分钟前
免费送源码:Java+ssm+MySQL 基于PHP在线考试系统的设计与实现 计算机毕业设计原创定制
java·hadoop·spring boot·python·mysql·spring cloud·php
帝锦_li1 小时前
微服务1--服务架构
分布式·微服务·系统架构
xuemenghan1 小时前
Numba 从零基础到实战:解锁 Python 性能新境界
开发语言·python
明月看潮生1 小时前
青少年编程与数学 02-016 Python数据结构与算法 22课题、并行算法
开发语言·python·青少年编程·并行计算·编程与数学
明月看潮生1 小时前
青少年编程与数学 02-016 Python数据结构与算法 20课题、几何算法
python·算法·青少年编程·编程与数学
星辰瑞云2 小时前
Spark-SQL核心编程2
大数据·分布式·spark
ErizJ2 小时前
Golang|Kafka在秒杀场景中的应用
开发语言·分布式·后端·golang·kafka
limengshi1383922 小时前
使用Python+xml+shutil修改目标检测图片和对应xml标注文件
xml·python·目标检测