【Pytorch】理解自动混合精度训练

【Pytorch】理解自动混合精度训练

更大的深度学习模型需要更多的计算能力和内存资源。一些新技术的提出,可以更快地训练深度神经网络。我们可以使用 FP16(半精度浮点数格式)来代替 FP32(全精度浮点数格式),研究人员发现串联使用它们是更好的选择。有的 GPU(例如 Paperspace 提供的 Ampere GPU)甚至可以利用较低级别的精度,例如 INT8。

混合精度允许半精度训练,同时仍保留大部分单精度网络精度。术语"混合精度技术"是指该方法同时使用单精度和半精度表示。

在使用 PyTorch 进行自动混合精度 (Amp) 训练的概述中,我们演示了该技术的工作原理,逐步介绍使用 Amp 的过程,并通过代码讨论 Amp 技术的应用。

混合精度概述

在深度学习的世界里,使用 FP16 进行计算不仅能显著提升性能,还能节省内存。然而,这种方法也带来了两个主要问题:精度溢出和舍入误差。这两个问题是深度学习中 FP16 计算的关键挑战。

精度溢出(Precision Overflow)

在 FP16 格式下,由于位宽较小,可表示的数值范围远小于 FP32 或 FP64。这容易导致数值过大或过小而无法在 FP16 的表示范围内精确表示。在深度学习中,这可能引起梯度消失或梯度爆炸,因为一些小的梯度值可能变成零(下溢),而一些大的梯度值可能变得无限大(上溢)。这种溢出问题会严重影响模型训练的稳定性和最终性能。

舍入误差(Rounding Error)

FP16 由于其16位的表示限制,相比于 FP32 或 FP64,舍入误差更加明显。在深度学习中,每次计算的舍入误差会累积,尤其是在多层和复杂运算中。这可能导致模型输出与使用更高精度计算时存在显著差异。对于那些对精确度要求极高的应用(比如金融或医疗领域),这种误差可能造成不可接受的后果。

为了缓解这些问题,混合精度训练方法在关键部分(如权重更新)使用 FP32 来保持精度,而在其他操作(如前向传播)中使用 FP16 来提高效率。混合精度训练中,我特别注意到了权重备份(Weight Backup)、损失放大(Loss Scaling)、精度累加(Precision Accumulated)这三种技术的重要性。

权重备份(Weight Backup):

在混合精度训练中,为了确保数值稳定性,模型的权重通常会在 FP16 和 FP32 两种格式下同时维护。权重备份是指保留 FP32 格式的权重副本,这样即使在大部分使用 FP16 格式的计算过程中出现数值不稳定现象,我们仍然能依靠 FP32 权重副本保持稳定和精确。这对于更新模型参数时的准确性至关重要。

损失放大(Loss Scaling):

在混合精度训练中,由于 FP16 的表示范围限制,梯度值可能太小而无法在 FP16 中准确表示,导致有效梯度变为零。损失放大是通过在计算梯度前将损失函数的值乘以一个较大的常数(放大因子),从而放大梯度值,使其在 FP16 范围内可表示且非零。在反向传播后,再将放大的梯度除以相同的放大因子,恢复原始比例,这样可以有效减少梯度下溢问题。

精度累加(Precision Accumulation):

精度累加是指在权重更新过程中,即使梯度计算是在 FP16 下完成的,但权重更新则在 FP32 精度下进行。这有助于减少舍入误差和累积误差,尤其在训练过程中涉及大量累积操作时。由于 FP32 提供更高的数值精度和更大的表示范围,可以更准确地累积小梯度值,避免更新权重时的数值不稳定性。

综上所述,通过将这些技术相结合,混合精度训练能够有效地利用 FP16 带来的性能优势,同时最大限度地减少精度损失和计算不稳定性。

实验对比

为了进一步验证这些技术的有效性,我设计了两个实验对比使用混合精度和传统的 FP32 两种方式进行训练。以下是我在这两个实验中使用的代码片段:

FP16与FP32混合训练代码:

python 复制代码
import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import time

from torch.cuda.amp import GradScaler, autocast


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20000),

            nn.Linear(20000, 20000),
            # nn.Dropout(0.1),

            nn.Linear(20000, 200),
            # nn.LayerNorm(20),

            nn.Linear(200, 20),
            # nn.LayerNorm(20),

            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

lr = 0.0001
iteration = 1000


x1 = torch.arange(-1000, 1000).float().to('cuda')
x2 = torch.arange(0, 2000).float().to('cuda')
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = (2*x1 - x2 + 1).to('cuda')

model = Model().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()

scaler = GradScaler()

start_time = time.time()
writer = SummaryWriter(comment='_FP16')

