09bab-斯坦福CS336作业一-Xavier初始化

Xavier 初始化 🎯

本文档深入讲解 Xavier 初始化(又称 Glorot 初始化)的原理、数学推导与代码实现。Xavier 初始化通过精心设计的权重方差,使信号在深度神经网络的前向传播和反向传播中保持稳定,有效解决梯度消失和梯度爆炸问题 🛠️


术语表 / Terminology

术语 / Term 中文 说明 / Description
Xavier Initialization Xavier 初始化 由 Xavier Glorot 和 Yoshua Bengio 于 2010 年提出的权重初始化方法
Glorot Initialization Glorot 初始化 Xavier 初始化的别称,以第一作者姓氏命名
Variance Preservation 方差保持 使每层激活值和梯度的方差在传播过程中保持不变的核心原则
Gradient Vanishing 梯度消失 反向传播中梯度逐层递减至接近 0,导致深层参数无法更新
Gradient Exploding 梯度爆炸 反向传播中梯度逐层递增,导致参数更新过大、模型发散
Fan-in ( nin n_{\text{in}} nin) 输入连接数 某一层神经元的输入连接数量(即上一层的神经元个数)
Fan-out ( nout n_{\text{out}} nout) 输出连接数 某一层神经元的输出连接数量(即下一层的神经元个数)
Symmetry Breaking 打破对称性 通过随机初始化使同一层的不同神经元学到不同的特征
Standard Normal Distribution 标准正态分布 均值为 0、方差为 1 的正态分布,记为 N(0,1)\mathcal{N}(0, 1) N(0,1)

章节阅读路线图 🗺️ / Chapter Reading Roadmap

  1. 为什么需要权重初始化 🔍 / Why Weight Initialization Matters → 理解梯度消失/爆炸与对称性问题
  2. Xavier 初始化的核心思想 💡 / Core Idea of Xavier Initialization → 方差保持原则的直觉理解
  3. 数学推导 📐 / Mathematical Derivation → 从前向传播到反向传播的完整公式推导
  4. PyTorch 代码实现 💻 / PyTorch Implementation → 手动实现与原生函数的使用
  5. 总结 📝 / Summary → 回顾核心要点

1. 为什么需要权重初始化 🔍 / Why Weight Initialization Matters

📖 Note: 本章讲解权重初始化的重要性,以及不当初始化导致的梯度消失/爆炸和对称性问题 / This chapter explains the importance of weight initialization and the problems caused by improper initialization.

1.1 梯度消失与梯度爆炸 💥 / Gradient Vanishing and Exploding

在深度神经网络中,反向传播算法通过链式法则逐层传递梯度。考虑一个 LL L 层的网络,第 ll l 层关于权重 W(l)\mathbf{W}^{(l)} W(l) 的梯度可以写成:
∂W(l)o= ∂h(L−1)h(L) ⏟ M(L) ⋅...⋅ ∂h(l)h(l+1) ⏟ M(l+1) ⋅ ∂W(l)h(l) ⏟ v(l) \partial_{\mathbf{W}^{(l)}} \mathbf{o} = \underbrace{\partial_{\mathbf{h}^{(L-1)}} \mathbf{h}^{(L)}}{\mathbf{M}^{(L)}} \cdot \ldots \cdot \underbrace{\partial{\mathbf{h}^{(l)}} \mathbf{h}^{(l+1)}}{\mathbf{M}^{(l+1)}} \cdot \underbrace{\partial{\mathbf{W}^{(l)}} \mathbf{h}^{(l)}}_{\mathbf{v}^{(l)}} ∂W(l)o=M(L) ∂h(L−1)h(L)⋅...⋅M(l+1) ∂h(l)h(l+1)⋅v(l) ∂W(l)h(l)

这是 L−lL - l L−l 个矩阵的连乘。当层数 LL L 很大时,连乘的结果会出现两种极端情况:

  • 梯度消失(Gradient Vanishing) 📉:每个矩阵的特征值都小于 1,连乘后梯度指数级衰减,深层参数几乎不更新
  • 梯度爆炸(Gradient Exploding) 📈:每个矩阵的特征值都大于 1,连乘后梯度指数级增长,参数更新过大导致模型发散

