PyTorch AMP Grad Scaler Source Code Walkthrough: The _unscale_grads_ and unscale_ Functions
Introduction
This article walks through two key functions in grad_scaler.py, part of PyTorch's automatic mixed precision (AMP) module: _unscale_grads_ and unscale_. These functions are central to scaling and unscaling gradients, and they matter especially for keeping training numerically stable on large deep learning models. Detailed examples and a numerical simulation are included to show how they are used in practice.
1. The _unscale_grads_ Function
python
def _unscale_grads_(
    self,
    optimizer: torch.optim.Optimizer,
    inv_scale: torch.Tensor,
    found_inf: torch.Tensor,
    allow_fp16: bool,
) -> Dict[torch.device, torch.Tensor]:
    per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
    per_device_found_inf = _MultiDeviceReplicator(found_inf)

    # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
    # There could be hundreds of grads, so we'd like to iterate through them just once.
    # However, we don't know their devices or dtypes in advance.

    # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
    # Google says mypy struggles with defaultdicts type annotations.
    per_device_and_dtype_grads: Dict[
        torch.device, Dict[torch.dtype, List[torch.Tensor]]
    ] = defaultdict(lambda: defaultdict(list))
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group["params"]:
                assert isinstance(param, torch.Tensor)
                if param.grad is None:
                    continue
                if (not allow_fp16) and param.grad.dtype == torch.float16:
                    raise ValueError("Attempting to unscale FP16 gradients.")
                if param.grad.is_sparse:
                    # is_coalesced() == False means the sparse grad has values with duplicate indices.
                    # coalesce() deduplicates indices and adds all values that have the same index.
                    # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                    # so we should check the coalesced _values().
                    if param.grad.dtype is torch.float16:
                        param.grad = param.grad.coalesce()
                    to_unscale = param.grad._values()
                else:
                    to_unscale = param.grad

                # TODO: is there a way to split by device and dtype without appending in the inner loop?
                per_device_and_dtype_grads[to_unscale.device][
                    to_unscale.dtype
                ].append(to_unscale)

        for device, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
                torch._amp_foreach_non_finite_check_and_unscale_(
                    grads,
                    per_device_found_inf.get(device),
                    per_device_inv_scale.get(device),
                )

    return per_device_found_inf._per_device_tensors
1.1 Function Definition
python
def _unscale_grads_(
    self,
    optimizer: torch.optim.Optimizer,
    inv_scale: torch.Tensor,
    found_inf: torch.Tensor,
    allow_fp16: bool,
) -> Dict[torch.device, torch.Tensor]:
This function restores gradients from their scaled state back to their original magnitude, while checking whether any inf or NaN values are present.
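To make this concrete, here is a minimal standalone sketch (not the library code; all names are local to this example) of what unscaling means for a single dense gradient: the scaled gradient is multiplied by the reciprocal of the scale factor, and a flag tensor records whether any inf/NaN value was encountered.
python
import torch

# Minimal sketch (not the library implementation) of unscaling one dense gradient.
scale = torch.full((), 65536.0)                  # current scale factor S
inv_scale = scale.double().reciprocal().float()  # 1/S; the real code also takes the reciprocal in FP64
found_inf = torch.full((), 0.0)                  # 0.0 means "no inf/NaN seen so far"

# Pretend this is param.grad after backward() on a scaled loss (i.e. true_grad * S),
# with one overflowed entry so we can watch found_inf react.
scaled_grad = torch.tensor([65.536, 131.072, float("inf")])

if not torch.isfinite(scaled_grad).all():        # the "non_finite_check" part
    found_inf.fill_(1.0)
unscaled_grad = scaled_grad * inv_scale          # the "unscale" part

print(unscaled_grad)  # roughly [0.001, 0.002, inf]
print(found_inf)      # tensor(1.) -> an overflow was detected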
1.2 Parameters
- `optimizer`: the optimizer object that owns all parameters used during training.
- `inv_scale`: the reciprocal of the scale factor, used to restore the gradients.
- `found_inf`: a tensor that records whether any inf or NaN values were found.
- `allow_fp16`: whether unscaling FP16 gradients is allowed; `unscale_` passes `False` by default.
1.3 Core Implementation Steps
- Group gradients by device and dtype:
  - Parameters in the optimizer are grouped by device and data type so they can be processed in batches.
  - A `defaultdict` of `defaultdict`s is used to hold the groups.
- Check and classify each gradient:
  - Iterate over every parameter; for sparse gradients, `coalesce()` is used to merge duplicate indices (see the sketch after this list, and the author's separate post on this method: PyTorch 中 coalesce() 函数详解与应用示例).
  - The gradients are collected into `per_device_and_dtype_grads`.
- Call the internal PyTorch function to unscale the gradients:
  - `torch._amp_foreach_non_finite_check_and_unscale_()` unscales the gradients in bulk and checks for NaN or inf values (for a detailed analysis, see the author's separate post: PyTorch源码_amp_foreach_non_finite_check_and_unscale_cpu_kernel 函数解析:自动混合精度AMP的一部分).
- Return the per-device overflow check results:
  - The return value maps each device to a flag tensor indicating whether an overflow (inf/NaN) was found on that device.
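The comment in the source about coalescing deserves a quick demonstration. The following sketch (standalone, with made-up numbers) shows the failure mode it warns about: when a scaled FP16 sparse gradient has duplicate indices, `coalesce()` sums their values in FP16, and that sum can overflow past the FP16 maximum (65504), which is exactly why the coalesced `_values()` must be checked.
python
import torch

# Sketch of why FP16 sparse grads are coalesced before the non-finite check:
# coalesce() sums values that share an index, and for scaled FP16 values that
# sum can exceed the FP16 maximum (65504) and become inf.
indices = torch.tensor([[0, 0]])                               # the same index appears twice
values = torch.tensor([40000.0, 40000.0], dtype=torch.float16)
sparse_grad = torch.sparse_coo_tensor(indices, values, size=(4,))

coalesced = sparse_grad.coalesce()               # 40000 + 40000 = 80000 > 65504 in FP16
print(coalesced.values())                        # contains inf
print(torch.isfinite(coalesced.values()).all())  # tensor(False): this is what found_inf would flag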
1.4 Key Code Snippet
python
with torch.no_grad():
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is None:
                continue
            if (not allow_fp16) and param.grad.dtype == torch.float16:
                raise ValueError("Attempting to unscale FP16 gradients.")
            to_unscale = param.grad._values() if param.grad.is_sparse else param.grad
            per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)

    for device, per_dtype_grads in per_device_and_dtype_grads.items():
        for grads in per_dtype_grads.values():
            torch._amp_foreach_non_finite_check_and_unscale_(
                grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
            )
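The internal op can also be exercised directly to see the unscale-and-check behavior in isolation. The sketch below calls it on a small list of fabricated gradients; note that the underscore prefix marks it as a private API whose availability (in particular the CPU kernel) and exact signature may differ between PyTorch versions.
python
import torch

# Hedged sketch: call the internal op directly on fabricated gradients.
inv_scale = torch.full((), 1.0 / 65536.0)
found_inf = torch.full((), 0.0)

grads = [
    torch.tensor([65536.0, 131072.0]),  # finite scaled gradients
    torch.tensor([float("inf"), 3.0]),  # one overflowed gradient
]

# Unscales every tensor in `grads` in place and sets found_inf to 1.0 if any
# element is inf/NaN (private API; argument order taken from the snippet above).
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)

print(grads[0])   # unscaled in place: roughly [1.0, 2.0]
print(found_inf)  # tensor(1.) because of the inf in grads[1]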
2. The unscale_ Function
python
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
    """
    Divides ("unscales") the optimizer's gradient tensors by the scale factor.

    :meth:`unscale_` is optional, serving cases where you need to
    :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
    between the backward pass(es) and :meth:`step`.
    If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

    Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

        ...
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()

    Args:
        optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

    .. note::
        :meth:`unscale_` does not incur a CPU-GPU sync.

    .. warning::
        :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
        and only after all gradients for that optimizer's assigned parameters have been accumulated.
        Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

    .. warning::
        :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
    """
    if not self._enabled:
        return

    self._check_scale_growth_tracker("unscale_")

    optimizer_state = self._per_optimizer_states[id(optimizer)]

    if optimizer_state["stage"] is OptState.UNSCALED:
        raise RuntimeError(
            "unscale_() has already been called on this optimizer since the last update()."
        )
    elif optimizer_state["stage"] is OptState.STEPPED:
        raise RuntimeError("unscale_() is being called after step().")

    # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
    assert self._scale is not None
    inv_scale = self._scale.double().reciprocal().float()
    found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
        optimizer, inv_scale, found_inf, False
    )
    optimizer_state["stage"] = OptState.UNSCALED
2.1 Function Definition
python
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
This is the public interface that PyTorch AMP exposes for users to unscale gradients explicitly.
2.2 Parameters
- `optimizer`: the optimizer object that owns the parameters (and gradients) to be unscaled.
2.3 Core Implementation Steps
- State check:
  - Verify that neither `unscale_` nor `step` has already been called for this optimizer since the last `update()` (see the sketch after this list).
- Compute the inverse scale factor:
  - The reciprocal of the scale factor is computed in FP64 to avoid precision errors. `reciprocal` is the function that takes the reciprocal; for details see the author's separate post: PyTorch 中 reciprocal(取倒数)函数的深入解析:分析底层实现CPP代码.
- Call the internal function `_unscale_grads_`:
  - Perform the actual unscaling, including sparse-gradient handling and the NaN/inf check.
- Update the state record:
  - Mark the optimizer's state as `OptState.UNSCALED`.
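The state check is easy to observe from user code. The sketch below (assuming a CUDA device is available; with AMP disabled, `unscale_` returns early and no error is raised) triggers the RuntimeError by calling `unscale_` twice before `step()`.
python
import torch

# Sketch of the per-optimizer state machine (assumes a CUDA device).
model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

loss = model(torch.randn(2, 4, device="cuda")).sum()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)        # first call: OK, state becomes UNSCALED
try:
    scaler.unscale_(optimizer)    # second call before step()/update(): RuntimeError
except RuntimeError as err:
    print(err)                    # "unscale_() has already been called on this optimizer ..."

scaler.step(optimizer)
scaler.update()                   # resets the per-optimizer state for the next iteration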
2.4 Key Code Snippet
python
if optimizer_state["stage"] is OptState.UNSCALED:
    raise RuntimeError(
        "unscale_() has already been called on this optimizer since the last update()."
    )

inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

optimizer_state["found_inf_per_device"] = self._unscale_grads_(
    optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
3. Usage Example and Numerical Simulation
3.1 Example Code
python
import torch
from torch.cuda.amp import GradScaler, autocast

# Create the model and optimizer
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

# Simulated training loop
for epoch in range(2):
    for step in range(5):
        data = torch.randn(16, 10).cuda()
        target = torch.randn(16, 1).cuda()

        optimizer.zero_grad()

        # Mixed-precision forward pass
        with autocast():
            output = model(data)
            loss = torch.nn.functional.mse_loss(output, target)

        # Scale the loss and backpropagate (gradients are now scaled)
        scaler.scale(loss).backward()

        # Manually unscale the gradients
        scaler.unscale_(optimizer)

        # Clip the (now unscaled) gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update the weights and the scaler
        scaler.step(optimizer)
        scaler.update()

        print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")
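The docstring warning that `unscale_` must come only after all gradients have been accumulated matters in particular for gradient accumulation. Below is a hedged sketch of that pattern (names such as `accum_steps` are introduced only for this sketch): `scaler.scale(loss).backward()` runs once per micro-batch, while `unscale_`, clipping, `step`, and `update` run once per effective batch.
python
import torch
from torch.cuda.amp import GradScaler, autocast

# Gradient-accumulation variant of the loop above (accum_steps is made up for this sketch).
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()
accum_steps = 4

optimizer.zero_grad()
for micro_step in range(accum_steps):
    data = torch.randn(16, 10).cuda()
    target = torch.randn(16, 1).cuda()
    with autocast():
        loss = torch.nn.functional.mse_loss(model(data), target) / accum_steps
    scaler.scale(loss).backward()  # scaled gradients accumulate in .grad across micro-batches

scaler.unscale_(optimizer)         # exactly once, after all gradients are accumulated
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()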
3.2 Numerical Simulation Analysis
- Effect of gradient scaling: with the default initial scale factor of 65536 (2^16), gradients are multiplied by a factor on the order of 10^4, which helps FP16 avoid underflow.
- Verifying the unscaled result: comparing gradient values before and after unscaling shows that the original magnitudes are recovered without overflow errors (a simulation of the overflow path is sketched after this list).
- Gradient clipping test: running `torch.nn.utils.clip_grad_norm_()` confirms that the unscaled gradient values can be clipped safely.
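To complement the points above, the following sketch simulates the overflow path end to end. It assumes a CUDA device, and inspecting `_per_optimizer_states` relies on the private internals shown earlier in this post, so treat it as illustrative rather than a stable API: an inf is injected into a gradient, `unscale_` records it in `found_inf_per_device`, `step` skips the parameter update, and `update` reduces the scale (by the default backoff factor of 0.5).
python
import torch
from torch.cuda.amp import GradScaler, autocast

# Hedged overflow simulation (assumes CUDA; uses private scaler internals for inspection).
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

with autocast():
    loss = torch.nn.functional.mse_loss(model(torch.randn(16, 10).cuda()),
                                        torch.randn(16, 1).cuda())
scaler.scale(loss).backward()

# Force an overflow so found_inf has something to detect.
next(model.parameters()).grad[0, 0] = float("inf")

old_scale = scaler.get_scale()
scaler.unscale_(optimizer)
found = scaler._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
print(found)                          # per-device flag tensor(s); 1.0 marks an overflow

scaler.step(optimizer)                # optimizer.step() is skipped because inf was found
scaler.update()                       # scale is reduced, e.g. 65536.0 -> 32768.0
print(old_scale, scaler.get_scale())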
4. Notes and Summary
- Mind the API call order: `unscale_` should be called after the backward pass has finished and before the optimizer update.
- Avoid repeated calls: calling it more than once can leave the scaler state inconsistent (and raises a RuntimeError), so make sure it is called at most once per optimizer per training step.
- Sparse gradient support: the special case of sparse gradients is handled automatically to avoid overflow.
These two functions form the core of the AMP module and provide stable, efficient support for mixed-precision training. With the examples and numerical analysis above, developers can better understand how AMP works and use it to optimize the training of deep learning models.
Postscript
Completed in Shanghai at 18:49 on January 2, 2025, with the assistance of the GPT-4o large model.