分布式训练类的定义以及创建分布式模型

一 、分布式训练类的定义

python 复制代码
from ..modules import Module
from typing import Any, Optional
from .common_types import _devices_t, _device_t


class DistributedDataParallel(Module):
    process_group: Any = ...
    dim: int = ...
    module: Module = ...
    device_ids: _devices_t = ...
    output_device: _device_t = ...
    broadcast_buffers: bool = ...
    check_reduction: bool = ...
    broadcast_bucket_size: float = ...
    bucket_bytes_cap: float = ...

    # TODO type process_group once `distributed` module is stubbed
    def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
                 output_device: Optional[_device_t] = ..., dim: int = ...,
                 broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
                 find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...

Python 类的定义,该类名为 DistributedDataParallel,是 PyTorch 中用于分布式数据并行训练的模块

from ..modules import Module :导入 PyTorch 中的 Module 类,表示神经网络模块的基类

from typing import Any, Optional : 导入 Any 和 Optional 类型,用于类型注解

from .common_types import _devices_t, _device_t:导入 _devices_t 和 _device_t 类型

class DistributedDataParallel(Module): 定义了一个类 DistributedDataParallel,它继承自Module类

类属性:

复制代码
    process_group: Any = ... : 代表分布式训练的进程组
    dim: int = ...: 代表分布式的维度
    module: Module = ...:代表要进行并行处理的神经网络模块
    device_ids: _devices_t = ...:代表设备的 ID 列表
    output_device: _device_t = ...:代表输出设备
    broadcast_buffers: bool = ...:是否广播缓冲区
    check_reduction: bool = ...:是否检查减少操作
    broadcast_bucket_size: float = ...: 广播桶大小
    bucket_bytes_cap: float = ...:桶的字节容量上限
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
             output_device: Optional[_device_t] = ..., dim: int = ...,
             broadcast_buffers: bool = ..., process_group: Optional[Any] = ...,   bucket_cap_mb: float = ...,find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...:

def __init__(self, ...):类的初始化方法,用于初始化对象的属性。参数包括神经网络模块 module、设备 ID 列表 device_ids、输出设备 output_device 等。这些参数都有默认值,可以在初始化对象时提供或使用默认值

-> None: 表示初始化方法没有返回值

总体而言,这段代码定义了一个分布式数据并行训练的模块 DistributedDataParallel,该模块可以在多个设备上并行处理神经网络模块,实现分布式训练

二、创建分布式模型

这段代码创建了一个分布式数据并行(DDP)模型,并在必要时进行版本检查

根据 PyTorch 版本的不同,采取不同的配置参数来创建 DDP 模型

python 复制代码
def smart_DDP(model):  # 定义了一个名为 smart_DPP 的函数,该函数接受一个参数model,表示神经网络模型
    # Model DDP creation with checks  版本检查,其目的是确保不使用不受支持的PyTorch版本进行DDP训练
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    #
    '''
    DDP 模型的创建:
    if check_version(torch.__version__, '1.11.0'): ...: 如果 PyTorch 版本为 1.11.0,则使用 DDP 类创建 DDP 模型,并设置 static_graph=True
    else: ...: 如果 PyTorch 版本不为 1.11.0,则使用 DDP 类创建 DDP 模型
    '''
    if check_version(torch.__version__, '1.11.0'):
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True): 返回创建的 DDP 模型对象
        device_ids 表示设备 ID 列表,这里设置为 [LOCAL_RANK],而 output_device 表示输出设备,也设置为 LOCAL_RANK
        如果 PyTorch 版本为 1.11.0,则 static_graph 被设置为 True
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

其中,static_graph=True 是在创建分布式数据并行模型时传递给 DDP 类的一个参数

这个参数的作用是告诉 PyTorch 是否使用静态图,如果将 static_graph 设置为 True,则表示希望使用静态图,这在某些情况下可以提高分布式训练的效率,尤其是在一些特定的 PyTorch 版本中可能需要使用静态图以避免问题

在 PyTorch 中,static_graph 参数是用于控制动态图(Dynamic Computational Graph)和静态图(Static Computational Graph)的一个设置。动态图和静态图是两种不同的计算图构建方式:

  1. 动态图(Dynamic Computational Graph)

    • 在动态图中,计算图是在运行时动态构建的,每次迭代都可以改变图的结构
    • PyTorch 的默认行为是使用动态图,这使得在模型训练过程中可以更灵活地调整模型结构
  2. 静态图(Static Computational Graph)

    • 在静态图中,计算图在模型定义阶段就被固定,不再改变。这意味着一旦定义了计算图,就无法在运行时修改
    • 静态图的优点之一是可以进行一些优化,例如静态图可以被预先分析以进行优化,从而提高计算效率
相关推荐
databook6 小时前
Manim实现闪光轨迹特效
后端·python·动效
Juchecar7 小时前
解惑:NumPy 中 ndarray.ndim 到底是什么?
python
用户8356290780517 小时前
Python 删除 Excel 工作表中的空白行列
后端·python
Json_7 小时前
使用python-fastApi框架开发一个学校宿舍管理系统-前后端分离项目
后端·python·fastapi
数据智能老司机14 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机15 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机15 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机15 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i15 小时前
drf初步梳理
python·django
每日AI新事件15 小时前
python的异步函数
python