直观类比 🎯:想象一排多米诺骨牌------如果每块骨牌比前一块稍微小一点(特征值 < 1),到最后一排时力量已经微弱到无法推倒(梯度消失);如果每块都比前一块大一点(特征值 > 1),到最后排时冲击力已经大到失控(梯度爆炸)。

参考资料:

1.2 对称性问题 🪞 / Symmetry Breaking Problem

如果将同一层的所有权重初始化为相同的值 cc c,那么前向传播时,该层所有神经元会产生完全相同的输出。反向传播时,这些神经元收到的梯度也完全相同,导致参数更新后权重仍然一样。⚠️

这意味着无论训练多久,同一层的所有神经元都在做同样的事情------网络的有效宽度退化为 1。

直观类比 🏫:想象一个课堂上所有学生都抄同一份答案------即使老师(梯度)反复纠正,所有人的答案始终一样,没有任何分工协作。只有每人独立思考(随机初始化),才能发挥集体的力量。

解决方案随机初始化(Random Initialization)------从某个概率分布中随机采样权重,打破这种对称性,让每个神经元从训练一开始就学到不同的特征。🎲

参考资料:

1.3 朴素初始化的缺陷 ❌ / Flaws of Naive Initialization

最朴素的思路是从标准正态分布 N(0,1)\mathcal{N}(0, 1) N(0,1) 中采样权重(均值为 0,方差为 1)。但这种做法在深度网络中会导致严重的方差偏移。

考虑一个线性层 y=w1x1+w2x2+...+wnxn y = w_1 x_1 + w_2 x_2 + \ldots + w_n x_n y=w1x1+w2x2+...+wnxn,假设每个 xi x_i xi 和 wi w_i wi 都独立同分布,且均值为 0、方差为 1。由于各项独立,输出的方差为:
Var(y)= ∑i=1n Var(wixi)=n⋅Var(w)⋅Var(x)=n⋅1⋅1=n\text{Var}(y) = \sum_{i=1}^{n} \text{Var}(w_i x_i) = n \cdot \text{Var}(w) \cdot \text{Var}(x) = n \cdot 1 \cdot 1 = n Var(y)=i=1∑nVar(wixi)=n⋅Var(w)⋅Var(x)=n⋅1⋅1=n

也就是说,经过一个有 nn n 个输入的线性层后,输出的方差变成了 nn n 倍!如果 n=512n = 512 n=512(常见的模型维度),方差会膨胀 512 倍,标准差膨胀约 22.6 倍。经过多层传播后,数值会彻底失控。💥

直观类比 📢:朴素的 N(0,1)\mathcal{N}(0,1) N(0,1) 初始化就像把 512 个人同时对着一个麦克风说话------每个人的声音正常大小,但 512 个人叠加在一起,声音就震耳欲聋了。Xavier 初始化的作用就是让每个人"小声说话",使叠加后的总音量保持在正常水平。

参考资料:


2. Xavier 初始化的核心思想 💡 / Core Idea of Xavier Initialization

💡 Note: 本章讲解 Xavier 初始化的核心原则------方差保持 / This chapter explains the core principle of Xavier initialization --- variance preservation.

2.1 方差保持原则 ⚖️ / Variance Preservation Principle

Xavier 初始化由 Xavier Glorot 和 Yoshua Bengio 在 2010 年的论文 "Understanding the difficulty of training deep feedforward neural networks" 中提出。其核心思想可以用一句话概括:

让每一层的激活值方差和梯度方差在传播过程中保持不变。 🎯

具体来说,需要同时满足两个条件:

  1. 前向传播:每一层输出的方差 = 输入的方差(信号不衰减也不放大)
  2. 反向传播:每一层梯度的方差 = 下一层梯度的方差(梯度不消失也不爆炸)

直观类比 🔊:想象一套音响系统,信号从麦克风(输入)经过多个放大器(网络层)传到扬声器(输出)。如果每个放大器都把信号放大,最后会失真(梯度爆炸);如果每个都衰减信号,最后听不到声音(梯度消失)。Xavier 初始化就是让每个放大器的增益恰好为 1------信号进来什么样,出去还是什么样,保真传输。

2.2 关键假设 📋 / Key Assumptions

