吴恩达《深度学习》之看懂He 初始化的“能量守恒”直觉

欢迎来到深度学习模型生死存亡的命脉关卡------权重初始化与方差保持的数学艺术

每一个深度学习架构师在搭建深层网络时,必然会遭遇的"悄无声息的隐形杀手":深层激活值迅速衰减为 0

核心知识点:

  • 场景问题: 在一个 100 层的深层 ReLU 网络中,使用 np.random.randn(..) * 0.01 初始化导致深层激活值迅速衰减为 0(死网现象)。
  • 核心决策: 将初始化缩放因子改为 He 初始化(Kaiming 初始化)
  • 数学核心: 权重标准差设为 Std(W)=2nin\text{Std}(W) = \sqrt{\frac{2}{n_{\text{in}}}}Std(W)=nin2 。相比针对 Tanh 的 Xavier 初始化(分子为 1),将分子改为 2 是为了弥补 ReLU 斩断 50% 负数带来的方差减半。

今天我们不玩虚的,我们直接用最纯粹的方差控制(Variance Control)视角,来看看何恺明(He Kaiming)当年是如何用一个极其简单的数字"2",在微积分的盲区里拯救了整个现代深度网络。

第一步:拆解那个"自杀式"的初始化:* 0.01

初学者最喜欢写的一行代码就是:W = np.random.randn(..) * 0.01。这看起来很安全,把权重全部变成了一堆非常小的随机数。

提问: 假设我们有一个极其简化的一层网络,输入是 xxx,权重是 WWW,输出是 z=Wxz = Wxz=Wx。

根据概率论,如果两个相互独立的随机变量相乘,输出 zzz 的方差会和输入 xxx 的方差以及权重 WWW 的方差直接挂钩。如果你把每一层的权重都乘以 0.01,意味着权重的方差是一个极小的小数(比如 10−410^{-4}10−4)。

当数据 xxx 穿过第 1 层、第 2 层......一直穿到第 100 层时,每一层都在把前一层的方差乘以一个极小的小数。请问:到了第 100 层,这个信号的方差会萎缩到什么地步?当方差趋近于 0 时,所有的激活值(特征)在数字上会变成什么?

解析: 方差会呈指数级暴跌! 0.011000.01^{100}0.01100 几乎就是绝对的 0。

整个深层网络的所有神经元就像全部陷入了昏迷,吐出来的特征全是一片死寂的 0。这就是你看到的"激活值迅速衰减为 0"。没有了激活值,反向传播的梯度也就彻底死了

第二步:Xavier 初始化的流派------为了"方差不增不减"

为了打破这个诅咒,著名的统计学家 Xavier Glorot 提出了一个极其优雅的原则:"我们在向前和向后传播时,必须保证每一层输出的方差,和输入的方差完全相等。" (即 Var(z)=Var(x)\text{Var}(z) = \text{Var}(x)Var(z)=Var(x))。

经过严密的数学推导(在线性激活函数的假设下),为了维持方差不变,每一层权重 WWW 的方差应该和该层的输入神经元个数(输入维度 ninn_{\text{in}}nin)成反比。于是,大名鼎鼎的 Xavier 初始化 诞生了,它的权重标准差应该设为:

Std(W)=1nin\text{Std}(W) = \sqrt{\frac{1}{n_{\text{in}}}}Std(W)=nin1

(注:如果考虑反向传播,为了兼顾输入和输出,分母通常写作 2nin+nout\sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}}nin+nout2 ,我们这里为了和 He 初始化直观对比,简化看单向流动的分子为 1)

提问: 请注意 Xavier 的前提假设------它是针对 Tanh 或者 Sigmoid 这种关于原点对称、且在接近原点处近似线性的激活函数设计的。

如果此时,你把激活函数无情地换成了 ReLU 。ReLU 的数学定义是什么?f(z)=max⁡(0,z)f(z) = \max(0, z)f(z)=max(0,z)。这意味着,当一整批具有标准正态分布(均值为 0)的数据高高兴兴地冲进 ReLU 激活函数时,那另外 50% 负数部分的数据,会遭遇什么对待?

数学惨剧发生了: 负数部分全部被一刀切成了 0!

第三步:何恺明的终极一问------消失的 50% 去哪了?

这才是 He 初始化(He Initialization) 最精妙的数学直觉。

