【Infra】Megatron学习

1. 3D 并行= TP(Tensor Parallel) + PP(Pipeline Parallel) + DP(Data Parallel)

这是大模型训练的"标配组合",三维并行就像把模型切成一个3D立方体来加速。

并行维度 切的是什么 主要解决的问题 每张卡内存下降倍数 通信量特点
TP 每一层权重矩阵(GEMM) 单层参数太大,一张卡放不下 ÷ TP_size AllReduce(频繁但小)
PP 模型的层(Layer) 模型太深,激活值占太多内存 ÷ PP_size P2P(气泡较大)
DP 样本(Batch) 提升训练吞吐 ≈ 1(ZeRO可再省) AllReduce(梯度大)

三者相乘得到总卡数 = TP × PP × DP

举例:

  • 模型:LLaMA-405B

  • 配置:TP=8, PP=8, DP=64 → 总共 8×8×64 = 4096 张卡

  • 每张卡只放 405B ÷ 8 ≈ 50B 参数,激活再被PP切8份,内存轻松放下

3D并行核心难点:

  • TP内部的AllReduce要和计算完全重叠(Megatron用列并行+行并行完美overlap)

  • PP的气泡(bubble)要压到最低(Megatron用1F1B调度,气泡只剩1/PP)

  • 三者混合后通信调度极其复杂,Megatron-Core把这部分做到极致

2. Sequence Parallel(序列并行)

问题:训练超长上下文(32k~128k~1M)时,单卡显存被激活值(attention的K/V cache + intermediate activation)打爆。

传统做法:只能减小batch size → 训练效率暴跌。

Sequence Parallel(SP)的解决思路: 把序列长度也切开!(和TP切权重矩阵类似)

具体做法(Megatron + FlashAttention-2实现):

  • 在一个TP组内(比如8张卡),把序列长度S再切成8份

  • 每张卡只存 S/8 长度的K/V和中间激活

  • Attention计算时,通过AllReduce把Q和所有分片的K做完整点积

  • 通信正好能和FlashAttention的GEMM完全重叠,几乎零开销

效果:

  • 上下文长度×8而不增加显存(理论上可无限扩展)

  • 目前DeepSeek-V3、LLaMA-3.1、Qwen2.5长上下文训练全部靠SP

一句话总结:Sequence Parallel = "在TP组内再做一次Tensor Parallel,但切的是序列维度"

3. Context Parallel(上下文并行)

Sequence Parallel已经很强了,但还有问题:

  • 只能在同一个TP组内切序列(最多切8-16路)

  • 想支持1M甚至10M长上下文还是不够

核心思想:

  • 彻底把"序列维度"从"模型并行"中独立出来

  • 允许跨节点、跨TP组、跨PP阶段切序列(可以切64路、128路甚至更多)

  • 使用Ring Attention + 可调度通信,把超长序列的attention通信变成线性扩展

优势对比:

技术 最大可切份数 是否跨节点 通信开销 代表模型
Sequence Parallel ≤ TP size (8-16) 不行 几乎0(overlap) LLaMA-3.1 128K
Context Parallel 64~512+ 可以 稍高但可接受 DeepSeek-V3 1M+, Orion等

总结对比

技术 解决的核心问题 显存下降倍数 通信开销 何时必须用
Tensor Parallel 权重太大 ÷ TP 70B+模型
Pipeline Parallel 层太多,激活爆显存 ÷ PP 中高 模型层数>60 + 多节点
Data Parallel + ZeRO 梯度/优化器状态爆显存 ZeRO-3几乎÷DP 所有大规模训练
Sequence Parallel 长序列激活爆显存 ÷ TP(序列维度) 极低 32K~128K上下文训练
Context Parallel 超长序列(128K+~1M+) 可÷64~512 1M+长上下文训练(2025主流)

为什么"三者相乘得到总卡数"?(TP × PP × DP = 总卡数)

这是 3D 并行的基本数学原理:

  • TP(Tensor Parallel):把每一层权重矩阵横着切,需要 TP 张卡组成一个完整的层。
  • PP(Pipeline Parallel):把模型的层纵着切成 PP 段,每段放在不同卡组。
  • DP(Data Parallel):每个完整的模型再复制 DP 份,处理不同 batch 的数据。

这三个维度是正交的、互相独立的,所以总卡数就是乘法:

举例(LLaMA-405B 真实训练配置):

  • TP = 8 → 8张卡才能拼出一个完整的 Transformer layer
  • PP = 8 → 需要 8 个这样的 TP 组首尾相连组成完整模型
  • DP = 64 → 上面这套东西再复制 64 份处理不同样本

为什么 Sequence Parallel、Context Parallel 和 3D 并行是分开算的?

因为它们解决的问题和切的维度完全不同,属于"第四维、第五维"并行:

并行技术 切的维度 属于几维并行 典型切分大小 备注
TP / PP / DP 权重、层、数 经典 3D 并行 8×8×几十 总卡数 = TP×PP×DP
Sequence Parallel 序列长度 S 第4维 ≤ TP 大小(通常 8~16) 必须在同一个 TP 组内切
Context Parallel 序列长度 S 第5维 64~512+ 可以跨节点、跨 TP 组、跨 PP 阶段

它们不是互斥的,而是叠加使用的:

  • 70B~400B 普通上下文(≤32K) → 3D + Sequence Parallel(SP)
  • 超长上下文(128K~1M+) → 3D + Sequence Parallel + Context Parallel(CP)
  • 总卡数还是只看 TP×PP×DP,SP 和 CP 不会额外增加卡数,只是在已有卡里再切序列维度
相关推荐
arvin_xiaoting9 分钟前
从 0 到 1:搭建自学习 AI Agent 系统的完整工程指南
人工智能·学习·系统设计·ai agent·lancedb·自学习·openclaw
飞Link20 分钟前
深度解析 TS2Vec:时序表示学习中的层次化建模(Hierarchical Contrastive Learning)
开发语言·python·学习·数据挖掘
格鸰爱童话1 小时前
向AI学习项目技能(二)
java·人工智能·python·学习
知识分享小能手1 小时前
PostgreSQL 入门学习教程,从入门到精通,PostgreSQL 16 服务器配置与数据库监控终极指南 —语法、案例与实战(18)
数据库·学习·postgresql
困死,根本不会1 小时前
蓝桥杯python备赛笔记之(八)动态规划(DP)
笔记·python·学习·算法·蓝桥杯·动态规划
懷淰メ1 小时前
python3GUI--socket+PyQt5开发局域网微信(含功能、详细介绍、分享)
python·学习·gui·大学生·pyqt5·微信界面
ByNotD0g2 小时前
Doris 学习笔记
android·笔记·学习
困死,根本不会2 小时前
Qt Designer 基础操作学习笔记
开发语言·笔记·qt·学习·microsoft
WJSKad12352 小时前
Focus瓶颈轻量化改进YOLOv26通道压缩与残差学习协同突破
学习·yolo
愚者游世2 小时前
<algorithm> 中 remove、remove_if、remove_copy、remove_copy_if 详解
c++·学习·程序人生·职场和发展·visual studio