为了推导方差保持的条件,需要做出以下简化假设:

假设 说明
零均值 权重和输入都以 0 为均值: E wij =0 Ew_{ij} = 0 Ewij=0, Exj=0 Ex_j = 0 Exj=0
独立同分布 权重之间、输入之间、权重与输入之间相互独立
偏置为零 偏置初始化为 0: bi=0 b_i = 0 bi=0
线性激活 使用近似线性的激活函数(如 tanh),使得 Var(a)≈Var(z)\text{Var}(a) \approx \text{Var}(z) Var(a)≈Var(z)

这些假设虽然简化了现实情况,但推导出的初始化方法在实践中被证明非常有效。💪

参考资料:


3. 数学推导 📐 / Mathematical Derivation

📐 Note: 本章完整推导 Xavier 初始化的数学公式 / This chapter provides the complete mathematical derivation of Xavier initialization.

3.1 前向传播的方差分析 📈 / Forward Pass Variance Analysis

考虑第 ll l 层的一个输出神经元 oi o_i oi:
oi= ∑j=1nin wij xj+bi o_i = \sum_{j=1}^{n_{\text{in}}} w_{ij} x_j + b_i oi=j=1∑ninwijxj+bi

其中 nin n_{\text{in}} nin 是该层的输入连接数(fan-in), xj x_j xj 是输入, wij w_{ij} wij 是权重。

由于 bi=0 b_i = 0 bi=0,且 wij w_{ij} wij 和 xj x_j xj 独立、均值为 0,根据独立随机变量乘积的方差公式:
Var( wij xj)=Exj2Var( wij )+E wij 2Var(xj)+Var( wij )Var(xj) \text{Var}(w_{ij} x_j) = Ex_j^2 \text{Var}(w_{ij}) + Ew_{ij}^2 \text{Var}(x_j) + \text{Var}(w_{ij}) \text{Var}(x_j) Var(wijxj)=Exj2Var(wij)+Ewij2Var(xj)+Var(wij)Var(xj)

因为 Exj=0 Ex_j = 0 Exj=0 且 E wij =0 Ew_{ij} = 0 Ewij=0,前两项为 0,简化为:
Var( wij xj)=Var( wij )⋅Var(xj) \text{Var}(w_{ij} x_j) = \text{Var}(w_{ij}) \cdot \text{Var}(x_j) Var(wijxj)=Var(wij)⋅Var(xj)

由于各项独立,输出的方差等于各项方差之和:
Var(oi)= ∑j=1nin Var( wij )⋅Var(xj)=nin⋅σw2⋅σx2 \text{Var}(o_i) = \sum_{j=1}^{n_{\text{in}}} \text{Var}(w_{ij}) \cdot \text{Var}(x_j) = n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2 Var(oi)=j=1∑ninVar(wij)⋅Var(xj)=nin⋅σw2⋅σx2

其中 σw2=Var( wij ) \sigma_w^2 = \text{Var}(w_{ij}) σw2=Var(wij) 是权重的方差, σx2=Var(xj) \sigma_x^2 = \text{Var}(x_j) σx2=Var(xj) 是输入的方差。

方差保持条件 :要使输出方差等于输入方差( Var(oi)=Var(xj) \text{Var}(o_i) = \text{Var}(x_j) Var(oi)=Var(xj)),需要:
nin⋅σw2=1⟹σw2= 1nin n_{\text{in}} \cdot \sigma_w^2 = 1 \quad \Longrightarrow \quad \sigma_w^2 = \frac{1}{n_{\text{in}}} nin⋅σw2=1⟹σw2=nin1

3.2 反向传播的方差分析 📉 / Backward Pass Variance Analysis

反向传播时,梯度从输出端向输入端流动。对于同一层,考虑梯度 ∂L ∂xj \frac{\partial L}{\partial x_j} ∂xj∂L 关于输入 xj x_j xj 的方差:
∂L ∂xj = ∑i=1nout wij ⋅ ∂L ∂oi \frac{\partial L}{\partial x_j} = \sum_{i=1}^{n_{\text{out}}} w_{ij} \cdot \frac{\partial L}{\partial o_i} ∂xj∂L=i=1∑noutwij⋅∂oi∂L

