目录

Kaiming Uniform 初始化:神经网络权重初始化的优雅解决方案

Kaiming Uniform 初始化:神经网络权重初始化的优雅解决方案

在深度学习的模型训练中,权重初始化的选择对网络的收敛速度和性能有着深远影响。传统的随机初始化(如高斯分布)在浅层网络中尚可接受,但随着网络深度增加,梯度消失或爆炸的问题变得愈发严重。为了解决这一问题,Kaiming He 等人在 2015 年提出了 Kaiming 初始化(也称为 He 初始化),其中 kaiming_uniform_ 是 PyTorch 中实现的一种均匀分布形式。本文将深入探讨其原理、实现细节以及在 LoRA(Low-Rank Adaptation)中的应用。

一、背景与动机

在深度神经网络(如 CNN 或 Transformer)中,前向传播和反向传播的稳定性依赖于每一层的输入和输出的数值范围。如果权重初始化不当,可能导致:

  • 梯度消失:每一层的激活值过小,梯度在反向传播中逐渐衰减至零。
  • 梯度爆炸:激活值过大,梯度在反向传播中指数增长,导致训练不稳定。

Kaiming 初始化通过分析网络的方差传播,提出了一种基于层输入和输出维度的初始化方法,确保信号在深层网络中的稳定传递。它最初在论文 Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification(He et al., 2015)中提出,针对 ReLU 激活函数设计,后来被广泛应用于各类网络。

二、Kaiming Uniform 的数学原理

Kaiming 初始化假设网络使用 ReLU 或其变体(如 Leaky ReLU)作为激活函数,其核心目标是保持每一层的输出方差与输入方差一致。

1. 前向传播的方差分析

考虑一个线性层:
y = W x y = W x y=Wx

  • ( x ∈ R n i n x \in \mathbb{R}^{n_{in}} x∈Rnin ):输入向量,( n i n n_{in} nin ) 是输入维度。
  • ( W ∈ R n o u t × n i n W \in \mathbb{R}^{n_{out} \times n_{in}} W∈Rnout×nin ):权重矩阵,( n o u t n_{out} nout ) 是输出维度。
  • ( y ∈ R n o u t y \in \mathbb{R}^{n_{out}} y∈Rnout ):输出向量。

假设:

  • ( x x x ) 的每个元素是独立同分布(i.i.d.),方差为 ( Var ( x ) \text{Var}(x) Var(x) )。
  • ( W W W ) 的每个元素也是 i.i.d.,初始方差为 ( Var ( W ) \text{Var}(W) Var(W) )。

则输出 ( y i y_i yi ) 的方差为:
Var ( y i ) = Var ( ∑ j = 1 n i n W i j x j ) = ∑ j = 1 n i n Var ( W i j x j ) \text{Var}(y_i) = \text{Var}\left( \sum_{j=1}^{n_{in}} W_{ij} x_j \right) = \sum_{j=1}^{n_{in}} \text{Var}(W_{ij} x_j) Var(yi)=Var(j=1∑ninWijxj)=j=1∑ninVar(Wijxj)

若 ( W i j W_{ij} Wij ) 和 ( x j x_j xj ) 独立:
Var ( y i ) = n i n ⋅ Var ( W i j ) ⋅ Var ( x j ) = n i n ⋅ Var ( W ) ⋅ Var ( x ) \text{Var}(y_i) = n_{in} \cdot \text{Var}(W_{ij}) \cdot \text{Var}(x_j) = n_{in} \cdot \text{Var}(W) \cdot \text{Var}(x) Var(yi)=nin⋅Var(Wij)⋅Var(xj)=nin⋅Var(W)⋅Var(x)

对于 ReLU 激活 ( f ( y ) = max ⁡ ( 0 , y ) f(y) = \max(0, y) f(y)=max(0,y) ),输出方差减半(因为一半的激活被置零):
Var ( f ( y ) ) = 1 2 n i n ⋅ Var ( W ) ⋅ Var ( x ) \text{Var}(f(y)) = \frac{1}{2} n_{in} \cdot \text{Var}(W) \cdot \text{Var}(x) Var(f(y))=21nin⋅Var(W)⋅Var(x)

