一 、分布式训练类的定义
python
from ..modules import Module
from typing import Any, Optional
from .common_types import _devices_t, _device_t
class DistributedDataParallel(Module):
process_group: Any = ...
dim: int = ...
module: Module = ...
device_ids: _devices_t = ...
output_device: _device_t = ...
broadcast_buffers: bool = ...
check_reduction: bool = ...
broadcast_bucket_size: float = ...
bucket_bytes_cap: float = ...
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
output_device: Optional[_device_t] = ..., dim: int = ...,
broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...
Python 类的定义,该类名为 DistributedDataParallel
,是 PyTorch 中用于分布式数据并行训练的模块
from ..modules import Module :导入 PyTorch 中的 Module 类,表示神经网络模块的基类
from typing import Any, Optional : 导入 Any 和 Optional 类型,用于类型注解
from .common_types import _devices_t, _device_t:导入 _devices_t 和 _device_t 类型
class DistributedDataParallel(Module): 定义了一个类 DistributedDataParallel,它继承自Module类
类属性:
process_group: Any = ... : 代表分布式训练的进程组
dim: int = ...: 代表分布式的维度
module: Module = ...:代表要进行并行处理的神经网络模块
device_ids: _devices_t = ...:代表设备的 ID 列表
output_device: _device_t = ...:代表输出设备
broadcast_buffers: bool = ...:是否广播缓冲区
check_reduction: bool = ...:是否检查减少操作
broadcast_bucket_size: float = ...: 广播桶大小
bucket_bytes_cap: float = ...:桶的字节容量上限
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
output_device: Optional[_device_t] = ..., dim: int = ...,
broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...:
def __init__(self, ...)
:类的初始化方法,用于初始化对象的属性。参数包括神经网络模块 module
、设备 ID 列表 device_ids
、输出设备 output_device
等。这些参数都有默认值,可以在初始化对象时提供或使用默认值
-> None
: 表示初始化方法没有返回值
总体而言,这段代码定义了一个分布式数据并行训练的模块 DistributedDataParallel
,该模块可以在多个设备上并行处理神经网络模块,实现分布式训练
二、创建分布式模型
这段代码创建了一个分布式数据并行(DDP)模型,并在必要时进行版本检查
根据 PyTorch 版本的不同,采取不同的配置参数来创建 DDP 模型
python
def smart_DDP(model): # 定义了一个名为 smart_DPP 的函数,该函数接受一个参数model,表示神经网络模型
# Model DDP creation with checks 版本检查,其目的是确保不使用不受支持的PyTorch版本进行DDP训练
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
#
'''
DDP 模型的创建:
if check_version(torch.__version__, '1.11.0'): ...: 如果 PyTorch 版本为 1.11.0,则使用 DDP 类创建 DDP 模型,并设置 static_graph=True
else: ...: 如果 PyTorch 版本不为 1.11.0,则使用 DDP 类创建 DDP 模型
'''
if check_version(torch.__version__, '1.11.0'):
'''
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True): 返回创建的 DDP 模型对象
device_ids 表示设备 ID 列表,这里设置为 [LOCAL_RANK],而 output_device 表示输出设备,也设置为 LOCAL_RANK
如果 PyTorch 版本为 1.11.0,则 static_graph 被设置为 True
'''
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
else:
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
其中,static_graph=True
是在创建分布式数据并行模型时传递给 DDP
类的一个参数
这个参数的作用是告诉 PyTorch 是否使用静态图,如果将 static_graph
设置为 True
,则表示希望使用静态图,这在某些情况下可以提高分布式训练的效率,尤其是在一些特定的 PyTorch 版本中可能需要使用静态图以避免问题
在 PyTorch 中,static_graph
参数是用于控制动态图(Dynamic Computational Graph)和静态图(Static Computational Graph)的一个设置。动态图和静态图是两种不同的计算图构建方式:
-
动态图(Dynamic Computational Graph):
- 在动态图中,计算图是在运行时动态构建的,每次迭代都可以改变图的结构
- PyTorch 的默认行为是使用动态图,这使得在模型训练过程中可以更灵活地调整模型结构
-
静态图(Static Computational Graph):
- 在静态图中,计算图在模型定义阶段就被固定,不再改变。这意味着一旦定义了计算图,就无法在运行时修改
- 静态图的优点之一是可以进行一些优化,例如静态图可以被预先分析以进行优化,从而提高计算效率