PyTorch 自动混合精度AMP Grad Scaler 源码解析:_unscale_grads_ 与 unscale_ 函数

PyTorch AMP Grad Scaler 源码解析:unscale_grads 与 unscale_ 函数

引言

本文详细解析 PyTorch 自动混合精度(AMP)模块中 grad_scaler.py 文件的两个关键函数:_unscale_grads_unscale_。这些函数在梯度缩放与反缩放过程中起到了关键作用,特别适用于训练大规模深度学习模型时的数值稳定性优化。我们还将给出详细的示例与数值模拟,帮助理解其具体应用。


1. _unscale_grads_ 函数解析

go 复制代码
def _unscale_grads_(
        self,
        optimizer: torch.optim.Optimizer,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads: Dict[
            torch.device, Dict[torch.dtype, List[torch.Tensor]]
        ] = defaultdict(lambda: defaultdict(list))
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    assert isinstance(param, torch.Tensor)
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(
                        grads,
                        per_device_found_inf.get(device),
                        per_device_inv_scale.get(device),
                    )

        return per_device_found_inf._per_device_tensors

1.1 函数定义

python 复制代码
def _unscale_grads_(
        self,
        optimizer: torch.optim.Optimizer,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool,
    ) -> Dict[torch.device, torch.Tensor]:

该函数主要用于将梯度从缩放状态恢复到原始大小,同时检查是否存在数值溢出情况。

1.2 参数说明

  • optimizer:优化器对象,包含训练过程中使用的所有参数。
  • inv_scale:缩放因子的倒数,用于恢复梯度。
  • found_inf:用于记录是否存在无穷大或 NaN 值。
  • allow_fp16:是否允许 FP16 精度的梯度反缩放,默认设置为 False。

1.3 核心实现步骤

  1. 按设备与数据类型分类梯度:

    • 将优化器中的参数按设备和数据类型进行分组,便于批量处理。
    • 使用 defaultdict 对分组存储。
  2. 检查梯度并分类:

    • 遍历每个参数,如果存在稀疏梯度,使用 coalesce() 消除重复索引。关于这个方法, 可以参考笔者的另一篇博客:PyTorch 中 coalesce() 函数详解与应用示例
    • 将梯度分组存储到 per_device_and_dtype_grads 中。
  3. 调用 PyTorch 内部函数反缩放梯度:

  4. 返回各设备上的溢出检查结果:

    • 输出包含各设备是否发现溢出的布尔值张量。

1.4 关键代码片段

python 复制代码
with torch.no_grad():
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is None:
                continue
            if (not allow_fp16) and param.grad.dtype == torch.float16:
                raise ValueError("Attempting to unscale FP16 gradients.")
            
            to_unscale = param.grad._values() if param.grad.is_sparse else param.grad
            per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)

for device, per_dtype_grads in per_device_and_dtype_grads.items():
    for grads in per_dtype_grads.values():
        torch._amp_foreach_non_finite_check_and_unscale_(
            grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
        )

2. unscale_ 函数解析

go 复制代码
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly,  gradients will be unscaled  automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, False
        )
        optimizer_state["stage"] = OptState.UNSCALED

2.1 函数定义

python 复制代码
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:

该函数是 PyTorch AMP 提供的外部接口,供用户调用以解除梯度缩放。

2.2 参数说明

  • optimizer:包含所有待训练参数的优化器对象。

2.3 核心实现步骤

  1. 状态检查:
    • 检查是否已经调用过 unscale_step
  2. 计算反缩放因子:
  3. 调用内部函数 _unscale_grads_
    • 执行反缩放过程,包含稀疏梯度与 NaN 检查。
  4. 更新状态记录:
    • 将优化器状态更新为 "UNSCALED"。

2.4 关键代码片段

python 复制代码
if optimizer_state["stage"] is OptState.UNSCALED:
    raise RuntimeError(
        "unscale_() has already been called on this optimizer since the last update()."
    )

inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

optimizer_state["found_inf_per_device"] = self._unscale_grads_(
    optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED

3. 使用示例与数值模拟

3.1 示例代码

python 复制代码
import torch
from torch.cuda.amp import GradScaler, autocast

# 创建模型和优化器
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

# 模拟训练循环
for epoch in range(2):
    for step in range(5):
        data = torch.randn(16, 10).cuda()
        target = torch.randn(16, 1).cuda()
        optimizer.zero_grad()

        # 使用混合精度训练
        with autocast():
            output = model(data)
            loss = torch.nn.functional.mse_loss(output, target)

        # 缩放梯度
        scaler.scale(loss).backward()

        # 手动解除梯度缩放
        scaler.unscale_(optimizer)

        # 使用梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # 更新权重与缩放器
        scaler.step(optimizer)
        scaler.update()

        print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")

3.2 数值模拟分析

  1. 梯度缩放影响:
    缩放因子 = 65536 时,梯度放大至 10^4 量级,有助于 FP16 避免下溢问题。
  2. 反缩放结果验证:
    对比反缩放前后的梯度值,可观察到恢复精度并避免溢出错误。
  3. 梯度裁剪测试:
    执行 torch.nn.utils.clip_grad_norm_(),确认反缩放后的梯度值能够被安全裁剪。

4. 注意事项与总结

  1. 注意 API 使用顺序:
    调用 unscale_ 应在反向传播完成后、优化器更新前进行。
  2. 防止重复调用:
    多次调用可能导致状态不一致,应确保每轮训练仅调用一次。
  3. 稀疏梯度支持:
    自动处理稀疏梯度的特殊情况,避免溢出。

这两个函数是 AMP 核心模块,提供了稳定高效的混合精度训练支持。通过示例与数值分析,开发者可以更好地理解 AMP 工作原理并优化深度学习模型训练过程。


后记

2025年1月2日18点49分于上海,在GPT4o大模型辅助下完成。

相关推荐
Fishel-1 小时前
预测facebook签到位置
人工智能·python·算法·机器学习·近邻算法·facebook
是阿静呀1 小时前
新手学习yolov8目标检测小记2--对比实验中经典模型库MMDetection使用方法(使用自己的数据集训练,并转换为yolo格式评价指标)
python·学习·yolo·目标检测
道友老李1 小时前
【PyTorch】实现卷积神经网络:使用CNN进行手写数字识别
人工智能·pytorch·cnn
视觉语言导航1 小时前
技术实践︱利用Docker快速体验Matterport3DSimulator!让视觉语言导航(VLN)任务入门再无门槛!
人工智能·docker·具身智能
luoganttcc2 小时前
香橙派安装 opencv 4.9.0
人工智能·opencv·webpack
技术程序猿华锋2 小时前
Cursor AI 编程代码助手:设置自定义 AI 与 OpenAI API Key 获取教程
人工智能
老大白菜2 小时前
Python 实现 冒泡排序算法示例
数据结构·python·算法
菠萝派爱跨境2 小时前
电商Google广告:2025年提升转化率的5种策略
大数据·人工智能
桂月二二2 小时前
解锁2025编程新高度:深入探索编程技术的最新趋势
前端·人工智能·flutter·neo4j·wasm
西电研梦3 小时前
西安电子科技大学初/复试笔试、面试、机试成绩占比
人工智能·考研·面试·职场和发展·研究生·西电·西安电子科技大学