混合精度训练详解

随着模型规模的不断增大,GPU的显存逐渐成为问题。为了能够在较小的显存中训练较大的模型,同时保证训练速度,Micikevicius等人提出了混合精度训练的概念。他们经过实验发现,可以使用更低的精度来训练神经网络,进而带来巨大速度收益。混合精度训练一般能够获得2-3倍的速度提升。

一、什么是精度

模型的参数都是用浮点数来进行表示的,而这里的精度指的就是表示一个浮点数所需要的位数。

单精度(FP32)

单精度浮点数(FP32)使用32位二进制来表示一个实数。在这32位中,1位用于符号位(表示正负),8位用于指数位(表示数值的范围),剩下的23位用于尾数位(表示数值的精度)。

半精度(FP16)

半精度浮点数(FP16)使用16位二进制来表示一个实数。在16位中,1位用于符号位,5位用于指数位,剩下的10位用于尾数位。由于FP16的指数位和尾数位都比FP32少,它能够表示的数值范围和精度都较低。但FP16在计算速度和显存利用率方面有显著优势,因为它占用的空间是FP32的一半。

半精度的优势

  1. 在同样的GPU显存下,半精度浮点数可以容纳更大的参数量、更多的训练数据。FP16的占用的空间是FP32的一半,因此权重等参数所占用的内存也是原来的一半,节省下来的内存可以放更大的网络模型或者使用更多的数据进行训练。
  2. 针对分布式训练,特别是在大模型训练的过程中,通讯的开销制约了网络模型训练的整体性能,低精度意味着可以提升通讯性能,减少等待时间,加快数据的流通。

二、混合精度训练的流程

综上,所谓的混合精度训练,就是指:原本模型在所有的计算中,都是采用FP32的格式进行计算,但是模型的所有计算过程中,有一部分过程对于精确度的要求并不高,因此可以将FP32格式的数据替换为FP16格式的数据。

1. 前向传播(Forward)

模型的权重(weights)仍然是以FP32的形式存储。在前向传播阶段,模型会先把FP32格式的权重复制一份,并转化为FP16格式进行前向传播的计算。

2. 损失计算(FP16)

在计算损失(Loss)时采用FP32形式进行计算

3. 损失缩放(Loss Scale)

由于反向传播阶段采用FP16格式计算梯度,而loss的值可能会非常小,此时会遇到数值稳定性的问题:

  • 因为FP16的表示范围较小,过小的loss可能导致梯度的值在转换过程中下溢(变得非常接近于零)。因此引入了损失缩放(Loss Scaling)的概念。

我们会在计算损失之前,将损失值乘以一个缩放因子(例如1024),这个操作会将可能下溢的数值提升到FP16可以表示的范围。

4. 反向传播(Backward)

在反向传播阶段,我们计算权重的梯度。

5. 权重更新(Optimize)

在得到FP16格式的梯度后,需要将其转换回FP32,以便更新FP32格式的参数。这一步骤通常涉及:

  1. 梯度转换:将缩放后的FP16梯度转换为FP32精度,以便与模型的FP32权重进行运算。
  2. 梯度还原:在更新权重之前,需要将缩放后的梯度除以Loss Scale因子,以恢复其原始的大小。这是因为在更新权重时,我们希望使用的是未经缩放的梯度值。
  3. 权重更新:将FP16梯度转换为FP32,使用FP32精度的梯度更新模型的权重。从而确保准确性。

三、Loss Scale详解


FP16的精度范围有限,训练一些模型的时候,梯度数值在FP16精度下都被表示为0,如上图所示。

为了让这些梯度能够被FP16表示,可以在计算Loss的时候,将loss乘以一个扩大的系数loss scale,比如1024。这样,一个接近0的极小的数字经过乘法,就能够被FP16表示。这个过程发生在前向传播的最后一步,反向传播之前。

四、基于Pytorch的代码样例

python 复制代码
import torch
from torch.cuda.amp import GradScaler, autocast
from torch import nn
from torch.optim import SGD

# 假设我们有一个简单的神经网络模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(32 * 32 * 32, 10)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x

# 初始化模型和优化器
model = SimpleNet().cuda()
optimizer = SGD(model.parameters(), lr=0.001)

# 初始化GradScaler
scaler = GradScaler()

# 训练循环
for epoch in range(num_epochs):
    for inputs, targets in data_loader:  # 假设data_loader是已经定义好的数据加载器
        optimizer.zero_grad()  # 清空之前的梯度

        # 自动混合精度的上下文
        with autocast():
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # 反向传播前先缩放损失
        scaler.scale(loss).backward()

        # 调用optimizer.step()来更新权重
        scaler.step(optimizer)

        # 更新缩放因子
        scaler.update()

我们首先定义了一个简单的神经网络模型SimpleNet,然后初始化了模型、优化器和GradScaler。在训练循环中,我们使用autocast上下文来指定哪些操作应该使用FP16执行。GradScaler负责自动处理损失的缩放和梯度的精度转换,使得我们可以在不牺牲数值稳定性的情况下进行混合精度训练。

torch.cuda.amp.GradScaler默认可以动态调整loss scale,torch.autocast自动为不同的算子选择合适的精度。

相关推荐
正在走向自律4 分钟前
AI写作(十)发展趋势与展望(10/10)
人工智能·aigc·ai写作
皓74122 分钟前
打造旅游卡服务新标杆:构建SOP框架与智能知识库应用
大数据·人工智能·旅游·敏捷流程
B站计算机毕业设计超人22 分钟前
计算机毕业设计Hive+Spark空气质量预测 空气质量可视化 空气质量分析 空气质量爬虫 Hadoop 机器学习 深度学习 Django 大模型
hive·hadoop·爬虫·深度学习·机器学习·spark·数据可视化
努力的小雨27 分钟前
从零开始学机器学习——了解聚类
机器学习
视窗中国29 分钟前
中信建投张青:以金融巨擘之姿,铸就公益慈善新篇章
人工智能·金融
Crossoads37 分钟前
【汇编语言】更灵活的定位内存地址的方法(二)—— 从 [bx+idata] 到 [bx+si+idata]:让你灵活的访问内存
android·java·服务器·网络协议·tcp/ip·机器学习·汇编语言
幽络源小助理37 分钟前
桥梁缺陷YOLO免费数据集分享 – 6308张已标注8类缺陷图像
人工智能·计算机视觉·目标跟踪
念啊啊啊啊丶1 小时前
【弱监督视频异常检测】2024-ESWA-基于扩散的弱监督视频异常检测常态预训练
人工智能·深度学习·神经网络·机器学习·计算机视觉
陌上阳光1 小时前
初学人工智不理解的名词3
人工智能·语音识别
ZHOU_WUYI1 小时前
5. langgraph中的react agent使用 (从零构建一个react agent)
人工智能·langchain