TorchAcc:基于 TorchXLA 的分布式训练框架

演讲人:林伟,阿里云研究员,阿里云人工智能平台 PAI 技术负责人

本文旨在探讨阿里云 TorchAcc,这是一个基于 PyTorch/XLA 的大模型分布式训练框架。

过去十年 AI 领域的显著进步,关键在于训练技术的革新和模型规模的快速攀升。尽管大模型展现了堪比人类的理解力,但其训练却对算力提出了极高的要求。唯有配备充足的计算资源,方能在海量数据上有效训练大模型,确保其在有限时间内实现优质收敛。

图片来源于 GTC 2024大会China AI Day 线上专场的演讲《TorchAcc:基于TorchXLA的分布式训练框架》

根据上图左侧图表显示,过去五年,大模型规模的增长态势尤为突出,平均每两年大小翻 15 倍;而对于 Transformer 为代表的语言模型以及多模态模型而言,其规模膨胀速度更加惊人,每隔两年以 750 倍剧增。对比之下,右侧图表揭示了一个明显的矛盾点:不论是单个 GPU 的计算能力抑或是 GPU 显存容量的发展速度,都无法跟上模型规模如此急剧的扩张步伐。这一现实状况直接催生了对分布式训练的迫切需求。分布式训练不再局限于以往单纯的数据并行模式,而是在此基础上,更加重视并采取模型并行策略,以弥补单个计算单元算力与存储提升速度相对于模型规模增长的滞后性。

在分布式训练实践中,开发人员普遍认同,构建模型并行的分布式训练系统相比数据并行更为复杂。数据并行从分布式角度来看,其逻辑相对直接和简洁,因为每个计算节点执行的任务本质上是对等且一致的。在这种情况下,只需在训练过程末尾插入 AllReduce 步骤,将各个工作节点(worker)独立计算出的梯度差异累加整合,然后求平均值,并将最终梯度结果广播至所有参与工作的节点,用以同步更新全局模型参数。

这类简单的分布式训练范式,确实呈现出类似单机计算的特点,主要涉及全局梯度同步的 AllReduce。然而步入大模型时代,由于模型规模过大,已无法容纳于单个 GPU 之内,我们就必须采用模型并行策略,其开发难度也就陡然上升了。

原因是,模型并行需要根据模型的规模和结构来决定如何恰当地"分割"模型,即将其分割为多个可以平衡计算负载的模块。在不同的分割策略下,模型在各个节点上算子的算法实现方式会发生变化,同时,不同分割方法还会引起节点间通信原语的差异,需要精心选择最优分割方案以及配套的通信原语。

在模型分割完成后,接下来的任务就是选用适合的通信原语,并精细地调度各个算子及其相关的通信操作,力求最大化计算与网络通信的重叠(overlap),以充分发挥底层计算资源的效率。正是由于存在多种可能的分割选项与调度决策,寻求最优模型并行策略的复杂性明显高于数据并行,对开发者的技巧和经验提出了更高的要求。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

本文将围绕四个核心方面展开。首个议题是如何在 TorchAcc 中实现多样化的并行策略,涵盖了常规的数据并行,以及当下备受关注的 FSDP(Fully Sharded Data Parallel,又称 ZeRO (Zero Redundancy Optimizer)) 。此外,还包括了模型并行的各种形态,诸如算子并行,即 Tensor Parallelism,以及流水线并行(Pipeline Parallelism)等。

TorchAcc 的一大亮点在于其能够自动探寻并有机整合各类并行策略,并为用户提供高度自动化的分布式策略配置方案;与此同时,为了满足高级开发者的定制化需求,TorchAcc 还提供了半自动化的控制接口,允许用户介入并调整自动探索并行策略的过程,从而在兼顾灵活性的同时,最大程度地提升训练效率和资源利用率。

通过上述方式,TorchAcc 有效地助力算法开发者将精力集中于模型自身的结构设计、训练方法的优化,以及追求模型收敛性能的提升上,而非花费精力在分布式训练的具体实现细节。TorchAcc 将智能化地协助开发者探寻并实现最佳的分布式训练方案,从而显著提升计算资源利用效率和算法迭代效率。

