分布式多卡训练,以及 Lightning 中启用 FSDP

总览

最近有多显卡训练的需求,于是研究了一下分布式训练。

大致来说分为 Data Parallelism 和 Model Parallelism 两种策略。前者相当于是单张卡的训练直接拷贝到多张卡并行运行,后者则要切分模型权重让多显卡协作。两种策略具有代表性的方案分别是 DDP、FDSP。

我选择多卡训练是因为单卡跑不动而不是速度不够快,所以 DDP 就不考虑了。FDSP 算是比较方便不需要大量修改代码的方案,想试试。

本文后面附带 Lightning 使用 FDSP 的方法。

不同的分布式训练方法

Data parallelism

一般是用 DDP(DistributedDataParallel),分布式数据并行。每张卡维护自己的一份模型拷贝,每步训练结果都会被合并与同步,是多卡训练最直接的方法。大致原理和步骤如下:

  1. 为每个 GPU 创建一个进程
  2. 每个 GPU 只会接收并处理数据集的一小部分
  3. 每个进程初始化各自的模型
  4. 每个进程各自进行前向和后向传播
  5. 同步各个进程的梯度,取平均
  6. 每个进程更新各自的优化器状态

必须整个模型(权重、优化器状态、激活、梯度等)能放到单个 GPU 才能使用 DDP。

单张卡放不下整个模型时,就需要分割模型到多个 GPU 上。这种方法称为模型并行(Model parallelism)。接下来介绍几个模型并行方案。

FDSP(Fully Sharded Data Parallelism)

完全共享数据并行,会拆分模型权重、梯度和优化器状态,能显著降低每个 GPU 的显存占用。但会引入频繁的 GPU 间通信。

Tensor parallelism

在张量级别对运算任务进行切分,每次运算结束就要进行同步。需要大量 GPU 间通信。

需要修改模型代码,所以不是很方便。

Pipeline parallelism

合理配置模型权重使其形成流水线,让多张卡依次处理一个个数据数据。相对来说 GPU 通信需求更低,不过这个方法要求模型本身适合流水线,也许要重写模型。

Lightning 的 FSDP

要在 Lightning Trainer 启用 FSDP,只需要向 Trainer 传入一个 FSDPStrategy 实例。

python 复制代码
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(accelerator="cuda", devices=2, strategy=FSDPStrategy())

auto_warp_policy

为了减少通信压力,要手动进行配置,防止参数量较少的层被拆分到不同 GPU。通过 auto_warp_policy 参数进行配置:

python 复制代码
policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
strategy = FSDPStrategy(auto_warp_policy=policy)

configure_model()

可以把模型初始化写在 Lightning Module 的 configure_model() 接口里,减少多余步骤(从 CPU 移动到 GPU),加快加载速度。

python 复制代码
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # don't instantiate layers here

    def configure_model(self):
        self.layers = nn.Sequential(...)

sharding_strategy

对于 FSDP,选择不同的拆分策略:

python 复制代码
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
  • FULL_SHARD,拆分权重、梯度和优化器状态
  • SHARD_GRAD_OP,拆分梯度和优化器状态
  • HYBRID_SHARD,多机器
  • NO_SHARD,不拆分任何东西,类似 DDP

activation_checkpointing_policy

激活检查点(Activation Checkpointing),也可以说是梯度检查点(Gradient Checkpointing),在正向传播时不存储所有层的激活 而在反向传播时重新计算激活,达成以时间换空间的效果。

通常将检查点的层设为和切分策略 auto_warp_policy 一样的值。

python 复制代码
strategy = FSDPStrategy(
    activation_checkpointing_policy={
        nn.TransformerEncoderLayer,
        nn.TransformerDecoderLayer,
    },
)

cpu_offload

可以向 FSDPStrategy 传入 cpu_offload=True 参数来节省相当多的显存,但训练速度会变得非常慢。

state_dict_type

多卡训练时,保存 checkpoint 为单个文件会非常的慢。向 FSDPStrategy 传入 state_dict_type="sharded" 可以将各个进程的状态分别存储,加快保存速度。

分布式存储的 checkpoints 即使在 GPU 数量发生变化的情况下也能加载,只要保证训练模式是 FSDP。

Lightning 提供了将分布式存储的 checkpoints 转换为单文件的方法,具体请看文档

其他建议

向优化器传入 foreach=False 节省一点显存尖峰。

贴着显存上限进行训练会导致频繁的显存回收操作,降低训练速度。可以向 FSDPStrategy 传入 limit_all_gathers=True 减缓这个问题。

除了用 auto_warp_policy 参数指定拆分策略,还可以用 from torch.distributed.fsdp.warp import warp 来手动封装层。像是这样:

python 复制代码
from torch.distributed.fsdp.wrap import wrap
    ...
    linear_layer = wrap(self.linear_layer)
    for i, layer in enumerate(self.block):
        self.block[i] = wrap(layer)

warp 封装在非 FSDP 训练时不会有任何影响。

碎碎念

按照 Lightning 的实验结果,训练速度排行如下。最快比起最慢快了 35%,显存占用增加了 140%。

DDP > > FSDP(SHARD_GRAD_OP) > FSDP(FULL_SHARD)

感觉 Lightning 的文章写得真好,清晰易懂深入痛点。在实例中讲解库的使用方式,用合理的节奏讲授相关知识点。不仅是 Lightning 使用说明,还是一篇优秀的教学文章。这也是 Lightning 库本身的理念吧,

参考来源

相关推荐
青钰未央3 小时前
19、Python字符串高阶实战:转义字符深度解析、高效拼接与输入处理技巧
python·改行学it
Blue桃之夭夭5 小时前
Python进阶【四】:XML和JSON文件处理
xml·python·json
悲喜自渡7215 小时前
ST-GCN
pytorch
开发者工具分享5 小时前
Lua 的速度为什么比 Python 快
开发语言·python·lua
蔗理苦5 小时前
2025-05-28 Python&深度学习8——优化器
开发语言·pytorch·python·深度学习·优化器
杰瑞学AI6 小时前
在PyTorch中,对于一个张量,如何快速为多个元素赋值相同的值
人工智能·pytorch·python
写代码的小阿帆6 小时前
Attention Is All You Need论文阅读笔记
论文阅读·深度学习·机器学习·transformer
hongjianMa7 小时前
【论文阅读】User Diverse Preference Modeling by Multimodal Attentive Metric Learning
论文阅读·python·推荐系统·多模态推荐
乖乖der7 小时前
python同步mysql数据
开发语言·python·mysql
渐消散8 小时前
人工智障玩游戏
python