其中 nout n_{\text{out}} nout 是该层的输出连接数(fan-out)。

用与前向传播完全相同的推导逻辑,梯度方差保持的条件为:
nout⋅σw2=1⟹σw2= 1nout n_{\text{out}} \cdot \sigma_w^2 = 1 \quad \Longrightarrow \quad \sigma_w^2 = \frac{1}{n_{\text{out}}} nout⋅σw2=1⟹σw2=nout1

3.3 兼顾前向与反向 ⚖️ / Balancing Forward and Backward

现在面临一个两难困境:

  • 前向传播要求 σw2= 1nin \sigma_w^2 = \dfrac{1}{n_{\text{in}}} σw2=nin1
  • 反向传播要求 σw2= 1nout \sigma_w^2 = \dfrac{1}{n_{\text{out}}} σw2=nout1

对于同一层, nin n_{\text{in}} nin 和 nout n_{\text{out}} nout 通常不同,无法同时满足两个条件。Xavier 初始化的解决方案是取两者的调和平均
σw2= 2 nin+nout \sigma_w^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}} σw2=nin+nout2

直观类比 🎚️:就像调节音响的均衡器------前向传播是"高音",反向传播是"低音",你不可能把两个旋钮都调到最大。Xavier 的做法是取一个折中值,让高音和低音都保持在可接受的范围内。

3.4 两种分布形式 📊 / Two Distribution Forms

根据权重采样分布的不同,Xavier 初始化有两种常用形式:

正态分布形式(Xavier Normal) 🔔

从正态分布 N(0,σw2) \mathcal{N}(0, \sigma_w^2) N(0,σw2) 中采样:
wij ∼N (0, 2 nin+nout ) w_{ij} \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}} + n_{\text{out}}}\right) wij∼N(0,nin+nout2)

均匀分布形式(Xavier Uniform) 📏

从均匀分布 U(−a,a)U(-a, a) U(−a,a) 中采样。由于均匀分布 U(−a,a)U(-a, a) U(−a,a) 的方差为 a23 \dfrac{a^2}{3} 3a2,令其等于 σw2 \sigma_w^2 σw2:
a23= 2 nin+nout ⟹a= 6 nin+nout \frac{a^2}{3} = \frac{2}{n_{\text{in}} + n_{\text{out}}} \quad \Longrightarrow \quad a = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} 3a2=nin+nout2⟹a=nin+nout6

因此:
wij ∼U (− 6 nin+nout , 6 nin+nout ) w_{ij} \sim U\left(-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right) wij∼U(−nin+nout6 ,nin+nout6 )

3.4 为什么公式中会出现 2 和 6? 🤔 / Why 2 and 6 in the Formulas

在代码中,我们分别看到 sqrt(2 / (fan_in + fan_out))sqrt(6 / (fan_in + fan_out))。这两个常数 26 不是随意设定的,而是从方差保持条件推导出来的。

正态分布中的 2 📐 / The 2 in Normal Distribution

方差保持条件要求权重方差为:
σw2= 2 nin+nout \sigma_w^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}} σw2=nin+nout2

对于正态分布 N(0,σ2)\mathcal{N}(0, \sigma^2) N(0,σ2),参数 std 就是标准差 σ\sigma σ,所以:
std= σw2 = 2 nin+nout \text{std} = \sqrt{\sigma_w^2} = \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}} std=σw2 =nin+nout2

这就是为什么代码中用 sqrt(2 / (fan_in + fan_out)) ------分子 2 来自前向和反向传播方差的调和平均。

均匀分布中的 6 📐 / The 6 in Uniform Distribution

对于均匀分布 U(−a,a)U(-a, a) U(−a,a),其方差为:
VarU(−a,a)= (a−(−a))212 = (2a)212 = 4a212 =a23\text{Var}U(-a, a) = \frac{(a - (-a))^2}{12} = \frac{(2a)^2}{12} = \frac{4a^2}{12} = \frac{a^2}{3} VarU(−a,a)=12(a−(−a))2=12(2a)2=124a2=3a2

