深度学习专题:模型训练的张量并行(一)

深度学习专题:模型训练的张量并行(一)

张量并行的基本原理

(一)张量并行的定义

张量并行(Tensor Parallelism)是一种将单个张量分割到多个 GPU 上进行并行计算的技术,尤其在 Transformer 架构中广泛用于大模型训练和推理。

  • 将权重矩阵按行或列拆分到不同设备
  • 每个设备持有部分参数,计算部分结果,通过通信 如 all-reduce 聚合完整输出
  • 目标是减少单个设备的显存占用,同时利用多设备并行计算

(二)张量并行的切分方式

复制代码
输入 X: [b, s, h] (batch_size, sequence_length, hidden_size)

权重 W: [h, o] (hidden_size, output_size)

并行度 n: 设备数量
  1. 行并行(Row Parallelism)

    • 将权重矩阵按行拆分到不同设备,输入矩阵随之按列拆分

    • 每个设备计算拆分后的输入矩阵与拆分后的权重矩阵的乘积

    • 输出结果通过 all-reduce 聚合到所有设备

      (1) 切分权重矩阵W
      权重W形状: [h, o]
      按行切分: W被切成n块,每块形状 [h/n, o]

      (2) 切分输入张量X
      输入X形状: [b, s, h]
      对应切分: X的最后一维h也被切成n块,每块形状 [b, s, h/n]

      (3) 计算每个 GPU 的输出
      GPU0: Y0 = X0 @ W0

      X0: [b, s, h/n], W0: [h/n, o] → Y0: [b, s, o]

      GPU1: Y1 = X1 @ W1

      X1: [b, s, h/n], W1: [h/n, o] → Y1: [b, s, o]

      ...

      GPUn-1: Yn-1 = Xn-1 @ Wn-1

      Xn-1: [b, s, h/n], Wn-1: [h/n, o] → Yn-1: [b, s, o]

      (4) 聚合输出
      每个 GPU 计算完成后,将结果通过 all-reduce 聚合到所有设备
      数学上:Y = X @ W = (X0 @ W0) + (X1 @ W1) + ... + (Xn-1 @ Wn-1)
      所以要得到最终Y,需要把所有Yi相加:Y = ∑ Yi

  2. 列并行(Column Parallelism)

    • 将权重矩阵按列拆分到不同设备

    • 每个设备计算输入矩阵与拆分后的权重矩阵的乘积

    • 输出结果通过 all-gather 聚合到所有设备

      (1) 切分权重矩阵W
      权重W形状: [h, o]
      按列切分: W被切成n块,每块形状 [h, o/n]

      (2) 输入张量X保持不变
      输入X形状: [b, s, h]
      保持不变: X完整广播到所有GPU,形状 [b, s, h]

      (3) 计算每个 GPU 的输出
      GPU0: Y0 = X @ W0

      X: [b, s, h], W0: [h, o/n] → Y0: [b, s, o/n]

      GPU1: Y1 = X @ W1

      X: [b, s, h], W1: [h, o/n] → Y1: [b, s, o/n]

      ...

      GPUn-1: Yn-1 = X @ Wn-1

      X: [b, s, h], Wn-1: [h, o/n] → Yn-1: [b, s, o/n]

      (4) 聚合输出
      每个 GPU 计算完成后,将结果通过 all-gather 聚合到所有设备
      数学上:Y = X @ W = [X @ W0 | X @ W1 | ... | X @ Wn-1]
      所以要得到最终Y,需要把所有Yi拼接起来:Y = concat(Y0, Y1, ..., Yn-1)

相关推荐
智驱力人工智能4 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144874 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile4 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算
人工不智能5774 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
盟接之桥4 小时前
盟接之桥说制造:引流品 × 利润品,全球电商平台高效产品组合策略(供讨论)
大数据·linux·服务器·网络·人工智能·制造
kfyty7254 小时前
集成 spring-ai 2.x 实践中遇到的一些问题及解决方案
java·人工智能·spring-ai
h64648564h4 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切4 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
数据与后端架构提升之路4 小时前
论系统安全架构设计及其应用(基于AI大模型项目)
人工智能·安全·系统安全
忆~遂愿4 小时前
ops-cv 算子库深度解析:面向视觉任务的硬件优化与数据布局(NCHW/NHWC)策略
java·大数据·linux·人工智能