为了保持 ( Var ( f ( y ) ) = Var ( x ) \text{Var}(f(y)) = \text{Var}(x) Var(f(y))=Var(x) ):
1 2 n i n ⋅ Var ( W ) = 1    ⟹    Var ( W ) = 2 n i n \frac{1}{2} n_{in} \cdot \text{Var}(W) = 1 \implies \text{Var}(W) = \frac{2}{n_{in}} 21nin⋅Var(W)=1⟹Var(W)=nin2

2. 均匀分布的参数

Kaiming 初始化使用均匀分布 ( U ( − a , a ) U(-a, a) U(−a,a) ) 初始化权重,其中 ( a a a ) 是边界值。对于均匀分布:
Var ( W ) = ( a − ( − a ) ) 2 12 = ( 2 a ) 2 12 = 4 a 2 12 = a 2 3 \text{Var}(W) = \frac{(a - (-a))^2}{12} = \frac{(2a)^2}{12} = \frac{4a^2}{12} = \frac{a^2}{3} Var(W)=12(a−(−a))2=12(2a)2=124a2=3a2

令:
a 2 3 = 2 n i n \frac{a^2}{3} = \frac{2}{n_{in}} 3a2=nin2

解得:
a = 6 n i n a = \sqrt{\frac{6}{n_{in}}} a=nin6

因此,权重初始化为:
W ∼ U ( − 6 n i n , 6 n i n ) W \sim U\left(-\sqrt{\frac{6}{n_{in}}}, \sqrt{\frac{6}{n_{in}}}\right) W∼U(−nin6 ,nin6 )

3. 参数 ( a a a ) 的调整

在 PyTorch 的 nn.init.kaiming_uniform_ 中,参数 ( a a a ) 允许用户调整分布的宽度,默认值为 ( 5 \sqrt{5} 5 )(对应 Xavier 初始化的一种变体)。但在 Kaiming 原始设计中,( a = 0 a = 0 a=0 )(纯 ReLU)或基于激活函数的斜率调整。


三、PyTorch 中的实现

PyTorch 提供了 nn.init.kaiming_uniform_ 函数,签名如下:

python 复制代码
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='relu')
  • tensor :要初始化的张量,通常是权重矩阵(形状为 ( [ n o u t , n i n ] [n_{out}, n_{in}] [nout,nin] ))。
  • a:负斜率参数,对于 ReLU 通常为 0,对于 Leaky ReLU 为斜率值。
  • mode :选择 fan_in(输入维度,推荐用于前向稳定性)或 fan_out(输出维度,用于反向稳定性)。
  • nonlinearity :激活函数类型(如 'relu''leaky_relu')。

在 LoRA 的代码中(例如 Microsoft LoRA 实现, https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py):

python 复制代码
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
  • 这里 ( a = 5 a = \sqrt{5} a=5 ) 是 Xavier 初始化的遗留参数,但实际效果接近 Kaiming。
  • self.lora_A.weight 是 ( [ d , r ] [d, r] [d,r] ) 形状的矩阵,( d d d ) 是输入维度,( r r r ) 是秩。

四、在 LoRA 中的应用

LoRA(Low-Rank Adaptation)通过低秩矩阵 ( A A A ) 和 ( B B B ) 对预训练权重进行微调。在 reset_lora_parameters 函数中:

  • ( A A A ) 使用 kaiming_uniform_ 初始化:

    python 复制代码
    nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
  • ( B B B ) 初始化为零:

    python 复制代码
    nn.init.zeros_(self.lora_B[adapter_name].weight)

1. 为什么用 Kaiming 初始化 ( A A A )?

  • 数值稳定性 :( A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r ) 的初始值需要适度随机化,以确保训练早期 ( Δ W = A B \Delta W = A B ΔW=AB) 能够探索不同的低秩子空间。Kaiming 初始化根据输入维度 ( d d d ) 设置合理的方差,避免 ( A A A ) 的值过大或过小。
  • 与 ( B ) 零初始化的配合 :( B = 0 B = 0 B=0 ) 确保初始 ( Δ W = 0 \Delta W = 0 ΔW=0),而 ( A A A ) 的 Kaiming 初始化为后续更新提供了非零起点,避免训练陷入全零状态。
  • 兼容性:LoRA 常用于 Transformer 等深层模型,Kaiming 初始化与这些架构的 ReLU 或 GELU 激活兼容。

