[学术论文每日一读]大语言模型的摩尔定律:Scaling Laws for NLP Models

导言

本次介绍的论文为:OpenAI在2020年发表的Scaling laws for neural language models,虽然目前仍然只是挂在arXiv上,但是截至23年11月7日已经被引用了714 次了,在一些小方向这是只有开山之作才能达到的引用数,一作是理论机器学习方向的小牛Jared Kaplan,h-index为54。粗略鉴定本文还是有足够的含金量的。

文章结论概览

本文是典型的实验型论文,作者针对基于 Transformer 架构和交叉熵损失的自然语言模型,通过训练大量不同参数量的模型并计算测试集损失,研究了模型规模和模型性能之间的规律,并总结成定律。

因为本文是实验型论文,因此先直接总结实验结果:

  1. 测试集上的loss随模型规模、数据集大小和训练计算量的增加而按幂律比例下降。
  2. 网络的宽度或深度等其他架构细节在一个较大范围内对性能的影响甚微。
  3. 更大的模型可以更高效的利用训练样本。仅从训练效率的角度考虑,可以在较少的数据上训练大型模型,并在模型未完全收敛时就停止训练。

文章的实验环境与参数

为了探究语言模型规模扩展的影响,作者训练了一系列模型,然后用力控制变量法,这些模型在以下几个方面有所不同:

  1. 模型规模(N):参数量范围从768个到15亿。
  2. 数据集规模(D):数据量从2200万个token到230亿个。
  3. 模型结构:模型的深度、宽度、注意力头数以及前馈网络的维度都进行了实验。
  4. 上下文长度:在大多数测试中为1024,但也对较短的上下文长度进行了实验。
  5. 批量大小(batch size):大多数模型的batch size为524,288。在这个批量大小进行训练,能在训练时间和计算资源使用效率之间找到一个较好的平衡点。

实验部分分析

假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 表示测试时的交叉熵损失。 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C 表示训练模型所使用的计算资源量。

模型参数的数量、数据集的规模以及训练时消耗的计算资源与模型性能的关系

模型的性能极大程度上依赖于模型的参数量,并且相对较弱地受模型结构的影响

模型性能主要由三个要素决定:模型参数的数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N、数据集的规模 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 以及训练时消耗的计算资源 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C。 <math xmlns="http://www.w3.org/1998/Math/MathML"> N , D , C N, D, C </math>N,D,C够大时,模型的深度、宽度以及其他结构性超参数对于性能的影响非常小。简单地说,模型越大,数据越多,算力越强,力大砖飞。

上图展示了模型性能与这三个因素之间的实验图像,可以发现,控制另两个因素不变,只改变其中的一个因素时,模型性能与 <math xmlns="http://www.w3.org/1998/Math/MathML"> N , D , C N, D, C </math>N,D,C存在平滑的幂律关系,这里读者需要注意,实验图中的x坐标不是等间距的,而是不同的数量级,因此这并不是线性关系。可以用幂函数进行拟合,图中也给出了模型性能和因素之间的实验上的幂律函数。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = ( C min ⁡ / 2.3 ⋅ 1 0 8 ) − 0.050 ( 1 ) L = ( D / 5.4 ⋅ 1 0 13 ) − 0.095 ( 2 ) L = ( N / 8.8 ⋅ 1 0 13 ) − 0.076 ( 3 ) L=\left(C_{\min } / 2.3 \cdot 10^8\right)^{-0.050} \ \ \ \ \ \ \ \ \ \ (1) \\ L=\left(D / 5.4 \cdot 10^{13}\right)^{-0.095} \ \ \ \ \ \ \ \ \ \ \ \ \ (2) \\ L=\left(N / 8.8 \cdot 10^{13}\right)^{-0.076} \ \ \ \ \ \ \ \ \ \ \ \ \ (3) </math>L=(Cmin/2.3⋅108)−0.050 (1)L=(D/5.4⋅1013)−0.095 (2)L=(N/8.8⋅1013)−0.076 (3)

embedding参数和non-embedding参数对模型性能的影响

这篇论文在探讨模型参数时,将 embedding 参数和非 embedding 参数分开对待 ,是因为这两种参数对模型性能的影响程度不一样。上图中的参数数量指的是去除掉embedding参数以外的模型参数 。当考虑所有参数时,模型性能与模型的参数总量和模型深度强相关。不考虑嵌入参数时,不同深度的模型性能的趋势基本相同。

如上图左图所示:包含嵌入参数时,性能似乎强烈依赖于层数以及参数数量,相同数量,模型越深,效果越好。上图右图:排除嵌入参数时,不同深度的模型性能都为线性变化并且较为相近。

过拟合

只要我们同时扩展模型参数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 和训练集数据数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D,就可以有效提升模型性能,但如果保持 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 或 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 中的任何一个固定,而只增加另一个,就会进入递减回报的状态。

上图在实验时使用了early stop策略,如左图所示:控制数据集数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D不变,只增加模型参数,数据集大小就会成为模型性能的瓶颈所在 ,数据量过小, <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 不断增大只能导致模型过拟合。右图探究了过拟合和 <math xmlns="http://www.w3.org/1998/Math/MathML"> N / D N/D </math>N/D 之间的关系,数据量越小, <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 差距越大,模型越过拟合。

模型参数与样本效率与训练时长

如左图所示,模型越大,利用训练集样本的效率越高 ,越大的模型能够通过更少的学习步骤达到相同的性能标准。 右图则给出了模型在训练过程中的loss变化情况和训练时间的关系,很明显,模型的训练是有边际成本的,模型性能的提升,或者说loss的减小是有一段快速下降期,而后就基本不变了,因此可以使用early-stop机制来减小计算成本,同时获得可用模型。

算力有限时,Batch size,训练步数,模型参数两之间应该做何选择

随着算力的增加,我们可以灵活选择算力的分配:扩大模型尺寸、增加批量大小,或是延长训练步数。如上图所示,应将主要的计算资源投入到提升模型大小上,即在其他因素不变的情况下,增加参数对模型性能的增益最大,同时需要适度增加训练集大小。

总结

总的来说,模型大小、训练数据量和算力越大,模型性能就会平滑地提高,并且可以通过实验来发现性能和这几个因素之间的函数关系。当模型够大,训练集够多,算力够强,模型架构、模型超参数对模型性能影响很小。

本文的主要意义在于,在读者训练大模型时,可以根据自己手中的模型参数量,训练数据量,和算力,来大致估算模型最终能达到的性能,或者说是模型最终的Loss,从而判断自己的模型在训练时是否可以停止训练或者判断所处的训练阶段。

不过需要注意:本文的实验模型是transformer架构和交叉熵损失的NLP模型,对于CV和多模态,其他loss,推荐系统的模型可能不具有代表性。

相关推荐
Elastic 中国社区官方博客17 分钟前
使用 Elastic AI Assistant for Search 和 Azure OpenAI 实现从 0 到 60 的转变
大数据·人工智能·elasticsearch·microsoft·搜索引擎·ai·azure
江_小_白1 小时前
自动驾驶之激光雷达
人工智能·机器学习·自动驾驶
yusaisai大鱼3 小时前
TensorFlow如何调用GPU?
人工智能·tensorflow
珠海新立电子科技有限公司5 小时前
FPC柔性线路板与智能生活的融合
人工智能·生活·制造
IT古董5 小时前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论
曼城周杰伦5 小时前
自然语言处理:第六十三章 阿里Qwen2 & 2.5系列
人工智能·阿里云·语言模型·自然语言处理·chatgpt·nlp·gpt-3
余炜yw6 小时前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐6 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
如若1237 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr7 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络