利用PyTorch训练时的一些关于分布式训练的总结

1.PyTorch模型的并行化

PyTorch模型的并行化方法分为模型并行(Model Parallel)和数据并行(DataParallel) 。PyTorch主要支持的是数据并行化的概念,这个概念在PyTorch中分为两种类型,即数据并行化(Data Parallel, DP)和分布式数据并行化(Distributed Data Parallel, DDP)

2.两种数据并行化方式的说明及使用

(1)DP 使用的是torch.nn.DataParallel类

python 复制代码
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0) 

该类传入一个PyTorch的模块,且这个模块必须是先保存在主GPU上,其工作原理是将模型从主GPU设备上复制到device_ids指定的设备上;dim参数规定了迷你批次的分割方向。 使用示例如下:

python 复制代码
import torch.nn as nn
model = ...
model = model.cuda()
model = nn.DataParallel(model, device_ids=[0,1], dim=0)
output = model(input)

(2)DDP 使用的是torch.distributed分布式计算包,具体可分为对所有计算进程进行初始化、定义分布式训练的数据采样器、构建分布式数据并行模型。torch.distributed 提供了更好的接口和并行方式,搭配多进程接口 torch.multiprocessing可以提供更加高效的并行训练。 使用示例如下:

python 复制代码
"""""
@Author     :   jiguotong
@Contact    :   1776220977@qq.com
@site       :   
-----------------------------------------------
@Time       :   2024/8/1
@Description:   本代码用来测试torch的DDP使用方法;本代码使用PyTorch版本为2.0.1,不同版本调用方式不同
""" ""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import Dataset
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


class TestModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm1d(16)
        self.conv2 = nn.Conv1d(16, 32, 3, 1, 1)
        self.bn2 = nn.BatchNorm1d(32)
        self.conv3 = nn.Conv1d(32, 3, 3, 1, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.conv3(x))
        return x


class TestDataset(Dataset):

    def __init__(self, n):
        self.n = n
        pass

    def __len__(self):
        pass
        return self.n

    def __getitem__(self, index):
        data = torch.randn((3, 10000))
        target = torch.randn((3, 10000))
        return data, target


def main_worker(rank, world_size):
    model = TestModel().to(rank)
    train_dataset = TestDataset(100)

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # 用于分布式训练的数据采样器
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=4,
                                               sampler=train_sampler)
    # 构建分布式数据并行模型
    model = DDP(model, device_ids=[rank])

    optimizer = optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1000):
        # 避免数据一致
        train_sampler.set_epoch(epoch)
        for batch_idx, (input, target) in enumerate(train_loader):
            input = input.to(rank)
            target = target.to(rank)

            output = model(input)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if rank == 0:
            print("current epoch: {} , loss: {} ".format(epoch, loss.item()))
    print("Done!")


if __name__ == '__main__':

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "7777"

    world_size = 2

    # 使用torch.multiprocessing开启多个进程
    mp.spawn(main_worker, args=(world_size, ), nprocs=world_size, join=True)

    # pass
相关推荐
知乎的哥廷根数学学派3 小时前
面向可信机械故障诊断的自适应置信度惩罚深度校准算法(Pytorch)
人工智能·pytorch·python·深度学习·算法·机器学习·矩阵
且去填词3 小时前
DeepSeek :基于 Schema 推理与自愈机制的智能 ETL
数据仓库·人工智能·python·语言模型·etl·schema·deepseek
txinyu的博客3 小时前
解析业务层的key冲突问题
开发语言·c++·分布式
人工干智能3 小时前
OpenAI Assistants API 中 client.beta.threads.messages.create方法,兼谈一星*和两星**解包
python·llm
databook4 小时前
当条形图遇上极坐标:径向与圆形条形图的视觉革命
python·数据分析·数据可视化
阿部多瑞 ABU4 小时前
`chenmo` —— 可编程元叙事引擎 V2.3+
linux·人工智能·python·ai写作
acanab4 小时前
VScode python插件
ide·vscode·python
知乎的哥廷根数学学派5 小时前
基于生成对抗U-Net混合架构的隧道衬砌缺陷地质雷达数据智能反演与成像方法(以模拟信号为例,Pytorch)
开发语言·人工智能·pytorch·python·深度学习·机器学习
WangYaolove13145 小时前
Python基于大数据的电影市场预测分析(源码+文档)
python·django·毕业设计·源码
知乎的哥廷根数学学派6 小时前
基于自适应多尺度小波核编码与注意力增强的脉冲神经网络机械故障诊断(Pytorch)
人工智能·pytorch·python·深度学习·神经网络·机器学习