for iter in range(iteration):
    with autocast():
        y_pred = model(x)
        loss = loss_function(y, y_pred.squeeze())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    writer.add_scalar('loss', loss, iter)
    optimizer.zero_grad()

    if iter % 100 == 0:
        # 获取 GPU 的内存使用情况
        print("GPU Memory Allocated:", torch.cuda.memory_allocated())
        print("GPU Memory Cached:   ", torch.cuda.memory_reserved())

print("Time: ", time.time() - start_time)
torch.save(model.state_dict(), 'model_state_dict_fp16.pth')

FP32训练代码:

python 复制代码
import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import time

from torch.cuda.amp import GradScaler, autocast


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20000),

            nn.Linear(20000, 20000),
            # nn.Dropout(0.1),

            nn.Linear(20000, 200),
            # nn.LayerNorm(20),

            nn.Linear(200, 20),
            # nn.LayerNorm(20),

            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

lr = 0.0001
iteration = 1000


x1 = torch.arange(-1000, 1000).float().to('cuda')
x2 = torch.arange(0, 2000).float().to('cuda')
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = (2*x1 - x2 + 1).to('cuda')

model = Model().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()

scaler = GradScaler()

start_time = time.time()
writer = SummaryWriter(comment='_FP32')

for iter in range(iteration):

    y_pred = model(x)
    loss = loss_function(y, y_pred.squeeze())
    loss.backward()
    optimizer.step()


    writer.add_scalar('loss', loss, iter)
    optimizer.zero_grad()

    if iter % 100 == 0:
        # 获取 GPU 的内存使用情况
        print("GPU Memory Allocated:", torch.cuda.memory_allocated())
        print("GPU Memory Cached:   ", torch.cuda.memory_reserved())

print("Time: ", time.time() - start_time)
torch.save(model.state_dict(), 'model_state_dict_fp32.pth')

最终两者的效果如下:

实验名 占用GPU 消耗时间
FP16与FP32混合训练 4867271680 bytes 78.73 s
FP32训练 4867274752 bytes 140.18 s

实验分析如下:

1、在内存占用方面,两种训练方法几乎相同。这可能是因为模型结构和数据集大小相同,所以内存占用没有显著差异。然而,通常情况下,FP16训练应该占用更少的内存,因为它使用的是半精度浮点数。

2、在训练时间方面,混合精度训练明显快于纯FP32训练。这是因为FP16训练可以加快计算速度并降低内存需求,从而允许模型更快地运行。混合精度训练结合了FP16的高效率和FP32的数值稳定性,提供了一个平衡的解决方案。

在提供的损失图中,我们可以看到蓝色曲线代表使用全精度(FP32)训练的模型损失,而橙色曲线代表使用混合精度(FP16与FP32)训练的模型损失。以下是对两种训练方法损失曲线的进一步分析:

· 在下降到一定程度后,两条曲线都达到平稳状态,这表明模型已基本收敛。在这个平稳阶段,损失变化不大,说明模型在训练集上的表现已经稳定。

· 没有明显的过拟合迹象,因为损失曲线没有再次上升的趋势。

· 混合精度训练似乎在时间效率略优于全精度训练,尽管最终损失值的差异不大,但在追求快速迭代和高效训练的情况下,选择混合精度训练会更有优势;而在收敛速度上,混合精度训练在140epoch时收敛完毕,而全精度训练早在100epoch就收敛完,说明全精度训练收敛较于混合精度训练,在step层面更快,而它在时间层面更慢。

相关推荐
缺的不是资料,是学习的心20 分钟前
使用qwen作为基座训练分类大模型
python·机器学习·分类
AI趋势预见32 分钟前
使用AI生成金融时间序列数据:解决股市场的数据稀缺问题并提升信噪比
人工智能·深度学习·神经网络·语言模型·金融
Zda天天爱打卡1 小时前
【机器学习实战中阶】使用Python和OpenCV进行手语识别
人工智能·python·深度学习·opencv·机器学习
martian6651 小时前
第19篇:python高级编程进阶:使用Flask进行Web开发
开发语言·python
背太阳的牧羊人2 小时前
冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。
人工智能·语言模型·自然语言处理
gis收藏家2 小时前
利用 SAM2 模型探测卫星图像中的农田边界
开发语言·python
YiSLWLL2 小时前
Tauri2+Leptos开发桌面应用--绘制图形、制作GIF动画和mp4视频
python·rust·ffmpeg·音视频·matplotlib
数据馅2 小时前
python自动生成pg数据库表对应的es索引
数据库·python·elasticsearch
编程、小哥哥2 小时前
python操作mysql
android·python
Serendipity_Carl2 小时前
爬虫基础之爬取某站视频
爬虫·python·pycharm