分布式多卡训练,以及 Lightning 中启用 FSDP

总览

最近有多显卡训练的需求,于是研究了一下分布式训练。

大致来说分为 Data Parallelism 和 Model Parallelism 两种策略。前者相当于是单张卡的训练直接拷贝到多张卡并行运行,后者则要切分模型权重让多显卡协作。两种策略具有代表性的方案分别是 DDP、FDSP。

我选择多卡训练是因为单卡跑不动而不是速度不够快,所以 DDP 就不考虑了。FDSP 算是比较方便不需要大量修改代码的方案,想试试。

本文后面附带 Lightning 使用 FDSP 的方法。

不同的分布式训练方法

Data parallelism

一般是用 DDP(DistributedDataParallel),分布式数据并行。每张卡维护自己的一份模型拷贝,每步训练结果都会被合并与同步,是多卡训练最直接的方法。大致原理和步骤如下:

  1. 为每个 GPU 创建一个进程
  2. 每个 GPU 只会接收并处理数据集的一小部分
  3. 每个进程初始化各自的模型
  4. 每个进程各自进行前向和后向传播
  5. 同步各个进程的梯度,取平均
  6. 每个进程更新各自的优化器状态

必须整个模型(权重、优化器状态、激活、梯度等)能放到单个 GPU 才能使用 DDP。

单张卡放不下整个模型时,就需要分割模型到多个 GPU 上。这种方法称为模型并行(Model parallelism)。接下来介绍几个模型并行方案。

FDSP(Fully Sharded Data Parallelism)

完全共享数据并行,会拆分模型权重、梯度和优化器状态,能显著降低每个 GPU 的显存占用。但会引入频繁的 GPU 间通信。

Tensor parallelism

在张量级别对运算任务进行切分,每次运算结束就要进行同步。需要大量 GPU 间通信。

需要修改模型代码,所以不是很方便。

Pipeline parallelism

合理配置模型权重使其形成流水线,让多张卡依次处理一个个数据数据。相对来说 GPU 通信需求更低,不过这个方法要求模型本身适合流水线,也许要重写模型。

Lightning 的 FSDP

要在 Lightning Trainer 启用 FSDP,只需要向 Trainer 传入一个 FSDPStrategy 实例。

python 复制代码
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(accelerator="cuda", devices=2, strategy=FSDPStrategy())

auto_warp_policy

为了减少通信压力,要手动进行配置,防止参数量较少的层被拆分到不同 GPU。通过 auto_warp_policy 参数进行配置:

python 复制代码
policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
strategy = FSDPStrategy(auto_warp_policy=policy)

configure_model()

可以把模型初始化写在 Lightning Module 的 configure_model() 接口里,减少多余步骤(从 CPU 移动到 GPU),加快加载速度。

python 复制代码
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # don't instantiate layers here

    def configure_model(self):
        self.layers = nn.Sequential(...)

sharding_strategy

对于 FSDP,选择不同的拆分策略:

python 复制代码
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
  • FULL_SHARD,拆分权重、梯度和优化器状态
  • SHARD_GRAD_OP,拆分梯度和优化器状态
  • HYBRID_SHARD,多机器
  • NO_SHARD,不拆分任何东西,类似 DDP

activation_checkpointing_policy

激活检查点(Activation Checkpointing),也可以说是梯度检查点(Gradient Checkpointing),在正向传播时不存储所有层的激活 而在反向传播时重新计算激活,达成以时间换空间的效果。

通常将检查点的层设为和切分策略 auto_warp_policy 一样的值。

python 复制代码
strategy = FSDPStrategy(
    activation_checkpointing_policy={
        nn.TransformerEncoderLayer,
        nn.TransformerDecoderLayer,
    },
)

cpu_offload

可以向 FSDPStrategy 传入 cpu_offload=True 参数来节省相当多的显存,但训练速度会变得非常慢。

state_dict_type

多卡训练时,保存 checkpoint 为单个文件会非常的慢。向 FSDPStrategy 传入 state_dict_type="sharded" 可以将各个进程的状态分别存储,加快保存速度。

分布式存储的 checkpoints 即使在 GPU 数量发生变化的情况下也能加载,只要保证训练模式是 FSDP。

Lightning 提供了将分布式存储的 checkpoints 转换为单文件的方法,具体请看文档

其他建议

向优化器传入 foreach=False 节省一点显存尖峰。

贴着显存上限进行训练会导致频繁的显存回收操作,降低训练速度。可以向 FSDPStrategy 传入 limit_all_gathers=True 减缓这个问题。

除了用 auto_warp_policy 参数指定拆分策略,还可以用 from torch.distributed.fsdp.warp import warp 来手动封装层。像是这样:

python 复制代码
from torch.distributed.fsdp.wrap import wrap
    ...
    linear_layer = wrap(self.linear_layer)
    for i, layer in enumerate(self.block):
        self.block[i] = wrap(layer)

warp 封装在非 FSDP 训练时不会有任何影响。

碎碎念

按照 Lightning 的实验结果,训练速度排行如下。最快比起最慢快了 35%,显存占用增加了 140%。

DDP > > FSDP(SHARD_GRAD_OP) > FSDP(FULL_SHARD)

感觉 Lightning 的文章写得真好,清晰易懂深入痛点。在实例中讲解库的使用方式,用合理的节奏讲授相关知识点。不仅是 Lightning 使用说明,还是一篇优秀的教学文章。这也是 Lightning 库本身的理念吧,

参考来源

相关推荐
DuHz7 小时前
通过超宽带信号估计位置——论文精读
论文阅读·人工智能·机器学习·自动驾驶·汽车
喵手7 小时前
Python爬虫实战:针对Python官网,精准提取出每一个历史版本的版本号、发布日期以及对应的文档/详情页链接等信息,并最终清洗为标准化的CSV文件!
爬虫·python·爬虫实战·零基础python爬虫教学·python官方数据采集·采集历史版本版本号等信息·导出csv文件
Physicist in Geophy.7 小时前
一维波动方程(从变分法角度)
线性代数·算法·机器学习
databook7 小时前
像搭积木一样思考:数据科学中的“自下而上”之道
python·数据挖掘·数据分析
luoluoal7 小时前
基于python的医疗问句中的实体识别算法的研究(源码+文档)
python·mysql·django·毕业设计·源码
硅谷秋水7 小时前
REALM:用于机器人操作泛化能力的真实-仿真验证基准测试
人工智能·机器学习·计算机视觉·语言模型·机器人
啊阿狸不会拉杆7 小时前
《机器学习导论》第 9 章-决策树
人工智能·python·算法·决策树·机器学习·数据挖掘·剪枝
喵手7 小时前
Python爬虫实战:城市停车收费标准自动化采集系统 - 让停车费透明化的技术实践(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·城市停车收费标准·采集城市停车收费数据·采集停车数据csv文件导出
无水先生7 小时前
python函数的参数管理(01)*args和**kwargs
开发语言·python
曦月逸霜7 小时前
机器学习——个人笔记(持续更新中~)
人工智能·机器学习