分布式多卡训练,以及 Lightning 中启用 FSDP

总览

最近有多显卡训练的需求,于是研究了一下分布式训练。

大致来说分为 Data Parallelism 和 Model Parallelism 两种策略。前者相当于是单张卡的训练直接拷贝到多张卡并行运行,后者则要切分模型权重让多显卡协作。两种策略具有代表性的方案分别是 DDP、FDSP。

我选择多卡训练是因为单卡跑不动而不是速度不够快,所以 DDP 就不考虑了。FDSP 算是比较方便不需要大量修改代码的方案,想试试。

本文后面附带 Lightning 使用 FDSP 的方法。

不同的分布式训练方法

Data parallelism

一般是用 DDP(DistributedDataParallel),分布式数据并行。每张卡维护自己的一份模型拷贝,每步训练结果都会被合并与同步,是多卡训练最直接的方法。大致原理和步骤如下:

  1. 为每个 GPU 创建一个进程
  2. 每个 GPU 只会接收并处理数据集的一小部分
  3. 每个进程初始化各自的模型
  4. 每个进程各自进行前向和后向传播
  5. 同步各个进程的梯度,取平均
  6. 每个进程更新各自的优化器状态

必须整个模型(权重、优化器状态、激活、梯度等)能放到单个 GPU 才能使用 DDP。

单张卡放不下整个模型时,就需要分割模型到多个 GPU 上。这种方法称为模型并行(Model parallelism)。接下来介绍几个模型并行方案。

FDSP(Fully Sharded Data Parallelism)

完全共享数据并行,会拆分模型权重、梯度和优化器状态,能显著降低每个 GPU 的显存占用。但会引入频繁的 GPU 间通信。

Tensor parallelism

在张量级别对运算任务进行切分,每次运算结束就要进行同步。需要大量 GPU 间通信。

需要修改模型代码,所以不是很方便。

Pipeline parallelism

合理配置模型权重使其形成流水线,让多张卡依次处理一个个数据数据。相对来说 GPU 通信需求更低,不过这个方法要求模型本身适合流水线,也许要重写模型。

Lightning 的 FSDP

要在 Lightning Trainer 启用 FSDP,只需要向 Trainer 传入一个 FSDPStrategy 实例。

python 复制代码
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(accelerator="cuda", devices=2, strategy=FSDPStrategy())

auto_warp_policy

为了减少通信压力,要手动进行配置,防止参数量较少的层被拆分到不同 GPU。通过 auto_warp_policy 参数进行配置:

python 复制代码
policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
strategy = FSDPStrategy(auto_warp_policy=policy)

configure_model()

可以把模型初始化写在 Lightning Module 的 configure_model() 接口里,减少多余步骤(从 CPU 移动到 GPU),加快加载速度。

python 复制代码
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # don't instantiate layers here

    def configure_model(self):
        self.layers = nn.Sequential(...)

sharding_strategy

对于 FSDP,选择不同的拆分策略:

python 复制代码
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
  • FULL_SHARD,拆分权重、梯度和优化器状态
  • SHARD_GRAD_OP,拆分梯度和优化器状态
  • HYBRID_SHARD,多机器
  • NO_SHARD,不拆分任何东西,类似 DDP

activation_checkpointing_policy

激活检查点(Activation Checkpointing),也可以说是梯度检查点(Gradient Checkpointing),在正向传播时不存储所有层的激活 而在反向传播时重新计算激活,达成以时间换空间的效果。

通常将检查点的层设为和切分策略 auto_warp_policy 一样的值。

python 复制代码
strategy = FSDPStrategy(
    activation_checkpointing_policy={
        nn.TransformerEncoderLayer,
        nn.TransformerDecoderLayer,
    },
)

cpu_offload

可以向 FSDPStrategy 传入 cpu_offload=True 参数来节省相当多的显存,但训练速度会变得非常慢。

state_dict_type

多卡训练时,保存 checkpoint 为单个文件会非常的慢。向 FSDPStrategy 传入 state_dict_type="sharded" 可以将各个进程的状态分别存储,加快保存速度。

分布式存储的 checkpoints 即使在 GPU 数量发生变化的情况下也能加载,只要保证训练模式是 FSDP。

Lightning 提供了将分布式存储的 checkpoints 转换为单文件的方法,具体请看文档

其他建议

向优化器传入 foreach=False 节省一点显存尖峰。

贴着显存上限进行训练会导致频繁的显存回收操作,降低训练速度。可以向 FSDPStrategy 传入 limit_all_gathers=True 减缓这个问题。

除了用 auto_warp_policy 参数指定拆分策略,还可以用 from torch.distributed.fsdp.warp import warp 来手动封装层。像是这样:

python 复制代码
from torch.distributed.fsdp.wrap import wrap
    ...
    linear_layer = wrap(self.linear_layer)
    for i, layer in enumerate(self.block):
        self.block[i] = wrap(layer)

warp 封装在非 FSDP 训练时不会有任何影响。

碎碎念

按照 Lightning 的实验结果,训练速度排行如下。最快比起最慢快了 35%,显存占用增加了 140%。

DDP > > FSDP(SHARD_GRAD_OP) > FSDP(FULL_SHARD)

感觉 Lightning 的文章写得真好,清晰易懂深入痛点。在实例中讲解库的使用方式,用合理的节奏讲授相关知识点。不仅是 Lightning 使用说明,还是一篇优秀的教学文章。这也是 Lightning 库本身的理念吧,

参考来源

相关推荐
2401_838472515 分钟前
使用Python处理计算机图形学(PIL/Pillow)
jvm·数据库·python
盼小辉丶7 分钟前
PyTorch实战(27)——自动混合精度训练
pytorch·深度学习·混合精度训练
深蓝电商API14 分钟前
aiohttp爬取带登录态的异步请求
爬虫·python
rainbow688917 分钟前
Python学生管理系统:JSON持久化实战
java·前端·python
咕噜咕噜啦啦28 分钟前
ROS入门
linux·vscode·python
2301_7903009628 分钟前
用Matplotlib绘制专业图表:从基础到高级
jvm·数据库·python
XLYcmy36 分钟前
一个用于统计文本文件行数的Python实用工具脚本
开发语言·数据结构·windows·python·开发工具·数据处理·源代码
DFT计算杂谈1 小时前
VASP+PHONOPY+pypolymlpj计算不同温度下声子谱,附批处理脚本
java·前端·数据库·人工智能·python