其次,模型并行技术的必要性是因为大模型尺寸超出单个 GPU 显存容量的限制。显存容量对于模型训练至关重要,如何打破显存瓶颈,对于提升分布式训练的整体效率来说至关重要。因此,TorchAcc 提供了一种显存智能分配器,通过对显存资源的精细化调度与地址分配策略,最大限度地提高模型并行训练时的效率,确保模型能充分利用现有的显存地址空间。

再者,随着模型结构日益复杂,且规模不断增大,用户对计算资源的需求也在持续攀升,因此,进一步优化模型在训练过程中的计算密集度及减少访存开销也非常关键。

最后,考虑到当前数据中心基础设施的发展趋势,大模型训练对网络条件的要求日渐严苛。现代数据中心服务器间的互联带宽已达到 TB 级别,以满足大规模模型并行训练对高速数据交换的需求。然而,模型并行所带来的复杂通信模式与高频次的数据交互亦会对整体训练效率构成挑战。因此,如何有效利用网络带宽,减少通信过程在迭代计算中占据的时间比例,也就成了训练效率提升的另一重要因素。

在具体实现上,TorchAcc 通过一系列技术手段,成功地将用户在前端,无论是基于 PyTorch 还是 TensorFlow 构建的模型训练过程转化为统一的中间表示层(Model IR)的 graph。其中,对于 TensorFlow 而言,因其自身就是一种计算图模型,转化过程相对直接,而对于 PyTorch,我们采用了符号式追踪(symbolic tracing)以及 LazyTensor 等技术捕获计算图,进而转化为 IR Graph。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

基于中间表示层(IR Graph)的构建,TorchAcc 实施了一系列多元化的优化策略,涵盖计算优化、存储优化、通信优化以及分布式策略优化,IR Graph 以各类组合并反复执行这些优化的 Pass 后,最终得到一个最优的执行 Plan。然后交由底层 Backend 执行,以实现模型训练性能的最大化提升。

通过这一整套方案,TorchAcc 在多个模型的分布式训练场景中表现出了显著的性能优势。部分模型的训练过程得以实现高达 3 倍的性能提速,充分证明了 TorchAcc 在解决分布式训练难题上的高效性和实用性。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

这张图片主要展示了 TorchAcc 的框架总体架构。TorchAcc 以 Pytorch/XLA 为基础,并 TorchAcc 依托于 OpenXLA,构建了一套大模型训练加速框架。TorchAcc 在处理使用不同前端构建的模型时,会灵活采用适宜的图捕获技术,如 Symbolic Trace 和 LazyTensor,进而生成两种不同层级的图表示:FX Graph 和 HLO Graph。其中,FX Graph 位于较高抽象层次,而 HLO Graph 则更为底层。

基于捕获到的模型计算图,TorchAcc 即可进一步展开了四类优化工作,即前文提及的计算优化、存储优化、通信优化以及分布式策略优化。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在分布式策略优化层面,TorchAcc 支持业界广泛使用的各种并行策略,并能够灵活地结合这些策略对给定模型进行有效的并行化处理。具体而言,对于数据并行 DP(Data Parallelism)、流水并行 PP(Pipeline Parallelism)以及 FSDP(Fully Sharded Data Parallel, 也称为 ZeRO)这三种分布式策略,其实现和优化都是在 FX Graph 这一较高抽象层次上完成的。

选择在 FX Graph 层面对并行策略进行操作的原因在于,这一层级所包含的关于计算图结构和操作的信息已足够丰富,足以支撑开发人员设计出适应不同并行策略的优化方案。相较于在更低层的 HLO Graph 上直接进行优化,由于 FX Graph 具有更高的抽象性和概括性,在这一层面上进行优化的成本通常较低,更容易实施高效且针对性强的分布式策略调整。

以流水并行作为例子,系统能够自动检测 FX Graph 层级上的不同阶段,并确定合适的分割点,从而有效地将模型分割为多个连续执行的阶段,实现流水线并行化。在此过程中,我们可以利用 FX Graph 提供的详细计算结构信息来进行智能分割。

至于 Tensor Parallelism (张量并行)和 Sequence Parallelism (序列并行)这两种更为复杂的并行策略,它们要求更为细致精确的信息以便进行决策。为了实现这一点,系统需要对前向传播和反向传播的整个计算图的执行计划来进行分析。这时的工作主要在 HLO 这一低级别表示层面上进行。

