torch分布式通信基础

torch分布式通信基础

  • [1. 点到点通信](#1. 点到点通信)
  • [2. 集群通信](#2. 集群通信)

官网文档:WRITING DISTRIBUTED APPLICATIONS WITH PYTORCH

1. 点到点通信

python 复制代码
# 同步,peer-2-peer数据传递
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def test_send_recv_sync(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor=tensor, dst=1) # 需要指定dst,发送的目标
    else:
        dist.recv(tensor=tensor, src=0) # 需要指定src,从哪儿接收
    print('Rank ', rank, ' has data ', tensor[0])

# 异步
def test_send_recv_async(rank, size):
    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        req = dist.isend(tensor=tensor, dst=1)
    else:
        req = dist.irecv(tensor=tensor, src=0)
    req.wait()
    print('Rank ', rank, ' has data ', tensor[0])

def init_process(rank, size, backend='gloo'):
    """ 这里初始化分布式环境,设定Master机器以及端口号 """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29598'
    dist.init_process_group(backend, rank=rank, world_size=size)
    #test_send_recv_sync(rank, size)
    test_send_recv_async(rank, size)

if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

2. 集群通信



python 复制代码
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def test_broadcast(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
      tensor += 2
    else:
      tensor += 1
    dist.broadcast(tensor=tensor,src=0) # src指定broad_cast的源
    print("******test_broadcast******")
    print('Rank ', rank, ' has data ', tensor) # 结果都是 2

def test_scatter(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
      tensor_list = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0]), torch.tensor([4.0])]
      dist.scatter(tensor, scatter_list = tensor_list, src = 0)
    else:
      dist.scatter(tensor, scatter_list = [], src = 0)
    print("******test_scatter******")
    print('Rank ', rank, ' has data ', tensor) # 结果是[[1], [2], [3], [4]]

def test_reduce(rank, size):
    tensor = torch.ones(1)
    dist.reduce(tensor=tensor, dst=0) # dst指定哪个进程进行reduce, 默认操作是加法
    print("******test_reduce******")
    print('Rank ', rank, ' has data ', tensor)

def test_all_reduce(rank, size):
    tensor = torch.ones(1)
    dist.all_reduce(tensor=tensor,op=dist.ReduceOp.SUM)
    print("******test_all_reduce******")
    print('Rank ', rank, ' has data ', tensor)  # 结果都是 4

def test_gather(rank, size):
    tensor = torch.ones(1)
    if rank == 0:
      output = [torch.zeros(1) for _ in range(size)]
      dist.gather(tensor, gather_list=output, dst=0)
    else:
      dist.gather(tensor, gather_list=[], dst=0)
    if rank == 0:
      print("******test_gather******")
      print('Rank ', rank, ' has data ', output)  # 结果是 [[1,1,1,1]]

def test_all_gather(rank, size):
    output = [torch.zeros(1) for _ in range(size)]
    tensor = torch.ones(1)
    dist.all_gather(output, tensor)
    print("******test_all_gather******")
    print('Rank ', rank, ' has data ', output)  # 结果都是 [1,1,1,1]

def init_process(rank, size, backend='gloo'):
    """ 这里初始化分布式环境,设定Master机器以及端口号 """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29596'
    dist.init_process_group(backend, rank=rank, world_size=size)
    test_reduce(rank, size)
    test_all_reduce(rank, size)
    test_gather(rank, size)
    test_all_gather(rank, size)
    test_broadcast(rank, size)
    test_scatter(rank, size)

if __name__ == "__main__":
    size = 4
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

需要注意的一点是:

这里面的调用都是同步的,可以理解为,每个进程都调用到通信api时,真正的有效数据传输才开始,然后通信完成之后,代码继续往下跑。实际上有些通信进程并不获取数据,这些进程可能并不会被阻塞。

文档最后,提供了一个简单的类似 DDP 的实现,里面核心的部分就是:

这也进一步阐释了DDP的核心逻辑:
反向计算完成之后,汇总梯度信息(求均值),然后再更新参数

相关推荐
Data跳动1 小时前
Spark内存都消耗在哪里了?
大数据·分布式·spark
Java程序之猿3 小时前
微服务分布式(一、项目初始化)
分布式·微服务·架构
来一杯龙舌兰3 小时前
【RabbitMQ】RabbitMQ保证消息不丢失的N种策略的思想总结
分布式·rabbitmq·ruby·持久化·ack·消息确认
节点。csn5 小时前
Hadoop yarn安装
大数据·hadoop·分布式
NiNg_1_2346 小时前
基于Hadoop的数据清洗
大数据·hadoop·分布式
隔着天花板看星星8 小时前
Spark-Streaming集成Kafka
大数据·分布式·中间件·spark·kafka
技术路上的苦行僧12 小时前
分布式专题(8)之MongoDB存储原理&多文档事务详解
数据库·分布式·mongodb
龙哥·三年风水12 小时前
workman服务端开发模式-应用开发-后端api推送修改二
分布式·gateway·php
小小工匠13 小时前
分布式协同 - 分布式事务_2PC & 3PC解决方案
分布式·分布式事务·2pc·3pc
闯闯的日常分享15 小时前
分布式锁的原理分析
分布式