利用PyTorch训练时的一些关于分布式训练的总结

1.PyTorch模型的并行化

PyTorch模型的并行化方法分为模型并行(Model Parallel)和数据并行(DataParallel) 。PyTorch主要支持的是数据并行化的概念,这个概念在PyTorch中分为两种类型,即数据并行化(Data Parallel, DP)和分布式数据并行化(Distributed Data Parallel, DDP)

2.两种数据并行化方式的说明及使用

(1)DP 使用的是torch.nn.DataParallel类

python 复制代码
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0) 

该类传入一个PyTorch的模块,且这个模块必须是先保存在主GPU上,其工作原理是将模型从主GPU设备上复制到device_ids指定的设备上;dim参数规定了迷你批次的分割方向。 使用示例如下:

python 复制代码
import torch.nn as nn
model = ...
model = model.cuda()
model = nn.DataParallel(model, device_ids=[0,1], dim=0)
output = model(input)

(2)DDP 使用的是torch.distributed分布式计算包,具体可分为对所有计算进程进行初始化、定义分布式训练的数据采样器、构建分布式数据并行模型。torch.distributed 提供了更好的接口和并行方式,搭配多进程接口 torch.multiprocessing可以提供更加高效的并行训练。 使用示例如下:

python 复制代码
"""""
@Author     :   jiguotong
@Contact    :   1776220977@qq.com
@site       :   
-----------------------------------------------
@Time       :   2024/8/1
@Description:   本代码用来测试torch的DDP使用方法;本代码使用PyTorch版本为2.0.1,不同版本调用方式不同
""" ""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import Dataset
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


class TestModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm1d(16)
        self.conv2 = nn.Conv1d(16, 32, 3, 1, 1)
        self.bn2 = nn.BatchNorm1d(32)
        self.conv3 = nn.Conv1d(32, 3, 3, 1, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.conv3(x))
        return x


class TestDataset(Dataset):

    def __init__(self, n):
        self.n = n
        pass

    def __len__(self):
        pass
        return self.n

    def __getitem__(self, index):
        data = torch.randn((3, 10000))
        target = torch.randn((3, 10000))
        return data, target


def main_worker(rank, world_size):
    model = TestModel().to(rank)
    train_dataset = TestDataset(100)

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # 用于分布式训练的数据采样器
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=4,
                                               sampler=train_sampler)
    # 构建分布式数据并行模型
    model = DDP(model, device_ids=[rank])

    optimizer = optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1000):
        # 避免数据一致
        train_sampler.set_epoch(epoch)
        for batch_idx, (input, target) in enumerate(train_loader):
            input = input.to(rank)
            target = target.to(rank)

            output = model(input)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if rank == 0:
            print("current epoch: {} , loss: {} ".format(epoch, loss.item()))
    print("Done!")


if __name__ == '__main__':

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "7777"

    world_size = 2

    # 使用torch.multiprocessing开启多个进程
    mp.spawn(main_worker, args=(world_size, ), nprocs=world_size, join=True)

    # pass
相关推荐
Jetev几秒前
宝塔面板如何实现网站重定向_配置301永久跳转与域名更换
jvm·数据库·python
AI机器学习算法几秒前
说走就走的AI之旅第01课:浅谈机器学习
数据结构·人工智能·python·深度学习·机器学习·大模型·线性回归
༒࿈南林࿈༒几秒前
yi欣考研刷题题库js逆向
python·js逆向
idolao2 分钟前
CentOS 7 安装 libtool-1.5.22.tar.gz 详细步骤(源码编译、配置、验证)
开发语言·python
2401_833033622 分钟前
c++如何解析二进制协议中的可选字段与默认值读取逻辑实现【实战】
jvm·数据库·python
源码之家3 分钟前
计算机毕业设计:Python基于数据挖掘的医院疾病分析与预测系统 Flask框架 数据分析 可视化 ARIMA算法 大数据 大模型(建议收藏)✅
python·信息可视化·数据挖掘·数据分析·flask·lstm·课程设计
Francek Chen4 分钟前
【大数据存储与管理】云数据库:02 云数据库产品
大数据·数据库·分布式·云计算·云数据库
m0_591364739 分钟前
CSS 背景图滑动切换:纯 CSS 实现右进左出轮播效果
jvm·数据库·python
2401_8242226910 分钟前
Python测试代码如何实现自解释_使用pytest描述性命名规范
jvm·数据库·python
接着奏乐接着舞12 分钟前
springboot 常用注解
spring boot·后端·python