利用PyTorch训练时的一些关于分布式训练的总结

1.PyTorch模型的并行化

PyTorch模型的并行化方法分为模型并行(Model Parallel)和数据并行(DataParallel) 。PyTorch主要支持的是数据并行化的概念,这个概念在PyTorch中分为两种类型,即数据并行化(Data Parallel, DP)和分布式数据并行化(Distributed Data Parallel, DDP)

2.两种数据并行化方式的说明及使用

(1)DP 使用的是torch.nn.DataParallel类

python 复制代码
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0) 

该类传入一个PyTorch的模块,且这个模块必须是先保存在主GPU上,其工作原理是将模型从主GPU设备上复制到device_ids指定的设备上;dim参数规定了迷你批次的分割方向。 使用示例如下:

python 复制代码
import torch.nn as nn
model = ...
model = model.cuda()
model = nn.DataParallel(model, device_ids=[0,1], dim=0)
output = model(input)

(2)DDP 使用的是torch.distributed分布式计算包,具体可分为对所有计算进程进行初始化、定义分布式训练的数据采样器、构建分布式数据并行模型。torch.distributed 提供了更好的接口和并行方式,搭配多进程接口 torch.multiprocessing可以提供更加高效的并行训练。 使用示例如下:

python 复制代码
"""""
@Author     :   jiguotong
@Contact    :   [email protected]
@site       :   
-----------------------------------------------
@Time       :   2024/8/1
@Description:   本代码用来测试torch的DDP使用方法;本代码使用PyTorch版本为2.0.1,不同版本调用方式不同
""" ""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import Dataset
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


class TestModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm1d(16)
        self.conv2 = nn.Conv1d(16, 32, 3, 1, 1)
        self.bn2 = nn.BatchNorm1d(32)
        self.conv3 = nn.Conv1d(32, 3, 3, 1, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.conv3(x))
        return x


class TestDataset(Dataset):

    def __init__(self, n):
        self.n = n
        pass

    def __len__(self):
        pass
        return self.n

    def __getitem__(self, index):
        data = torch.randn((3, 10000))
        target = torch.randn((3, 10000))
        return data, target


def main_worker(rank, world_size):
    model = TestModel().to(rank)
    train_dataset = TestDataset(100)

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # 用于分布式训练的数据采样器
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=4,
                                               sampler=train_sampler)
    # 构建分布式数据并行模型
    model = DDP(model, device_ids=[rank])

    optimizer = optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1000):
        # 避免数据一致
        train_sampler.set_epoch(epoch)
        for batch_idx, (input, target) in enumerate(train_loader):
            input = input.to(rank)
            target = target.to(rank)

            output = model(input)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if rank == 0:
            print("current epoch: {} , loss: {} ".format(epoch, loss.item()))
    print("Done!")


if __name__ == '__main__':

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "7777"

    world_size = 2

    # 使用torch.multiprocessing开启多个进程
    mp.spawn(main_worker, args=(world_size, ), nprocs=world_size, join=True)

    # pass
相关推荐
zhanghongyi_cpp44 分钟前
4. “3+3”高考选考科目问题
python
知识中的海王3 小时前
js逆向入门图灵爬虫练习平台第六题
python
碳基学AI4 小时前
北京大学DeepSeek内部研讨系列:AI在新媒体运营中的应用与挑战|122页PPT下载方法
大数据·人工智能·python·算法·ai·新媒体运营·产品运营
forestsea4 小时前
Python进阶编程总结
开发语言·python·notepad++
袖清暮雨5 小时前
Python刷题笔记
笔记·python·算法
乌旭6 小时前
AI芯片混战:GPU vs TPU vs NPU的算力与能效博弈
人工智能·pytorch·python·深度学习·机器学习·ai·ai编程
群联云防护小杜6 小时前
基于AI的Web应用防火墙(AppWall)实战:漏洞拦截与威胁情报集成
前端·分布式·安全·ddos
MinggeQingchun6 小时前
Python - 爬虫-网页抓取数据-库requests
爬虫·python·requests
拓端研究室TRL7 小时前
Python贝叶斯回归、强化学习分析医疗健康数据拟合截断删失数据与参数估计3实例
开发语言·人工智能·python·数据挖掘·回归
云之兕7 小时前
分布式ID生成器设计详解
分布式