LLM 分布式推理:切分、通信与优化
-
- [1. MLP 层 (Feed Forward) 的切分](#1. MLP 层 (Feed Forward) 的切分)
-
- [1.1 第一层: W u p W_{up} Wup (升维) → \rightarrow → **列切分 (Column Parallel)**](#1.1 第一层: W u p W_{up} Wup (升维) → \rightarrow → 列切分 (Column Parallel))
- [1.2 第二层: W d o w n W_{down} Wdown (降维) → \rightarrow → **行切分 (Row Parallel)**](#1.2 第二层: W d o w n W_{down} Wdown (降维) → \rightarrow → 行切分 (Row Parallel))
- [2. Attention 层 (QKV) 的切分](#2. Attention 层 (QKV) 的切分)
-
- [2.1 QKV 投影层 → \rightarrow → **列切分 (Column Parallel)**](#2.1 QKV 投影层 → \rightarrow → 列切分 (Column Parallel))
- [2.2 输出投影层 (Output Linear) → \rightarrow → **行切分 (Row Parallel)**](#2.2 输出投影层 (Output Linear) → \rightarrow → 行切分 (Row Parallel))
- [🔗 3. 通信发生点 (Critical Path)](#🔗 3. 通信发生点 (Critical Path))
- [🛠️ 4. 针对"片内分布式"的优化策略](#🛠️ 4. 针对“片内分布式”的优化策略)
-
- [4.1 序列并行 (Sequence Parallelism) - 解决 GQA 负载不均](#4.1 序列并行 (Sequence Parallelism) - 解决 GQA 负载不均)
- [4.2 通信与计算重叠 (Overlap / Hiding)](#4.2 通信与计算重叠 (Overlap / Hiding))
- [4.3 拓扑感知的 All-Reduce (Topology-aware Reduction)](#4.3 拓扑感知的 All-Reduce (Topology-aware Reduction))
- [4.4 针对 Prefill 的分块 (Chunked Prefill)](#4.4 针对 Prefill 的分块 (Chunked Prefill))
- [📝 总结](#📝 总结)
💡 核心逻辑 (The Core Logic)
目前主流的 LLM 推理(如 vLLM)采用 张量并行 (Tensor Parallelism, TP)。
- 黄金法则: "列切分 (Column Parallel) + 行切分 (Row Parallel)" 的组合。
- 目的: 确保在两层计算之间不需要通信,只在模块的末尾进行一次通信,从而最大化吞吐量。
1. MLP 层 (Feed Forward) 的切分
MLP 层通常由两个线性层组成:先升维(Up Proj),激活,再降维(Down Proj)。
公式: Y = Down ( Act ( Up ( X ) ) ) Y = \text{Down}(\text{Act}(\text{Up}(X))) Y=Down(Act(Up(X)))
1.1 第一层: W u p W_{up} Wup (升维) → \rightarrow → 列切分 (Column Parallel)
- 切分方式: 将权重矩阵 W u p W_{up} Wup 竖着切 。
- Device A 拿左半边列,Device B 拿右半边列。
- 输入: 输入 X X X 被广播 (Broadcast) 到所有设备(或者 X X X 本身就是全量的)。
- 计算: Y p a r t i a l = X × W u p _ c o l Y_{partial} = X \times W_{up\_col} Ypartial=X×Wup_col。
- 状态: 每张卡得到一部分输出通道的结果(Split Output)。
- 通信: 无。计算完直接做激活函数(Activation),因为激活是 Element-wise 的,各算各的即可。
1.2 第二层: W d o w n W_{down} Wdown (降维) → \rightarrow → 行切分 (Row Parallel)
- 切分方式: 将权重矩阵 W d o w n W_{down} Wdown 横着切 。
- Device A 拿上半边行,Device B 拿下半边行。
- 输入: 正好接收上一层产生的"切分好的输出"。数据不需要搬运,直接喂进去。
- 计算: Y p a r t i a l _ s u m = Y p r e v × W d o w n _ r o w Y_{partial\sum} = Y{prev} \times W_{down\_row} Ypartial_sum=Yprev×Wdown_row。
- 状态: 每张卡得到一个部分和 (Partial Sum),维度是完整的,但数值是不完整的。
- 通信: All-Reduce (Sum) 。
- 必须把所有卡的结果加起来,才能得到最终的 Y Y Y。
2. Attention 层 (QKV) 的切分
这是 Transformer 的核心,切分粒度通常是 Head (注意力头)。
2.1 QKV 投影层 → \rightarrow → 列切分 (Column Parallel)
- 切分方式: 按 Head 分组。
- 假设有 16 个 Head,2 张卡。Card 0 拿 Head 0-7,Card 1 拿 Head 8-15。
- W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV 全部竖着切。
- KV Cache 存储: 切分存储 (Split Storage) 。
- 这是显存优化的关键。Card 0 显存里只存 Head 0-7 的 KV Cache,Card 1 只存 Head 8-15。
- 优势: 单卡显存压力随卡数线性降低。
- Attention 计算: 独立计算 (Local Compute) 。
- O = Softmax ( Q K T ) V O = \text{Softmax}(QK^T)V O=Softmax(QKT)V。因为 Head 之间没有依赖,所以各卡算各的,不需要通信。
2.2 输出投影层 (Output Linear) → \rightarrow → 行切分 (Row Parallel)
- 输入: 各个 Head 算出来的结果拼接在一起 (Concat)。
- 切分方式: W o u t W_{out} Wout 横着切。
- Card 0 处理对应 Head 0-7 的那部分权重行。
- Card 1 处理对应 Head 8-15 的那部分权重行。
- 计算: 得到部分和 (Partial Sum)。
- 通信: All-Reduce (Sum) 。
- 将结果相加,得到最终的 Attention Output。
🔗 3. 通信发生点 (Critical Path)
在一个标准的 Transformer Block 中,必须 发生通信的地方只有 2 处:
- Attention 结束时: Output Linear 之后 → \rightarrow → All-Reduce。
- MLP 结束时: Down Proj 之后 → \rightarrow → All-Reduce。
- 瓶颈分析:
- 这两个 All-Reduce 是同步点 (Synchronization Point)。
- 所有卡必须等最慢的那张卡算完,才能进行通信。
- 对于"片内分布式"架构,这里的瓶颈在于 NoC (片上网络) 的带宽 和 Reduce 逻辑的延迟。
🛠️ 4. 针对"片内分布式"的优化策略
结合 vLLM 和硬件特性,有以下优化点:
4.1 序列并行 (Sequence Parallelism) - 解决 GQA 负载不均
- 场景: 芯片核心数极多 (如 100核),但 Llama-3 只有 8 个 KV Head (GQA)。
- 问题: 按 Head 切分,只有 8 个核在干活,92 个核围观。
- 优化: 在 Sequence (Token) 维度进行切分。
- 将同一个 Head 的计算任务,按 Token 长度拆分给多个核。
- 虽然增加了部分通信(需要汇总不同 Token 的 Attention 结果),但填满了算力,大幅降低了 Latency。
4.2 通信与计算重叠 (Overlap / Hiding)
- 原理: 不要等算完了再发数据,边算边发。
- 操作:
- 在计算 Output Linear (行切分) 的最后几行时,前面的计算结果已经可以开始在 NoC 上跑 All-Reduce 了。
- 利用近存计算架构低延迟的特性,尽可能掩盖通信开销。
4.3 拓扑感知的 All-Reduce (Topology-aware Reduction)
- 场景: 片内 Block 的连接方式可能是 Mesh (网格) 或 Ring (环)。
- 优化:
- 不要用简单的"广播+求和"。
- 设计符合物理拓扑的树状归约 (Tree Reduction)。比如相邻的 4 个 Block 先在本地求和,再把结果往上传递。减少长距离的数据搬运。
4.4 针对 Prefill 的分块 (Chunked Prefill)
- 场景: 输入 Prompt 特别长。
- 优化: 不要一次性把所有 Prompt 塞进去算,导致长时间占用计算资源(卡顿)。
- 做法: 把 Prompt 切成 Chunk,和 Decode 阶段的任务混合调度 (Piggybacking),保证计算单元的流水线始终是满的。
📝 总结
| 模块 | 权重切分 | 输入数据 | 输出数据 | 是否通信 |
|---|---|---|---|---|
| QKV Proj | 列 (Col) | 广播 (全量) | 切分 (Split) | ❌ 否 |
| Attn 计算 | - | 切分 | 切分 | ❌ 否 |
| Output Proj | 行 (Row) | 切分 | 部分和 (Partial) | ✅ All-Reduce |
| MLP Up | 列 (Col) | 广播 (全量) | 切分 | ❌ 否 |
| Activation | - | 切分 | 切分 | ❌ 否 |
| MLP Down | 行 (Row) | 切分 | 部分和 (Partial) | ✅ All-Reduce |