我们要求这个方差等于 Xavier 的方差条件 σw2= 2 nin+nout \sigma_w^2 = \dfrac{2}{n_{\text{in}} + n_{\text{out}}} σw2=nin+nout2:
a23= 2 nin+nout \frac{a^2}{3} = \frac{2}{n_{\text{in}} + n_{\text{out}}} 3a2=nin+nout2

两边同时乘以 3:
a2= 6 nin+nout a^2 = \frac{6}{n_{\text{in}} + n_{\text{out}}} a2=nin+nout6

开平方得:
a= 6 nin+nout a = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} a=nin+nout6

这就是为什么代码中用 sqrt(6 / (fan_in + fan_out)) ------6 是由 2×3=62 \times 3 = 6 2×3=6 得来的,其中 2 来自方差保持条件,3 来自均匀分布方差公式 a23 \dfrac{a^2}{3} 3a2 的分母。

直观理解 🎯:把 6 拆解为两个部分

  • 2 = 前向传播 + 反向传播(调和平均的两个方向)
  • 3 = 均匀分布 U(−a,a)U(-a, a) U(−a,a) 的方差特性(区间宽度的平方除以 12 化简后的结果)
  • 6 = 2 × 3 = 方差保持条件与均匀分布特性的结合

💡 记忆技巧 💭:均匀分布的 6 是正态分布的 2 的 3 倍,因为均匀分布 "更平"(概率密度在整个区间均匀分布),需要更大的范围才能达到相同的方差效果。

3.5 公式总结 📋 / Formula Summary

形式 公式 常数来源 适用场景
Xavier Normal w∼N (0, 2 nin+nout ) w \sim \mathcal{N}\left(0, \dfrac{2}{n_{\text{in}} + n_{\text{out}}}\right) w∼N(0,nin+nout2) 2 = 前向 + 反向调和平均 需要精确控制方差时
Xavier Uniform w∼U (− 6 nin+nout , 6 nin+nout ) w \sim U\left(-\sqrt{\dfrac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\dfrac{6}{n_{\text{in}} + n_{\text{out}}}}\right) w∼U(−nin+nout6 ,nin+nout6 ) 6 = 2 × 3(均匀分布方差特性) PyTorch 默认形式

💡 注意 :在 CS336 作业中,Xavier 初始化通常使用正态分布形式,权重标准差为 σ= 2 nin+nout \sigma = \sqrt{\dfrac{2}{n_{\text{in}} + n_{\text{out}}}} σ=nin+nout2 。

参考资料:


4. PyTorch 代码实现 💻 / PyTorch Implementation

💻 Note: 本章通过代码实现 Xavier 初始化,验证方差保持效果 / This chapter implements Xavier initialization in code and verifies the variance preservation effect.

4.1 手动实现 Xavier 初始化 🔧 / Manual Implementation

python 复制代码
import torch                                              # 导入 PyTorch 核心库 🔥
import math                                               # 导入数学库,用于计算平方根 🔢

"""Xavier 正态分布初始化(手动实现) 🎯

参数 / Args:
    tensor: 需要初始化的权重张量 / Weight tensor to initialize
    
返回 / Returns:
    原地修改后的张量 / In-place modified tensor
    
示例 / Example:
    W = torch.empty(512, 256)
    xavier_normal(W)  # W 被原地修改为 Xavier 初始化的值
"""
def xavier_normal(tensor):
    # 计算 fan_in 和 fan_out 📐
    fan_in = tensor.size(1) if tensor.dim() > 1 else tensor.size(0)  # 输入连接数
    fan_out = tensor.size(0)                                          # 输出连接数
    
    # 计算 Xavier 标准差:sqrt(2 / (fan_in + fan_out)) ⚖️
    std = math.sqrt(2.0 / (fan_in + fan_out))
    
    # 从正态分布 N(0, std^2) 中采样,原地修改 🎲
    with torch.no_grad():
        tensor.normal_(mean=0.0, std=std)
    return tensor


