深度学习中的并行策略概述:2 Data Parallelism

深度学习中的并行策略概述:2 Data Parallelism

数据并行(Data Parallelism)的核心在于将模型的数据处理过程并行化。具体来说,面对大规模数据批次时,将其拆分为较小的子批次,并在多个计算设备上同时进行处理。每个设备负责处理一个子批次,实现并行计算。处理完成后,将各个设备上的计算结果汇总,以便对模型进行统一更新。由于其在深度学习中的普遍应用,数据并行成为了一种广泛支持的并行计算策略,并在主流框架中得到了良好的实现。

以下代码展示了如何在PyTorch中使用nn.DataParallel和DistributedDataParallel实现数据并行,以加速模型的训练过程。

使用nn.DataParallel实现数据并行

python 复制代码
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self, input_dim):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

# 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 10
X = torch.randn(n_sample, n_dim)
Y = torch.randint(0, 2, (n_sample,)).float()
dataset = SimpleDataset(X, Y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型
device_ids = [0, 1, 2]  # 指定使用的GPU编号
model = SimpleModel(n_dim).to(device_ids[0])
model = nn.DataParallel(model, device_ids=device_ids)

# 定义优化器和损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCELoss()

# 训练模型
for epoch in range(10):
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs, targets = inputs.to('cuda'), targets.to('cuda')
        outputs = model(inputs)
        loss = criterion(outputs, targets.unsqueeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

使用DistributedDataParallel实现数据并行

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self, input_dim):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

# 初始化进程组
def init_process(rank, world_size, backend='nccl'):
    dist.init_process_group(backend, rank=rank, world_size=world_size)

# 训练函数
def train(rank, world_size):
    init_process(rank, world_size)
    torch.cuda.set_device(rank)
    model = SimpleModel(10).to(rank)
    model = DDP(model, device_ids=[rank])

    dataset = SimpleDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)).float())
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    data_loader = DataLoader(dataset, batch_size=10, sampler=sampler)

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.BCELoss()

    for epoch in range(10):
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(rank), targets.to(rank)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets.unsqueeze(1))
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    world_size = 4
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)
相关推荐
mit6.824几秒前
[1Prompt1Story] 注意力机制增强 IPCA | 去噪神经网络 UNet | U型架构分步去噪
人工智能·深度学习·神经网络
挽淚17 分钟前
(小白向)什么是Prompt,RAG,Agent,Function Calling和MCP ?
人工智能·程序员
Jina AI24 分钟前
回归C++: 在GGUF上构建高效的向量模型
人工智能·算法·机器学习·数据挖掘·回归
Coovally AI模型快速验证29 分钟前
YOLO、DarkNet和深度学习如何让自动驾驶看得清?
深度学习·算法·yolo·cnn·自动驾驶·transformer·无人机
科大饭桶1 小时前
昇腾AI自学Day2-- 深度学习基础工具与数学
人工智能·pytorch·python·深度学习·numpy
什么都想学的阿超1 小时前
【大语言模型 02】多头注意力深度剖析:为什么需要多个头
人工智能·语言模型·自然语言处理
努力还债的学术吗喽2 小时前
2021 IEEE【论文精读】用GAN让音频隐写术骗过AI检测器 - 对抗深度学习的音频信息隐藏
人工智能·深度学习·生成对抗网络·密码学·音频·gan·隐写
明道云创始人任向晖2 小时前
20个进入实用阶段的AI应用场景(零售电商业篇)
人工智能·零售
数据智研2 小时前
【数据分享】大清河(大庆河)流域上游土地利用
人工智能
聚客AI2 小时前
🔷告别天价算力!2025性价比最高的LLM私有化训练路径
人工智能·llm·掘金·日新计划