通过利用 PyTorch/XLA 提供的 mark sharding 接口,系统能够在模型参数上添加相应的拆分标记,然后将这些拆分信息传递给 OpenXLA 的 SPMD 优化 Pass,进而触发计算图的拆分、优化、推导和重写过程,最终实现自动的 Tensor Parallelism 和 Sequence Parallelism 功能。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在算子优化层面,TorchAcc 引入 FlashAttention 技术来提升 Attention 模块的执行效率。首先,通过 XLA 的 custom call 功能,将 FlashAttention 的实现无缝地融入到了 OpenXLA 编译器和运行时框架中。这意味着 FlashAttention 可以直接在 XLA 内核层级被执行,从而充分利用硬件加速能力。

在整合过程中,要处理好在 PyTorch 与 XLA 之间 Tensor 数据的传递问题,确保在两个系统间转换时的数据一致性与性能优化,同时,还要妥善处理 FlashAttention内部参数传递等细节问题,保证在并行计算和优化的过程中,这些关键参数能够正确且高效地应用到计算中,进一步提升模型在执行注意力机制部分的运算速度和资源利用率。

为了用户能便捷地使用 FlashAttention 优化功能,我们提供了两种接口,用户也可以直接通过 Python 接口调用预先写好的 FlashAttention 算子,第三种方法是用户可以使用我们在 OpenXLA 上写好的 Pattern Match Pass,该 Pass 能够自动识别计算图中的 Attention Block,并将这部分计算结构提取出来,替换为FlashAttention 的 custom call。这样设计的优势在于,既能充分利用 XLA 原本就十分出色的 Kernel fusion 等算子优化功能,又能结合 FlashAttention 带来的先进计算优化技术。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在 Llama 2-7B 模型的性能测试中,我们能够明显观察到上述计算优化带来的效果。通过利用 XLA 自身的优化技术,尤其是 kernel fusion,我们将大量的访存密集型算子做了有效合并,从而大幅减少其数量,在叠加 FlashAttention 后,优化性能进一步提升。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在通信优化层面,我们主要完成了三项核心任务以提升分布式训练效率:首先,我们合并了一些零散的 collective 通讯算子,通过减少算子数量来降低通讯开销和调度复杂度,其次,我们将合并的 collective 通讯算子移至独立的 CUDA Stream 上执行,这样一来,就能够异步实现计算与通讯的重叠执行。最后,我们充分利用了 OpenXLA 的 Latency Hiding Scheduler 功能,对通讯算子的调度进行了精细优化,使其尽早启动和执行,从而增强通讯与计算之间的重叠效果。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

通过在 Llama2 -7B 模型上进行的端到端多机性能测试,我们发现,应用了通讯优化策略后,在 128 张 GPU 卡上进行分布式训练,优化后的加速比从原来的 88 提升到了 116,通过 timeline 图我们也可以直观地看到,优化后的通讯算子更加有序,并且能够更好地和计算重叠执行。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

本文最后一个章节绍 TorchAcc 的显存优化功能,该功能通过优化计算图中算子的执行顺序以及 Tensor 在显存中的地址分配,来降低显存开销。

如图举例说明,假设有一个包含四个算子 V0、V1、V2、V3 的计算图,如果不控制算子执行顺序,如左图所示按照 V0-V1-V2-V3 的顺序执行,若每个 Tensor 按照默认方式进行显存地址申请,则可能出现如 B 图左半部分所示的情况,即显存容量不足以容纳所有 Tensor,导致 out of memory 错误。

然而,如果我们能够预判并精细管理内存分配,即在分配地址时预知后续执行的算子序列,即可如 B 图右半部分所示进行更优的显存布局,使得整体计算可在有限显存内顺利完成。更进一步,通过精确控制执行顺序,比如按照 V0-V2-V1-V3 的方式执行,可以进一步压缩显存需求至原始需求的 70% 左右。