"""Xavier 均匀分布初始化(手动实现) 📏

参数 / Args:
    tensor: 需要初始化的权重张量 / Weight tensor to initialize
    
返回 / Returns:
    原地修改后的张量 / In-place modified tensor
    
示例 / Example:
    W = torch.empty(512, 256)
    xavier_uniform(W)  # W 被原地修改为 Xavier 初始化的值
"""
def xavier_uniform(tensor):
    # 计算 fan_in 和 fan_out 📐
    fan_in = tensor.size(1) if tensor.dim() > 1 else tensor.size(0)  # 输入连接数
    fan_out = tensor.size(0)                                          # 输出连接数
    
    # 计算均匀分布的边界 a = sqrt(6 / (fan_in + fan_out)) ⚖️
    a = math.sqrt(6.0 / (fan_in + fan_out))
    
    # 从均匀分布 U(-a, a) 中采样,原地修改 🎲
    with torch.no_grad():
        tensor.uniform_(-a, a)
    return tensor

4.2 PyTorch 原生函数 ⚡ / PyTorch Native Functions

PyTorch 在 torch.nn.init 模块中提供了开箱即用的 Xavier 初始化函数:

python 复制代码
import torch                                              # 导入 PyTorch 核心库 🔥
import torch.nn as nn                                     # 导入神经网络模块 🧠
import torch.nn.init as init                              # 导入初始化工具箱 🛠️

# ========== 方式一:直接初始化张量 ==========

# 创建权重张量 🏗️
W = torch.empty(256, 512)                                 # 形状 [fan_out=256, fan_in=512]

# Xavier 正态分布 ⚡
init.xavier_normal_(W)                                    # 原地修改,W ~ N(0, 2/(512+256))

# Xavier 均匀分布 ⚡
init.xavier_uniform_(W)                                   # 原地修改,W ~ U(-a, a)

# 带增益参数 ⚖️
init.xavier_normal_(W, gain=1.5)                          # gain 参数放大标准差,适用于特殊激活函数


# ========== 方式二:在模型定义时初始化 ==========

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 256, bias=False)    # 无偏置线性层
        self.linear2 = nn.Linear(256, 128, bias=False)    # 无偏置线性层
        
        # 对所有线性层应用 Xavier 正态初始化 🎯
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)             # 初始化权重

model = SimpleNet()                                       # 创建模型实例

参数说明 📋:

参数 说明
tensor 需要初始化的权重张量
gain 可选的缩放因子,默认 1.0。使用不同激活函数时需要调整(如 ReLU 对应 2 \sqrt{2} 2 )

4.3 验证方差保持效果 ✅ / Verifying Variance Preservation

python 复制代码
import torch                                              # 导入 PyTorch 🔥
import torch.nn as nn                                     # 导入神经网络模块 🧠
import torch.nn.init as init                              # 导入初始化工具箱 🛠️
import math                                               # 导入数学库 🔢

# 设置随机种子,保证结果可复现 🎯
torch.manual_seed(42)

# ========== 实验对比 ==========

# 构建一个 5 层网络,每层 512 个神经元 🏗️
n_layers = 5
n_features = 512

# --- 对比 1:朴素初始化 N(0, 1) ---
print("=" * 60)                                           # 打印分隔线 📏
print("朴素初始化 N(0, 1):")                               # 打印标题 📝
print("=" * 60)                                           # 打印分隔线 📏

weights_naive = [torch.randn(n_features, n_features) for _ in range(n_layers)]  # 标准正态初始化 🎲
x = torch.randn(1, n_features)                            # 随机输入 🎲

for i, w in enumerate(weights_naive):
    x = x @ w.T                                           # 前向传播(线性变换) ⚡
    print(f"  第 {i+1} 层输出标准差: {x.std().item():.4f}")  # 观察方差变化 🔍

# --- 对比 2:Xavier 正态初始化 ---
print("=" * 60)                                           # 打印分隔线 📏
print("Xavier 正态初始化:")                                # 打印标题 📝
print("=" * 60)                                           # 打印分隔线 📏

weights_xavier = [torch.empty(n_features, n_features) for _ in range(n_layers)]  # 创建空权重 🏗️
for w in weights_xavier:
    init.xavier_normal_(w)                                # Xavier 初始化 ⚡

x = torch.randn(1, n_features)                            # 相同的随机输入 🎲

for i, w in enumerate(weights_xavier):
    x = x @ w.T                                           # 前向传播(线性变换) ⚡
    print(f"  第 {i+1} 层输出标准差: {x.std().item():.4f}")  # 观察方差变化 🔍

运行结果示例 📊:

