autocast in PyTorch: How Mixed Precision Training Is Implemented

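In deep learning, especially when training large models, compute and GPU memory consumption are often the limiting factors. Mixed precision training emerged to address this. autocast is a tool provided by PyTorch that automatically chooses the numeric type for each operation during mixed precision training, improving performance while keeping precision loss to a minimum.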

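1. What is autocast?

autocast is the context manager (and decorator) PyTorch provides for enabling automatic mixed precision. Within the region it wraps, eligible operations automatically run in a lower-precision floating-point type (float16 or bfloat16), which speeds up computation and saves memory while keeping training accuracy as close as possible to float32.

  • Purpose: improve performance and reduce memory usage.
  • How: inside the wrapped region, ops that benefit from it (such as matrix multiplication and convolution) run in a lower precision (float16 or bfloat16), while precision-sensitive ops stay in float32. Model parameters remain float32, so the optimizer still updates float32 weights.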
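A minimal sketch of this behavior, assuming a CPU-only setup with dtype=torch.bfloat16 so that it runs without a GPU (the tensor names are illustrative):

```python
import torch

a = torch.randn(8, 8)  # float32 by default
b = torch.randn(8, 8)

# Outside autocast, a matmul of float32 inputs stays float32.
full = torch.mm(a, b)

# Inside autocast, torch.mm is one of the ops that runs in the low-precision dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    low = torch.mm(a, b)

print(full.dtype, full.element_size())  # float32, 4 bytes per element
print(low.dtype, low.element_size())    # bfloat16, 2 bytes per element
```

No manual cast of the inputs is needed; the halved element size is where the memory saving for activations comes from.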

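2. How does autocast work?

Workflow

autocast is essentially a context manager. It works as follows:

  1. Entering the autocast context

    • When execution enters the autocast context, PyTorch automatically runs eligible operations (matrix multiplication, convolution, and so on) in a lower-precision floating-point type (float16 or bfloat16) to speed up computation and save memory.
  2. Choosing the dtype

    • The default low-precision dtype depends on the device type: float16 for CUDA and bfloat16 for CPU. It can be overridden with the dtype argument.
  3. Leaving the context

    • On exit, autocast restores the previous autocast state (enabled flag, dtype, and cache setting). It does not convert existing tensors: a tensor produced inside the region keeps its low-precision dtype, so cast it with .float() if it has to be combined with float32 tensors afterwards. Model parameters are never modified and stay in float32.
  4. Avoiding precision loss in gradients and weight updates

    • The forward pass inside autocast uses low precision where it is safe. Backward ops run in the same dtype autocast chose for the corresponding forward ops, but the gradients accumulated on the float32 parameters are float32, and the optimizer updates the float32 weights, which keeps the update numerically stable.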
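A small sketch for checking what happens to dtypes once the region is left. It assumes CPU autocast with bfloat16 so it runs anywhere; the layer and tensor names are illustrative:

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)   # parameters are created in float32
x = torch.randn(3, 4)     # float32 input

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)          # the linear op runs in bfloat16 inside the region

# After the region is left, nothing is converted back automatically.
print(y.dtype)             # still the low-precision dtype
print(layer.weight.dtype)  # parameters were never changed

# Cast explicitly before mixing with float32 tensors outside the region.
y32 = y.float()
print(y32.dtype)
```

In this sketch the output stays bfloat16 and the weights stay float32, which matches the note in the autocast docstring about casting results back to float32 after leaving the region.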

Code walkthrough

The following is a basic example of using autocast in PyTorch:

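import torch
from torch import nn, optim
from torch.amp import GradScaler  # on older releases: torch.cuda.amp.GradScaler

# Create the model, loss function, and optimizer in default precision (float32)
model = nn.Linear(10, 1).cuda()
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# GradScaler performs loss scaling
scaler = GradScaler("cuda")

# A toy training loop on random data
for epoch in range(10):
    inputs = torch.randn(32, 10, device="cuda")   # input batch
    targets = torch.randn(32, 1, device="cuda")   # target batch
    optimizer.zero_grad()

    # Run only the forward pass and the loss computation under autocast.
    # torch.autocast takes device_type; torch.cuda.amp.autocast does not.
    with torch.autocast(device_type="cuda"):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

    # Scale the loss, backpropagate, step the optimizer, then update the scale
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()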

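3. Walking through the autocast source

After installing torch, the source can be found at a path similar to ~/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/autocast_mode.py. The source is reproduced at the end of this post.

Let's look at the implementation of autocast to understand how it works: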

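The __enter__ method:

When the autocast context is entered, __enter__ is called. It does the following:

  • Saves the current settings

    • Records the device's current autocast enabled state, dtype, and cache setting.
  • Enables autocast

    • Calls torch.set_autocast_enabled() to turn autocast on (or off) for the given device, and torch.set_autocast_dtype() to set the dtype (float16 or bfloat16).
  • Configures the cache

    • Increments the nesting counter and enables or disables autocast's weight cache, which exists for performance.

A simplified excerpt (the TorchScript check is omitted):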
def __enter__(self):
    self.prev_cache_enabled = torch.is_autocast_cache_enabled()
    self.prev = torch.is_autocast_enabled(self.device)
    self.prev_fastdtype = torch.get_autocast_dtype(self.device)
    torch.set_autocast_enabled(self.device, self._enabled)
    torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
    torch.autocast_increment_nesting()
    torch.set_autocast_cache_enabled(self._cache_enabled)

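Readers interested in how functions such as set_autocast_dtype, called in __enter__, are implemented can refer to another post by the author: "The __init__.pyi file in PyTorch: its role and relationship to the C++ implementation".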

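The __exit__ method:

When the autocast context is exited, __exit__ is called:

  • Restores the previous settings

    • Restores the saved autocast enabled state, dtype, and cache setting.
  • Clears the cache

    • When the nesting level drops to 0, calls torch.clear_autocast_cache() to free the cached casts.

A simplified excerpt (the TorchScript check is omitted):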
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()
    torch.set_autocast_enabled(self.device, self.prev)
    torch.set_autocast_dtype(self.device, self.prev_fastdtype)
    torch.set_autocast_cache_enabled(self.prev_cache_enabled)
    return False
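
The save and restore behavior of __enter__ and __exit__ can be observed from the outside by nesting regions. This is a sketch under the assumption of CPU autocast with bfloat16; it only inspects result dtypes and does not touch the internal state functions:

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    outer_before = torch.mm(a, b).dtype      # autocast is active

    with torch.autocast(device_type="cpu", enabled=False):
        inner = torch.mm(a, b).dtype         # autocast is switched off here

    # Leaving the inner region restored the state that its __enter__ saved.
    outer_after = torch.mm(a, b).dtype

outside = torch.mm(a, b).dtype               # state from before any region

print(outer_before, inner, outer_after, outside)
```

Each region only remembers the state it found on entry, which is why regions can be nested to any depth without tracking anything globally beyond the nesting counter.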

The __call__ method:

autocast can also be used as a decorator:

def __call__(self, func):
    return autocast_decorator(self, func)

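This lets you apply autocast directly to a function, for example a model's forward method, instead of wrapping its body in a with block.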
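A short sketch of the decorator form, again assuming CPU autocast with bfloat16; the function name project is hypothetical:

```python
import torch


@torch.autocast(device_type="cpu", dtype=torch.bfloat16)
def project(a, b):
    # The whole function body runs inside an autocast region.
    return torch.mm(a, b)


a = torch.randn(4, 4)
b = torch.randn(4, 4)

decorated = project(a, b)    # computed under autocast
plain = torch.mm(a, b)       # same op, called outside any region

print(decorated.dtype, plain.dtype)
```

The decorator only affects calls made through the decorated function; the same op called elsewhere keeps its normal float32 behavior.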

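4. Mixed precision training with autocast

During training, autocast picks a dtype per operation so that training stays efficient:

  • Forward pass: eligible operations run in low precision (float16 or bfloat16) for speed. autocast should wrap only the forward pass and the loss computation.
  • Backward pass: call backward() outside the autocast region. Backward ops run in the same dtype autocast used for the corresponding forward ops, which is why float16 training is usually paired with loss scaling.
  • Weight update: parameters and their gradients are float32, so the optimizer step happens in float32, which helps the model converge stably.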
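
To see which dtypes the gradients actually have, the following sketch runs one forward and backward pass under CPU autocast with bfloat16 (an assumption made so it runs without a GPU) and keeps the gradient of an intermediate activation for inspection:

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)            # float32 parameters
x = torch.randn(3, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)                 # bfloat16 activation
    out.retain_grad()              # keep the gradient of this non-leaf tensor

# Loss and backward happen outside the autocast region.
loss = out.float().pow(2).mean()
loss.backward()

print(out.grad.dtype)              # gradient of the low-precision activation
print(layer.weight.grad.dtype)     # gradient stored on the float32 parameter
print(layer.bias.grad.dtype)
```

The activation gradient follows the activation's dtype, while the gradients stored on the parameters match the float32 parameters, so the optimizer works entirely in float32.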

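5. Combining autocast with GradScaler

To make mixed precision training more stable, PyTorch also provides GradScaler, which performs loss scaling. Its main purpose is to keep small float16 gradients from underflowing to zero; it also detects gradients that overflowed to inf or NaN and skips the optimizer step for that iteration:

  • GradScaler multiplies the loss by a scale factor before backpropagation, so the gradients are scaled by the same factor and stay within the representable range.
  • scaler.scale(loss).backward() runs the scaled backward pass, scaler.step(optimizer) unscales the gradients and applies the optimizer step, and scaler.update() adjusts the scale factor for the next iteration.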
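
The arithmetic behind loss scaling can be illustrated without a GPU. The sketch below scales a single value directly rather than a loss (in real training the gradients come out scaled through the chain rule); the value 1e-8 is an arbitrary small gradient chosen for illustration, and 2**16 is GradScaler's default initial scale:

```python
import torch

scale = 2.0 ** 16                  # GradScaler's default initial scale
grad = torch.tensor(1e-8)          # a small "true" gradient, in float32

# Stored directly in float16, the value is too small to represent and becomes 0.
unscaled_fp16 = grad.half()

# Scaling first moves the value into float16's representable range.
scaled_fp16 = (grad * scale).half()

# Unscaling in float32 recovers the gradient up to float16 rounding error.
recovered = scaled_fp16.float() / scale

print(unscaled_fp16.item())        # 0.0
print(scaled_fp16.item() != 0.0)   # True
print(recovered.item())
```

Because the scale is a power of two, multiplying and dividing by it only shifts the exponent, so the only error introduced is the rounding to float16.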

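Summary

  • autocast is PyTorch's automatic mixed precision tool. It chooses a suitable dtype for each operation during training, which speeds up computation and reduces memory usage.
  • autocast runs eligible forward ops in low precision (float16 or bfloat16), while parameters, their gradients, and the weight update stay in float32 for numerical stability.
  • Used together with GradScaler, mixed precision training saves resources while avoiding gradient underflow.

autocast makes training large deep learning models more efficient while keeping accuracy and stability close to full precision, which makes it particularly useful for training on high-performance hardware.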

Appendix: PyTorch source code

class autocast:
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
    computation(s).  Backward passes under autocast are not recommended.
    Backward ops run in the same type that autocast used for corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with torch.autocast(device_type="cuda"):
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

        class AutocastModel(nn.Module):
            ...
            @torch.autocast(device_type="cuda")
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
    After returning to an autocast-disabled region, using them with floating-point
    Tensors of different dtypes may cause type mismatch errors.  If so, cast the Tensor(s)
    produced in the autocast region back to ``float32`` (or other dtype if desired).
    If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
    and incurs no additional overhead.
    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()


    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = nn.Linear(input_size, num_classes)
            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest to disable the Jit Autocast Pass,
        # As the issue: https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.cpu.amp.autocast(cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)
        # Models Run
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
    please file an issue.

    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
    Locally disabling autocast can be useful, for example, if you want to force a subregion
    to run in a particular ``dtype``.  Disabling autocast gives you explicit control over
    the execution type.  In the subregion, inputs from the surrounding region
    should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local.  If you want it enabled in a new thread, the context manager or decorator
    must be invoked in that thread.  This affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).

    Args:
        device_type(str, required):  Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'.
                                     The type is the same as the `type` attribute of a :class:`torch.device`.
                                     Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        enabled(bool, optional):  Whether autocasting should be enabled in the region.
            Default: ``True``
        dtype(torch_dtype, optional):  Data type for ops run in autocast. It uses the default value
            (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
            :func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``.
            Default: ``None``
        cache_enabled(bool, optional):  Whether the weight cache inside autocast should be enabled.
            Default: ``True``
    """

    def __init__(
        self,
        device_type: str,
        dtype: Optional[_dtype] = None,
        enabled: bool = True,
        cache_enabled: Optional[bool] = None,
    ):
        if not isinstance(device_type, str):
            raise ValueError(
                f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
            )
        if dtype is None:
            dtype = torch.get_autocast_dtype(device_type)
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            assert dtype is not None
            return
        self.device = device_type
        if not is_autocast_available(self.device):
            raise RuntimeError(
                f"User specified an unsupported autocast device_type '{self.device}'"
            )
        self.custom_backend_name = torch._C._get_privateuse1_backend_name()
        self.fast_dtype = torch.get_autocast_dtype(self.device)
        if self.device == self.custom_backend_name:
            necessary_funcs = [
                "get_amp_supported_dtype",
            ]
            message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
            message += "registered a module or  the module miss some necessary funcs. The backend should register "
            message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"

            assert hasattr(torch, self.custom_backend_name), message
            self.custom_device_mod = getattr(torch, self.custom_backend_name)
            for func in necessary_funcs:
                assert hasattr(self.custom_device_mod, func), (
                    message + f"But the func `{func}` is missing. \n"
                )

        self._cache_enabled = torch.is_autocast_cache_enabled()
        if (
            enabled
            and torch.cuda.amp.common.amp_definitely_not_available()
            and self.device == "cuda"
        ):
            warnings.warn(
                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
            )
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        if self.device == "cpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype and enabled:
                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "CPU Autocast only supports dtype of "
                error_message += (
                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "xpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "ipu":
            supported_dtypes = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtypes:
                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "hpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == self.custom_backend_name:
            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
            if self.fast_dtype not in supported_dtype:
                error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
                error_message += (
                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "cuda":
            if (
                enabled
                and self.fast_dtype == torch.bfloat16
                and not torch.cuda.is_bf16_supported()
            ):
                raise RuntimeError(
                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
                )
        elif self.device == "xla":
            supported_dtype = [torch.float16, torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += (
                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
                )
                warnings.warn(error_message)
                enabled = False
        self._enabled = enabled

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        self.prev = torch.is_autocast_enabled(self.device)
        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
        torch.set_autocast_enabled(self.device, self._enabled)
        torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
        torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.device, self.prev)
        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)

Postscript

Written in Shanghai at 21:10 on December 31, 2024, with the assistance of the GPT-4o large language model.