2. ( a = 5 a = \sqrt{5} a=5 ) 的意义

  • ( a = 5 a = \sqrt{5} a=5 ) 实际上是向后兼容 Xavier 初始化的一种选择,但与 Kaiming 的 ( a = 0 a = 0 a=0 )(纯 ReLU)略有不同。它使得分布稍宽,可能增加初始探索性,但仍接近 ( 6 n i n \sqrt{\frac{6}{n_{in}}} nin6 ) 的理论值。

五、优势与局限性

优势

  • 稳定性:通过控制方差,Kaiming 初始化显著减少了深层网络中的梯度问题。
  • 通用性:适用于大多数使用 ReLU 及其变体的网络。
  • 简单性:只需输入维度即可计算,无需复杂调参。

局限性

  • 激活函数依赖:主要针对 ReLU 设计,对于 Sigmoid 或 Tanh 等激活可能不适用。
  • 固定假设:假设输入是 i.i.d.,在实际复杂数据中可能不完全成立。
  • LoRA 场景的特殊性 :( r r r ) 远小于 ( d d d ),低秩约束可能削弱初始化的理论效果。

六、扩展与改进方向

对于 LoRA 或其他场景,可以考虑以下改进:

  1. 动态调整 ( a a a ) :根据 ( r r r ) 或任务复杂度自适应选择 ( a a a ),而不是固定为 ( 5 \sqrt{5} 5 )。
  2. 正交初始化 :对于低秩矩阵 ( A A A ),尝试 torch.nn.init.orthogonal_,可能提升多样性。
  3. 与 scaling 结合 :LoRA 中的 scaling = lora_alpha / r 可以与初始化协同设计,避免重复调整幅度。

七、结语

nn.init.kaiming_uniform_ 是深度学习中权重初始化的经典方法,通过均匀分布确保信号在网络中的稳定传播。在 LoRA 中,它为 ( A A A ) 提供了合理的初始值,与 ( B B B ) 的零初始化配合,兼顾了稳定性和学习能力。对于研究者而言,理解其背后的数学原理不仅有助于调参,还能启发新的初始化策略。欢迎在评论区分享你的使用经验或改进想法!

后记

2025年3月11日22点44分于上海,在Grok 3大模型辅助下完成。

本文是转载文章,点击查看原文
如有侵权,请联系 xyy@jishuzhan.net 删除
相关推荐
孔令飞7 分钟前
16 | 实现简洁架构的 Store 层
人工智能·ai·云原生·golang·kubernetes
zzzyzh35 分钟前
Work【2】:PGP-SAM —— 无需额外提示的自动化 SAM!
人工智能·深度学习·计算机视觉·sam·medical·image segment
极客 - L U42 分钟前
机器学习 : 训练过程
人工智能·机器学习
醒李1 小时前
AP AR
人工智能
今天炼丹了吗1 小时前
RTDETR融合[CVPR2025]ARConv中的自适应矩阵卷积
人工智能·深度学习·计算机视觉
pen-ai2 小时前
【NLP】 5. Word Analogy Task(词类比任务)与 Intrinsic Metric(内在度量)
人工智能·自然语言处理·word
大湾区经济门户网2 小时前
科技工作者之家建设扬帆起航,为科技人才提供更多优质服务
大数据·人工智能·科技·媒体
爱嘿嘿的小黑2 小时前
宇宙厂学到的思维模型,工作学习必备
前端·人工智能·面试
夏莉莉iy2 小时前
[ICLR 2025]CBraMod: A Criss-Cross Brain Foundation Model for EEG Decoding
人工智能·python·深度学习·神经网络·机器学习·计算机视觉·transformer
广药门徒2 小时前
yolo环境 pytorch环境配置 CUDA安装
人工智能·pytorch·yolo