在训练深度神经网络时,一个经典问题是 梯度消 和 梯度爆炸。
随着网络层数增加,反向传播中的梯度会在多层之间不断相乘。如果这些乘积逐渐变小,梯度就会衰减到接近 0;如果逐渐变大,梯度则会迅速增大。这会导致网络难以训练甚至完全不收敛。
这一问题在 深层前馈网络 和 循环神经网络(RNN) 中都非常常见。
一、梯度是如何在网络中传播的
神经网络通过 反向传播(Backpropagation) 更新参数。设网络共有 L 层,第 层的权重为
,损失函数为
。
梯度可以写成:
关键在于中间这一项:
它是 一系列导数的乘积 。其中 表示第 k 层的激活输出。
如果这些导数平均小于 1,梯度会指数衰减;如果大于 1,则会指数增长。
二、梯度消失
梯度消失指的是:在反向传播过程中,梯度逐层变小,最后接近 0。
一个典型原因是 Sigmoid 激活函数。
Sigmoid 函数:
其导数为:
Sigmoid 导数最大值为:
如果网络有 10 层:
可以看到梯度已经几乎为 0。
结果是:
-
输入层附近的参数几乎不更新
-
网络训练速度极慢
-
深层网络无法学习复杂特征
三、梯度爆炸
梯度爆炸则是另一种情况:如果导数大于 1,梯度会指数增长。
例如:
随着层数增加可能迅速变得非常大。
结果包括:
-
loss剧烈震荡
-
权重更新过大
-
参数出现 NaN
-
模型训练失败
四、TensorFlow实验:观察梯度消失
下面用一个简单实验直观展示梯度消失。
python
import tensorflow as tf
import numpy as np
tf.random.set_seed(42)
np.random.seed(42)
# 输入 64 个样本,每个样本 100 维
x = tf.random.normal((64, 100))
# 回归目标
y = tf.random.normal((64, 1))
model = tf.keras.Sequential()
# 输入层后接 20 层 sigmoid,全连接层较深,容易出现梯度消失
# 切换relu,梯度消失明显减轻。
for _ in range(20):
# model.add(tf.keras.layers.Dense(
# 64,
# activation='relu',
# kernel_initializer='he_normal'
# ))
model.add(tf.keras.layers.Dense(
64,
activation='sigmoid',
kernel_initializer='glorot_uniform'
))
# 输出层
model.add(tf.keras.layers.Dense(1))
with tf.GradientTape() as tape:
pred = model(x)
loss = tf.reduce_mean(tf.square(pred - y))
# 计算所有可训练参数的梯度
grads = tape.gradient(loss, model.trainable_variables)
print(f"Loss = {loss.numpy():.6f}\n")
for i, (var, grad) in enumerate(zip(model.trainable_variables, grads)):
grad_norm = tf.norm(grad).numpy()
print(f"{i:02d} | {var.name:30s} | grad norm = {grad_norm:.8e}")
结果:
sigmoid:
00 | dense/kernel:0 | grad norm = 8.17580270e-13
01 | dense/bias:0 | grad norm = 2.33749597e-13
02 | dense_1/kernel:0 | grad norm = 4.91960465e-12
03 | dense_1/bias:0 | grad norm = 1.22193434e-12
04 | dense_2/kernel:0 | grad norm = 2.31958393e-11
05 | dense_2/bias:0 | grad norm = 5.75166347e-12
06 | dense_3/kernel:0 | grad norm = 9.13673998e-11
07 | dense_3/bias:0 | grad norm = 2.22901229e-11
08 | dense_4/kernel:0 | grad norm = 3.77236742e-10
09 | dense_4/bias:0 | grad norm = 9.34700928e-11
10 | dense_5/kernel:0 | grad norm = 1.86993665e-09
11 | dense_5/bias:0 | grad norm = 4.74612183e-10
12 | dense_6/kernel:0 | grad norm = 8.78709727e-09
13 | dense_6/bias:0 | grad norm = 2.14278661e-09
14 | dense_7/kernel:0 | grad norm = 3.70847388e-08
15 | dense_7/bias:0 | grad norm = 8.95946783e-09
16 | dense_8/kernel:0 | grad norm = 1.75122992e-07
17 | dense_8/bias:0 | grad norm = 4.32770797e-08
18 | dense_9/kernel:0 | grad norm = 8.62212744e-07
19 | dense_9/bias:0 | grad norm = 2.04899578e-07
20 | dense_10/kernel:0 | grad norm = 3.72339127e-06
21 | dense_10/bias:0 | grad norm = 8.78402432e-07
22 | dense_11/kernel:0 | grad norm = 1.38942723e-05
23 | dense_11/bias:0 | grad norm = 3.39990129e-06
24 | dense_12/kernel:0 | grad norm = 5.44022914e-05
25 | dense_12/bias:0 | grad norm = 1.38788946e-05
26 | dense_13/kernel:0 | grad norm = 2.60285946e-04
27 | dense_13/bias:0 | grad norm = 6.29604619e-05
28 | dense_14/kernel:0 | grad norm = 9.69215936e-04
29 | dense_14/bias:0 | grad norm = 2.40968046e-04
30 | dense_15/kernel:0 | grad norm = 3.73327779e-03
31 | dense_15/bias:0 | grad norm = 9.40103782e-04
32 | dense_16/kernel:0 | grad norm = 1.55217508e-02
33 | dense_16/bias:0 | grad norm = 3.68970097e-03
34 | dense_17/kernel:0 | grad norm = 6.85354173e-02
35 | dense_17/bias:0 | grad norm = 1.61047429e-02
36 | dense_18/kernel:0 | grad norm = 3.30706924e-01
37 | dense_18/bias:0 | grad norm = 8.29390287e-02
38 | dense_19/kernel:0 | grad norm = 1.16695201e+00
39 | dense_19/bias:0 | grad norm = 2.90607035e-01
40 | dense_20/kernel:0 | grad norm = 3.62535310e+00
41 | dense_20/bias:0 | grad norm = 8.80031526e-01
relu:
00 | dense/kernel:0 | grad norm = 1.68664873e+00
01 | dense/bias:0 | grad norm = 1.36555120e-01
02 | dense_1/kernel:0 | grad norm = 1.36872327e+00
03 | dense_1/bias:0 | grad norm = 1.90614641e-01
04 | dense_2/kernel:0 | grad norm = 1.29527867e+00
05 | dense_2/bias:0 | grad norm = 1.92289561e-01
06 | dense_3/kernel:0 | grad norm = 1.16377687e+00
07 | dense_3/bias:0 | grad norm = 1.82995304e-01
08 | dense_4/kernel:0 | grad norm = 1.35152102e+00
09 | dense_4/bias:0 | grad norm = 1.93998978e-01
10 | dense_5/kernel:0 | grad norm = 1.62513626e+00
11 | dense_5/bias:0 | grad norm = 2.17954010e-01
12 | dense_6/kernel:0 | grad norm = 2.20767760e+00
13 | dense_6/bias:0 | grad norm = 2.97283173e-01
14 | dense_7/kernel:0 | grad norm = 2.05378723e+00
15 | dense_7/bias:0 | grad norm = 3.55660528e-01
16 | dense_8/kernel:0 | grad norm = 2.14699459e+00
17 | dense_8/bias:0 | grad norm = 3.32726449e-01
18 | dense_9/kernel:0 | grad norm = 2.20528030e+00
19 | dense_9/bias:0 | grad norm = 3.99901599e-01
20 | dense_10/kernel:0 | grad norm = 2.79961729e+00
21 | dense_10/bias:0 | grad norm = 5.02139211e-01
22 | dense_11/kernel:0 | grad norm = 3.55761576e+00
23 | dense_11/bias:0 | grad norm = 4.86391962e-01
24 | dense_12/kernel:0 | grad norm = 4.30128431e+00
25 | dense_12/bias:0 | grad norm = 5.73925495e-01
26 | dense_13/kernel:0 | grad norm = 5.22420502e+00
27 | dense_13/bias:0 | grad norm = 6.69952810e-01
28 | dense_14/kernel:0 | grad norm = 4.91555786e+00
29 | dense_14/bias:0 | grad norm = 7.18435049e-01
30 | dense_15/kernel:0 | grad norm = 5.29556084e+00
31 | dense_15/bias:0 | grad norm = 1.03335011e+00
32 | dense_16/kernel:0 | grad norm = 6.20802498e+00
33 | dense_16/bias:0 | grad norm = 1.16042447e+00
34 | dense_17/kernel:0 | grad norm = 4.97060537e+00
35 | dense_17/bias:0 | grad norm = 1.04178715e+00
36 | dense_18/kernel:0 | grad norm = 4.75352526e+00
37 | dense_18/bias:0 | grad norm = 1.11193109e+00
38 | dense_19/kernel:0 | grad norm = 4.99584198e+00
39 | dense_19/bias:0 | grad norm = 1.11559403e+00
40 | dense_20/kernel:0 | grad norm = 4.60004091e+00
41 | dense_20/bias:0 | grad norm = 1.30675387e+00
五、梯度问题的常见解决方法
为了解决梯度消失与梯度爆炸问题,深度学习提出了多种方法。
1 使用 ReLU 激活函数
ReLU 函数:
导数:
在正区间导数为 1,因此不会出现指数衰减。
2 合理的权重初始化
如果权重初始化过大或过小,都会导致梯度问题。
常见初始化方法:
Xavier初始化
He初始化
这些初始化方法可以保持信号在网络中传播时的方差稳定。
3 Batch Normalization
BatchNorm 会对每一层输入进行归一化:
作用:
-
稳定输入分布
-
加快收敛
-
缓解梯度消失
4 梯度裁剪(Gradient Clipping)
当梯度过大时,可以限制梯度的最大范数:
例如
梯度 = 100
阈值 = 5
则
g = 100 * 5/100 = 5
这样梯度被限制在阈值以内。
5 残差连接(ResNet)
ResNet 引入跳跃连接:
求导:
即使:
梯度仍然可以通过 +1路径传播。
因此 ResNet 可以训练 上百甚至上千层网络。
六、小结
梯度消失与梯度爆炸来源于 反向传播中的连乘效应。
当导数小于 1 时,梯度指数衰减;当大于 1 时,梯度指数增长。
现代深度学习主要通过以下方法解决这一问题:
-
ReLU 激活函数
-
合理权重初始化(Xavier / He)
-
Batch Normalization
-
Gradient Clipping
-
Residual Connection
这些技术共同推动了深度神经网络的发展,使得数百层甚至上千层模型成为可能。
附录:
为什么 Xavier / He 初始化可以缓解梯度问题
核心目标其实只有一个:
让信号在网络中传播时保持方差稳定。
如果每层输出方差逐渐变小 → 梯度消失
如果逐渐变大 → 梯度爆炸
1 前向传播方差分析
假设一层神经网络:
其中
-
x为输入
-
W为权重矩阵
设
-
输入方差 Var(x)
-
权重方差 Var(W)
输出方差为
这里
nin 表示输入维度,即:
输入神经元数量
例如
Dense(100,50)
则
nin = 100
nout = 50
2 为了保持方差稳定
希望
代入
得到
但如果考虑 前向传播 + 反向传播同时稳定,Glorot(Xavier)推导得到
其中
-
nin:输入神经元数量
-
nout:输出神经元数量
3 He 初始化
He 初始化是针对 ReLU 网络推导的。
因为 ReLU 会把一半信号置为 0,因此有效方差减少一半。
因此需要更大的初始化:
欢迎大家关注,后续我会更新更多关于深度学习与信号处理的知识。