markdown 复制代码
============================================================
朴素初始化 N(0, 1):
============================================================
  第 1 层输出标准差: 22.8856
  第 2 层输出标准差: 525.8154
  第 3 层输出标准差: 11866.4004
  第 4 层输出标准差: 269670.8750
  第 5 层输出标准差: 6072693.5000
============================================================
Xavier 正态初始化:
============================================================
  第 1 层输出标准差: 1.0172
  第 2 层输出标准差: 1.0472
  第 3 层输出标准差: 1.1137
  第 4 层输出标准差: 1.1729
  第 5 层输出标准差: 1.1200

可以看到:👀

  • 朴素初始化:标准差从第 1 层的 22.9 爆炸到第 5 层的 607 万,完全失控
  • Xavier 初始化:标准差始终稳定在 1.0 附近(1.02 ~ 1.17),方差保持效果显著

💡 Key Takeaways / 核心要点

  • Naive N(0,1) causes variance explosion --- std grows exponentially with depth / 朴素 N(0,1) 导致方差爆炸,标准差随深度指数增长
  • Xavier initialization preserves variance --- std stays near 1.0 across layers / Xavier 初始化保持方差稳定,标准差在各层保持在 1.0 附近
  • Variance preservation prevents gradient issues --- stable forward pass implies stable backward pass / 方差保持防止梯度问题,前向稳定意味着反向也稳定

参考资料:


5. 总结 📝 / Summary

本节深入讲解了 Xavier 初始化的原理、推导和实现,核心要点回顾:🎯

要点 内容
问题 🔍 朴素初始化导致方差逐层膨胀或收缩,引发梯度消失/爆炸
核心思想 💡 方差保持------使每层激活值和梯度的方差在传播中不变
正态分布公式 🔔 w∼N (0, 2 nin+nout ) w \sim \mathcal{N}\left(0, \dfrac{2}{n_{\text{in}} + n_{\text{out}}}\right) w∼N(0,nin+nout2)
均匀分布公式 📏 w∼U (− 6 nin+nout , 6 nin+nout ) w \sim U\left(-\sqrt{\dfrac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\dfrac{6}{n_{\text{in}} + n_{\text{out}}}}\right) w∼U(−nin+nout6 ,nin+nout6 )
适用激活函数 ⚙️ tanh、sigmoid 等线性/近似线性激活函数
PyTorch 函数 init.xavier_normal_()init.xavier_uniform_()

🔴 关键理解

  • 💡 Xavier 初始化通过缩小权重的初始方差( σ2= 2 nin+nout \sigma^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}} σ2=nin+nout2),抵消了多输入求和带来的方差膨胀效应 ⚖️
  • ⚡ PyTorch 原生函数可直接使用,无需手动计算 fan-in 和 fan-out 🚀
  • 📐 对于 ReLU 激活函数,应使用 Kaiming 初始化(He 初始化),因为 ReLU 会将负值置零,改变了方差关系 🔧

最后更新时间:2026-06-16

相关推荐
thinking_talk1 小时前
会议纪要工具选型指南
人工智能·语音识别
牢七1 小时前
这几天的 ai 体验
人工智能
2601_961963381 小时前
Spring Boot集成电子签章的7个典型问题与解决方案:从入门到生产级实践
大数据·人工智能·spring boot·python·区块链·智能合约
诺***帝1 小时前
GPT-Image-2 氛围渲染能力全解析:光影、景深、材质还原的 Prompt 实战教程
人工智能·gpt
朱大喜1 小时前
机器学习驱动的异常检测:从统计基线到根因定位的工程化实战
人工智能
字节跳动数据库1 小时前
文章分享——好代码 - 半点没用的话题
人工智能·程序员
xcLeigh1 小时前
数学之美:数字革命背后的底层逻辑
人工智能·数学·ai·数学原理·书籍·数学之美·绝对边界
星轨zb1 小时前
[Corner项目实战]Spring Boot + LangChain4j Tool Calling实战:让AI自动选择推荐策略
人工智能·spring boot·后端·langchain4j
Deepoch1 小时前
VLA多模态架构赋能无人机 拓展全域智能巡检应用
人工智能·机器人·无人机·具身模型·deepoc