混合精度训练详解

随着模型规模的不断增大,GPU的显存逐渐成为问题。为了能够在较小的显存中训练较大的模型,同时保证训练速度,Micikevicius等人提出了混合精度训练的概念。他们经过实验发现,可以使用更低的精度来训练神经网络,进而带来巨大速度收益。混合精度训练一般能够获得2-3倍的速度提升。

一、什么是精度

模型的参数都是用浮点数来进行表示的,而这里的精度指的就是表示一个浮点数所需要的位数。

单精度(FP32)

单精度浮点数(FP32)使用32位二进制来表示一个实数。在这32位中,1位用于符号位(表示正负),8位用于指数位(表示数值的范围),剩下的23位用于尾数位(表示数值的精度)。

半精度(FP16)

半精度浮点数(FP16)使用16位二进制来表示一个实数。在16位中,1位用于符号位,5位用于指数位,剩下的10位用于尾数位。由于FP16的指数位和尾数位都比FP32少,它能够表示的数值范围和精度都较低。但FP16在计算速度和显存利用率方面有显著优势,因为它占用的空间是FP32的一半。

半精度的优势

  1. 在同样的GPU显存下,半精度浮点数可以容纳更大的参数量、更多的训练数据。FP16的占用的空间是FP32的一半,因此权重等参数所占用的内存也是原来的一半,节省下来的内存可以放更大的网络模型或者使用更多的数据进行训练。
  2. 针对分布式训练,特别是在大模型训练的过程中,通讯的开销制约了网络模型训练的整体性能,低精度意味着可以提升通讯性能,减少等待时间,加快数据的流通。

二、混合精度训练的流程

综上,所谓的混合精度训练,就是指:原本模型在所有的计算中,都是采用FP32的格式进行计算,但是模型的所有计算过程中,有一部分过程对于精确度的要求并不高,因此可以将FP32格式的数据替换为FP16格式的数据。

1. 前向传播(Forward)

模型的权重(weights)仍然是以FP32的形式存储。在前向传播阶段,模型会先把FP32格式的权重复制一份,并转化为FP16格式进行前向传播的计算。

2. 损失计算(FP16)

在计算损失(Loss)时采用FP32形式进行计算

3. 损失缩放(Loss Scale)

由于反向传播阶段采用FP16格式计算梯度,而loss的值可能会非常小,此时会遇到数值稳定性的问题:

  • 因为FP16的表示范围较小,过小的loss可能导致梯度的值在转换过程中下溢(变得非常接近于零)。因此引入了损失缩放(Loss Scaling)的概念。

我们会在计算损失之前,将损失值乘以一个缩放因子(例如1024),这个操作会将可能下溢的数值提升到FP16可以表示的范围。

4. 反向传播(Backward)

在反向传播阶段,我们计算权重的梯度。

5. 权重更新(Optimize)

在得到FP16格式的梯度后,需要将其转换回FP32,以便更新FP32格式的参数。这一步骤通常涉及:

  1. 梯度转换:将缩放后的FP16梯度转换为FP32精度,以便与模型的FP32权重进行运算。
  2. 梯度还原:在更新权重之前,需要将缩放后的梯度除以Loss Scale因子,以恢复其原始的大小。这是因为在更新权重时,我们希望使用的是未经缩放的梯度值。
  3. 权重更新:将FP16梯度转换为FP32,使用FP32精度的梯度更新模型的权重。从而确保准确性。

三、Loss Scale详解


FP16的精度范围有限,训练一些模型的时候,梯度数值在FP16精度下都被表示为0,如上图所示。

为了让这些梯度能够被FP16表示,可以在计算Loss的时候,将loss乘以一个扩大的系数loss scale,比如1024。这样,一个接近0的极小的数字经过乘法,就能够被FP16表示。这个过程发生在前向传播的最后一步,反向传播之前。

四、基于Pytorch的代码样例

python 复制代码
import torch
from torch.cuda.amp import GradScaler, autocast
from torch import nn
from torch.optim import SGD

# 假设我们有一个简单的神经网络模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(32 * 32 * 32, 10)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x

# 初始化模型和优化器
model = SimpleNet().cuda()
optimizer = SGD(model.parameters(), lr=0.001)

# 初始化GradScaler
scaler = GradScaler()

# 训练循环
for epoch in range(num_epochs):
    for inputs, targets in data_loader:  # 假设data_loader是已经定义好的数据加载器
        optimizer.zero_grad()  # 清空之前的梯度

        # 自动混合精度的上下文
        with autocast():
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # 反向传播前先缩放损失
        scaler.scale(loss).backward()

        # 调用optimizer.step()来更新权重
        scaler.step(optimizer)

        # 更新缩放因子
        scaler.update()

我们首先定义了一个简单的神经网络模型SimpleNet,然后初始化了模型、优化器和GradScaler。在训练循环中,我们使用autocast上下文来指定哪些操作应该使用FP16执行。GradScaler负责自动处理损失的缩放和梯度的精度转换,使得我们可以在不牺牲数值稳定性的情况下进行混合精度训练。

torch.cuda.amp.GradScaler默认可以动态调整loss scale,torch.autocast自动为不同的算子选择合适的精度。

相关推荐
杭州泽沃电子科技有限公司1 分钟前
核电的“热血管”与它的智能脉搏:热转换在线监测如何守护能源生命线
人工智能·在线监测
yuzhiboyouye6 分钟前
指引上调是什么意思
人工智能
昨夜见军贴061623 分钟前
IACheck × AI审核:重构来料证书报告审核流程,赋能生产型企业高质量发展
人工智能·重构
OidEncoder26 分钟前
绝对值编码器工作原理、与增量编码器的区别及单圈多圈如何选择?
人工智能
计算机科研狗@OUC32 分钟前
(NeurIPS25) Spiking Meets Attention: 基于注意力脉冲神经网络的高效遥感图像超分辨率重建
人工智能·神经网络·超分辨率重建
EasyGBS33 分钟前
EasyGBS打造变电站高效智能视频监控解决方案
网络·人工智能·音视频
汤姆yu33 分钟前
基于深度学习的杂草检测系统
人工智能·深度学习
LaughingZhu34 分钟前
Product Hunt 每日热榜 | 2026-01-06
人工智能·经验分享·深度学习·神经网络·产品运营
东方佑34 分钟前
SamOutVXP-2601: 轻量级高效语言模型
人工智能·语言模型·自然语言处理
管理快车道36 分钟前
连锁零售利润增长:我的实践复盘
大数据·人工智能·零售