【Transformer 与注意力机制】06|梯度下降与反向传播
每次跟人讲神经网络,最容易卡住的不是激活函数,也不是矩阵乘法,而是「网络到底是怎么学的」这件事。前向传播好理解:把输入塞进一堆矩阵和非线性,最后吐出一个预测。但「学习」这个词背后究竟发生了什么?参数是怎么被推动的?为什么是「梯度」这个东西在主导一切?
为什么我们用反向传播而不是数值微分?为什么实际工程中没人写 SGD,全都在用 Adam?为什么 Transformer 训练前几千步要 warmup、之后又要 cosine 衰减?这些问题在很多教材里被分开讲,散落在不同章节,初学者很难把它们拼成一张连贯的图。
这一篇要把这些问题一次性讲透。我尽量不跳步:从高中数学里的「求最小值」这件事讲起,一路过渡到一维梯度下降、多元梯度、损失曲面、计算图、反向传播,最后再到现代优化器和 Transformer 训练里的那些坑。
它是后续所有篇章的共同地基------后面无论讲注意力、讲位置编码、讲 LayerNorm、讲 Scaling Laws,模型「能不能训得动」永远绕不开本篇这套机制。如果你已经熟悉这些,可以跳到 7、8、9 节看与 Transformer 直接相关的工程细节;如果是第一次系统接触,建议从头读。
一、从「求最小值」开始想问题
1.1 一个我们都做过的题
回忆一下高中数学里那道无数人做过的题:求 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = ( x − 3 ) 2 + 2 f(x) = (x-3)^2 + 2 </math>f(x)=(x−3)2+2 的最小值。
当时的标准解法有两条路。要么配方看出顶点 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 3 , 2 ) (3, 2) </math>(3,2),要么求导 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ′ ( x ) = 2 ( x − 3 ) f'(x) = 2(x-3) </math>f′(x)=2(x−3),让导数等于零解出 <math xmlns="http://www.w3.org/1998/Math/MathML"> x = 3 x = 3 </math>x=3。两种方法的本质是同一个:找一个点,让函数在这个点附近不再下降。这就是「最小值」最朴素的几何意义------在它身边的小邻域里,没人比它更低。
把这个题再变得难一点:换成 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = x 4 − 5 x 2 + 4 x + 7 f(x) = x^4 - 5x^2 + 4x + 7 </math>f(x)=x4−5x2+4x+7。求导得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ′ ( x ) = 4 x 3 − 10 x + 4 f'(x) = 4x^3 - 10x + 4 </math>f′(x)=4x3−10x+4,这是一个三次方程。我们当然有公式可以解,但稍微再升一阶到五次以上,连解析公式都没有了------这是 Galois 在十九世纪初已经证明过的事。
再考虑一下高维: <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x 1 , x 2 , ... , x 1000 ) f(x_1, x_2, \dots, x_{1000}) </math>f(x1,x2,...,x1000),让所有偏导同时等于零,得到的方程组没有任何希望「一笔写出解」。神经网络的损失函数动辄上亿个参数,更不可能用「求导等于零」的办法直接求解。
那怎么办?换条路:既然解不出来,我们就「猜一个起点,然后一步一步往低处挪」。这条路就叫迭代下降 ,而它在最常见的形态里有一个名字:梯度下降(Gradient Descent)。
这个想法本身比深度学习老得多。Cauchy 1847 年那封短短两页的笔记《Méthode générale pour la résolution des systèmes d'équations simultanées》就是它最早的雏形。Cauchy 当时只是想数值求解非线性方程组,并没有意识到一百多年后这个思路会成为人工智能的引擎。
1.2 一维梯度下降的几何直觉
先回到最简单的一维。假设我们站在曲线 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = f ( x ) y = f(x) </math>y=f(x) 上某一点 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0,眼睛看不见整条曲线,只能感知到脚下的「斜率」。
如果当前斜率 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ′ ( x 0 ) > 0 f'(x_0) > 0 </math>f′(x0)>0,说明面朝右是上坡、朝左是下坡,所以应该往左挪;如果斜率为负,则应该往右挪。朝着导数符号的反方向走,就是在下山。这件事用一行公式表达就是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t + 1 = x t − η ⋅ f ′ ( x t ) x_{t+1} = x_t - \eta \cdot f'(x_t) </math>xt+1=xt−η⋅f′(xt)
这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η(eta)叫学习率(learning rate),它控制每一步走多远。学习率不是一个无关紧要的旋钮,它几乎决定了整个训练的成败。
设想一下你站在一个抛物线 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = x 2 f(x) = x^2 </math>f(x)=x2 上的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 = 10 x_0 = 10 </math>x0=10 处,导数是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 x 0 = 20 2x_0 = 20 </math>2x0=20。如果取 <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 0.1 \eta = 0.1 </math>η=0.1,下一步到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 = 10 − 0.1 × 20 = 8 x_1 = 10 - 0.1 \times 20 = 8 </math>x1=10−0.1×20=8,再下一步到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 6.4 6.4 </math>6.4、 <math xmlns="http://www.w3.org/1998/Math/MathML"> 5.12 5.12 </math>5.12,每步都在缩小,最终收敛到零。
但如果取 <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 1.1 \eta = 1.1 </math>η=1.1,下一步到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 10 − 1.1 × 20 = − 12 10 - 1.1 \times 20 = -12 </math>10−1.1×20=−12,再下一步到 <math xmlns="http://www.w3.org/1998/Math/MathML"> − 12 + 1.1 × 24 = 14.4 -12 + 1.1 \times 24 = 14.4 </math>−12+1.1×24=14.4,每步都在变大,整个过程发散------一个看起来无害的旋钮稍微调过头,整个训练就炸了。
这就是为什么任何一本深度学习教材在第一次讲到梯度下降时都会强调「学习率是个艺术」。太大震荡甚至发散,太小则像蚂蚁搬家一样慢得令人崩溃。后面我们会看到 Adam 把这件事自动化到一定程度,但即便用 Adam,初始学习率仍然要手调,仍然是 Transformer 训练里最常见的事故源头之一。
更精确一点的说法:对二次函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = 1 2 α x 2 f(x) = \frac{1}{2}\alpha x^2 </math>f(x)=21αx2,梯度下降的收敛性条件是 <math xmlns="http://www.w3.org/1998/Math/MathML"> η < 2 / α \eta < 2/\alpha </math>η<2/α;当 <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 1 / α \eta = 1/\alpha </math>η=1/α 时一步就到。对一般凸函数也有类似条件,跟函数的 Lipschitz 常数挂钩。这些理论结果在凸优化教材里很标准(参见 Boyd & Vandenberghe 的《Convex Optimization》),但在非凸的神经网络里只能当作「定性参考」。
上图左侧画的就是这一维过程:在抛物线上从一个高点出发,沿着切线斜率反方向一步步往下挪,最终落到最低点。它给我们提供了一个最干净的心智模型:梯度下降不是一种神秘算法,它就是「站在哪、看脚下、朝低处走」这件事的代数化。
1.3 把直觉推到多维:偏导与梯度向量
真实的神经网络从来不是一元函数。损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 是关于「网络全部参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ」的函数,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 维度可能是几百万、几十亿甚至更多。
这时一维的导数概念要升级。每个维度上各有一个偏导数 (partial derivative)------它的含义是:「固定其他维度不动,只动这一维时,函数怎么变」。把所有偏导数排成一个向量,就是梯度(gradient):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ L ( θ ) = ( ∂ L ∂ θ 1 , ∂ L ∂ θ 2 , ... , ∂ L ∂ θ n ) \nabla L(\theta) = \left(\frac{\partial L}{\partial \theta_1}, \frac{\partial L}{\partial \theta_2}, \dots, \frac{\partial L}{\partial \theta_n}\right) </math>∇L(θ)=(∂θ1∂L,∂θ2∂L,...,∂θn∂L)
梯度向量有一个非常漂亮的几何性质:它指向函数值增长最快的方向。
换句话说,沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ L \nabla L </math>∇L 走是上升最快,沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> − ∇ L -\nabla L </math>−∇L 走是下降最快。这件事可以从方向导数的定义推出来:函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 在单位方向 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u 上的方向导数等于 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ L ⋅ u \nabla L \cdot u </math>∇L⋅u,这个点积在 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u 与梯度同向时取最大值,与梯度反向时取最小值。
这条结论是整个梯度下降的理论基石------我们之所以朝 <math xmlns="http://www.w3.org/1998/Math/MathML"> − ∇ L -\nabla L </math>−∇L 方向走,是因为在所有方向里它就是「下降最快」的那一个。注意这里说的是「下降最快的方向」,不是「最快到达最小值的方向」,这两件事在椭圆形损失面上并不一致,下面就会看到。
于是多维梯度下降的更新公式只是把一维公式里的导数换成梯度:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ t + 1 = θ t − η ⋅ ∇ L ( θ t ) \theta_{t+1} = \theta_t - \eta \cdot \nabla L(\theta_t) </math>θt+1=θt−η⋅∇L(θt)
但「最快下降」并不意味着「最快到达最小值」。这里面有一个非常重要却容易被忽略的细节。
在等高线呈圆形(即各方向曲率相同)时,负梯度方向直指圆心;但当损失曲面是椭圆形(即不同方向的曲率差异很大)时,负梯度方向并不指向最低点,而是垂直于等高线。结果就是参数沿着「短轴方向」剧烈震荡、沿着「长轴方向」缓慢挪动。
下图左侧是各方向曲率接近一致的情形,负梯度几乎直指圆心;右侧是狭长的病态谷底,负梯度每一步都只顾局部最陡方向,于是参数会在短轴方向来回折返,只能沿长轴方向缓慢逼近最低点。这种现象在深度网络里非常普遍------损失面常常是高度病态的(ill-conditioned),不同参数维度的曲率可以相差几个数量级------也是后面 Momentum、RMSProp、Adam 一系列工作要解决的核心问题。
1.4 凸优化与非凸现实
梯度下降原本是为「凸函数」准备的算法。在凸函数上它有非常漂亮的理论保证:只要学习率合适,从任何起点出发都能收敛到全局最小值,速度大约是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 / t ) O(1/t) </math>O(1/t)(无加速)或 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 / t 2 ) O(1/t^2) </math>O(1/t2)(加速版,Nesterov)。
然而神经网络的损失函数几乎从来不是凸的。它在参数空间里有无数局部最小值、鞍点、平台区。我们用一个理论上只能保证局部下降的算法,去训练一个布满坑洼的曲面,居然能训出 GPT-4 这种东西,这本身就是深度学习里最玄学也最经验的部分。
但实践里有一个被多次复现的观察:在足够高维的损失面上,严格的局部最小值非常少,绝大多数「停滞点」都是鞍点。鞍点的特点是有的方向上是极小、有的方向上是极大,只要梯度下降还能感知到任何下降方向,它就能继续走。
这一观察在 Goodfellow 等人 2014 年的工作《Qualitatively characterizing neural network optimization problems》里有很好的实验证据。他们沿着「初始参数 → 训练完毕参数」的直线方向画 loss 曲线,发现这条直线上几乎是单调下降的------也就是说从随机初始化到最终解之间没有需要翻越的「山」。这个观察和「损失面上几乎没有真正麻烦的局部极小」一脉相承,也是「为什么 SGD 能训出大模型」这个看似奇迹的现象背后最常被引用的解释之一。
更近一些的工作如 Li et al. 2018 的《Visualizing the Loss Landscape of Neural Nets》进一步说明:残差连接和合理的网络宽度让损失面变得更「平」、更接近凸函数。这是为什么 ResNet 之后的深度网络比 ResNet 之前的更好训。
1.5 Newton 法、梯度下降与「为什么不用更好的方法」
有人会反问:既然梯度下降只用一阶信息,为什么不直接用 Newton 法?
Newton 法在二次函数上一步就到,对一般非线性函数也比梯度下降快得多。它的更新规则是 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 = θ t − H − 1 ∇ L \theta_{t+1} = \theta_t - H^{-1} \nabla L </math>θt+1=θt−H−1∇L,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> H H </math>H 是 Hessian 矩阵。在低维优化里 Newton 法非常实用------SciPy 的 minimize 默认就有 Newton-CG 选项。
下图左侧画的是 Newton 法的核心直觉:在当前点附近把目标函数做二阶 Taylor 展开,得到一个局部二次近似,然后直接跳到这条近似抛物线的顶点;如果目标本身就是二次函数,这一步就是精确解。右侧则把它的代价也画出来了:Hessian 有 <math xmlns="http://www.w3.org/1998/Math/MathML"> n 2 n^2 </math>n2 个元素,而且每一步都要解一次线性系统 <math xmlns="http://www.w3.org/1998/Math/MathML"> H Δ = ∇ L H\Delta = \nabla L </math>HΔ=∇L。
也正因为如此,深度学习里 Newton 法几乎不可用:Hessian 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × n n \times n </math>n×n 矩阵,参数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> n = 1 0 9 n = 10^9 </math>n=109 时存储就要 <math xmlns="http://www.w3.org/1998/Math/MathML"> 4 × 1 0 18 4 \times 10^{18} </math>4×1018 字节,等于 4 EB。即便有「Hessian-vector product」的技巧避免显式存储,每次还要解一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 维线性系统,计算量远超梯度下降。
退一步:可以考虑「拟 Newton 法」(quasi-Newton)比如 BFGS、L-BFGS。它们用过去几步的梯度近似 Hessian 的逆。L-BFGS 在中小规模优化(比如 logistic 回归、共轭梯度法)里非常成功,但在百亿参数的神经网络里仍然吃不消------存几步历史就要几十 GB。
所以「梯度下降 + 自适应优化器」不是因为它们最优,而是因为它们在「计算复杂度 / 内存复杂度 / 训练效果」三角里给出了一个工程可行的权衡。这一点要记住:所有深度学习的算法选择,都是在「数学最优」和「工程可行」之间的妥协。
1.6 一个数值发散的实验
为让前面那句「学习率太大就发散」的话有体感,做一个最小实验。
考虑 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = x 2 / 2 f(x) = x^2 / 2 </math>f(x)=x2/2,从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 = 1 x_0 = 1 </math>x0=1 出发。学习率取 <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 0.5 , 1.0 , 1.5 , 2.0 , 2.5 \eta = 0.5, 1.0, 1.5, 2.0, 2.5 </math>η=0.5,1.0,1.5,2.0,2.5 五种情形:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 0.5 \eta = 0.5 </math>η=0.5: <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 = 0.5 , x 2 = 0.25 , x 3 = 0.125 , ... x_1 = 0.5, x_2 = 0.25, x_3 = 0.125, \dots </math>x1=0.5,x2=0.25,x3=0.125,... 收敛
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 1.0 \eta = 1.0 </math>η=1.0: <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 = 0 x_1 = 0 </math>x1=0 一步就到
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 1.5 \eta = 1.5 </math>η=1.5: <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 = − 0.5 , x 2 = 0.25 , x 3 = − 0.125 , ... x_1 = -0.5, x_2 = 0.25, x_3 = -0.125, \dots </math>x1=−0.5,x2=0.25,x3=−0.125,... 振荡收敛
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 2.0 \eta = 2.0 </math>η=2.0: <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 = − 1 , x 2 = 1 , x 3 = − 1 , ... x_1 = -1, x_2 = 1, x_3 = -1, \dots </math>x1=−1,x2=1,x3=−1,... 周期震荡,永远不收敛
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 2.5 \eta = 2.5 </math>η=2.5: <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 = − 1.5 , x 2 = 2.25 , x 3 = − 3.375 , ... x_1 = -1.5, x_2 = 2.25, x_3 = -3.375, \dots </math>x1=−1.5,x2=2.25,x3=−3.375,... 发散
把这五条轨迹画在同一坐标系里,会更直观地看到学习率的临界值(这里是 <math xmlns="http://www.w3.org/1998/Math/MathML"> η c = 2 \eta_c = 2 </math>ηc=2,等于 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 / f ′ ′ 2/f'' </math>2/f′′):低于它收敛、等于它震荡、高于它发散。
深度网络里这个临界值依赖于损失曲面的最大特征值,每个时刻都不一样、每个参数维度也不一样,所以才需要 Adam 这样的自适应方案。
1.7 高维空间与低维直觉的差异
我们的几何直觉来自二维和三维世界。在二维里画一张地形图,山峰、山谷、鞍点是常见的;但更严格地说,「鞍点」至少需要两个独立方向才能定义,一维里更准确的说法是驻点或拐点。
随着维度升高,发生「所有方向上都是局部极小」这件事变得指数级地稀有。一个临界点要成为局部极小,要求 Hessian 矩阵在所有方向上都半正定;如果 Hessian 的特征值分布近似随机,在 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 维空间里所有特征值都为正的概率大致是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 − n 2^{-n} </math>2−n。在百万维参数空间里这个概率可以忽略不计。
这就是「鞍点比局部极小多得多」的数学本质。Dauphin 等人 2014 年的论文《Identifying and attacking the saddle point problem in high-dimensional non-convex optimization》给了这个观察的严格分析,也提出了一些专门处理鞍点的方法。但实践中 SGD + Adam 已经能自然「滑过」绝大多数鞍点,所以这些专门方法在工程上并不常用。
二、神经网络的损失函数与梯度
2.1 损失:我们想最小化什么
在第 04 篇里我们说神经网络是一个可调函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ = f ( x ; θ ) \hat{y} = f(x; \theta) </math>y^=f(x;θ)。「学习」就是调整参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ,让这个函数在训练数据上输出尽量接近真实标签。
但「接近」是一个模糊的词,工程上必须把它定义成一个具体的标量------这个标量就是损失函数 (loss function),通常记作 <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ) L(\theta) </math>L(θ)。
最经典的两种损失:回归任务用均方误差
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L MSE = 1 2 N ∑ i ( y ^ i − y i ) 2 L_{\text{MSE}} = \frac{1}{2N} \sum_i (\hat{y}_i - y_i)^2 </math>LMSE=2N1i∑(y^i−yi)2
分类任务用交叉熵
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L CE = − 1 N ∑ i ∑ c y i , c log y ^ i , c L_{\text{CE}} = -\frac{1}{N}\sum_i \sum_c y_{i,c} \log \hat{y}_{i,c} </math>LCE=−N1i∑c∑yi,clogy^i,c
它们各自的物理含义、概率解释、和最大似然估计的关系,是另一篇专门的内容(后面 27、30 篇会深入讲)。这里只需要记住一个关键事实:损失函数把「网络输出和标签的差距」量化成一个可对参数求导的标量。
可求导这一点是梯度下降能用的前提;而把整个网络设计成处处可微(除了一些零测度集合上的非可微点比如 ReLU 的尖角),是深度学习能跑起来的工程前提。
「处处可微」这件事其实没有看起来那么 trivial。max-pool、ReLU、argmax、采样这些操作都不是处处可微的,但深度学习社区花了很大力气把它们「亚梯度化」(subgradient)或者「松弛化」(如 Gumbel-Softmax)以便嵌入 backprop。第 33 篇讲 RLHF 时会回到「不可微操作的梯度估计」这件事。
2.2 损失曲面:我们站在哪
把 <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ) L(\theta) </math>L(θ) 画出来,就是损失曲面(loss landscape)。
在二维(两个参数)时它是一张可视化的地形图;在百万维时它没法画,但我们仍然可以想象:起伏的高维空间,里面有山峰、有山谷、有狭长的盆地、有几乎平坦的高原、有看似最低实则只是局部的洼地。
每一次训练,本质就是从随机初始化开始,沿着某条不可见的轨迹在这片高维地形里行走,希望走到一个足够低的位置。
这里立刻冒出几个有趣的问题。
第一,初始化重要吗? 答案是非常重要。后面 25、36 篇会讲到 Xavier、He、以及 Transformer 训练初期的爆炸问题,几乎都和初始化挂钩。一个错误的初始化会让前向激活方差爆炸或衰减,反向梯度同样爆炸或消失,模型从一开始就处在梯度链路的崩溃边缘。
第二,能保证走到「最低」吗? 不能,但实践中走到一个「足够好」的极小已经够用。这里的「足够好」是相对训练目标而言的------在 cross-entropy 上达到接近熵下界、在测试集上达到目标准确率,这些都是足够好的指标。
第三,不同起点会走到同一个解吗? 通常不会。但有一个非常令人意外的观察:现代大模型在不同种子下训出来的损失曲线几乎一致,最终性能差异常常不超过零点几个百分点------损失曲面虽然非凸,但在大模型尺度下表现得像「处处差不多」。
这种「大模型损失面变平」的现象在多个 scaling 工作中都被观察到。一个解释是:模型容量足够大时,几乎所有初始化都被训练拉到同一个「等价类」的解空间,那里的极小值在性能上等价。
2.3 数值微分:朴素却昂贵的近似
到这里有人会反问:既然要求梯度,为什么不直接用导数定义?
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ θ i ≈ L ( θ + ϵ e i ) − L ( θ − ϵ e i ) 2 ϵ \frac{\partial L}{\partial \theta_i} \approx \frac{L(\theta + \epsilon e_i) - L(\theta - \epsilon e_i)}{2\epsilon} </math>∂θi∂L≈2ϵL(θ+ϵei)−L(θ−ϵei)
这就是中心差分法,在数值分析里非常经典。它对每一个参数都需要做两次前向传播。对一个一亿参数的网络来说,算一次完整梯度需要做两亿次前向,每次前向本身就是亿级浮点运算------这件事在工程上完全不可行。
数值微分还有别的问题:浮点截断误差、 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 大小的取舍。 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 太大近似不准, <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 太小会被浮点抵消放大相对误差,最优 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 大约在 <math xmlns="http://www.w3.org/1998/Math/MathML"> 机器精度 \sqrt{\text{机器精度}} </math>机器精度 的量级,单精度下大约是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 4 10^{-4} </math>10−4 量级。但即便这些问题都不存在,「O(参数量) 次前向」这一条就足以判它死刑。
我们需要一种算法,能在一次前向 + 一次反向就把所有参数的梯度全部算出来。这就是反向传播。
反向传播只在调试时偶尔会和数值微分一起出现:写 PyTorch 自定义算子时,社区强烈建议用 torch.autograd.gradcheck(内部就是用数值微分作 ground truth)来验证你的反向实现是不是正确。这是数值微分在 21 世纪还能找到工程价值的少数场景。
2.4 损失函数与概率:为什么交叉熵几乎一统天下
简单说一下交叉熵为什么在分类任务上几乎一统天下,因为它和后面讨论 softmax、训练目标的内容紧密相关。
如果模型输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^ 解释成「样本属于各类别的概率」,那么训练目标可以写成「最大化训练数据出现的对数似然」,等价于最小化负对数似然 <math xmlns="http://www.w3.org/1998/Math/MathML"> − ∑ i log p ( y i ∣ x i ; θ ) -\sum_i \log p(y_i \mid x_i; \theta) </math>−∑ilogp(yi∣xi;θ)。当 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p 是 softmax 输出时,这个负对数似然就是交叉熵。
更深一层:交叉熵和 KL 散度只差一个常数(数据自身的熵),所以「最小化交叉熵 = 最小化模型预测分布与真实分布之间的 KL 散度」。这件事在 Transformer 的语言模型预训练目标里再次出现------预测下一个 token 的 cross entropy 就是估计真实语言分布的 KL 距离。第 30 篇会详谈。
2.5 二阶信息:Hessian 矩阵长什么样
梯度只是损失函数的「一阶信息」,还有二阶信息------Hessian 矩阵。
Hessian 是所有二阶偏导组成的方阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> H i j = ∂ 2 L / ( ∂ θ i ∂ θ j ) H_{ij} = \partial^2 L / (\partial \theta_i \partial \theta_j) </math>Hij=∂2L/(∂θi∂θj)。它描述了损失曲面在当前点附近的曲率。Newton 法用 Hessian 的逆来选取更新方向 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 = θ t − H − 1 ∇ L \theta_{t+1} = \theta_t - H^{-1} \nabla L </math>θt+1=θt−H−1∇L,理论上比梯度下降快得多------Newton 法一步就能到二次函数的最小值。
但 Hessian 在深度学习里几乎不可用:一个一亿参数模型的 Hessian 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 8 × 1 0 8 = 1 0 16 10^8 \times 10^8 = 10^{16} </math>108×108=1016 个数,光存储就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 16 × 4 10^{16} \times 4 </math>1016×4 字节 = 40 PB。即便用 Hessian-vector 乘积(HVP)等技巧避免显式存储,每次迭代的成本仍然远超梯度下降。
后续在工业界出现的一些「准二阶」方法(K-FAC、Shampoo、Sophia)尝试用对角块或 Kronecker 因子近似 Hessian,在小批量大模型训练上已经有一些成功(比如 2024 年 Google 在某些任务上用 Shampoo 替代 Adam,训练步数减少 30% 以上)。但 Adam 仍然是工业界默认选择,原因是它的工程稳定性和可调性更可控。
2.6 损失景观可视化:从二维切片看大模型
百万维曲面没法画,但有一个常用的可视化技巧:在参数空间里取两个随机方向 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 1 , u 2 u_1, u_2 </math>u1,u2,画 <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ∗ + α u 1 + β u 2 ) L(\theta^* + \alpha u_1 + \beta u_2) </math>L(θ∗+αu1+βu2) 在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( α , β ) (\alpha, \beta) </math>(α,β) 平面上的等高线。
Li et al. 2018 的论文《Visualizing the Loss Landscape of Neural Nets》系统做了这件事,图像非常震撼:
- 没有残差连接的深网络,损失面像「破碎的悬崖」,到处是陡峭的局部极小;
- 加上残差连接后,损失面变得平滑、近似凸,从初始点到最终解之间几乎没有「山」要翻;
- 网络越宽(每层神经元越多),损失面越平。
这些可视化结果给「为什么残差连接管用」「为什么宽网络更好训」这些直觉提供了图像证据。Transformer 的设计------堆很多层 + 每层都有残差 + LayerNorm 归一化------让损失景观尽可能平,这是它能 scale 到千亿参数的工程根基。
2.7 凸函数的几个简单事实
虽然神经网络非凸,但凸函数仍然是理解优化算法的起点。
凸函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 满足 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( λ x + ( 1 − λ ) y ) ≤ λ f ( x ) + ( 1 − λ ) f ( y ) f(\lambda x + (1-\lambda)y) \le \lambda f(x) + (1-\lambda)f(y) </math>f(λx+(1−λ)y)≤λf(x)+(1−λ)f(y)。直观上:曲线上任意两点的连线段在曲线上方。
凸函数的几个性质对深度学习仍然有教学意义:
- 任何局部极小都是全局极小。这是凸函数最重要的性质,也是为什么经典优化理论以凸为中心。
- 梯度为零是充分条件 :在凸函数上, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ f = 0 \nabla f = 0 </math>∇f=0 就保证是全局极小;非凸时则只能保证「驻点」,可能是鞍点或极大值。
- 二阶条件:Hessian 处处半正定。
线性回归的 loss、logistic 回归的 loss 都是凸函数;soft-margin SVM 的 hinge loss 是凸的;普通 MLP 加 ReLU 的 loss 几乎从不是凸的。但「凸近似」在深度学习里常被用作分析工具------比如把损失景观在最优点附近做二阶展开,得到一个局部凸的二次近似。
三、链式法则与反向传播
3.1 高中链式法则的复习
链式法则不是什么深奥的东西。
如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = f ( g ( x ) ) y = f(g(x)) </math>y=f(g(x)),那么 <math xmlns="http://www.w3.org/1998/Math/MathML"> d y d x = f ′ ( g ( x ) ) ⋅ g ′ ( x ) \frac{dy}{dx} = f'(g(x)) \cdot g'(x) </math>dxdy=f′(g(x))⋅g′(x)。再多一层: <math xmlns="http://www.w3.org/1998/Math/MathML"> y = f ( g ( h ( x ) ) ) y = f(g(h(x))) </math>y=f(g(h(x))),那么 <math xmlns="http://www.w3.org/1998/Math/MathML"> d y d x = f ′ ⋅ g ′ ⋅ h ′ \frac{dy}{dx} = f' \cdot g' \cdot h' </math>dxdy=f′⋅g′⋅h′。也就是说,复合函数的导数等于沿着函数路径每一段局部导数的乘积。
这条规则在高中就学过,但从来没人告诉我们:把它推广到「复合上千个函数」「每个函数有亿级参数」的情形,并且找一个高效的计算顺序,就是反向传播。
把这件事推广到多变量:如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 经过若干中间变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z 1 , z 2 , ... , z L z_1, z_2, \dots, z_L </math>z1,z2,...,zL 才依赖到某个参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ,那么
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ θ = ∂ L ∂ z L ⋅ ∂ z L ∂ z L − 1 ⋯ ∂ z 2 ∂ z 1 ⋅ ∂ z 1 ∂ θ \frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial z_L} \cdot \frac{\partial z_L}{\partial z_{L-1}} \cdots \frac{\partial z_2}{\partial z_1} \cdot \frac{\partial z_1}{\partial \theta} </math>∂θ∂L=∂zL∂L⋅∂zL−1∂zL⋯∂z1∂z2⋅∂θ∂z1
这条公式就是反向传播的全部数学基础。神经网络的所有「神奇」全部建立在它之上。
但请注意,这条公式只是说「梯度怎么算」,并没有说「该按什么顺序算」。反向传播之所以叫「反向」,是因为它选择从右往左乘------从损失开始一步步往输入方向算回去。
这种顺序之所以高效,是因为它复用了大量中间结果,让计算复杂度从「O(参数量) 次前向」(如数值微分)降到了「一次反向遍历」。下面用一个具体的两层网络把这件事讲清楚。
3.2 一个两层网络的手算反向传播
考虑最小可重现的两层网络。输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R d x \in \mathbb{R}^d </math>x∈Rd,先过线性层得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> z 1 = W 1 x + b 1 z_1 = W_1 x + b_1 </math>z1=W1x+b1,再过激活 <math xmlns="http://www.w3.org/1998/Math/MathML"> a 1 = σ ( z 1 ) a_1 = \sigma(z_1) </math>a1=σ(z1),再过第二层线性 <math xmlns="http://www.w3.org/1998/Math/MathML"> z 2 = W 2 a 1 + b 2 z_2 = W_2 a_1 + b_2 </math>z2=W2a1+b2,最后输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ = z 2 \hat{y} = z_2 </math>y^=z2。损失用均方误差 <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 1 2 ∥ y ^ − y ∥ 2 L = \frac{1}{2}\|\hat{y} - y\|^2 </math>L=21∥y^−y∥2。
我们想算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ W 1 , ∂ L / ∂ b 1 , ∂ L / ∂ W 2 , ∂ L / ∂ b 2 \partial L / \partial W_1, \partial L/\partial b_1, \partial L/\partial W_2, \partial L/\partial b_2 </math>∂L/∂W1,∂L/∂b1,∂L/∂W2,∂L/∂b2 这四个梯度。
按链式法则从右往左推。先算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ y ^ \partial L / \partial \hat{y} </math>∂L/∂y^。损失对 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^ 的导数很简单: <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ y ^ = y ^ − y \partial L / \partial \hat{y} = \hat{y} - y </math>∂L/∂y^=y^−y,这是一个跟 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^ 同维的向量。把这个向量记作 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ 2 \delta_2 </math>δ2(第二层的「误差信号」)。
一步往前推,因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ = z 2 \hat{y} = z_2 </math>y^=z2,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ z 2 = δ 2 \partial L / \partial z_2 = \delta_2 </math>∂L/∂z2=δ2。
接下来要算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ W 2 \partial L / \partial W_2 </math>∂L/∂W2 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ b 2 \partial L / \partial b_2 </math>∂L/∂b2。 <math xmlns="http://www.w3.org/1998/Math/MathML"> z 2 = W 2 a 1 + b 2 z_2 = W_2 a_1 + b_2 </math>z2=W2a1+b2,对 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 2 W_2 </math>W2 求偏导得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ z 2 / ∂ W 2 = a 1 \partial z_2 / \partial W_2 = a_1 </math>∂z2/∂W2=a1(更准确说,因为是矩阵,写成外积形式 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ W 2 = δ 2 ⋅ a 1 T \partial L / \partial W_2 = \delta_2 \cdot a_1^T </math>∂L/∂W2=δ2⋅a1T);对 <math xmlns="http://www.w3.org/1998/Math/MathML"> b 2 b_2 </math>b2 求偏导是 1,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ b 2 = δ 2 \partial L/\partial b_2 = \delta_2 </math>∂L/∂b2=δ2。这两个梯度算完,第二层的工作就结束了。
然后把误差信号往前传递到 <math xmlns="http://www.w3.org/1998/Math/MathML"> a 1 a_1 </math>a1:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L / ∂ a 1 = W 2 T δ 2 \partial L / \partial a_1 = W_2^T \delta_2 </math>∂L/∂a1=W2Tδ2
再过激活函数的局部导数 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ a 1 / ∂ z 1 = σ ′ ( z 1 ) \partial a_1 / \partial z_1 = \sigma'(z_1) </math>∂a1/∂z1=σ′(z1),得到第一层的误差信号:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ 1 = ( W 2 T δ 2 ) ⊙ σ ′ ( z 1 ) \delta_1 = (W_2^T \delta_2) \odot \sigma'(z_1) </math>δ1=(W2Tδ2)⊙σ′(z1)
( <math xmlns="http://www.w3.org/1998/Math/MathML"> ⊙ \odot </math>⊙ 是逐元素乘)。最后用同样的方式算第一层的梯度:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L / ∂ W 1 = δ 1 ⋅ x T , ∂ L / ∂ b 1 = δ 1 \partial L / \partial W_1 = \delta_1 \cdot x^T, \qquad \partial L / \partial b_1 = \delta_1 </math>∂L/∂W1=δ1⋅xT,∂L/∂b1=δ1
整个过程一气呵成。它有一个非常优美的结构:每一层只需要从下游接收一个误差信号 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ \delta </math>δ,就能算出本层参数的梯度,并向上一层传递新的 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ \delta </math>δ。
每一层的工作量正比于它自己的参数量。把所有层加起来,反向传播的总工作量正比于整个网络的参数量------和一次前向同阶。这就是为什么 backprop 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 参数量 ) O(参数量) </math>O(参数量) 而不是数值微分的 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 参数 量 2 ) O(参数量^2) </math>O(参数量2)。
这一性质决定了深度学习能不能「scale 起来」。如果反向传播也是平方级,今天就不会有任何深度模型存在;连 LeNet 都跑不起来。
3.3 一个数值例子:把抽象变具体
为了让上面的推导有体感,给一个最小数值例子。
假设输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x = 2 x = 2 </math>x=2(标量),参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 = 0.5 , b 1 = 0.1 , W 2 = 0.3 , b 2 = − 0.2 W_1 = 0.5, b_1 = 0.1, W_2 = 0.3, b_2 = -0.2 </math>W1=0.5,b1=0.1,W2=0.3,b2=−0.2,激活函数取 sigmoid <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( z ) = 1 / ( 1 + e − z ) \sigma(z) = 1/(1+e^{-z}) </math>σ(z)=1/(1+e−z),标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = 1 y = 1 </math>y=1。
前向:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> z 1 = 0.5 × 2 + 0.1 = 1.1 z_1 = 0.5 \times 2 + 0.1 = 1.1 </math>z1=0.5×2+0.1=1.1
- <math xmlns="http://www.w3.org/1998/Math/MathML"> a 1 = σ ( 1.1 ) ≈ 0.7503 a_1 = \sigma(1.1) \approx 0.7503 </math>a1=σ(1.1)≈0.7503
- <math xmlns="http://www.w3.org/1998/Math/MathML"> z 2 = 0.3 × 0.7503 − 0.2 ≈ 0.0251 z_2 = 0.3 \times 0.7503 - 0.2 \approx 0.0251 </math>z2=0.3×0.7503−0.2≈0.0251
- <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ = 0.0251 \hat{y} = 0.0251 </math>y^=0.0251
- <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 1 2 ( 0.0251 − 1 ) 2 ≈ 0.4752 L = \frac{1}{2}(0.0251 - 1)^2 \approx 0.4752 </math>L=21(0.0251−1)2≈0.4752
反向:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> δ 2 = y ^ − y = 0.0251 − 1 = − 0.9749 \delta_2 = \hat{y} - y = 0.0251 - 1 = -0.9749 </math>δ2=y^−y=0.0251−1=−0.9749
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ W 2 = δ 2 ⋅ a 1 ≈ − 0.7315 \partial L/\partial W_2 = \delta_2 \cdot a_1 \approx -0.7315 </math>∂L/∂W2=δ2⋅a1≈−0.7315
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ b 2 = − 0.9749 \partial L/\partial b_2 = -0.9749 </math>∂L/∂b2=−0.9749
- <math xmlns="http://www.w3.org/1998/Math/MathML"> δ 1 = ( W 2 ⋅ δ 2 ) ⋅ σ ′ ( z 1 ) \delta_1 = (W_2 \cdot \delta_2) \cdot \sigma'(z_1) </math>δ1=(W2⋅δ2)⋅σ′(z1),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ′ ( z 1 ) = a 1 ( 1 − a 1 ) ≈ 0.1873 \sigma'(z_1) = a_1 (1 - a_1) \approx 0.1873 </math>σ′(z1)=a1(1−a1)≈0.1873
- <math xmlns="http://www.w3.org/1998/Math/MathML"> δ 1 ≈ 0.3 × ( − 0.9749 ) × 0.1873 ≈ − 0.0548 \delta_1 \approx 0.3 \times (-0.9749) \times 0.1873 \approx -0.0548 </math>δ1≈0.3×(−0.9749)×0.1873≈−0.0548
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ W 1 = δ 1 ⋅ x ≈ − 0.1096 \partial L/\partial W_1 = \delta_1 \cdot x \approx -0.1096 </math>∂L/∂W1=δ1⋅x≈−0.1096
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ b 1 ≈ − 0.0548 \partial L/\partial b_1 \approx -0.0548 </math>∂L/∂b1≈−0.0548
可以看到第一层的梯度量级(0.05)已经显著小于第二层(0.97)。这就是「梯度从输出层往输入层逐层衰减」的最简单例子。如果这是一个 50 层的网络,到第一层时梯度可能小到几乎为零------这就是梯度消失的最初萌芽。
3.4 计算图视角:反向传播的工程实现
把上面的推导抽象一下:一个神经网络其实可以画成一张计算图(computation graph)。
节点是运算(加、乘、ReLU、softmax......),边是数据流。前向传播是按拓扑序从输入节点走到损失节点;反向传播是按反拓扑序从损失节点走回输入。
每个节点只需要知道两件事:自己的输出对自己的输入的局部导数怎么算(这是节点自带的「向后函数」),以及上游传下来的全局梯度。把它们乘起来再传给输入边,就完成了一个节点的反向运算。
这种「每个节点自顾自,框架负责拓扑调度」的抽象正是 PyTorch、TensorFlow、JAX 这些框架的工程内核。
PyTorch 的 autograd 用的是动态图:前向时一边算结果一边构建反向图;TensorFlow 1.x 用的是静态图:先定义图再喂数据。这两种风格在「反向传播是计算图上的反拓扑遍历」这件事上完全一致,差别只在于图什么时候构建、能不能整体优化。
JAX 用的是 trace-and-transform:用 Python 写函数,JAX 在追踪时把它转换成中间表达式,再做反向(或前向)自动微分;这个设计在某些场景下比 PyTorch 更易于做高阶微分(求二阶导数、三阶导数)和函数变换(vmap、pmap)。
理解了计算图视角后,再看反向传播就不再是「一长串链式法则乘起来」,而是一种局部规则 + 全局调度的算法。
每个算子只关心自己那一小块的导数;框架负责把所有算子按正确顺序串起来。Transformer 里那些花里胡哨的算子------multi-head attention、LayerNorm、RoPE、RMSNorm------在反向传播视角下都不过是「定义了前向、定义了后向」的两个函数对,框架照着图自动求出梯度。我们之所以可以「写一个新结构、跑一遍、看效果」,全靠这套抽象。
3.5 自动微分的两种模式:前向模式 vs 反向模式
为完整起见提一下:自动微分(automatic differentiation, AD)其实有两种模式。
反向模式 就是我们前面讲的反向传播------从输出往输入算。它的成本是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( ops ) O(\text{ops}) </math>O(ops),与输入维度无关。当输出是标量、输入是高维(损失函数对参数)时是最优选择,正好对应神经网络训练。
前向模式则相反,从输入往输出算,需要为每个输入维度单独跑一遍。它的成本与输入维度成正比,但与输出维度无关。当输入低维、输出高维时(比如某些物理仿真里的雅可比矩阵),前向模式更划算。
JAX 同时支持两种(jax.grad 是反向,jax.jvp 是前向)。在神经网络训练里几乎只用反向模式;前向模式偶尔出现在二阶优化、敏感性分析等场景。第 53 篇讲机制可解释性时,部分归因方法(比如 Path Patching)会用到前向模式或两者混合。
3.6 为什么不能直接用「反向自动微分」搞所有事
到这里有人会问:既然 autograd 这么好用,为什么还有那么多研究讨论梯度估计、变分推断、直通估计器(Straight-Through Estimator)这些东西?
答案是:autograd 只能处理「每个节点都可微」的图。一旦图里出现采样、离散选择、argmax 这些不可微的操作,反向传播就走不通了。
Transformer 里的常规操作几乎全是可微的,所以训练顺风顺水;但当我们想训练「带采样的策略网络」(RLHF 里就有),梯度就得借助策略梯度、Gumbel-Softmax 等技巧绕一下。第 33 篇讲 RLHF 时会回到这件事。
另一个相关问题:内存。反向传播需要保存前向时的所有中间激活,因为它们在反向时要用到。一个 100 层、batch size 32、序列长 1024 的 Transformer,激活内存可能高达数十 GB。梯度检查点(gradient checkpointing)是一种常用补救:只保留若干「锚点」激活,反向时再前向重算其余的。它用计算换内存,可以让大模型训练在更小显存的卡上跑。后续 llm-infra 系列会详细讲。
3.7 一个常被忽略的事实:反向传播的内存代价
反向传播虽然计算上 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 参数量 ) O(\text{参数量}) </math>O(参数量),但内存上不便宜。
要算反向,需要保存前向时的所有中间激活------因为局部导数 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ z ℓ + 1 / ∂ z ℓ \partial z_{\ell+1} / \partial z_\ell </math>∂zℓ+1/∂zℓ 通常依赖 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ℓ z_\ell </math>zℓ 的值。一个 GPT-3 175B 模型,单次前向激活内存可能达到几百 GB,单卡装不下。
这是为什么大模型训练在工程上必须分布式,并大量使用「梯度检查点」(gradient checkpointing):只保留每隔 <math xmlns="http://www.w3.org/1998/Math/MathML"> L \sqrt{L} </math>L 层的激活,反向时再前向重算,把内存从 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( L ) O(L) </math>O(L) 降到 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( L ) O(\sqrt{L}) </math>O(L ),代价是计算量约增加 33%。
PyTorch 的 torch.utils.checkpoint 实现了这个机制;FlashAttention 在 Attention 这一段做了更精细的内存优化,让长序列下的 Attention 反向不再需要存中间矩阵。第 42 篇会专门讲。
3.8 检查反向是否实现正确:gradcheck
写自定义 PyTorch 算子时,怎么验证你的 backward 函数写对了?
社区标准做法是 torch.autograd.gradcheck:它用数值微分(中心差分)作为 ground truth,对你的解析梯度做小批量验证。如果误差超过阈值(默认 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 6 10^{-6} </math>10−6 量级),就报错。
这是数值微分在现代深度学习里仍然有用的少数场景之一------不是当作训练算法,而是作为「单元测试」。
四、梯度消失与爆炸:连乘的诅咒
4.1 梯度链上的乘积放大与衰减
链式法则的乘法形式带来一个副作用:层数一多,梯度沿着链路连乘,要么指数衰减、要么指数爆炸。
具体来说,假设网络有 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 层,每一层的雅可比矩阵的「典型尺度」是 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s(更精确地说是雅可比的奇异值),那么从最后一层传到第一层时,梯度大小大约是 <math xmlns="http://www.w3.org/1998/Math/MathML"> s L s^L </math>sL。
如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> s = 0.9 s = 0.9 </math>s=0.9、 <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 50 L = 50 </math>L=50,梯度就被压缩到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.005 0.005 </math>0.005;如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> s = 1.1 s = 1.1 </math>s=1.1、 <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 50 L = 50 </math>L=50,则爆炸到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 117 117 </math>117 量级。前者叫梯度消失 (vanishing gradient),意味着浅层几乎拿不到学习信号;后者叫梯度爆炸(exploding gradient),意味着一次更新就能把参数推到天边。
这件事在深度网络里曾经长期是个大麻烦。
1990 年代到 2000 年代中期,超过 5 层的全连接网络几乎训不动,原因就是梯度消失。Hochreiter 1991 年的硕士论文(德语原版)和 Bengio 等 1994 年的论文《Learning long-term dependencies with gradient descent is difficult》第一次系统分析了这个问题。
后来 Hochreiter 和 Schmidhuber 1997 年提出 LSTM,本质就是给梯度修了一条「不会被反复连乘」的高速通道(详见第 09 篇)。
再往后深度卷积网络里的 ResNet(He et al. 2015)用残差连接(residual connection)解决了同一个问题:让信息可以走「捷径」绕过一层,等价于把雅可比写成 <math xmlns="http://www.w3.org/1998/Math/MathML"> J = I + R J = I + R </math>J=I+R,只要 <math xmlns="http://www.w3.org/1998/Math/MathML"> R R </math>R 不太大, <math xmlns="http://www.w3.org/1998/Math/MathML"> J J </math>J 的奇异值就接近 1,连乘起来不会爆炸也不会消失。
Transformer 把残差连接和 LayerNorm 一起用,也是出于同样的原因------后续第 24、25 篇会深入。
4.2 一个数值上的实验
为了让连乘的影响有体感,做一个简化思想实验。
假设我们有一个 100 层的 MLP,每层用 sigmoid 激活。Sigmoid 的导数最大值是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / 4 1/4 </math>1/4(在 <math xmlns="http://www.w3.org/1998/Math/MathML"> z = 0 z=0 </math>z=0 处)。也就是说,每一层贡献给反向梯度的「sigmoid 局部导数」最大不超过 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.25 0.25 </math>0.25。
那么从 100 层往前传,仅仅 sigmoid 这一项就会衰减到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.2 5 100 ≈ 6 × 1 0 − 61 0.25^{100} \approx 6 \times 10^{-61} </math>0.25100≈6×10−61。这是一个比单精度浮点最小正常数( <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 38 10^{-38} </math>10−38)还要小几十个数量级的数------直接就是零。
这就是为什么 sigmoid 在深层网络里几乎不能用。ReLU 的导数要么是 0 要么是 1,期望值大约 0.5;GELU 的导数在大部分区间接近 1,所以它们在深网络里更稳定。第 05 篇专门讲过激活函数选择,背后的根本原因就是这件事。
4.3 现代办法:从架构到训练技巧
今天处理梯度消失/爆炸的常用手段,可以大致分成三类。
第一类是架构层面:残差连接(ResNet、Transformer)、归一化层(BatchNorm、LayerNorm、RMSNorm)、合理的激活函数(ReLU、GELU、SwiGLU)、合理的初始化(Xavier、He、SP-init)。
第二类是优化器层面 :自适应优化器(Adam、AdamW)天生对梯度尺度有归一化效果,能在一定程度上缓解爆炸。具体说,Adam 的更新量大致是 <math xmlns="http://www.w3.org/1998/Math/MathML"> m ^ t / v ^ t \hat{m}_t / \sqrt{\hat{v}_t} </math>m^t/v^t ,当梯度本身很小时这个比值仍然在合理量级;当梯度突然变大时分母也变大,自动「踩刹车」。
第三类是训练技巧:梯度裁剪(gradient clipping)把梯度范数硬限制在某个阈值以下、混合精度训练(mixed precision)里的动态损失缩放(loss scaling)防止 fp16 下溢。Transformer 训练几乎把这三类全都用上。
但这些手段不是免费的。
残差连接虽然救了梯度,但本身需要前向、反向多算一次加法;LayerNorm 在每一步都要算均值方差,反向传播也要算它的导数;梯度裁剪本质是改了梯度方向(裁剪后已经不是原来的梯度),训练动力学因此变得更复杂;混合精度的 loss scaling 调不好可能直接 NaN。
这就是为什么训练大模型从来不是「调出超参就行」,每一个看似无害的工程决策背后都连着梯度链路上的某个细节。
4.4 初始化:让梯度链一开始就在合理量级
初始化是控制梯度尺度的第一关。
最朴素的初始化是「全部填随机数」(高斯 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 1 ) \mathcal{N}(0, 1) </math>N(0,1)),但这在深网络里立刻爆炸------每过一层方差大致乘以「输入维度 × 权重方差」,多层下来量级失控。
Xavier 初始化(Glorot & Bengio 2010)对线性激活做了精确的方差守恒推导,给出 <math xmlns="http://www.w3.org/1998/Math/MathML"> W ∼ N ( 0 , 2 / ( n in + n out ) ) W \sim \mathcal{N}(0, 2/(n_{\text{in}} + n_{\text{out}})) </math>W∼N(0,2/(nin+nout));He 初始化(He et al. 2015)针对 ReLU 多了一个因子 2,写成 <math xmlns="http://www.w3.org/1998/Math/MathML"> W ∼ N ( 0 , 2 / n in ) W \sim \mathcal{N}(0, 2/n_{\text{in}}) </math>W∼N(0,2/nin)。
在 Transformer 里,原论文用了一个简单方案:embedding 用 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 1 ) \mathcal{N}(0, 1) </math>N(0,1)、线性层用 Xavier。后来的工作(如 GPT-NeoX、Megatron)又加了一些细节------比如对残差分支后的线性层除一个额外的 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 L \sqrt{2L} </math>2L (其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 是层数)来防止深网络放大,这叫「scaled init」。
每一种初始化的背后都是「让方差在前向、反向传播时近似守恒」这一原则。理解了 4.1 节的连乘机制,初始化背后的数学就不再神秘------它只是在调一个让 <math xmlns="http://www.w3.org/1998/Math/MathML"> s ≈ 1 s \approx 1 </math>s≈1 的开关。
4.5 ResNet 的数学解释
为什么残差连接 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = x + F ( x ) y = x + F(x) </math>y=x+F(x) 能解决梯度消失?让我们用反向传播的视角具体看。
考虑一个 ResNet 块的反向:上游传下来梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ y \partial L / \partial y </math>∂L/∂y,我们要算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ x \partial L / \partial x </math>∂L/∂x。链式法则给出:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ x = ∂ L ∂ y ⋅ ∂ y ∂ x = ∂ L ∂ y ⋅ ( I + F ′ ( x ) ) \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \cdot (I + F'(x)) </math>∂x∂L=∂y∂L⋅∂x∂y=∂y∂L⋅(I+F′(x))
注意那个 <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I------不管 <math xmlns="http://www.w3.org/1998/Math/MathML"> F ′ ( x ) F'(x) </math>F′(x) 是什么,加上单位矩阵后,反向梯度至少有一条「直达」通道: <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ x ⊇ ∂ L / ∂ y \partial L / \partial x \supseteq \partial L / \partial y </math>∂L/∂x⊇∂L/∂y(指方向上至少包含)。即便 <math xmlns="http://www.w3.org/1998/Math/MathML"> F ′ ( x ) F'(x) </math>F′(x) 因为 saturated activation 接近零,梯度也不会消失。
把多个残差块串起来,整体的雅可比是 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∏ ℓ ( I + F ℓ ′ ( x ℓ ) ) \prod_\ell (I + F'\ell(x\ell)) </math>∏ℓ(I+Fℓ′(xℓ))。展开后,里面会包含从 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ y L \partial L / \partial y_L </math>∂L/∂yL 到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ x 0 \partial L / \partial x_0 </math>∂L/∂x0 的「单位路径」,跳过任意中间层的子集。这条路径的存在保证了梯度永远有一条不被衰减的高速通道。
这就是为什么 ResNet 能轻松训练 1000 层网络,而普通堆叠超过 20 层就训不动。Transformer 沿用了这个设计,每一层 attention 和 FFN 都被残差包裹,在 GPT-4 那样接近 100 层的深度上仍然能稳定训练。
4.6 一个被反复发现的事实:归一化层让梯度稳定
LayerNorm、BatchNorm、RMSNorm 这些归一化层除了「让激活分布稳定」之外,还有一个隐藏的副作用:让反向梯度的尺度也稳定。
具体说,归一化层的输出是 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = ( x − μ ) / σ y = (x - \mu) / \sigma </math>y=(x−μ)/σ,反向时它的雅可比涉及 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / σ 1/\sigma </math>1/σ。如果某个 batch 里 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ \sigma </math>σ 很小(激活分布塌缩),梯度就会被 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / σ 1/\sigma </math>1/σ 放大。这听起来是坏事,但在实际训练里它扮演了「自动重标定梯度」的角色------当激活分布塌缩时梯度变大,能把分布重新撑开。
这件事在《How Does Batch Normalization Help Optimization?》(Santurkar 2018)里有详细分析,结论是 BN 不是通过「减少 Internal Covariate Shift」起作用,而是通过「让损失景观变平滑」起作用。LayerNorm 的机制大致类似,第 25 篇会展开。
4.7 Attention logits 的数值溢出:Transformer 时代的新爆炸
梯度爆炸在 Transformer 时代以一种新形态出现:attention logits 的数值溢出。
在 Self-Attention 里我们要算 <math xmlns="http://www.w3.org/1998/Math/MathML"> softmax ( Q K T / d k ) \text{softmax}(QK^T / \sqrt{d_k}) </math>softmax(QKT/dk )。如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 的数值幅度因为 LayerNorm 失效或某些极端 token 而变大, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT 的某些元素可能达到几百甚至几千。fp16 的最大正常数是 65504------一旦 logit 超过约 11, <math xmlns="http://www.w3.org/1998/Math/MathML"> e logit e^{\text{logit}} </math>elogit 就会溢出成 inf,softmax 输出 NaN,梯度也是 NaN,训练崩溃。
这件事在 BERT、GPT 早期训练里是个重大事故源。后来的解决方案是「logit 裁剪」(logit cap):在 attention 的 softmax 之前把 logit 限制在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − 30 , 30 ] [-30, 30] </math>[−30,30] 之类的范围内。Gemma、Grok 等近年模型公开过类似做法。
更深一层的解决方案是 FlashAttention(Dao 2022):把 attention 的 softmax 在 fp32 累加器里完成,再写回 fp16/bf16。这绕过了中间数值不稳定的问题,第 42 篇会展开讲。
理解这件事的关键在于:梯度爆炸不止发生在「反向传播的连乘」里,也可能发生在「前向某个算子的中间数值溢出」里。Transformer 时代的训练稳定性问题,多了一个新的爆炸源头。
五、随机梯度下降与小批量
5.1 全批量太慢,单样本太抖
回到梯度下降。前面写的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 = θ t − η ∇ L ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) </math>θt+1=θt−η∇L(θt) 里, <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ) L(\theta) </math>L(θ) 是「整个数据集上的损失」,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ L \nabla L </math>∇L 也是「整个数据集的梯度」。
算一次梯度需要遍历整个数据集------在 ImageNet 上是 128 万张图、在 Common Crawl 上是几万亿 token------单步成本天文级,整个训练过程跑不完。
另一个极端是随机梯度下降(Stochastic Gradient Descent, SGD):每次只用一个样本估计梯度。这样每一步快得多,但单样本的梯度方差很大,更新方向抖动严重。
后来发现一个折中做法------小批量梯度下降(Mini-batch SGD)------每次随机抽一个小批量(比如 32、64、256 个样本),用这个批量上的梯度作为整体梯度的估计。这是今天所有深度学习训练的默认做法。
它兼具「单步成本可控」和「方差较小」两个优点,并且因为方差仍然存在,反而能帮助训练逃出鞍点和浅局部最小。
5.2 批量大小的微妙
批量大小(batch size)听起来只是工程参数,其实是有数学含义的旋钮。
批量越大,梯度估计越准、方差越小,但每步信息量大、对学习率要求也更高(直观上每步移动距离更大);批量越小,方差越大,但隐含的「噪声正则化」效果更强,泛化往往更好。
Goyal 等人 2017 年的论文《Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour》给出了一个经验法则:批量增大 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 倍时学习率也要增大约 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 倍(线性缩放规则),并配以适当的 warmup。
Transformer 训练里这条法则被反复使用,但它不是绝对的------不同任务上需要的缩放比例并不一致。比如 LLM 预训练里通常用 <math xmlns="http://www.w3.org/1998/Math/MathML"> k \sqrt{k} </math>k 缩放更稳定(因为 Adam 的自适应分母已经吸收了一部分梯度方差信息)。
随机性还带来一个有趣的副作用:训练曲线永远是抖动的。我们看到的损失下降曲线表面上漂亮,放大看其实每一步都在抖。这种抖动既是「估计误差」也是「探索机会」。SGD 在很多任务上比全批量更容易找到好解,并不是因为它精确,而恰恰因为它不精确。
5.3 梯度累积:用计算换更大的有效批量
工程上还有一个技巧:梯度累积(gradient accumulation)。
假设单卡显存只能装下 batch size 4 的前向反向,但你想要等价于 batch size 64 的训练效果,做法是:连续算 16 个 batch 的梯度,每个 batch 算完不立刻更新参数,而是把梯度累加起来;累加到第 16 个 batch 时再做一次参数更新,然后清空累加器。
这等价于一次 batch size <math xmlns="http://www.w3.org/1998/Math/MathML"> 4 × 16 = 64 4 \times 16 = 64 </math>4×16=64 的训练,只是时间慢了 16 倍(节省的是显存)。今天几乎所有大模型预训练都用梯度累积来达到目标的「全局批量大小」。GPT-3 的全局 batch size 是 3.2M token------单卡显然装不下,必须靠数据并行 + 梯度累积凑出来。
5.4 Epoch、Step、Token:三个常被混淆的概念
在 CV 时代我们说「训练 90 个 epoch」;在 NLP 大模型时代很少有人说 epoch,一律说 token 数或 step 数。这不是命名习惯的差异,是训练范式的根本不同。
Epoch:完整遍历一次训练集称为一个 epoch。CV 数据集(ImageNet 128 万张)几个小时就能跑完一个 epoch,训 90 个 epoch 是常态。
Step:一次参数更新称为一个 step。和 batch size、设备数有关------比如 GPT-3 一个 step 处理 3.2M token,一个 epoch(300B token)需要约 9.4 万 step。
Token:训练总共喂进的 token 数。LLM 时代主流指标,因为:(1) 大数据集(几万亿 token)几乎不需要重复,0.5--2 个 epoch 就够;(2) Scaling laws 里说的「训练量」就是 token 数。
混淆这几个概念会犯错。比如「学习率衰减到训练末期」在 CV 里指 epoch 末,在 LLM 里指 token 数末。论文里写「3000 step warmup」和「1B token warmup」是不同的指标。Chinchilla scaling law 里专门用 token 数作为单位,原因就是它和 epoch 解耦、和数据集大小解耦,是更纯净的「训练量」度量。
5.5 同步 SGD 与异步 SGD:分布式下的梯度
大模型训练几乎都是数据并行------多张卡各自处理一部分 batch,再把梯度聚合。这里有两种典型策略:
同步 SGD(synchronous SGD):所有 worker 算完梯度后做一次 all-reduce,得到平均梯度,所有 worker 用同样的梯度做更新。等价于一次大批量更新,训练动力学清晰、可复现。
异步 SGD(asynchronous SGD):每个 worker 独立算梯度独立更新,参数服务器接收所有更新。延迟更小但梯度「过时」(stale),对收敛性有损。早期 Google DistBelief、Hogwild! 用过这种方案,今天大模型训练几乎不用------因为同步 SGD 配合 NCCL all-reduce 已经足够快。
LLM 工业界默认用同步 SGD,代价是任何一个 worker 慢都会拖慢全局(straggler 问题)。这也是 ZeRO、FSDP 这些技术要解决的工程问题。
六、动量与自适应:从 SGD 到 Adam
6.1 动量(Momentum):把历史方向考虑进来
回到前面那个椭圆形损失曲面的尴尬场面:负梯度方向并不指向最低点,而是垂直于等高线,所以参数在窄轴上反复震荡。
一个朴素但非常有效的修补办法是把「历史梯度方向」也考虑进来。Polyak 1964 年提出的动量法(Momentum)就是这件事:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> v t + 1 = μ v t + ∇ L ( θ t ) , θ t + 1 = θ t − η v t + 1 v_{t+1} = \mu v_t + \nabla L(\theta_t), \qquad \theta_{t+1} = \theta_t - \eta v_{t+1} </math>vt+1=μvt+∇L(θt),θt+1=θt−ηvt+1
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt 是「速度」, <math xmlns="http://www.w3.org/1998/Math/MathML"> μ ∈ [ 0 , 1 ) \mu \in [0, 1) </math>μ∈[0,1) 是动量系数(典型值 0.9)。
直觉上这个公式说:每一步不止看当前梯度,还把过去所有梯度加权累计起来。如果一个方向上梯度反复一致,速度就在这个方向上累加;如果一个方向上梯度反复反向,正负相消。
最终的效果是:沿着谷底方向加速,沿着窄轴方向阻尼------震荡被压住,进入谷底速度更快。
把动量理解成「带惯性的小球沿着曲面滚」是经典比喻。 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ \mu </math>μ 是「摩擦系数」的反面------ <math xmlns="http://www.w3.org/1998/Math/MathML"> μ = 0 \mu = 0 </math>μ=0 完全没惯性,等于普通 SGD; <math xmlns="http://www.w3.org/1998/Math/MathML"> μ → 1 \mu \to 1 </math>μ→1 接近无摩擦,越走越快但也越难刹车。
Nesterov 1983 年提出了一个稍微聪明一点的变体,叫 Nesterov Accelerated Gradient(NAG):先按动量方向「往前看一步」,再在那一步算梯度。
直观上 NAG 比 momentum 更早感知到「前方要拐弯」,所以转弯更及时;理论上 NAG 在凸优化里有更好的收敛速度( <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 / t 2 ) O(1/t^2) </math>O(1/t2) vs <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 / t ) O(1/t) </math>O(1/t))。对非凸的深度网络来说,差异不大但它常常被当作 momentum 的「升级版」用。
6.2 自适应学习率:RMSProp 与 AdaGrad
动量解决了「方向」问题,但没解决「步长」问题:不同参数维度上的合理步长应该不一样。
AdaGrad(Duchi 2011)第一次系统地把这件事搬到深度学习------给每个参数维护一个「历史梯度平方和」,用它的平方根做分母去缩放学习率:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s t = s t − 1 + g t 2 , θ t + 1 = θ t − η s t + ϵ g t s_{t} = s_{t-1} + g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{s_t} + \epsilon} g_t </math>st=st−1+gt2,θt+1=θt−st +ϵηgt
历史梯度大的方向,分母大,步长小;历史梯度小的方向,分母小,步长大。
在稀疏特征学习(推荐、广告)里 AdaGrad 用得很好。它特别适合「有些特征几乎不出现,但出现一次就要狠学一下」的场景。
但它有一个致命缺点: <math xmlns="http://www.w3.org/1998/Math/MathML"> s t s_t </math>st 是单调递增的,时间一长所有方向的步长都被压到接近零,模型再也走不动了。
RMSProp(Hinton 2012 年课件里提出,没有正式论文)是 AdaGrad 的改良:把累加换成指数滑动平均:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s t = β s t − 1 + ( 1 − β ) g t 2 s_t = \beta s_{t-1} + (1-\beta) g_t^2 </math>st=βst−1+(1−β)gt2
这样 <math xmlns="http://www.w3.org/1998/Math/MathML"> s t s_t </math>st 不再单调递增,而是反映「最近一段时间的梯度平方均值」。它解决了 AdaGrad 走不动的问题,又保留了「自适应步长」的好处。
RMSProp 在 RNN 训练里曾经非常流行(DQN 等强化学习经典工作都用它),但在 Transformer 时代被 Adam 全面取代------Adam 本质就是「RMSProp + Momentum + 偏差修正」。
6.3 Adam:动量 + 自适应一锅端
Adam(Kingma & Ba, 2014 年论文《Adam: A Method for Stochastic Optimization》)把 momentum 和 RMSProp 的思路合在一起,并加了一个偏差修正:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t </math>mt=β1mt−1+(1−β1)gt
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 </math>vt=β2vt−1+(1−β2)gt2
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}t = \frac{v_t}{1 - \beta_2^t} </math>m^t=1−β1tmt,v^t=1−β2tvt
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ t + 1 = θ t − η ⋅ m ^ t v ^ t + ϵ \theta{t+1} = \theta_t - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} </math>θt+1=θt−η⋅v^t +ϵm^t
这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 = 0.9 , β 2 = 0.999 , ϵ = 1 0 − 8 \beta_1 = 0.9, \beta_2 = 0.999, \epsilon = 10^{-8} </math>β1=0.9,β2=0.999,ϵ=10−8 是默认值。
<math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt 是梯度的一阶矩(动量), <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt 是二阶矩(梯度平方的滑动均值)。 <math xmlns="http://www.w3.org/1998/Math/MathML"> m ^ t \hat{m}_t </math>m^t、 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \hat{v}_t </math>v^t 是「偏差修正」------因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> m 0 = v 0 = 0 m_0 = v_0 = 0 </math>m0=v0=0,初期估计偏小,需要除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − β t 1 - \beta^t </math>1−βt 来拉回来。
最后那一步看起来像 momentum 和 RMSProp 的乘积:动量决定方向,自适应分母决定步长。
Adam 不是没有缺点。它对超参(特别是 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 \beta_2 </math>β2)敏感,在某些任务上泛化比 SGD-with-momentum 差,2017 年 Wilson 等人的论文《The Marginal Value of Adaptive Gradient Methods in Machine Learning》指出过这个现象。
但在 Transformer 训练里,几乎没人用纯 SGD------原因后面会说。
6.4 AdamW:权重衰减的修正
2019 年 Loshchilov 和 Hutter 在论文《Decoupled Weight Decay Regularization》里指出:当我们在 Adam 里同时使用 L2 正则化时,正则项会被自适应分母缩放,效果会被破坏。
具体来说,「L2 正则」意味着在梯度里加一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ θ \lambda \theta </math>λθ,然后整个梯度被 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ \sqrt{\hat{v}} </math>v^ 缩放------结果是参数维度上 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v 大的地方衰减弱, <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v 小的地方衰减强,跟「权重衰减」想要的「均匀拉回到零」效果完全不一致。
他们提出 AdamW,把权重衰减从梯度里拿出来,单独应用到参数上:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ t + 1 = θ t − η ⋅ ( m ^ t v ^ t + ϵ + λ θ t ) \theta_{t+1} = \theta_t - \eta \cdot \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right) </math>θt+1=θt−η⋅(v^t +ϵm^t+λθt)
差别看似细微,但对大模型训练效果显著。
今天几乎所有 Transformer 训练(GPT、LLaMA、PaLM 系列)默认都用 AdamW,不是 Adam。这是工程社区花了几年时间踩坑后达成的共识。
6.5 为什么 Transformer 一定要用 Adam/AdamW
Transformer 训练的几个特征让 SGD 几乎不可用。
第一,参数维度间的曲率差异极大。Transformer 里有 Embedding 矩阵、注意力的 Q/K/V 投影、FFN 的两个矩阵、LayerNorm 的 scale 和 bias,每种参数的「合理步长」可能差几个数量级。SGD 用一个全局学习率,根本没法兼顾。
第二,训练初期梯度尺度变化剧烈。LayerNorm 后的激活分布会随训练快速漂移,一个固定学习率很容易在初期炸裂。Adam 的自适应分母对此天然鲁棒。
第三,训练时间长、计算昂贵,每次重启代价高。工程上需要一个「调一次就能跑到底」的优化器,Adam 在这件事上的鲁棒性比 SGD 强很多。
第四,Adam 在稀疏更新场景表现更好。Embedding 矩阵每个 token 只在它出现时被更新(其余 batch 元素的梯度对这一行是零),这就是稀疏更新;Adam 的自适应分母会按「这一行的实际更新频率」调步长,比 SGD 更合理。
但代价是显著的:Adam 每个参数要维护两个状态( <math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt、 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt),相比 SGD 多 2 倍的优化器内存。对一个 70B 参数的模型来说,光优化器状态就要 70B × 2 × 4 字节(fp32)= 560 GB,这是「为什么训大模型要 ZeRO」(DeepSpeed 的优化器分片)的直接原因。
后续 llm-infra 系列会详细讲这一点。
6.6 一些更新的优化器探索
Adam 不是终点,但在工业界几乎是默认。简单提一下近年的几个值得关注的尝试。
Lion (Chen et al. 2023, Google Brain):用 sign(动量) 代替 Adam 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> m ^ / v ^ \hat{m}/\sqrt{\hat{v}} </math>m^/v^ 。它的思想是「方向比大小重要」,所以只用动量的符号;优化器状态从 2 个减到 1 个,节省一半显存。Lion 在某些视觉任务和小型 Transformer 上比 AdamW 更好,但在大规模 LLM 上还没成为默认。
Sophia (Liu et al. 2023, Stanford):用对角 Hessian 估计代替 Adam 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ \sqrt{\hat{v}} </math>v^ ,把二阶信息引入。在 GPT-2 规模下相对 AdamW 节省 50% 训练步数;在更大规模下尚待验证。
Shampoo / Distributed Shampoo:用 Kronecker 因子近似全 Hessian,是「准二阶方法工业化」的代表。Google 在某些产品级训练里已用。
这些方法都没有完全取代 AdamW,主要原因是工程稳定性、超参敏感性、与 LayerNorm/RMSNorm 的协同还没充分验证。但它们都在挑战「Adam 是不是最终答案」这个问题。第 36 篇会再讨论。
6.7 Adam 在不同模型上的实证差异
把 SGD-momentum 和 AdamW 在三类典型任务上的表现总结成一张表:
| 任务 | SGD-momentum | AdamW | 差异原因 |
|---|---|---|---|
| ImageNet ResNet-50 | 76.5% top-1 | 75.8% top-1 | SGD 略好,CV 任务损失面相对良好 |
| BERT-base 预训练 | 难以收敛 | 标准做法 | LayerNorm + 各向异性梯度 |
| GPT-3 175B 预训练 | 不可用 | 必需 | 大模型训练不稳定性 |
| 强化学习(PPO) | 通常用 Adam | 通常用 Adam | 梯度方差大、不平稳 |
| 推荐系统稀疏 embedding | 不合适 | AdaGrad/Adam 更好 | 稀疏更新需要自适应 |
这张表说明一个原则:优化器选择是任务相关的,没有「先进就一定好」的简单规律。Transformer 时代 AdamW 一统天下,但在视觉、强化学习、推荐里仍有不同选择。
6.8 一个不为人知的细节:Adam 的 epsilon
Adam 公式里那个 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ = 1 0 − 8 \epsilon = 10^{-8} </math>ϵ=10−8 看起来无关紧要,其实它是「对小梯度方向的步长上限」。
具体来说:当 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t → 0 \hat{v}_t \to 0 </math>v^t→0(某个方向梯度很小很久了),Adam 的更新量 <math xmlns="http://www.w3.org/1998/Math/MathML"> m ^ t / ( v ^ t + ϵ ) → m ^ t / ϵ \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) \to \hat{m}_t / \epsilon </math>m^t/(v^t +ϵ)→m^t/ϵ。所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 越小,「冷门方向」的步长越大; <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 越大,「冷门方向」被压制得越多。
在 BERT、GPT 训练里通常用 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ = 1 0 − 8 \epsilon = 10^{-8} </math>ϵ=10−8 默认值;但 Vision Transformer(ViT)社区发现 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ = 1 0 − 7 \epsilon = 10^{-7} </math>ϵ=10−7 或 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 6 10^{-6} </math>10−6 训练更稳定,因为 ViT 早期梯度更分散。这是一个非常容易被忽略的旋钮。
6.9 一个从未被打破的事实:Adam 的更新等价于「方向 × 单位步长」
仔细看 Adam 的更新公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Δ θ = − η ⋅ m ^ t v ^ t + ϵ \Delta\theta = -\eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} </math>Δθ=−η⋅v^t +ϵm^t
如果忽略 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ,这个比值在每个维度上的「典型大小」近似为 1------因为分子是梯度均值、分母是梯度方差的平方根,二者量级相近。
也就是说,Adam 实际上是「方向 × <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η」的更新 。学习率 <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η 直接控制每步移动的「实际距离」,而不像 SGD 那样需要根据梯度大小再换算。
这个性质是 Adam 工程鲁棒性的核心来源:不管模型规模多大、不管梯度分布多偏,Adam 的更新尺度都被自适应分母「单位化」了,所以同一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η 在不同规模上往往能用。SGD 没有这个性质,超参对规模高度敏感。
这也是为什么近年 Lion 这种「sign(动量)」优化器能起作用------它本质也是「方向 × <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η」,把 Adam 的动机简化到极致。
七、学习率调度:训练的另一只手
7.1 为什么要 warmup
Transformer 训练有一个非常常见的现象:如果直接用一个目标学习率(比如 1e-4)从第一步开始,训练初期 loss 经常爆炸。
原因是 Adam 的自适应分母 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \sqrt{\hat{v}_t} </math>v^t 在前几步还没充分估计准,更新方向不可靠;同时模型参数刚初始化,输出分布远离目标分布,梯度本身就大且方向乱。两件事叠加,第一步就可能把参数推到崩溃区域。
Warmup 的做法是在训练开始的前若干步(典型几千步)把学习率从 0 线性增加到目标学习率,再开始正常训练。
这给了优化器估计 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt 的时间,也给了模型「适应一下输出分布」的缓冲。
原 Transformer 论文(Vaswani 2017)的学习率调度是 warmup + 反平方根衰减;后来 BERT、GPT、LLaMA 都基本沿用了「warmup + 某种衰减」的两段式结构,只是衰减形式有所不同(cosine、polynomial、constant 都有)。
具体的 warmup 长度因模型而异。BERT-base 用 1 万步 warmup(约 1% 训练时长);LLaMA-65B 用 2000 步 warmup(约 0.13% 训练时长);GPT-3 用 3.75 亿 token 的 warmup。一个粗略的法则是「warmup 占总训练步数的 0.5%--5%」。
7.2 衰减:从大步到小步
Warmup 之后必须配某种学习率衰减。
直觉上:训练初期希望大步快速进入低 loss 区域,后期需要小步精修。常见的衰减形式有线性衰减、cosine 衰减、阶梯衰减、polynomial 衰减。
Cosine 衰减由 Loshchilov 2017 年的 SGDR 论文带火,今天在 Transformer 训练里非常普遍------比如 LLaMA 2 用的就是 cosine schedule,最大学习率衰减到 10% 左右。
学习率调度看起来只是个超参,其实和优化器本身一样重要。Chinchilla(Hoffmann et al. 2022)里就专门讨论过:学习率调度长度应该匹配训练 token 数。
如果你计划训 100B token 但 cosine 只衰减到 30B token 的位置,剩下的 70B token 全是小步学习,几乎没用。这一点在第 34 篇 Scaling Laws 会再讲。
7.3 warmup-stable-decay:把训练和验证都照顾到
一种近年流行的 schedule 叫「warmup-stable-decay」(WSD):先 warmup 升上去,然后保持恒定学习率训练大段时间,最后再 cosine 或线性衰减下来。
它的好处是「中间恒定段」可以随时停下来 checkpoint,再换不同的衰减长度从同一个 checkpoint 出发训出多个版本。这在做 scaling law 实验、模型体检时非常有用,不需要每次都从头训。
MiniCPM、DeepSeek 等近年模型都用了这种 schedule。
7.4 cyclic 与 restart:另一类思路
还有一类 schedule 叫 cyclic learning rate(CLR)或 SGDR(Stochastic Gradient Descent with warm Restarts):学习率周期性升降,每隔若干 epoch 就「重启」一次。
它的思想是「让模型周期性地被晃动一下」,避免陷入同一个局部极小区域。在某些视觉任务和小模型上效果很好,但在大型 Transformer 上不常见------大模型预训练成本太高,一次 cosine 已经够了,不希望承受 restart 带来的训练动力学复杂度。
7.5 什么时候不需要 warmup
虽然 Transformer 训练几乎一律用 warmup,但 warmup 不是免费的------前几千步学习率小,相当于这段训练「白训」。
近年有几个工作探索「不需要 warmup 也能稳定训练」的初始化与归一化方案。比如 DeepMind 的 NFNet(Brock 2021)用「自适应梯度裁剪 + 标量化残差」让深网络不需要 BN 也不需要 warmup;ReLoRA、SP-Init 等工作也尝试过「让模型从第一步就能用目标学习率」。
但这些方案都有代价:要么对初始化要求严苛、要么对架构有约束。在主流 Transformer 训练里 warmup 仍是默认。
7.6 学习率与批量大小的耦合
最后再强调一次:学习率从来不是「孤立的旋钮」,它和批量大小高度耦合。
一个粗略的经验:当批量大小翻倍时,学习率大约也要 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 \sqrt{2} </math>2 到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 2 </math>2 倍地放大(具体倍数因优化器、任务而异)。GPT-3 训练用 3.2M token 的全局批量,对应的学习率是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 6 × 1 0 − 5 6 \times 10^{-5} </math>6×10−5;如果批量降到 1M,学习率应该相应调到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 3 × 1 0 − 5 3 \times 10^{-5} </math>3×10−5 左右。
这一耦合在工程上很重要:当你为了适配新硬件改批量时,不能忘了同步调学习率。否则训练动力学会和论文里的设定完全不同。
7.7 学习率本质上是步长的工程旋钮
回顾整个学习率调度这一节的设计思想:把「每一步走多远」拆成几个阶段,针对不同训练阶段的需求做不同的步长安排。
Warmup 是为了让自适应优化器先估准 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt、让模型先适应输出分布;恒定段是为了让模型在最有价值的「中段」用大步快进;衰减段是为了让模型在末段精修。每一段都对应训练动力学的一个特定阶段,不能省略也不能换顺序。
理解了这套思路,再看 BERT、GPT、LLaMA、Gemini 论文里的学习率 schedule 图,就不会觉得每个项目都在「神秘调参」------它们其实都在同一个框架下做不同细节的取舍。
八、实战中的坑
8.1 梯度裁剪:救命稻草
即便有 Adam 和 LayerNorm,Transformer 训练偶尔仍然会遇到「loss 突然飙升」(loss spike)。
原因可能是某个 batch 里碰到了少数极端样本、某个数据 shard 损坏、混合精度下某次 fp16 溢出。一旦发生 loss spike,梯度可能瞬间放大几个数量级,这一步把参数推到深坑里,模型再也救不回来。
梯度裁剪 (gradient clipping)是最常用的防御机制。具体做法是:算完梯度后,先算它的全局 L2 范数 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∥ g ∥ \|g\| </math>∥g∥,如果超过某个阈值 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c,就把梯度乘以 <math xmlns="http://www.w3.org/1998/Math/MathML"> c / ∥ g ∥ c / \|g\| </math>c/∥g∥ 缩到阈值以下。
这会改变梯度方向吗?不会,只是缩小幅度。GPT-3 训练里用的是 <math xmlns="http://www.w3.org/1998/Math/MathML"> c = 1 c = 1 </math>c=1,LLaMA 也用类似量级。这个简单的技巧已经救过无数次大模型训练。
也有按 element 裁剪(clip by value)和按 norm 裁剪(clip by global norm)两种风格。今天主流大模型几乎都用 global norm,因为它能保留方向信息。
8.2 混合精度的陷阱
为了节省显存和加速,Transformer 训练大量使用混合精度(mixed precision):参数和激活用 fp16/bf16 存储,关键累加(如 LayerNorm 的均值方差、softmax 的 exp 累加)用 fp32。
fp16 的动态范围比 fp32 小很多(指数位只有 5 bit,能表示的最大约 <math xmlns="http://www.w3.org/1998/Math/MathML"> 6.5 × 1 0 4 6.5 \times 10^4 </math>6.5×104,最小约 <math xmlns="http://www.w3.org/1998/Math/MathML"> 6 × 1 0 − 8 6 \times 10^{-8} </math>6×10−8),loss 太小时直接下溢成零。Loss scaling 是常用补救:把 loss 乘一个大常数(比如 1024),反向传播完再把梯度除回来。
bf16 因为指数位多(8 bit,和 fp32 同),动态范围接近 fp32,不需要 loss scaling,是今天大模型训练的主流。代价是 bf16 尾数位少(7 bit,比 fp16 的 10 bit 还少),精度反而低------但因为深度学习对精度不敏感、对动态范围敏感,这个权衡是值得的。
8.3 优化器状态的内存
回到前面的算账:Adam/AdamW 每个 fp32 参数需要 8 字节优化器状态( <math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt + <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt 各 4 字节),加上梯度 4 字节、参数本体 4 字节,每个参数总共要 16 字节。
一个 7B 模型仅训练状态就是 112 GB,单卡 80 GB 显存装不下;70B 模型则是 1120 GB。
这是为什么训练大模型必须分布式、必须分片优化器状态。这件事在本系列里只是一个边角细节,但它支配了整个 LLM 工业界的工程选型------ZeRO、FSDP、tensor parallel、pipeline parallel 都是为了把这块内存切下去。详细可以看 llm-infra 系列。
8.4 学习率寻找:LR Finder
有一种实用工具叫 LR finder(Smith 2017):在一次短训练里把学习率从极小(如 1e-7)指数地增加到极大(如 10),同时记录 loss。loss-learning rate 曲线通常呈现「先平、后下降、最后爆炸」三段式。最优学习率大致取「下降段最陡的那一点」往前 1/4 处。
这个方法在中小规模训练上非常实用,几分钟就能找到一个合理的初始学习率。在大模型预训练上不太用------成本太高,工程师更倾向于看 scaling law 论文里推荐的值。
8.5 Loss 曲线的诊断
最后一个实战经验:怎么从 loss 曲线看出训练健康度。
正常曲线:开始有快速下降(warmup 期),然后进入对数尺度下的稳定下降,偶尔有小幅波动。
学习率太大:早期 loss 反复跳跃,甚至上升后才下降;严重时直接 NaN。
学习率太小:loss 下降极慢,跑了大半训练还在和初始 loss 同一个数量级。
梯度消失:loss 在某个值上「卡住」,但参数明显没收敛------可能是初始化不当或激活函数选错。
Loss spike:训练中突然出现一个尖刺,恢复 / 不恢复都有可能。如果反复出现,要查数据、查混合精度、查梯度裁剪。
过拟合:训练 loss 继续下降但验证 loss 开始上升。在大模型预训练里很少看到(因为数据几乎不重复),在微调里常见。
熟练的工程师能在头几小时的训练曲线就判断出「这次训练值不值得继续跑下去」,省掉无数 GPU 小时。
8.6 训练监控:除了 loss 还该看什么
只看 loss 曲线,会错过很多早期警报。健康的训练监控至少要看以下指标:
梯度全局范数 :每个 step 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∥ g ∥ \|g\| </math>∥g∥,应当在合理量级(比如 0.1--10)内波动。如果突然飙升到几百,说明出现 spike;如果衰减到接近零,说明梯度消失。
参数范数:每层参数的 L2 范数。如果某一层范数持续上涨,可能权重衰减不够;如果某一层范数衰减到接近零,可能这一层「死了」(dead unit)。
激活值的方差:每层 forward 后激活的方差。LayerNorm 之后应该接近 1;如果某层激活方差持续偏离,说明归一化没起作用。
学习率:画出来确认 warmup、衰减按预期进行。这看起来 trivial,但实战里学习率 schedule 写错的事故很多。
优化器状态 :Adam 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \hat{v}_t </math>v^t 平方根分布。如果某个范围的参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \hat{v}_t </math>v^t 很小(梯度长期为零),可能是这部分参数没在学。
throughput:每秒处理的 token 数。如果突然下降,可能是 dataloader 卡了或某个 GPU 慢了。
成熟的训练 pipeline 会把这些指标都打到 TensorBoard 或 Weights & Biases,定期人工 review。GPT-3、PaLM 这些大模型训练时都有专人 24 小时看曲线,发现异常立即热修。
8.7 数据流水线对训练的影响
训练中常被忽略的一环是数据流水线(data pipeline)。
如果 dataloader 比 GPU 慢,GPU 会闲置等数据;如果 dataloader 抢占了 CPU,训练主进程会被拖慢。在大模型训练里,数据 throughput 和模型 throughput 必须匹配,否则任何一边的瓶颈都会拖垮整个训练。
具体说,一个 GPT-3 175B 模型每 step 处理 3.2M token,单 step 耗时约 1.5 秒,意味着 dataloader 必须每秒供 2M+ token 的 tokenized 数据。这要求并行 worker、预先 tokenize 后存盘、shard 化按需加载等一整套工程。
数据流水线还和「训练动力学」有关。如果数据 shuffle 不充分,模型可能在某一段时间反复看到相似样本,loss 局部下降但梯度方向偏向某种数据特性,最终影响泛化。LLaMA、PaLM 等模型都强调「全 shuffle、跨 shard、跨语种」混杂采样。
8.8 一个真实事故复盘:Megatron-Turing NLG 530B 训练
为了让本节不只是抽象建议,举一个公开的真实案例。
Megatron-Turing NLG 530B(Smith 2022)是 NVIDIA 与 Microsoft 联合训练的 530B 参数模型,预训练用了 270B token、跑了几个月。论文里专门提到训练过程中遇到的 loss spike:在训练大约 25% 进度时出现一次明显尖刺,loss 短时间从 1.9 飙升到 2.3。
团队的处理方式是:回滚到 spike 前的 checkpoint,跳过引发 spike 的那几个 batch(怀疑是某些极端样本),从同一 checkpoint 重新出发。这在工程上叫「skip-and-resume」。论文还提到他们一共做过 3 次这样的 skip-and-resume,最终模型才训完。
这件事在多个大模型训练中都被独立观察到:GPT-3、PaLM、LLaMA 的训练日志(部分公开)都报告过 spike 事件。今天的处理是「自动 spike 检测 + 自动回滚」,已经成为大型训练系统的标配。
8.9 训练的工程闭环:从 loss 到下一次实验
最后一节讲一件「软」事情:训练不是一次性的事,是一个闭环。
成熟的团队每次跑训练都遵循「假设 → 跑实验 → 监控 → 复盘 → 下次调整」的循环。每次只动一个旋钮(学习率、batch、初始化、架构等),观察对 loss 曲线的影响,写进实验台账,下次有针对性地调整。
这看起来像研究方法论而不是技术,但在大模型时代它越来越重要:训练成本动辄几十万美元,没法靠「试错」推进;每个决策都得能解释清楚为什么这么调,以及下一步如何继续。
第 35、36 篇会从「数据工程」「训练稳定性」两个角度专门讲这套方法论。
九、再看一眼:把这些机制和 Transformer 串起来
到这里我们讲的几乎全部是「通用深度学习训练机制」。在结尾这一节,我把这些机制和 Transformer 训练里的几个经典选择具体对上号,方便你在后续阅读论文时心里有底。
9.1 GPT-3 的训练配置在做什么
GPT-3 论文(Brown et al. 2020)的训练超参可以拆开来看:
- 优化器 :AdamW, <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 = 0.9 , β 2 = 0.95 , ϵ = 1 0 − 8 \beta_1 = 0.9, \beta_2 = 0.95, \epsilon = 10^{-8} </math>β1=0.9,β2=0.95,ϵ=10−8。注意 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 = 0.95 \beta_2 = 0.95 </math>β2=0.95 而不是默认的 0.999------更短的二阶矩窗口让 Adam 对当前梯度更敏感,适合大批量、高动力学的场景。
- 学习率 :最大 <math xmlns="http://www.w3.org/1998/Math/MathML"> 6 × 1 0 − 5 6 \times 10^{-5} </math>6×10−5(175B 模型),cosine 衰减到 10%。
- Warmup:3.75 亿 token(约 0.125% 训练量)。
- Batch:3.2M token 全局批量,配合 weight decay 0.1。
- 梯度裁剪:global norm = 1.0。
- 混合精度:fp16(loss scaling)。
这些配置每一个背后都有本篇讲过的原理:AdamW 解决各向异性 + 权重衰减,warmup 防止初期不稳定,cosine 配合 token 总数实现「大步快进 + 小步精修」,gradient clipping 防 spike,混合精度省内存。每一个旋钮都不是「随便调的」。
9.2 LLaMA 与 GPT-3 的差异
LLaMA(Touvron 2023)的优化器配置和 GPT-3 大同小异,但有几处细节差异:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 = 0.95 \beta_2 = 0.95 </math>β2=0.95(与 GPT-3 一致,同样是为了大批量训练);
- weight decay 0.1(与 GPT-3 一致);
- 学习率根据模型大小调(7B 用 <math xmlns="http://www.w3.org/1998/Math/MathML"> 3 × 1 0 − 4 3 \times 10^{-4} </math>3×10−4,65B 用 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1.5 × 1 0 − 4 1.5 \times 10^{-4} </math>1.5×10−4,呈反平方根趋势);
- cosine 衰减到 10%;
- Adam 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 仍是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 5 10^{-5} </math>10−5(注意比 GPT-3 大!)。
那个 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 5 10^{-5} </math>10−5 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 是 LLaMA 的「鲁棒性配方」------更大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 抑制了冷门方向的过激更新,让训练更不容易 spike。这是 Meta 团队踩过坑后的工程选择,不是论文里写的「最优值」。
9.3 训练异常的常见 root cause
你要是做大模型训练,未来一定会遇到 loss 突然飙升或 NaN。常见 root cause 大致有这么几类:
- 数据问题:某个 shard 里有损坏样本(极长 token、空文本、非法字符);
- 混合精度溢出:fp16 下某次 attention logits 超过 65504;
- 学习率过大:尤其在 warmup 结束、刚进入 plateau 时;
- 初始化不当:embedding 或某层 W 的初始尺度过大;
- gradient clipping 阈值过松:spike 没被压住;
- 数值不稳定的算子:softmax 在极大输入下饱和、exp 溢出(这是 FlashAttention 也要解决的问题,第 42 篇详谈)。
发现训练异常时,按上面的清单顺序排查,能解决大部分事故。第 36 篇会把这些 root cause 配合具体案例展开。
9.4 训练 vs 推理:梯度只在训练时存在
最后一个常见混淆:梯度、反向传播、Adam 这些只在训练时存在。
推理(inference)时只跑前向,不算梯度、不更新参数、不需要优化器状态。这就是为什么推理服务的内存只需要「参数本体 + 当前激活」,远小于训练所需的 16 字节/参数。
理解这一点对 LLM 部署很重要:一个 70B 模型推理只需要约 140 GB 显存(fp16 参数),但训练它需要至少 1.1 TB 显存(参数 + 梯度 + Adam 状态)。这个差异决定了「为什么 70B 模型能在两张 H100 上推理,却需要 64+ 张 H100 才能训练」。后续 KV cache、推理量化等话题会在第 49、51 篇展开。
十、关键概念回顾
写到这里,本篇覆盖的内容已经横跨数学、算法、工程三个维度。下面把它们用一段总结性的叙述串起来,作为对全文的一次回顾。
到这里如果你能把下面这件事讲给一个不懂深度学习的朋友,那你已经把这一篇消化了。
神经网络的「学习」本质上是一种「站在哪、看脚下、朝低处走」的迭代过程:网络的参数是高维空间中的一个点,损失函数把这个点映射到一个标量「不爽程度」,我们想找到一个让「不爽程度」尽可能小的位置。
怎么走?算一下「不爽程度」对每个参数的偏导(即梯度),梯度的反方向就是当前下降最快的方向,沿着它走一小步------这就是梯度下降。
但每走一步都需要算几亿、几百亿个偏导。如果用数值微分一个一个算,复杂度是参数量平方级,永远训不完。
反向传播是一种聪明的算法:它把神经网络看成一张计算图,从损失节点出发按反拓扑序一路传递「误差信号」,每个节点只需要本地的局部导数,整张图的所有梯度在一次反向遍历中就被算完。复杂度回到了和一次前向同阶。这件事是深度学习能 scale 起来的隐形地基。
而朴素的梯度下降在真实损失曲面上走得并不优雅。曲面是高度病态的、在不同方向曲率差异巨大、噪声估计带来抖动。
Momentum 用历史方向信息平滑掉震荡;自适应优化器(AdaGrad、RMSProp、Adam)用每个参数自己的「梯度平方均值」做分母,让步长在不同维度上自动平衡。Adam 把 momentum 和自适应一锅端,又加上偏差修正,是过去十年深度学习最有影响力的优化器。AdamW 是它在带权重衰减场景下的小改良,今天是大模型训练的事实标准。
最后还有一个常被忽略的事实:训练过程在工程上从来不是单调下降。loss spike、梯度爆炸、混合精度下溢、学习率不合适、warmup 不够长------这些都是看似无害但能直接毁掉数百万 GPU 小时训练的坑。梯度裁剪、warmup、cosine 衰减、bf16 是当代 Transformer 训练的标配工具箱。第 36 篇会再回到这些工程细节。
如果一句话总结:梯度下降是「想法」,反向传播是「工程实现」,Adam/AdamW 是「让这套机制在病态曲面上能稳定收敛的修补」。三者缺一不可。
十一、常见误解
误解一:「梯度下降总能找到全局最小值」。
不会。梯度下降在凸函数上有这个保证,但神经网络损失曲面是非凸的,理论上 SGD 能停在任意鞍点或局部最小。实践中它常常能找到「足够好的解」,但这是经验观察,不是定理。把它当数学保证看会出问题。
误解二:「反向传播是一种神秘的深度学习算法」。
不是。反向传播就是链式法则在计算图上的反向遍历,1970 年代就被多个领域独立发现(Linnainmaa 1970、Werbos 1974)。它在深度学习里之所以重要,是因为它正好把「算所有参数的梯度」这件事压成了 O(参数量);它在其它领域(比如最优控制、自动微分)早就有等价物。
误解三:「Adam 永远比 SGD 好」。
不一定。在大多数 CNN 视觉任务上,SGD-with-momentum 的最终泛化经常比 Adam 好(《The Marginal Value of Adaptive Gradient Methods in Machine Learning》, Wilson 2017)。但在 Transformer 上 Adam/AdamW 几乎是必须的,原因是 Transformer 损失曲面的极端各向异性和初期梯度的剧烈波动。「用什么优化器」是任务相关的,不是「先进就好」。
误解四:「学习率调小就更稳」。
学习率太小不仅训得慢,还可能让模型陷入坏的局部极小区域走不出去。它和批量大小、warmup、衰减形式联动决定训练动力学。盲目调小学习率往往掩盖问题(比如初始化不好、数据有噪声)而不是解决问题。
误解五:「梯度消失/爆炸是历史问题,现代架构都解决了」。
没有彻底解决。残差连接、LayerNorm、合理初始化都缓解了它,但在极深网络(GPT-4 级别)、极长上下文、低精度训练里,它仍然会以更隐蔽的形式出现------比如 attention logits 数值溢出、softmax 饱和、深层激活漂移。理解梯度链路上的乘积放大与衰减,仍是当代大模型工程师的必修课。
十二、下一步
读完这一篇,你应该能解释:神经网络是怎么把「不会」变成「会」的;为什么求导是核心运算;反向传播为什么不是数值微分;Adam 解决的是什么问题;学习率为什么需要 warmup。
接下来第 07. Softmax 与概率分布 会处理另一个关键齿轮:怎么把任意一组实数变成一个加起来等于一的概率分布,以及为什么 Softmax 是注意力机制的灵魂------后者把分数变成权重的过程,正是用 Softmax 完成的。
如果你对反向传播在 Transformer 里的具体数值表现有兴趣,第 25. Layer Normalization 与第 36. 训练稳定性 会回到本篇里提到的那些坑,给出更具体的工程数据。如果想理解优化器内存为什么是 LLM 工程的核心约束,可以转去看 llm-infra 系列里关于 ZeRO、FSDP 的章节。
十三、参考文献
- Cauchy, A. (1847). Méthode générale pour la résolution des systèmes d'équations simultanées. Comptes Rendus de l'Académie des Sciences.
- Linnainmaa, S. (1970). The representation of the cumulative rounding error of an algorithm as a Taylor expansion of the local rounding errors. Master's thesis, University of Helsinki.
- Werbos, P. J. (1974). Beyond Regression: New Tools for Prediction and Analysis in the Behavioral Sciences. PhD thesis, Harvard University.
- Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning representations by back-propagating errors. Nature, 323(6088), 533--536.
- Bengio, Y., Simard, P., & Frasconi, P. (1994). Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks, 5(2).
- Polyak, B. T. (1964). Some methods of speeding up the convergence of iteration methods. USSR Computational Mathematics and Mathematical Physics, 4(5).
- Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence O(1/k²). Soviet Math. Doklady.
- Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. JMLR.
- Tieleman, T. & Hinton, G. (2012). Lecture 6.5 - rmsprop. COURSERA: Neural Networks for Machine Learning.
- Kingma, D. P. & Ba, J. (2014). Adam: A Method for Stochastic Optimization. arXiv:1412.6980.
- Loshchilov, I. & Hutter, F. (2019). Decoupled Weight Decay Regularization. ICLR 2019.
- Goodfellow, I. J., Vinyals, O., & Saxe, A. M. (2014). Qualitatively characterizing neural network optimization problems. arXiv:1412.6544.
- Goyal, P., et al. (2017). Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. arXiv:1706.02677.
- Loshchilov, I. & Hutter, F. (2017). SGDR: Stochastic Gradient Descent with Warm Restarts. ICLR 2017.
- Wilson, A. C., et al. (2017). The Marginal Value of Adaptive Gradient Methods in Machine Learning. NeurIPS 2017.
- Hoffmann, J., et al. (2022). Training Compute-Optimal Large Language Models (Chinchilla). arXiv:2203.15556.
- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. CVPR 2016.
- Glorot, X. & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. AISTATS.
- He, K., et al. (2015). Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. ICCV.
- Dauphin, Y., et al. (2014). Identifying and attacking the saddle point problem in high-dimensional non-convex optimization. NeurIPS.
- Li, H., et al. (2018). Visualizing the Loss Landscape of Neural Nets. NeurIPS.
- Smith, L. N. (2017). Cyclical Learning Rates for Training Neural Networks. WACV.
- Chen, X., et al. (2023). Symbolic Discovery of Optimization Algorithms (Lion). arXiv:2302.06675.
- Liu, H., et al. (2023). Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training. arXiv:2305.14342.
- Boyd, S. & Vandenberghe, L. (2004). Convex Optimization. Cambridge University Press.
- 上一篇:05. 激活函数与非线性
- 下一篇:07. Softmax 与概率分布
- 系列总览:【Transformer 与注意力机制】