混合精度训练详解

随着模型规模的不断增大,GPU的显存逐渐成为问题。为了能够在较小的显存中训练较大的模型,同时保证训练速度,Micikevicius等人提出了混合精度训练的概念。他们经过实验发现,可以使用更低的精度来训练神经网络,进而带来巨大速度收益。混合精度训练一般能够获得2-3倍的速度提升。

一、什么是精度

模型的参数都是用浮点数来进行表示的,而这里的精度指的就是表示一个浮点数所需要的位数。

单精度(FP32)

单精度浮点数(FP32)使用32位二进制来表示一个实数。在这32位中,1位用于符号位(表示正负),8位用于指数位(表示数值的范围),剩下的23位用于尾数位(表示数值的精度)。

半精度(FP16)

半精度浮点数(FP16)使用16位二进制来表示一个实数。在16位中,1位用于符号位,5位用于指数位,剩下的10位用于尾数位。由于FP16的指数位和尾数位都比FP32少,它能够表示的数值范围和精度都较低。但FP16在计算速度和显存利用率方面有显著优势,因为它占用的空间是FP32的一半。

半精度的优势

  1. 在同样的GPU显存下,半精度浮点数可以容纳更大的参数量、更多的训练数据。FP16的占用的空间是FP32的一半,因此权重等参数所占用的内存也是原来的一半,节省下来的内存可以放更大的网络模型或者使用更多的数据进行训练。
  2. 针对分布式训练,特别是在大模型训练的过程中,通讯的开销制约了网络模型训练的整体性能,低精度意味着可以提升通讯性能,减少等待时间,加快数据的流通。

二、混合精度训练的流程

综上,所谓的混合精度训练,就是指:原本模型在所有的计算中,都是采用FP32的格式进行计算,但是模型的所有计算过程中,有一部分过程对于精确度的要求并不高,因此可以将FP32格式的数据替换为FP16格式的数据。

1. 前向传播(Forward)

模型的权重(weights)仍然是以FP32的形式存储。在前向传播阶段,模型会先把FP32格式的权重复制一份,并转化为FP16格式进行前向传播的计算。

2. 损失计算(FP16)

在计算损失(Loss)时采用FP32形式进行计算

3. 损失缩放(Loss Scale)

由于反向传播阶段采用FP16格式计算梯度,而loss的值可能会非常小,此时会遇到数值稳定性的问题:

  • 因为FP16的表示范围较小,过小的loss可能导致梯度的值在转换过程中下溢(变得非常接近于零)。因此引入了损失缩放(Loss Scaling)的概念。

我们会在计算损失之前,将损失值乘以一个缩放因子(例如1024),这个操作会将可能下溢的数值提升到FP16可以表示的范围。

4. 反向传播(Backward)

在反向传播阶段,我们计算权重的梯度。

5. 权重更新(Optimize)

在得到FP16格式的梯度后,需要将其转换回FP32,以便更新FP32格式的参数。这一步骤通常涉及:

  1. 梯度转换:将缩放后的FP16梯度转换为FP32精度,以便与模型的FP32权重进行运算。
  2. 梯度还原:在更新权重之前,需要将缩放后的梯度除以Loss Scale因子,以恢复其原始的大小。这是因为在更新权重时,我们希望使用的是未经缩放的梯度值。
  3. 权重更新:将FP16梯度转换为FP32,使用FP32精度的梯度更新模型的权重。从而确保准确性。

三、Loss Scale详解


FP16的精度范围有限,训练一些模型的时候,梯度数值在FP16精度下都被表示为0,如上图所示。

为了让这些梯度能够被FP16表示,可以在计算Loss的时候,将loss乘以一个扩大的系数loss scale,比如1024。这样,一个接近0的极小的数字经过乘法,就能够被FP16表示。这个过程发生在前向传播的最后一步,反向传播之前。

四、基于Pytorch的代码样例

python 复制代码
import torch
from torch.cuda.amp import GradScaler, autocast
from torch import nn
from torch.optim import SGD

# 假设我们有一个简单的神经网络模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(32 * 32 * 32, 10)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x

# 初始化模型和优化器
model = SimpleNet().cuda()
optimizer = SGD(model.parameters(), lr=0.001)

# 初始化GradScaler
scaler = GradScaler()

# 训练循环
for epoch in range(num_epochs):
    for inputs, targets in data_loader:  # 假设data_loader是已经定义好的数据加载器
        optimizer.zero_grad()  # 清空之前的梯度

        # 自动混合精度的上下文
        with autocast():
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # 反向传播前先缩放损失
        scaler.scale(loss).backward()

        # 调用optimizer.step()来更新权重
        scaler.step(optimizer)

        # 更新缩放因子
        scaler.update()

我们首先定义了一个简单的神经网络模型SimpleNet,然后初始化了模型、优化器和GradScaler。在训练循环中,我们使用autocast上下文来指定哪些操作应该使用FP16执行。GradScaler负责自动处理损失的缩放和梯度的精度转换,使得我们可以在不牺牲数值稳定性的情况下进行混合精度训练。

torch.cuda.amp.GradScaler默认可以动态调整loss scale,torch.autocast自动为不同的算子选择合适的精度。

相关推荐
Cosolar33 分钟前
AutoGen:微软开源的多Agent对话框架详解
人工智能·系统架构·大模型·agent·rag
Urbano34 分钟前
一条休闲束脚裤的工业化诞生科普 八道自动化缝纫工序拆解
人工智能
陕西企来客5 小时前
企来客科技来客 GEO 优化系统深度解析:核心技术与原因分析
大数据·人工智能·科技·搜索引擎
来让爷抱一个8 小时前
MonkeyCode 多模型切换技巧:什么时候用 Claude/GPT/DeepSeek
人工智能·ai编程
李白你好8 小时前
AI Agent 架构的自动化渗透测试工具
运维·人工智能·自动化
2601_949499948 小时前
8 大工业光模块供应商选型:芯瑞科技 400G OSFP 助力 AI 算力集群升级
人工智能·科技
温柔只给梦中人8 小时前
NLP学习:注意力机制
人工智能·学习·自然语言处理
weixin_429630268 小时前
3.49 HVLF:一种跨场景的整体视觉定位框架
深度学习·机器学习·计算机视觉
广州灵眸科技有限公司8 小时前
瑞芯微RV1126B开发板(EASY-EAI-PI2) Easy-Eai编译环境准备与更新
服务器·前端·人工智能·python·深度学习
深度学习lover8 小时前
<数据集>yolo樱桃识别<目标检测>
人工智能·深度学习·yolo·目标检测·计算机视觉·数据集·樱桃识别