PyTorch AMP 混合精度中的 scale
函数解析
混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler
类负责动态调整和管理损失缩放因子,以解决 FP16 运算中的数值精度问题。而 scale
函数是 GradScaler
的一个重要方法,用于将输出的张量按当前缩放因子进行缩放。
本文将详细解析 scale
函数的作用、代码逻辑,以及 apply_scale
子函数的递归作用。
函数代码回顾
以下是 scale
函数的完整代码:
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py
torch 2.4.0+cu121版本
python
def scale(
self,
outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[
_MultiDeviceReplicator
] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
if isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
return iterable
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
1. 函数作用
scale
函数的主要作用是将输出张量(outputs
)按当前的缩放因子(self._scale
)进行缩放。它支持以下两种输入:
- 单个张量:直接将缩放因子乘以张量。
- 张量的可迭代对象(如列表或元组):递归地对每个张量进行缩放。
当 AMP 功能未启用时(即 self._enabled
为 False
),scale
函数会直接返回原始的 outputs
,不执行任何缩放操作。
使用场景
- 放大梯度:在反向传播之前,放大输出张量的数值,以减少数值舍入误差对 FP16 计算的影响。
- 支持多设备 :通过
_MultiDeviceReplicator
支持张量分布在多个设备(如多 GPU)的场景。
2. 核心代码解析
(1) 短路处理单个张量
当输入 outputs
是单个张量(torch.Tensor
)时,函数直接对其进行缩放:
python
if isinstance(outputs, torch.Tensor):
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
逻辑解析:
- 如果缩放因子
self._scale
尚未初始化,则调用_lazy_init_scale_growth_tracker
方法在指定设备上初始化缩放因子。 - 使用
outputs * self._scale
对张量进行缩放。这里使用了to(device=outputs.device)
确保缩放因子与张量在同一设备上。
这是单个张量输入的快速路径处理。
(2) 多张量递归处理逻辑
当输入为张量的可迭代对象(如列表或元组)时,函数调用子函数 apply_scale
进行递归缩放:
python
stash: List[_MultiDeviceReplicator] = [] # 用于存储缩放因子对象
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
if isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
return iterable
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
apply_scale
子函数的作用
-
张量处理:
- 如果
val
是单个张量,检查stash
是否为空。 - 如果为空,初始化缩放因子对象
_MultiDeviceReplicator
,并存储在stash
中。 - 使用
stash[0].get(val.device)
获取对应设备上的缩放因子,并对张量进行缩放。
- 如果
-
递归处理可迭代对象:
- 如果
val
是一个可迭代对象,调用map(apply_scale, val)
,对其中的每个元素递归地调用apply_scale
。 - 如果输入是
list
或tuple
,则保持其原始类型。
- 如果
-
类型检查:
- 如果
val
既不是张量也不是可迭代对象,抛出错误。
- 如果
3. apply_scale
是递归函数吗?
是的,apply_scale
是一个递归函数。
递归逻辑
- 当输入为嵌套结构(如张量的列表或列表中的列表)时,
apply_scale
会递归调用自身,将缩放因子应用到最底层的张量。 - 递归的终止条件是
val
为单个张量(torch.Tensor
)。
示例:
假设输入为嵌套张量列表:
python
outputs = [torch.tensor([1.0, 2.0]), [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]]
scaled_outputs = scaler.scale(outputs)
递归处理过程如下:
-
对
outputs
调用apply_scale
:- 第一个元素是张量
torch.tensor([1.0, 2.0])
,直接缩放。 - 第二个元素是列表,递归调用
apply_scale
。
- 第一个元素是张量
-
进入嵌套列表
[torch.tensor([3.0]), torch.tensor([4.0, 5.0])]
:- 第一个元素是张量
torch.tensor([3.0])
,缩放。 - 第二个元素是张量
torch.tensor([4.0, 5.0])
,缩放。
- 第一个元素是张量
4. _MultiDeviceReplicator
的作用
_MultiDeviceReplicator
是一个工具类,用于在多设备场景下管理缩放因子对象的复用。它根据张量所在的设备返回正确的缩放因子。
- 当张量分布在多个设备(如 GPU)时,
_MultiDeviceReplicator
可以高效地为每个设备提供所需的缩放因子,避免重复初始化。
总结
scale
函数是 AMP 混合精度训练中用于梯度缩放的重要方法,其作用是将输出张量按当前缩放因子进行缩放。通过递归函数 apply_scale
,该函数能够处理嵌套的张量结构,同时支持多设备场景。
关键点总结:
- 快速路径:单张量输入的情况下,直接进行缩放。
- 递归处理:对于张量的嵌套结构,递归地对每个张量进行缩放。
- 设备管理 :通过
_MultiDeviceReplicator
支持多设备场景。
通过 scale
函数,PyTorch 的 AMP 模块能够高效地调整梯度数值范围,提升混合精度训练的稳定性和效率。
后记
2025年1月2日15点47分于上海,在GPT4o大模型辅助下完成。