pytorch中对nn.BatchNorm2d()函数的理解

pytorch中对BatchNorm2d函数的理解

  • 简介
  • 计算
  • [3. Pytorch的nn.BatchNorm2d()函数](#3. Pytorch的nn.BatchNorm2d()函数)
  • [4 代码示例](#4 代码示例)

简介

机器学习中,进行模型训练之前,需对数据做归一化处理,使其分布一致。在深度神经网络训练过程中,通常一次训练是一个batch,而非全体数据。每个batch具有不同的分布产生了internal covarivate shift问题------在训练过程中,数据分布会发生变化,对下一层网络的学习带来困难。Batch Normalization强行将数据拉回到均值为0,方差为1的正太分布上,一方面使得数据分布一致,另一方面避免梯度消失。

计算

如图所示:

3. Pytorch的nn.BatchNorm2d()函数

其主要需要输入4个参数:

(1)num_features:输入数据的shape一般为[batch_size, channel, height, width], num_features为其中的channel;

(2)eps: 分母中添加的一个值,目的是为了计算的稳定性,默认:1e-5;

(3)momentum: 一个用于运行过程中均值和方差的一个估计参数,默认值为0.1.

(4)affine:当设为true时,给定可以学习的系数矩阵 γ \gamma γ和 β \beta β

4 代码示例

复制代码
import torch

data = torch.ones(size=(2, 2, 3, 4))
data[0][0][0][0] = 25
print("data = ", data)

print("\n")

print("=========================使用封装的BatchNorm2d()计算================================")
BN = torch.nn.BatchNorm2d(num_features=2, eps=0, momentum=0)
BN_data = BN(data)
print("BN_data = ", BN_data)

print("\n")

print("=========================自行计算================================")
x = torch.cat((data[0][0], data[1][0]), dim=1)      # 1.将同一通道进行拼接(即把同一通道当作一个整体)
x_mean = torch.Tensor.mean(x)                       # 2.计算同一通道所有制的均值(即拼接后的均值)
x_var = torch.Tensor.var(x, False)                  # 3.计算同一通道所有制的方差(即拼接后的方差)

# 4.使用第一个数按照公式来求BatchNorm后的值
bn_first = ((data[0][0][0][0] - x_mean) / ( torch.pow(x_var, 0.5))) * BN.weight[0] + BN.bias[0]
print("bn_first = ", bn_first)


相关推荐
人工智能训练3 小时前
【极速部署】Ubuntu24.04+CUDA13.0 玩转 VLLM 0.15.0:预编译 Wheel 包 GPU 版安装全攻略
运维·前端·人工智能·python·ai编程·cuda·vllm
yaoming1683 小时前
python性能优化方案研究
python·性能优化
源于花海3 小时前
迁移学习相关的期刊和会议
人工智能·机器学习·迁移学习·期刊会议
码云数智-大飞4 小时前
使用 Python 高效提取 PDF 中的表格数据并导出为 TXT 或 Excel
python
DisonTangor5 小时前
DeepSeek-OCR 2: 视觉因果流
人工智能·开源·aigc·ocr·deepseek
薛定谔的猫19825 小时前
二十一、基于 Hugging Face Transformers 实现中文情感分析情感分析
人工智能·自然语言处理·大模型 训练 调优
发哥来了5 小时前
《AI视频生成技术原理剖析及金管道·图生视频的应用实践》
人工智能
biuyyyxxx5 小时前
Python自动化办公学习笔记(一) 工具安装&教程
笔记·python·学习·自动化
数智联AI团队5 小时前
AI搜索引领开源大模型新浪潮,技术创新重塑信息检索未来格局
人工智能·开源
极客数模5 小时前
【2026美赛赛题初步翻译F题】2026_ICM_Problem_F
大数据·c语言·python·数学建模·matlab