分布式训练类的定义以及创建分布式模型

一 、分布式训练类的定义

python 复制代码
from ..modules import Module
from typing import Any, Optional
from .common_types import _devices_t, _device_t


class DistributedDataParallel(Module):
    process_group: Any = ...
    dim: int = ...
    module: Module = ...
    device_ids: _devices_t = ...
    output_device: _device_t = ...
    broadcast_buffers: bool = ...
    check_reduction: bool = ...
    broadcast_bucket_size: float = ...
    bucket_bytes_cap: float = ...

    # TODO type process_group once `distributed` module is stubbed
    def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
                 output_device: Optional[_device_t] = ..., dim: int = ...,
                 broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
                 find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...

Python 类的定义,该类名为 DistributedDataParallel,是 PyTorch 中用于分布式数据并行训练的模块

from ..modules import Module :导入 PyTorch 中的 Module 类,表示神经网络模块的基类

from typing import Any, Optional : 导入 Any 和 Optional 类型,用于类型注解

from .common_types import _devices_t, _device_t:导入 _devices_t 和 _device_t 类型

class DistributedDataParallel(Module): 定义了一个类 DistributedDataParallel,它继承自Module类

类属性:

复制代码
    process_group: Any = ... : 代表分布式训练的进程组
    dim: int = ...: 代表分布式的维度
    module: Module = ...:代表要进行并行处理的神经网络模块
    device_ids: _devices_t = ...:代表设备的 ID 列表
    output_device: _device_t = ...:代表输出设备
    broadcast_buffers: bool = ...:是否广播缓冲区
    check_reduction: bool = ...:是否检查减少操作
    broadcast_bucket_size: float = ...: 广播桶大小
    bucket_bytes_cap: float = ...:桶的字节容量上限
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
             output_device: Optional[_device_t] = ..., dim: int = ...,
             broadcast_buffers: bool = ..., process_group: Optional[Any] = ...,   bucket_cap_mb: float = ...,find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...:

def __init__(self, ...):类的初始化方法,用于初始化对象的属性。参数包括神经网络模块 module、设备 ID 列表 device_ids、输出设备 output_device 等。这些参数都有默认值,可以在初始化对象时提供或使用默认值

-> None: 表示初始化方法没有返回值

总体而言,这段代码定义了一个分布式数据并行训练的模块 DistributedDataParallel,该模块可以在多个设备上并行处理神经网络模块,实现分布式训练

二、创建分布式模型

这段代码创建了一个分布式数据并行(DDP)模型,并在必要时进行版本检查

根据 PyTorch 版本的不同,采取不同的配置参数来创建 DDP 模型

python 复制代码
def smart_DDP(model):  # 定义了一个名为 smart_DPP 的函数,该函数接受一个参数model,表示神经网络模型
    # Model DDP creation with checks  版本检查,其目的是确保不使用不受支持的PyTorch版本进行DDP训练
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    #
    '''
    DDP 模型的创建:
    if check_version(torch.__version__, '1.11.0'): ...: 如果 PyTorch 版本为 1.11.0,则使用 DDP 类创建 DDP 模型,并设置 static_graph=True
    else: ...: 如果 PyTorch 版本不为 1.11.0,则使用 DDP 类创建 DDP 模型
    '''
    if check_version(torch.__version__, '1.11.0'):
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True): 返回创建的 DDP 模型对象
        device_ids 表示设备 ID 列表,这里设置为 [LOCAL_RANK],而 output_device 表示输出设备,也设置为 LOCAL_RANK
        如果 PyTorch 版本为 1.11.0,则 static_graph 被设置为 True
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

其中,static_graph=True 是在创建分布式数据并行模型时传递给 DDP 类的一个参数

这个参数的作用是告诉 PyTorch 是否使用静态图,如果将 static_graph 设置为 True,则表示希望使用静态图,这在某些情况下可以提高分布式训练的效率,尤其是在一些特定的 PyTorch 版本中可能需要使用静态图以避免问题

在 PyTorch 中,static_graph 参数是用于控制动态图(Dynamic Computational Graph)和静态图(Static Computational Graph)的一个设置。动态图和静态图是两种不同的计算图构建方式:

  1. 动态图(Dynamic Computational Graph)

    • 在动态图中,计算图是在运行时动态构建的,每次迭代都可以改变图的结构
    • PyTorch 的默认行为是使用动态图,这使得在模型训练过程中可以更灵活地调整模型结构
  2. 静态图(Static Computational Graph)

    • 在静态图中,计算图在模型定义阶段就被固定,不再改变。这意味着一旦定义了计算图,就无法在运行时修改
    • 静态图的优点之一是可以进行一些优化,例如静态图可以被预先分析以进行优化,从而提高计算效率
相关推荐
AC赳赳老秦36 分钟前
OpenClaw+Power Apps 实战:自动生成 Power Apps 应用、连接 Excel 数据源
大数据·开发语言·python·serverless·excel·deepseek·openclaw
JiaHao汤1 小时前
分布式事务方案全景:从理论到 Seata 落地
java·分布式·spring·spring cloud
茉莉玫瑰花茶2 小时前
综合案例 - AI 智能租房助手 [ 5 ]
服务器·数据库·人工智能·python·ai
文艺倾年2 小时前
【强化学习】强化学习基本概念,20W字总结(一)
人工智能·python·语言模型·自然语言处理·面试·职场和发展·大模型
宸丶一2 小时前
Day 13:持久化记忆 - 让 Agent 拥有长期记忆
jvm·python·ai
南部余额3 小时前
RabbitMQ 进阶:延迟队列完全指南
java·分布式·spring·rabbitmq
码云骑士3 小时前
13-列表append的底层真相(上)-listobject源码中的预分配策略
开发语言·python
浦信仿真大讲堂3 小时前
达索系统SIMULIA Abaqus 2026接触和约束的增强新功能介绍
人工智能·python·算法·仿真软件·达索软件
xufengzhu3 小时前
第三方 Python 库 Loguru 的进阶实战
python·loguru
极光代码工作室4 小时前
基于深度学习的手写数字识别系统
人工智能·python·深度学习·神经网络·机器学习