一文读懂BatchNorm2d()函数的计算逻辑

1.简介

机器学习中,进行模型训练之前,需对数据做归一化处理,使其分布一致。在深度神经网络训练过程中,通常一次训练是一个batch,而非全体数据。每个batch具有不同的分布产生了internal covarivate shift问题------在训练过程中,数据分布会发生变化,对下一层网络的学习带来困难。Batch Normalization将数据规范到均值为0,方差为1的分布上,一方面使得数据分布一致,另一方面避免梯度消失。

2.计算

如图所示:

上为输入数据,其shape=[5, 3, h, w]

Step1: 计算同一通道下的均值,如图中的红色图块,均表示同一通道

Step2: 计算同一通道下的方差,如图中的红色图块,均表示同一通道

Step3: 对当前通道下的每个数据做归一化 其中的x表示具体的一个点,如x = X[0][0][0][0][0]这个数据点。

Step4: 增加缩放和平移变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β, 归一化后的值为

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ是一个设置的常量,默认为1e^-5,其作用是防止除0。 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β这两个参数一般情况也不需要我们管(如果,参数affine=true, 就需要我们给定)。

3. Pytorch中的nn.BatchNorm2d()函数的解释

其主要需要输入4个参数:

(1)num_features:输入数据的shape一般为[batch_size, channel, height, width], num_features为其中的channel;

(2)eps: 分母中添加的一个值,目的是为了计算的稳定性,默认:1e-5;

(3)momentum : 一个用于运行过程中均值和方差的一个估计参数,默认值为0.1.

(4)affine :当设为true时,给定可以学习的系数矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β

4.代码示例:

scss 复制代码
import torch

data = torch.ones(size=(2, 2, 3, 4))
data[0][0][0][0] = 25
print("data = ", data)

print("\n")

print("=========================使用封装的BatchNorm2d()计算================================")
BN = torch.nn.BatchNorm2d(num_features=2, eps=0, momentum=0)
BN_data = BN(data)
print("BN_data = ", BN_data)

print("\n")

print("=========================自行计算================================")
x = torch.cat((data[0][0], data[1][0]), dim=1)      # 1.将同一通道进行拼接(即把同一通道当作一个整体)
x_mean = torch.Tensor.mean(x)                       # 2.计算同一通道所有制的均值(即拼接后的均值)
x_var = torch.Tensor.var(x, False)                  # 3.计算同一通道所有制的方差(即拼接后的方差)

# 4.使用第一个数按照公式来求BatchNorm后的值
bn_first = ((data[0][0][0][0] - x_mean) / ( torch.pow(x_var, 0.5))) * BN.weight[0] + BN.bias[0]
print("bn_first = ", bn_first)

运行结果:

(1)原数据

(2)使用BatchNorm()函数

(3)自行计算批归一化的值 图中标红的两个框数据完全相等,完结撒花!!!

注: 有借鉴该篇文章

相关推荐
圈圈编码17 分钟前
LeetCode Hot100刷题——合并两个有序链表
java·数据结构·算法·leetcode·链表
老歌老听老掉牙1 小时前
旋量理论:刚体运动的几何描述与机器人应用
python·算法·机器学习·机器人·旋量
无聊的小坏坏1 小时前
用递归算法解锁「子集」问题 —— LeetCode 78题解析
算法·深度优先
m0_738596321 小时前
十大排序算法
算法·排序算法
jingfeng5141 小时前
详解快排的四种方式
数据结构·算法·排序算法
MoRanzhi12032 小时前
245. 2019年蓝桥杯国赛 - 数正方形(困难)- 递推
python·算法·蓝桥杯·国赛·递推·2019
henyaoyuancc2 小时前
vla学习 富
人工智能·算法
Gyoku Mint3 小时前
机器学习×第五卷:线性回归入门——她不再模仿,而开始试着理解你
人工智能·python·算法·机器学习·pycharm·回归·线性回归
蒙奇D索大3 小时前
【数据结构】图论最短路径算法深度解析:从BFS基础到全算法综述
数据结构·算法·图论·广度优先·图搜索算法
trouvaille4 小时前
哈希数据结构的增强
算法·go