分布式训练类的定义以及创建分布式模型

一 、分布式训练类的定义

python 复制代码
from ..modules import Module
from typing import Any, Optional
from .common_types import _devices_t, _device_t


class DistributedDataParallel(Module):
    process_group: Any = ...
    dim: int = ...
    module: Module = ...
    device_ids: _devices_t = ...
    output_device: _device_t = ...
    broadcast_buffers: bool = ...
    check_reduction: bool = ...
    broadcast_bucket_size: float = ...
    bucket_bytes_cap: float = ...

    # TODO type process_group once `distributed` module is stubbed
    def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
                 output_device: Optional[_device_t] = ..., dim: int = ...,
                 broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
                 find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...

Python 类的定义,该类名为 DistributedDataParallel,是 PyTorch 中用于分布式数据并行训练的模块

from ..modules import Module :导入 PyTorch 中的 Module 类,表示神经网络模块的基类

from typing import Any, Optional : 导入 Any 和 Optional 类型,用于类型注解

from .common_types import _devices_t, _device_t:导入 _devices_t 和 _device_t 类型

class DistributedDataParallel(Module): 定义了一个类 DistributedDataParallel,它继承自Module类

类属性:

复制代码
    process_group: Any = ... : 代表分布式训练的进程组
    dim: int = ...: 代表分布式的维度
    module: Module = ...:代表要进行并行处理的神经网络模块
    device_ids: _devices_t = ...:代表设备的 ID 列表
    output_device: _device_t = ...:代表输出设备
    broadcast_buffers: bool = ...:是否广播缓冲区
    check_reduction: bool = ...:是否检查减少操作
    broadcast_bucket_size: float = ...: 广播桶大小
    bucket_bytes_cap: float = ...:桶的字节容量上限
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
             output_device: Optional[_device_t] = ..., dim: int = ...,
             broadcast_buffers: bool = ..., process_group: Optional[Any] = ...,   bucket_cap_mb: float = ...,find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...:

def __init__(self, ...):类的初始化方法,用于初始化对象的属性。参数包括神经网络模块 module、设备 ID 列表 device_ids、输出设备 output_device 等。这些参数都有默认值,可以在初始化对象时提供或使用默认值

-> None: 表示初始化方法没有返回值

总体而言,这段代码定义了一个分布式数据并行训练的模块 DistributedDataParallel,该模块可以在多个设备上并行处理神经网络模块,实现分布式训练

二、创建分布式模型

这段代码创建了一个分布式数据并行(DDP)模型,并在必要时进行版本检查

根据 PyTorch 版本的不同,采取不同的配置参数来创建 DDP 模型

python 复制代码
def smart_DDP(model):  # 定义了一个名为 smart_DPP 的函数,该函数接受一个参数model,表示神经网络模型
    # Model DDP creation with checks  版本检查,其目的是确保不使用不受支持的PyTorch版本进行DDP训练
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    #
    '''
    DDP 模型的创建:
    if check_version(torch.__version__, '1.11.0'): ...: 如果 PyTorch 版本为 1.11.0,则使用 DDP 类创建 DDP 模型,并设置 static_graph=True
    else: ...: 如果 PyTorch 版本不为 1.11.0,则使用 DDP 类创建 DDP 模型
    '''
    if check_version(torch.__version__, '1.11.0'):
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True): 返回创建的 DDP 模型对象
        device_ids 表示设备 ID 列表,这里设置为 [LOCAL_RANK],而 output_device 表示输出设备,也设置为 LOCAL_RANK
        如果 PyTorch 版本为 1.11.0,则 static_graph 被设置为 True
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

其中,static_graph=True 是在创建分布式数据并行模型时传递给 DDP 类的一个参数

这个参数的作用是告诉 PyTorch 是否使用静态图,如果将 static_graph 设置为 True,则表示希望使用静态图,这在某些情况下可以提高分布式训练的效率,尤其是在一些特定的 PyTorch 版本中可能需要使用静态图以避免问题

在 PyTorch 中,static_graph 参数是用于控制动态图(Dynamic Computational Graph)和静态图(Static Computational Graph)的一个设置。动态图和静态图是两种不同的计算图构建方式:

  1. 动态图(Dynamic Computational Graph)

    • 在动态图中,计算图是在运行时动态构建的,每次迭代都可以改变图的结构
    • PyTorch 的默认行为是使用动态图,这使得在模型训练过程中可以更灵活地调整模型结构
  2. 静态图(Static Computational Graph)

    • 在静态图中,计算图在模型定义阶段就被固定,不再改变。这意味着一旦定义了计算图,就无法在运行时修改
    • 静态图的优点之一是可以进行一些优化,例如静态图可以被预先分析以进行优化,从而提高计算效率
相关推荐
研究点啥好呢16 分钟前
3月15日GitHub热门项目推荐 | 当AI拥有记忆
人工智能·python·github·openclaw
小龙报22 分钟前
【AI】高效交互的艺术:AI提示工程与大模型对话指南
人工智能·深度学习·神经网络·自然语言处理·chatgpt·交互·语音识别
肖永威27 分钟前
Python 工程化实战:从目录结构到 VSCode 完美配置
vscode·python·python工程
HyperAI超神经1 小时前
基于2.5万临床数据,斯坦福大学发布首个原生3D腹部CT视觉语言模型,Merlin在752类任务中全面领先
人工智能·深度学习·神经网络·机器学习·3d·语言模型·cpu
smj2302_796826521 小时前
解决leetcode第3869题.统计区间内奇妙数的数目
python·算法·leetcode
AI视觉网奇2 小时前
pycharm ui 历史版本
python
sunxunyong2 小时前
spark History Server 重启失败
大数据·分布式·spark
只与明月听2 小时前
RAG深入学习之Emabedding
前端·python·面试
2401_883035462 小时前
数据分析与科学计算
jvm·数据库·python
我的xiaodoujiao2 小时前
API 接口自动化测试详细图文教程学习系列2--相关Python基础知识
python·学习·测试工具·pytest