【Infra】Megatron学习

1. 3D 并行= TP(Tensor Parallel) + PP(Pipeline Parallel) + DP(Data Parallel)

这是大模型训练的"标配组合",三维并行就像把模型切成一个3D立方体来加速。

并行维度 切的是什么 主要解决的问题 每张卡内存下降倍数 通信量特点
TP 每一层权重矩阵(GEMM) 单层参数太大,一张卡放不下 ÷ TP_size AllReduce(频繁但小)
PP 模型的层(Layer) 模型太深,激活值占太多内存 ÷ PP_size P2P(气泡较大)
DP 样本(Batch) 提升训练吞吐 ≈ 1(ZeRO可再省) AllReduce(梯度大)

三者相乘得到总卡数 = TP × PP × DP

举例:

  • 模型:LLaMA-405B

  • 配置:TP=8, PP=8, DP=64 → 总共 8×8×64 = 4096 张卡

  • 每张卡只放 405B ÷ 8 ≈ 50B 参数,激活再被PP切8份,内存轻松放下

3D并行核心难点:

  • TP内部的AllReduce要和计算完全重叠(Megatron用列并行+行并行完美overlap)

  • PP的气泡(bubble)要压到最低(Megatron用1F1B调度,气泡只剩1/PP)

  • 三者混合后通信调度极其复杂,Megatron-Core把这部分做到极致

2. Sequence Parallel(序列并行)

问题:训练超长上下文(32k~128k~1M)时,单卡显存被激活值(attention的K/V cache + intermediate activation)打爆。

传统做法:只能减小batch size → 训练效率暴跌。

Sequence Parallel(SP)的解决思路: 把序列长度也切开!(和TP切权重矩阵类似)

具体做法(Megatron + FlashAttention-2实现):

  • 在一个TP组内(比如8张卡),把序列长度S再切成8份

  • 每张卡只存 S/8 长度的K/V和中间激活

  • Attention计算时,通过AllReduce把Q和所有分片的K做完整点积

  • 通信正好能和FlashAttention的GEMM完全重叠,几乎零开销

效果:

  • 上下文长度×8而不增加显存(理论上可无限扩展)

  • 目前DeepSeek-V3、LLaMA-3.1、Qwen2.5长上下文训练全部靠SP

一句话总结:Sequence Parallel = "在TP组内再做一次Tensor Parallel,但切的是序列维度"

3. Context Parallel(上下文并行)

Sequence Parallel已经很强了,但还有问题:

  • 只能在同一个TP组内切序列(最多切8-16路)

  • 想支持1M甚至10M长上下文还是不够

核心思想:

  • 彻底把"序列维度"从"模型并行"中独立出来

  • 允许跨节点、跨TP组、跨PP阶段切序列(可以切64路、128路甚至更多)

  • 使用Ring Attention + 可调度通信,把超长序列的attention通信变成线性扩展

优势对比:

技术 最大可切份数 是否跨节点 通信开销 代表模型
Sequence Parallel ≤ TP size (8-16) 不行 几乎0(overlap) LLaMA-3.1 128K
Context Parallel 64~512+ 可以 稍高但可接受 DeepSeek-V3 1M+, Orion等

总结对比

技术 解决的核心问题 显存下降倍数 通信开销 何时必须用
Tensor Parallel 权重太大 ÷ TP 70B+模型
Pipeline Parallel 层太多,激活爆显存 ÷ PP 中高 模型层数>60 + 多节点
Data Parallel + ZeRO 梯度/优化器状态爆显存 ZeRO-3几乎÷DP 所有大规模训练
Sequence Parallel 长序列激活爆显存 ÷ TP(序列维度) 极低 32K~128K上下文训练
Context Parallel 超长序列(128K+~1M+) 可÷64~512 1M+长上下文训练(2025主流)

为什么"三者相乘得到总卡数"?(TP × PP × DP = 总卡数)

这是 3D 并行的基本数学原理:

  • TP(Tensor Parallel):把每一层权重矩阵横着切,需要 TP 张卡组成一个完整的层。
  • PP(Pipeline Parallel):把模型的层纵着切成 PP 段,每段放在不同卡组。
  • DP(Data Parallel):每个完整的模型再复制 DP 份,处理不同 batch 的数据。

这三个维度是正交的、互相独立的,所以总卡数就是乘法:

举例(LLaMA-405B 真实训练配置):

  • TP = 8 → 8张卡才能拼出一个完整的 Transformer layer
  • PP = 8 → 需要 8 个这样的 TP 组首尾相连组成完整模型
  • DP = 64 → 上面这套东西再复制 64 份处理不同样本

为什么 Sequence Parallel、Context Parallel 和 3D 并行是分开算的?

因为它们解决的问题和切的维度完全不同,属于"第四维、第五维"并行:

并行技术 切的维度 属于几维并行 典型切分大小 备注
TP / PP / DP 权重、层、数 经典 3D 并行 8×8×几十 总卡数 = TP×PP×DP
Sequence Parallel 序列长度 S 第4维 ≤ TP 大小(通常 8~16) 必须在同一个 TP 组内切
Context Parallel 序列长度 S 第5维 64~512+ 可以跨节点、跨 TP 组、跨 PP 阶段

它们不是互斥的,而是叠加使用的:

  • 70B~400B 普通上下文(≤32K) → 3D + Sequence Parallel(SP)
  • 超长上下文(128K~1M+) → 3D + Sequence Parallel + Context Parallel(CP)
  • 总卡数还是只看 TP×PP×DP,SP 和 CP 不会额外增加卡数,只是在已有卡里再切序列维度
相关推荐
lizhihai_991 天前
股市学习心得—半导体12种核心材料
大数据·人工智能·学习
sakiko_1 天前
UIKit学习笔记3-布局、滚动视图、隐藏或显示视图
前端·笔记·学习·objective-c·swift·uikit
嵌入式-老费1 天前
瑞芯微soc的学习和应用(题外话之esp32开发)
学习
辰同学ovo1 天前
从全局登录状态管理学习 Redux
前端·javascript·学习·react.js
ting94520001 天前
告别无效学习:Scholé 如何用 AI 重构职场学习,让学习直接嵌入工作流
人工智能·学习·重构
xian_wwq1 天前
【学习笔记】Harness到底是什么
笔记·学习·ai·harness
wuxinyan1231 天前
大模型学习之路004:RAG 零基础入门教程(第一篇):基础理论与文档处理流水线
人工智能·学习·rag
冯诺依曼的锦鲤1 天前
从零实现高并发内存池:TCMalloc 核心架构拆解
c++·学习·算法·架构
网络工程小王1 天前
【LangChain Output Parser 输出解析器】输出篇
人工智能·学习·langchain
AI周红伟1 天前
周红伟:DeepSeek官方教您如何部署Hermes Agent 和接入 DeepSeek-V4-Pro
人工智能·深度学习·学习·机器学习·copilot·openclaw