大语言模型的工程技巧(二)——混合精度训练

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
混合精度训练的示例请参考如下链接:regression2chatgpt/ch11_llm/gpt2_lora_optimum.ipynb

本文将讨论如何利用混合精度训练(Mixed Precision Training)来减少内存的开销,特别是GPU内存的开销。这在大语言模型的训练当中是非常重要的。关于GPU的计算可以参考

关于大语言模型的讨论请参考:

内容大纲

一、概述

在人工智能领域,反向传播算法(计算参数梯度的算法)是非常重要的。而在进行反向传播计算时,必须将经过膨胀的计算图存储在内存中(如果使用GPU运算,那么将存储在GPU的专用内存中)。然而,这种存储量相当庞大,在整个计算图的存储结构中,数值存储占据了最大的比例。这些数值包括各个节点的计算结果(来自向前传播的输出),以及相应的梯度(这些梯度是来自反向传播的结果)。虽然梯度累积技术可以通过分解计算图来限制计算图的膨胀,从而降低内存的使用,但面对庞大的模型时,即便是单个数据点的计算图,其所需的内存都是巨大的。例如,大语言模型的参数数量可能高达数十亿甚至上百亿。

二、什么是混合精度训练?

为了解决这个具有挑战性的问题,需要采取额外的优化策略来降低内存的使用。在深入探讨这些策略之前,我们需要更详细地了解数字在计算机中的存储方式。一般而言,数值计算结果使用32位浮点数(需要4字节来存储,使用32位的二进制的方式表示)存储。这种存储方式被称为单精度浮点数。那么,如果使用16位二进制数表示一个数值,会产生什么影响呢?

这种方法的好处之一是能够立即减少所需的存储空间,同时提升计算速度。然而,这种方法也存在一个明显的缺陷,即能够表示的数值范围受限。为了便于讨论,下面以能够表示的最小正数为例。使用16位浮点数,能够表示的最小正数是 2 − 24 2^{-24} 2−24(相比之下,32位浮点数能够表示的最小正数为 2 − 149 2^{-149} 2−149)。当实际的数值小于这个阈值时,计算机会错误地将其视作0,这就是浮点数下溢(Underflow)。

为了尽可能地减少这类错误的发生,可以混合精度训练(Mixed Precision Training)算法,顾名思义,它是指在模型训练过程中使用不同的数值精度来处理不同部分的计算。

三、算法细节

这一算法包含两个主要部分。

  1. 精度分层处理:在这种训练中,模型本身(模型参数)依然使用32位浮点数进行存储,参数更新过程也使用32位浮点数。在模型的向前传播和反向传播过程中,转而使用16位浮点数进行计算。具体情况如图1所示。


图1

  1. 引入比例因子(Scale Factor):在数学上,要防止浮点数下溢是相当容易的,只需要将模型损失乘以一个较大的常数n,该常数也被称为比例因子。根据链式法则,这将导致所有节点的梯度都增大n倍。这种方法确保了梯度落入16位浮点数表示的范围,从而解决浮点数下溢问题。在使用这些梯度进行参数更新时,需要将引入的缩放移除,也就是将梯度除以n。将这个过程与精度分层处理相结合,如图2所示。


图2

混合精度训练方法的优势在于,在保持适当的模型表示能力的同时,显著降低了内存开锁。通过将高精度的32位浮点数与16位浮点数的计算相结合,在不牺牲模型性能的前提下,显著减少内存需求,使计算机能够处理更大规模的模型和数据集。

四、代码实现

在实际应用中,PyTorch已经提供了相应的封装函数,分别是torch.cuda.amp.autocast和torch.cuda.amp.GradScaler。其中autocast实现的是第一部分------精度分层处理;GradScaler实现的是第二部分------引入比例因子。借助这两个工具,在优化算法中使用混合精度训练就变得很容易了。示意代码如下:

python 复制代码
# 常规的模型训练实现
for epoch in range(0): 
    for input, target in zip(data, targets):
        # 启动混合精度训练
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # 在触发反向传播之前,启动缩放因子
        scaler.scale(loss).backward()

        # 更新模型参数
        scaler.step(opt)
        scaler.update()

        opt.zero_grad()
相关推荐
神秘的土鸡2 分钟前
神经网络图像隐写术:用AI隐藏信息的艺术
人工智能·深度学习·神经网络
数据分析能量站3 分钟前
神经网络-LeNet
人工智能·深度学习·神经网络·机器学习
Jaly_W11 分钟前
用于航空发动机故障诊断的深度分层排序网络
人工智能·深度学习·故障诊断·航空发动机
小嗷犬13 分钟前
【论文笔记】Cross-lingual few-shot sign language recognition
论文阅读·人工智能·多模态·少样本·手语翻译
夜幕龙20 分钟前
iDP3复现代码数据预处理全流程(二)——vis_dataset.py
人工智能·python·机器人
吃个糖糖37 分钟前
36 Opencv SURF 关键点检测
人工智能·opencv·计算机视觉
AI慧聚堂1 小时前
自动化 + 人工智能:投标行业的未来是什么样的?
运维·人工智能·自动化
盛世隐者1 小时前
【pytorch】循环神经网络
人工智能·pytorch
cdut_suye1 小时前
Linux工具使用指南:从apt管理、gcc编译到makefile构建与gdb调试
java·linux·运维·服务器·c++·人工智能·python
开发者每周简报1 小时前
微软的AI转型故事
人工智能·microsoft