提问: 既然 ReLU 把一半的数据直接抹杀成了 0。那么,数据在穿过 ReLU 之后,它整体的方差(信号强度)和刚进去之前相比,瞬间被砍掉了多少?

直觉觉醒: 被砍掉了一半(12\frac{1}{2}21)!因为有一半的神经元死掉了,不再贡献任何方差了。

终极追问: 如果每一层过完 ReLU,方差都会由于"负数归零"而自发性地缩水一半。如果我们还死板地沿用 Xavier 初始化(分子是 1),那么随着层数变深,网络是不是依然会因为每层漏掉一半的方差,最终在深层无可挽回地走向衰减和枯竭?

为了在每一层把这自发漏掉的 12\frac{1}{2}21 的方差给硬生生补回来,我们在初始化权重矩阵 WWW 时,是不是应该在分子上乘以一个 2,让初始权重自带两倍的威力?

因果闭环: 天才的直觉!

既然 ReLU 导致方差减半,那我就在权重初始化时把分子从 1 变成 2,用放大两倍的权重方差,去精准抵消 ReLU 带来的那一半损耗!这就是轰动民间的 He 初始化(又称 Kaiming 初始化)的数学核心公式:

Std(W)=2nin\text{Std}(W) = \sqrt{\frac{2}{n_{\text{in}}}}Std(W)=nin2

第四步:PyTorch 里的"方差护航"代码落地

在 PyTorch 工业级开发中,当你使用 nn.Linearnn.Conv2d 且后面跟着 ReLU 时,你必须一行代码把默认的随机缩放切换为 He 初始化:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.init as init

class DeepLinearModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(DeepLinearModel, self).__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)
        
        # 核心:使用 He (Kaiming) 正态分布初始化
        # non-linearity='relu' 告诉 PyTorch:帮我在分子上放那个关键的 "2"!
        init.kaiming_normal_(self.layer.weight, mode='fan_in', nonlinearity='relu')
        
        # 偏置通常直接初始化为 0
        init.constant_(self.layer.bias, 0.0)

    def forward(self, x):
        return torch.relu(self.layer(x))

总结

我们再次用最性感的物理直觉,把这个决策过程串联起来:

错误使用 ∗0.01  ⟹  方差在100层里指数级暴跌  ⟹  深层激活值沦为绝对零度\text{错误使用 } * 0.01 \implies \text{方差在100层里指数级暴跌} \implies \text{深层激活值沦为绝对零度}错误使用 ∗0.01⟹方差在100层里指数级暴跌⟹深层激活值沦为绝对零度

Xavier 初始化 (分子为1)  ⟹  假设数据全通过 (Tanh)  ⟹  面对 ReLU 斩断50%负数的现状无能为力\text{Xavier 初始化 (分子为1)} \implies \text{假设数据全通过 (Tanh)} \implies \text{面对 ReLU 斩断50\%负数的现状无能为力}Xavier 初始化 (分子为1)⟹假设数据全通过 (Tanh)⟹面对 ReLU 斩断50%负数的现状无能为力

He 初始化 (分子改为2)  ⟹  用2倍的权重方差×ReLU留下的 12 信号=1 (完美守恒)  ⟹  深层网络信号无损通过\text{He 初始化 (分子改为2)} \implies \text{用2倍的权重方差} \times \text{ReLU留下的 } \frac{1}{2} \text{ 信号} = 1 \text{ (完美守恒)} \implies \text{深层网络信号无损通过}He 初始化 (分子改为2)⟹用2倍的权重方差×ReLU留下的 21 信号=1 (完美守恒)⟹深层网络信号无损通过

何恺明的伟大,不在于他推导了多么冗长复杂的微积分方程,而在于他一眼看穿了 "ReLU 会把能量砍掉一半" 的物理现实,并用最优雅的补码(将分子从 1 改为 2),让信号在长达百层的数学时空中,达成了能量守恒。


欢迎在评论区留下你的思考: 既然 He 初始化完美解决了 ReLU 网络的方差衰减问题,那么如果我们在网络中引入了 批归一化(Batch Normalization, BN) 层,BN 会在每一层强制重新调整数据的方差。在这种情况下,权重初始化方法的选择(比如用 Xavier 还是 He)还会像以前一样对深层网络的生死起到绝对的决定性作用吗?为什么?