【Infra】Megatron学习

1. 3D 并行= TP(Tensor Parallel) + PP(Pipeline Parallel) + DP(Data Parallel)

这是大模型训练的"标配组合",三维并行就像把模型切成一个3D立方体来加速。

并行维度 切的是什么 主要解决的问题 每张卡内存下降倍数 通信量特点
TP 每一层权重矩阵(GEMM) 单层参数太大,一张卡放不下 ÷ TP_size AllReduce(频繁但小)
PP 模型的层(Layer) 模型太深,激活值占太多内存 ÷ PP_size P2P(气泡较大)
DP 样本(Batch) 提升训练吞吐 ≈ 1(ZeRO可再省) AllReduce(梯度大)

三者相乘得到总卡数 = TP × PP × DP

举例:

  • 模型:LLaMA-405B

  • 配置:TP=8, PP=8, DP=64 → 总共 8×8×64 = 4096 张卡

  • 每张卡只放 405B ÷ 8 ≈ 50B 参数,激活再被PP切8份,内存轻松放下

3D并行核心难点:

  • TP内部的AllReduce要和计算完全重叠(Megatron用列并行+行并行完美overlap)

  • PP的气泡(bubble)要压到最低(Megatron用1F1B调度,气泡只剩1/PP)

  • 三者混合后通信调度极其复杂,Megatron-Core把这部分做到极致

2. Sequence Parallel(序列并行)

问题:训练超长上下文(32k~128k~1M)时,单卡显存被激活值(attention的K/V cache + intermediate activation)打爆。

传统做法:只能减小batch size → 训练效率暴跌。

Sequence Parallel(SP)的解决思路: 把序列长度也切开!(和TP切权重矩阵类似)

具体做法(Megatron + FlashAttention-2实现):

  • 在一个TP组内(比如8张卡),把序列长度S再切成8份

  • 每张卡只存 S/8 长度的K/V和中间激活

  • Attention计算时,通过AllReduce把Q和所有分片的K做完整点积

  • 通信正好能和FlashAttention的GEMM完全重叠,几乎零开销

效果:

  • 上下文长度×8而不增加显存(理论上可无限扩展)

  • 目前DeepSeek-V3、LLaMA-3.1、Qwen2.5长上下文训练全部靠SP

一句话总结:Sequence Parallel = "在TP组内再做一次Tensor Parallel,但切的是序列维度"

3. Context Parallel(上下文并行)

Sequence Parallel已经很强了,但还有问题:

  • 只能在同一个TP组内切序列(最多切8-16路)

  • 想支持1M甚至10M长上下文还是不够

核心思想:

  • 彻底把"序列维度"从"模型并行"中独立出来

  • 允许跨节点、跨TP组、跨PP阶段切序列(可以切64路、128路甚至更多)

  • 使用Ring Attention + 可调度通信,把超长序列的attention通信变成线性扩展

优势对比:

技术 最大可切份数 是否跨节点 通信开销 代表模型
Sequence Parallel ≤ TP size (8-16) 不行 几乎0(overlap) LLaMA-3.1 128K
Context Parallel 64~512+ 可以 稍高但可接受 DeepSeek-V3 1M+, Orion等

总结对比

技术 解决的核心问题 显存下降倍数 通信开销 何时必须用
Tensor Parallel 权重太大 ÷ TP 70B+模型
Pipeline Parallel 层太多,激活爆显存 ÷ PP 中高 模型层数>60 + 多节点
Data Parallel + ZeRO 梯度/优化器状态爆显存 ZeRO-3几乎÷DP 所有大规模训练
Sequence Parallel 长序列激活爆显存 ÷ TP(序列维度) 极低 32K~128K上下文训练
Context Parallel 超长序列(128K+~1M+) 可÷64~512 1M+长上下文训练(2025主流)

为什么"三者相乘得到总卡数"?(TP × PP × DP = 总卡数)

这是 3D 并行的基本数学原理:

  • TP(Tensor Parallel):把每一层权重矩阵横着切,需要 TP 张卡组成一个完整的层。
  • PP(Pipeline Parallel):把模型的层纵着切成 PP 段,每段放在不同卡组。
  • DP(Data Parallel):每个完整的模型再复制 DP 份,处理不同 batch 的数据。

这三个维度是正交的、互相独立的,所以总卡数就是乘法:

举例(LLaMA-405B 真实训练配置):

  • TP = 8 → 8张卡才能拼出一个完整的 Transformer layer
  • PP = 8 → 需要 8 个这样的 TP 组首尾相连组成完整模型
  • DP = 64 → 上面这套东西再复制 64 份处理不同样本

为什么 Sequence Parallel、Context Parallel 和 3D 并行是分开算的?

因为它们解决的问题和切的维度完全不同,属于"第四维、第五维"并行:

并行技术 切的维度 属于几维并行 典型切分大小 备注
TP / PP / DP 权重、层、数 经典 3D 并行 8×8×几十 总卡数 = TP×PP×DP
Sequence Parallel 序列长度 S 第4维 ≤ TP 大小(通常 8~16) 必须在同一个 TP 组内切
Context Parallel 序列长度 S 第5维 64~512+ 可以跨节点、跨 TP 组、跨 PP 阶段

它们不是互斥的,而是叠加使用的:

  • 70B~400B 普通上下文(≤32K) → 3D + Sequence Parallel(SP)
  • 超长上下文(128K~1M+) → 3D + Sequence Parallel + Context Parallel(CP)
  • 总卡数还是只看 TP×PP×DP,SP 和 CP 不会额外增加卡数,只是在已有卡里再切序列维度
相关推荐
markuszhang2 小时前
G1 垃圾回收器学习
java·学习
星月IWJ2 小时前
领域驱动设计学习
java·学习·设计模式
菜鸟‍2 小时前
【论文学习】SAMed-2: 选择性记忆增强的医学任意分割模型
人工智能·学习·算法
weixin_409383122 小时前
简单四方向a*寻路学习记录2 先做个数组地图 在cocos编辑器模式上运行出格子 计算角色世界坐标跟数组地图的联系
学习·编辑器·cocos
一过菜只因3 小时前
MySql学习(2)
数据库·学习·mysql
灰灰勇闯IT3 小时前
虚拟机性能优化实战:从基础调优到深度压榨性能
开发语言·学习·性能优化·虚拟机
xxp43213 小时前
Linux 根文件系统构建
linux·学习
vi121233 小时前
农业图像预处理技术学习综述:原理、实现与应用
人工智能·学习
世界宇宙超级无敌究极特级顶级第一非常谱尼3 小时前
RF Power Amplifiers for Wireless Communications 第二章学习笔记
笔记·学习·pa·功率放大器·mmic