深度学习 --- stanford cs231学习笔记五(训练神经网络的几个重要组成部分之三,权重矩阵的初始化)

权重矩阵的初始化

3,权重矩阵的初始化

深度学习所学习的重点就是要根据损失函数训练权重矩阵中的系数。即便如此,权重函数也不能为空,总是需要初始化为某个值。


3,1 全都初始化为同一个常数可以吗?

首先要简单回顾一下隐含层中的神经元,他是由权重矩阵中的每行系数与输入x的乘积后接一个非线性函数决定的。 W有多少行,隐藏层就有多少个神经元。

因此,当矩阵权重W中的所有元素都是同一个常数时,所有神经元的计算结果不论是在前向传播的过程中,还是在反向传播的过程中计算结果都是一样的。如此一来,隐藏层所有的神经元的功效都废了,变成了只有一个神经元。

例如,把W的所有元素都初始化为0。不论有多少个神经元,那么前向传播的计算结果都是0,反向传播的结果都相同。


3,2 把W初始化为一组小的随机数

下面是一个6层的神经网络,有5个隐含层,每层都有4096个神经元。

在初始化的时候把权重矩阵W初始化为均值为0标准差为1的随机数,并且让这组数统一乘以一个很小的数。使用的激活函数为tanh,每层的计算结果也就是神经元的值,保存在hs中。

下图为每一层神经元值的分布:

可见随着神经网络的深度越来越深,越来越多的神经元的值为0。

对于第i层而言,前向传播的公式为:

其中表示第i层的神经元。结合上面的结果来看,当前向传播到很深层的网络后,深层的神经元就全是死神经元了

此外,在反向传播时,关于第i层的权重W的本地梯度为:

因此,当深层网络神经元的值(也就是上面公式中的)很多都是0或者趋近于0后,梯度最终会趋于0,即,梯度消失 。也就是说, 把W初始化为一组小的随机数是行不通的。


3,3 把W初始化为一组不太小的随机数

既然乘以0.01不行,容易出现梯度消失,何不试一试乘以0.05呢?同样是6层网络,每层同样是4096个神经元,同样是用tanh为激活函数。

结合每层神经元值的分布来看,出现1和-1的概率比较高。

结合本地梯度来看,容易让1-tanh(x*w)^2为0,即,本地梯度为0。


3,4 如果依然要用随机数,缩放的比例是多少才合适呢?Xavier

同样是6层网络,每层同样是4096个神经元,同样是用tanh为激活函数。所不同的是,之前是通过手动调整缩放系数观察神经元值的分布。现在是基于输入的尺寸,自适应的选择缩放系数。这种初始化的方法被称之为Xavier初始化。他有严格的数学证明,其目的是使得每层神经网络在前向传播和反向传播过程中保持输出的方差一致。

计算结果如下图所示,经过xavier初始化后,所有隐藏层的神经元即不会集中在0附近,也不会徘徊于+-1两端。(对于tanh激活函数而言)


3,5 Kaiming初始化/He初始化

上面提到的Xavier初始化,对于激活函数为tanh的网络是适用的,表现结果也比较好。但当激活函数为ReLU的网络中,依然会出现梯度消失的情况。这是ReLU函数自身天然决定的。

为了克服这个问题何凯明发明了一种适合ReLU函数的初始化方式。采用kaiming初始化后的后的直方图会分散的更加均匀,而不是集中在0附近。


4, 批归一化(Batch Normalization)

前面讲的初始化权重函数W,其主要目的是通过慎重的选择权重函数W的初值以避免神经元值要么产生大量的0值,要么集中在+-1。最好能保证神经元值的分布能够尽可能的均等,具体来说,每层神经元值(激活函数的输出)的分布应该尽量朝着以下这些特点努力:

**1,0均值。**即正负值出现的频次都有,且差不多相同。

**2,适当的方差。**因为如果方差太大,容易出现梯度爆炸,而方差太小,就会引起梯度消失。Xavier初始化和He初始化就是为了确保每层的激活值方差适当而设计的。


4,1 批归一化的处理对象与维度

为了达到这一目的,相对于尝试不同初始化W的方法。Batch normalization则着重于处理全连接层的计算结果,也就是对线性变换的输出做二次处理,即对进行再处理。

在下面的这张ppt中,我们看到输入x的维度是NxD,也就是全连接层输出的维度。要搞清楚每个维度代表什么,这里我们可以稍微先回顾一下全连接层。