这一理念是基于 XLA 中间表示层已有的 scheduler 和 buffer 管理机制,我们在此基础上提出了更先进的显存优化方法。目前业界存在多种优化显存分配的方法,如启发式算法、约束求解等,但这些方法往往难以兼顾时效性和高效性,在实际生产环境的集群中应用时可能存在局限性。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在训练场景中实现有效且高效的显存优化是一项极具挑战的任务,原因主要包括以下几个方面:

  1. NP-Hard 问题本质:由于模型的规模、算子的种类繁多,以及算子间显存分配的复杂性,显存优化问题成为一个典型的 NP-hard 问题,即找到全局最优解在计算上通常是不可行的。

  2. 算子执行灵活性:训练过程中,前向传播、反向传播和权重更新等操作具有很高的灵活性,特别是在权重更新方面,梯度产生后随时可以被用于权重更新,但不同的执行时机会影响显存的申请和释放,增加了优化难度。

  3. 显存复用复杂性:在训练过程中,前向和反向传播可以通过复用显存减少重新计算,但 Tensor 生命周期的多样性和尺寸的变化使得显存复用变得极为复杂,这对启发式算法等传统优化手段构成了严峻挑战。

为了解决上述难题,我们采取了一种分治策略:

  1. Memory-aware Weight Update Scheduler:引入了显存感知的权重更新调度器,它会根据梯度产生的时机、使用的优化器类型以及当前显存资源状况,选择合适的权重更新时间点,避免即时更新加重显存压力,特别是对于复杂的优化器如 Adam,需考虑动量和其他变量的存储。

  2. Graph 分割与局部优化:将大计算图根据关键节点 (memory insensitive operator) 分割成多个内存无关性的子图,子图间执行顺序固定,而子图内部的执行顺序则可以多样化。通过这种方式,可以将复杂的全局线性规划问题分解成多个局部问题,在子图范围内采用高效的优化方法,如线性规划求解最优执行顺序。

通过上述分治策略,最终我们能够聚合这些子图的求解结果,这也就是我们提出的 ROAM (Reorder Operators and Arrange Tensors Address to Reduce Memory Usage) 这一内存优化探索方式。

上述方法可以成功实现对显存优化问题的高效处理。实验结果显示,与原生 PyTorch、启发式算法以及 Facebook 近期基于整数线性规划的优化方法等 baseline 相比,ROAM 分别节省了约 16%、13% 和 27% 的显存开销,且在优化时长和可扩展性方面表现出色,证实了这种方法的有效性。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

从另一个维度衡量效果,我们考察了算法求解的时间开销。实验证明,在常见的深度学习场景中,我们的优化算法能够在短短几分钟内得出优化结果。从右图所示对比中可以看出,相较于 Facebook 最近提出的 MODeL(一种基于线性规划的优化方法),我们的方法在求解时间上实现了显著的缩减。原因在于,MODeL 在处理大规模图时并未对其进行有效分割,而我们的方法通过引入 memory-aware weight update scheduler 和子图划分策略,有效地降低了优化问题的空间复杂度,从而提高了求解效率。

综上所述,TorchAcc 在显存优化、计算优化、通信优化以及并行策略优化等方面均取得显著成效,全方位提升了分布式训练的效率与性能。


以上内容来源于 GTC 2024 大会 China AI Day 线上中文演讲专场。扫描图片二维码或登录大会官网,观看演讲视频,并可下载讲义。

相关推荐
无须logic ᭄3 分钟前
CrypTen项目实践
python·机器学习·密码学·同态加密
小韩学长yyds3 小时前
从入门到精通:RabbitMQ的深度探索与实战应用
分布式·rabbitmq
Coovally AI模型快速验证7 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
orion-orion9 小时前
贝叶斯机器学习:高斯分布及其共轭先验
机器学习·统计学习
问道飞鱼9 小时前
【分布式知识】Spring Cloud Gateway实现跨集群应用访问
分布式·eureka·gateway
Shinobi_Jack10 小时前
c#使用Confluent.Kafka实现生产者发送消息至kafka(远程连接kafka发送消息超时的解决 Local:Message timed out)
分布式·kafka
余炜yw11 小时前
深入探讨激活函数在神经网络中的应用
人工智能·深度学习·机器学习
S-X-S11 小时前
RabbitMQ的消息可靠性保证
分布式·rabbitmq
赛丽曼12 小时前
机器学习-分类算法评估标准
人工智能·机器学习·分类
yuanbenshidiaos13 小时前
【大数据】机器学习----------计算机学习理论
大数据·学习·机器学习