LLM 分布式推理:切分、通信与优化

LLM 分布式推理:切分、通信与优化

    • [1. MLP 层 (Feed Forward) 的切分](#1. MLP 层 (Feed Forward) 的切分)
      • [1.1 第一层: W u p W_{up} Wup (升维) → \rightarrow → **列切分 (Column Parallel)**](#1.1 第一层: W u p W_{up} Wup (升维) → \rightarrow → 列切分 (Column Parallel))
      • [1.2 第二层: W d o w n W_{down} Wdown (降维) → \rightarrow → **行切分 (Row Parallel)**](#1.2 第二层: W d o w n W_{down} Wdown (降维) → \rightarrow → 行切分 (Row Parallel))
    • [2. Attention 层 (QKV) 的切分](#2. Attention 层 (QKV) 的切分)
      • [2.1 QKV 投影层 → \rightarrow → **列切分 (Column Parallel)**](#2.1 QKV 投影层 → \rightarrow → 列切分 (Column Parallel))
      • [2.2 输出投影层 (Output Linear) → \rightarrow → **行切分 (Row Parallel)**](#2.2 输出投影层 (Output Linear) → \rightarrow → 行切分 (Row Parallel))
    • [🔗 3. 通信发生点 (Critical Path)](#🔗 3. 通信发生点 (Critical Path))
    • [🛠️ 4. 针对"片内分布式"的优化策略](#🛠️ 4. 针对“片内分布式”的优化策略)
      • [4.1 序列并行 (Sequence Parallelism) - 解决 GQA 负载不均](#4.1 序列并行 (Sequence Parallelism) - 解决 GQA 负载不均)
      • [4.2 通信与计算重叠 (Overlap / Hiding)](#4.2 通信与计算重叠 (Overlap / Hiding))
      • [4.3 拓扑感知的 All-Reduce (Topology-aware Reduction)](#4.3 拓扑感知的 All-Reduce (Topology-aware Reduction))
      • [4.4 针对 Prefill 的分块 (Chunked Prefill)](#4.4 针对 Prefill 的分块 (Chunked Prefill))
    • [📝 总结](#📝 总结)

💡 核心逻辑 (The Core Logic)

目前主流的 LLM 推理(如 vLLM)采用 张量并行 (Tensor Parallelism, TP)

  • 黄金法则: "列切分 (Column Parallel) + 行切分 (Row Parallel)" 的组合。
  • 目的: 确保在两层计算之间不需要通信,只在模块的末尾进行一次通信,从而最大化吞吐量。

1. MLP 层 (Feed Forward) 的切分

MLP 层通常由两个线性层组成:先升维(Up Proj),激活,再降维(Down Proj)。

公式: Y = Down ( Act ( Up ( X ) ) ) Y = \text{Down}(\text{Act}(\text{Up}(X))) Y=Down(Act(Up(X)))

1.1 第一层: W u p W_{up} Wup (升维) → \rightarrow → 列切分 (Column Parallel)

  • 切分方式: 将权重矩阵 W u p W_{up} Wup 竖着切
    • Device A 拿左半边列,Device B 拿右半边列。
  • 输入: 输入 X X X 被广播 (Broadcast) 到所有设备(或者 X X X 本身就是全量的)。
  • 计算: Y p a r t i a l = X × W u p _ c o l Y_{partial} = X \times W_{up\_col} Ypartial=X×Wup_col。
  • 状态: 每张卡得到一部分输出通道的结果(Split Output)。
  • 通信: 。计算完直接做激活函数(Activation),因为激活是 Element-wise 的,各算各的即可。

1.2 第二层: W d o w n W_{down} Wdown (降维) → \rightarrow → 行切分 (Row Parallel)

  • 切分方式: 将权重矩阵 W d o w n W_{down} Wdown 横着切
    • Device A 拿上半边行,Device B 拿下半边行。
  • 输入: 正好接收上一层产生的"切分好的输出"。数据不需要搬运,直接喂进去
  • 计算: Y p a r t i a l _ s u m = Y p r e v × W d o w n _ r o w Y_{partial\sum} = Y{prev} \times W_{down\_row} Ypartial_sum=Yprev×Wdown_row。
  • 状态: 每张卡得到一个部分和 (Partial Sum),维度是完整的,但数值是不完整的。
  • 通信: All-Reduce (Sum)
    • 必须把所有卡的结果加起来,才能得到最终的 Y Y Y。

2. Attention 层 (QKV) 的切分

这是 Transformer 的核心,切分粒度通常是 Head (注意力头)

2.1 QKV 投影层 → \rightarrow → 列切分 (Column Parallel)

  • 切分方式:Head 分组。
    • 假设有 16 个 Head,2 张卡。Card 0 拿 Head 0-7,Card 1 拿 Head 8-15。
    • W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV 全部竖着切。
  • KV Cache 存储: 切分存储 (Split Storage)
    • 这是显存优化的关键。Card 0 显存里只存 Head 0-7 的 KV Cache,Card 1 只存 Head 8-15。
    • 优势: 单卡显存压力随卡数线性降低。
  • Attention 计算: 独立计算 (Local Compute)
    • O = Softmax ( Q K T ) V O = \text{Softmax}(QK^T)V O=Softmax(QKT)V。因为 Head 之间没有依赖,所以各卡算各的,不需要通信

2.2 输出投影层 (Output Linear) → \rightarrow → 行切分 (Row Parallel)

  • 输入: 各个 Head 算出来的结果拼接在一起 (Concat)。
  • 切分方式: W o u t W_{out} Wout 横着切。
    • Card 0 处理对应 Head 0-7 的那部分权重行。
    • Card 1 处理对应 Head 8-15 的那部分权重行。
  • 计算: 得到部分和 (Partial Sum)。
  • 通信: All-Reduce (Sum)
    • 将结果相加,得到最终的 Attention Output。

🔗 3. 通信发生点 (Critical Path)

在一个标准的 Transformer Block 中,必须 发生通信的地方只有 2 处

  1. Attention 结束时: Output Linear 之后 → \rightarrow → All-Reduce
  2. MLP 结束时: Down Proj 之后 → \rightarrow → All-Reduce
  • 瓶颈分析:
    • 这两个 All-Reduce 是同步点 (Synchronization Point)
    • 所有卡必须等最慢的那张卡算完,才能进行通信。
    • 对于"片内分布式"架构,这里的瓶颈在于 NoC (片上网络) 的带宽Reduce 逻辑的延迟

🛠️ 4. 针对"片内分布式"的优化策略

结合 vLLM 和硬件特性,有以下优化点:

4.1 序列并行 (Sequence Parallelism) - 解决 GQA 负载不均

  • 场景: 芯片核心数极多 (如 100核),但 Llama-3 只有 8 个 KV Head (GQA)。
  • 问题: 按 Head 切分,只有 8 个核在干活,92 个核围观。
  • 优化:Sequence (Token) 维度进行切分。
    • 将同一个 Head 的计算任务,按 Token 长度拆分给多个核。
    • 虽然增加了部分通信(需要汇总不同 Token 的 Attention 结果),但填满了算力,大幅降低了 Latency。

4.2 通信与计算重叠 (Overlap / Hiding)

  • 原理: 不要等算完了再发数据,边算边发
  • 操作:
    • 在计算 Output Linear (行切分) 的最后几行时,前面的计算结果已经可以开始在 NoC 上跑 All-Reduce 了。
    • 利用近存计算架构低延迟的特性,尽可能掩盖通信开销。

4.3 拓扑感知的 All-Reduce (Topology-aware Reduction)

  • 场景: 片内 Block 的连接方式可能是 Mesh (网格) 或 Ring (环)。
  • 优化:
    • 不要用简单的"广播+求和"。
    • 设计符合物理拓扑的树状归约 (Tree Reduction)。比如相邻的 4 个 Block 先在本地求和,再把结果往上传递。减少长距离的数据搬运。

4.4 针对 Prefill 的分块 (Chunked Prefill)

  • 场景: 输入 Prompt 特别长。
  • 优化: 不要一次性把所有 Prompt 塞进去算,导致长时间占用计算资源(卡顿)。
  • 做法: 把 Prompt 切成 Chunk,和 Decode 阶段的任务混合调度 (Piggybacking),保证计算单元的流水线始终是满的。

📝 总结

模块 权重切分 输入数据 输出数据 是否通信
QKV Proj 列 (Col) 广播 (全量) 切分 (Split) ❌ 否
Attn 计算 - 切分 切分 ❌ 否
Output Proj 行 (Row) 切分 部分和 (Partial) All-Reduce
MLP Up 列 (Col) 广播 (全量) 切分 ❌ 否
Activation - 切分 切分 ❌ 否
MLP Down 行 (Row) 切分 部分和 (Partial) All-Reduce
相关推荐
HZjiangzi2 小时前
文物古董如何实现高保真三维数字化?思看科技3DeVOK MT彩色扫描+智能贴图方案权威解析
人工智能·科技·制造·三维扫描仪
救救孩子把2 小时前
58-机器学习与大模型开发数学教程-5-5 牛顿法与拟牛顿法(BFGS、L-BFGS)
人工智能·机器学习
junziruruo2 小时前
三叉预测头Trident prediction head(RGBT目标跟踪以MTNET为例)
人工智能·计算机视觉·目标跟踪
光羽隹衡2 小时前
计算机视觉--Opencv(图像形态学)
人工智能·opencv·计算机视觉
懈尘2 小时前
基于Spring Boot与LangChain4j的AI驱动新闻系统设计与工程实现
java·大数据·人工智能·spring boot·后端·langchain
倔强的石头1062 小时前
假设空间与版本空间 —— 机器学习是 “猜规律” 的过程
人工智能·机器学习
flying_13142 小时前
图神经网络分享系列-GGNN(GATED GRAPH SEQUENCE NEURAL NETWORKS)(三)
人工智能·深度学习·神经网络·图神经网络·ggnn·门控机制·图特征学习
cooldream20092 小时前
Agent Skill:新一代 AI 设计模式的原理、实践与 MCP 协同应用解析
人工智能·mcp·agent skill
言無咎2 小时前
传统财务RPA陷入性能瓶颈?AI财务机器人用LLM重构智能财税
人工智能·机器人·rpa