下图为神经网络中的一张PPT,如果说batch normalization中的输入x是****的话,那他的维度就应该等于这里h的维度。h的维度又是由W的其中一个维度决定,他的另一个维度等于前一层的输入。如果是单张图像则输入x的维度为Dx1,W为HxD,输出h的维度为H,H就是神经元的个数。如果输入是N张图,则输入x的维度为DxN,W为HxD,输出h的维度为HxN。

这也就是说,在batch normalization的PPT中维度是NxD的输入x,其中N表示样本数,D表示神经元的个数。


4,2 批归一化具体的处理方式

Batch normalization的做法和前面提过的data preprocessing很像,即,数据减去均值然后再除以标准差(虽然确实存在一些差异)。只不过data preprocessing的对象是最原始的输入数据,而Batch normalization,也叫BN层,是放在全连接层和激活函数之间的。

与data preprocessing处理数据的不同之处是,除了下图中的第一步完全一样之外。Batch normalization的不同之处在下面图中的第二步中。首先,在除以标准差的时候,为了避免除0,所除的标准差会加上一个很小的数。此外,在减去均值再除以标准差之后,又要再经过一个以为缩放以为偏置的线性化处理。


4,2,1 全连接层FC的Batch Normalization

对于全连接层FC而言,在batch normalization的PPT中输入x的维度是NxD,其中N表示样本数,D表示神经元的个数。 Batch normalization的处理是对N个样本求均值


4,2,2 CNN卷积层的Batch Normalization

对于CNN的卷积层而言,若,输入图像的维度是CinxWxH,共N张图,即NxCxWxH。filter的维度是CinxKwxKh,总共有Cout个filter,即CoutxCinxKwxKh。则输出结果的维度是NxCoutxW'xH'(即下图中输入x的维度)。 Batch normalization的处理是对N个WxH的样本求均值


4,2,3 全连接层FC的Layer Normalization

除了Batch normalization以外,类似的,还有一个变种叫Layer Normalization。对于全连接层FC而言,输入x的维度是NxD,其中N表示样本数,D表示神经元的个数。 Layer normalization的处理是对D个神经元求均值


4,2,4 CNN卷积层的Instance Normalization

对于CNN的卷积层而言,若,输入图像的维度是CinxWxH,共N张图,即NxCxWxH。filter的维度是CinxKwxKh,总共有Cout个filter,即CoutxCinxKwxKh。则输出结果的维度是NxCoutxW'xH'。 Instance Normalization的处理是对WxH的样本求均值


(全文完)

--- 作者,松下J27

参考文献(鸣谢):

1,Stanford University CS231n: Deep Learning for Computer Vision

2,训练神经网络(第一部分)_哔哩哔哩_bilibili

3,10 Training Neural Networks I_哔哩哔哩_bilibili

4,Schedule | EECS 498-007 / 598-005: Deep Learning for Computer Vision

**版权声明:**所有的笔记,可能来自很多不同的网站和说明,在此没法一一列出,如有侵权,请告知,立即删除。欢迎大家转载,但是,如果有人引用或者COPY我的文章,必须在你的文章中注明你所使用的图片或者文字来自于我的文章,否则,侵权必究。 ----松下J27

相关推荐
ehiway12 分钟前
FPGA+GPU+CPU国产化人工智能平台
人工智能·fpga开发·硬件工程·国产化
天天爱吃肉821815 分钟前
碳化硅(SiC)功率器件:新能源汽车的“心脏”革命与技术突围
大数据·人工智能
萧鼎1 小时前
利用 OpenCV 进行棋盘检测与透视变换
人工智能·opencv·计算机视觉
神秘的土鸡1 小时前
使用Open WebUI下载的模型文件(Model)默认存放在哪里?
人工智能·llama·ollama·openwebui
梦里是谁N1 小时前
【deepseek之我问】如何把AI技术与教育相结合,适龄教育,九年义务教育,以及大学教育,更着重英语学习。如何结合,给出观点。结合最新智能体Deepseek
人工智能·学习
小白狮ww2 小时前
国产超强开源大语言模型 DeepSeek-R1-70B 一键部署教程
人工智能·深度学习·机器学习·语言模型·自然语言处理·开源·deepseek
风口猪炒股指标2 小时前
想象一个AI保姆机器人使用场景分析
人工智能·机器人·deepseek·深度思考
Blankspace空白2 小时前
【小白学AI系列】NLP 核心知识点(八)多头自注意力机制
人工智能·自然语言处理
Sodas(填坑中....)2 小时前
SVM对偶问题
人工智能·机器学习·支持向量机·数据挖掘
forestsea2 小时前
DeepSeek 提示词:定义、作用、分类与设计原则
人工智能·prompt·deepseek