混合精度Mixed Precision Training

神经网络的参数是用浮点精度表示的, 浮点精度的标准是IEEE 754 - Wikipedia,以下是一个FP16数值在内存中存储格式。

随着神经网络模型规模越来越大,如何减少模型占用的内存并且缩短训练时间成为亟需解决的问题,混合精度训练就是其中之一的解决方案,并且几乎不会影响模型训练的效果。

混合精度原理

想象一下,如果模型参数+loss+gradient都是用fp16保存的,fp16的最小值是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 6.1 × 1 0 − 5 6.1\times 10^{-5} </math>6.1×10−5,小于最小值的gradient都会变成0,相当于浪费了一次梯度传播。或许小的gradient并没有很重要,但是积累多次就会变得不可忽略。当前大模型普遍较低的学习率也会加剧这个问题的影响。

因此为了解决这个问题,就需要用更高精度fp32保存一份参数,在正常前向推理和反向传播时都用fp16,计算好得梯度先转换为fp32,再乘以学习率,然后更新到fp32存储得参数上,最终将fp32参数转换成fp16更新模型参数。

整个流程如下如:

这种用fp16和fp32共同训练模型得技术就叫做混合精度训练(MP, Mixed-Precision training),显然MP并不能节省模型加载需要的内存,因为需要多存储一份fp16的参数和梯度,但是用fp16进行模型前向和后向计算,能够减少中间计算值存储需要的内存,这部分内存会随着sequence length和batch size增大而增大,所以只有在这部分中间值占用内存比重较高时才能带来一定的内存节约。

虽然计算时间的影响不大,但是fp16训练时间的确会大大减少,通常是减少1.5~5.5倍。

更多资料:

fastai - Mixed precision training

Understanding Mixed Precision Training | by Jonathan Davis | Towards Data Science

Loss Scale

是不是混合精度训练就完全没有梯度损失了呢,并不是,在反向传播过程中其实已经有部分梯度因为精度原因丢失了(因为正常模型梯度都不会太大,所以我们主要考虑下溢出)。那么如何解决这部分问题呢,就要用到Loss Scale。

原理是将Loss乘以一个比较大的数scale,因为Loss是用fp32存储的,所以scale的选值范围是比较大的。这样因为反向传播链式法则原理,梯度也会放大很多倍,原本下溢出的值也会保存下来。然后在梯度转换成fp32后除以scale,最后更新就与正常混合精度训练一致了。

流程如下:

一般在开始训练时scale会设定成一个比较大的值,如果计算过程中fp16梯度发生上溢出,会跳过当前步的参数更新,并将scale下调。训练log中会输出如下消息:
⚠️ Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to...

相关推荐
数据智能老司机4 分钟前
CockroachDB权威指南——SQL调优
数据库·分布式·架构
数据智能老司机5 分钟前
CockroachDB权威指南——应用设计与实现
数据库·分布式·架构
数据智能老司机18 分钟前
CockroachDB权威指南——CockroachDB 模式设计
数据库·分布式·架构
数据智能老司机19 小时前
CockroachDB权威指南——CockroachDB SQL
数据库·分布式·架构
数据智能老司机20 小时前
CockroachDB权威指南——开始使用
数据库·分布式·架构
数据智能老司机20 小时前
CockroachDB权威指南——CockroachDB 架构
数据库·分布式·架构
IT成长日记20 小时前
【Kafka基础】Kafka工作原理解析
分布式·kafka
州周1 天前
kafka副本同步时HW和LEO
分布式·kafka
爱的叹息1 天前
主流数据库的存储引擎/存储机制的详细对比分析,涵盖关系型数据库、NoSQL数据库和分布式数据库
数据库·分布式·nosql
千层冷面1 天前
RabbitMQ 发送者确认机制详解
分布式·rabbitmq·ruby