PyTorch -- Batch Normalization(BN) 快速实践

  • Batch Normalization 可以

    • 改善梯度消失/爆炸问题:前面层的梯度经过多次传递后会变得非常小(大),从而导致网络收敛速度慢(不收敛),应用 BN 可缓解
    • 加速网络收敛:BN 使得每个神经元的输入分布更加稳定
    • 减少过拟合:BN 可减少由于数据分布的变化导致的模型性能下降
    • 提高模型泛化能力:BN 使得模型对输入的微小变化更加稳定
    • 缓解超参敏感:对于 learning rate 等超参数敏感性降低
    • ...
  • Batch Normalization(BN):使 feature map 满足均值为 0,方差为 1 的分布规律

    • 如果batch size为m,则在前向传播过程中,网络中每个节点都有m个输出,所谓的Batch Normalization,就是对该层每个节点的这m个输出进行归一化再输出
    • 数学表达:每个 channel 下统计一个对应的均值和方差
      x norm = x − E x V a r x + ϵ ∗ γ + β x_{\text{norm}} = \frac{x - \mathbb{E}x}{\sqrt{Varx+\epsilon}} * \gamma + \beta xnorm=Varx+ϵ x−Ex∗γ+β
      • 其中 γ , β \gamma, \beta γ,β 为可学习的参数

  • 代码实践:

    python3 复制代码
    >>> import torch
    >>> import torch.nn as nn
    >>>
    >>> x = torch.rand(2,1,28,28)   		## *0.创建输入 x 
    >>> bn = nn.BatchNorm2d(				## *1. 创建 bn 层,
    						1,  				# -- 输入的 channel 数
    						training = False, 	# -- 是否为训练模式
    						affine = False) 	# -- 是否学习 γ β 				
    >>> out = bn(x) 						## *2 获取输出
    
    >>> # 查看相关数值 ------------------------------------------------
    >>> bn.running_mean					# 均值
    tensor([0.0507])
    >>> bn.running_var 					# 方差
    tensor([0.9080])
    >>> bn.weight						# γ
    Parameter containing:
    tensor([1.], requires_grad=True)
    >>> bn.bias							# β
    Parameter containing:
    tensor([0.], requires_grad=True)

相关推荐
IT空门:门主1 分钟前
MySQL MCP Server 从零安装到使用实战,AI 直接查询数据库
数据库·人工智能·mysql
Evand J2 分钟前
【自适应滤波】基于新息协方差匹配的自适应CKF目标跟踪 MATLAB 实战——在目标跟踪、雷达定位、组合导航和传感器融合等问题
人工智能·matlab·目标跟踪
Aipollo2 分钟前
多Agent架构设计模式、通讯间沟通对比分析
人工智能·ai
InternLM3 分钟前
从「模型类型不支持」到成功推理:Intern-S2-Preview oMLX 4bit 量化实录 | 与书生共创
人工智能·大模型·多模态模型
kcuwu.3 分钟前
模型压缩技术深度解析博客
人工智能
AI刀刀6 分钟前
豆包粘贴到 word 格式混乱,AI 导出鸭高效解决导出难题
人工智能·word·ai导出鸭
也非非也6 分钟前
Agnes AI 全模态 API 免费实测报告:文生图 + 文生视频完整测试
人工智能·音视频
KaMeidebaby11 分钟前
卡梅德生物技术快报|酵母表达系统工程:裂殖酵母穿梭载体分子改造与载体构建技术总结
网络·人工智能·网络协议·tcp/ip·算法
市象11 分钟前
可灵头上缺了一朵遮风挡雨的云
人工智能
盼小辉丶11 分钟前
PyTorch强化学习实战(11)——N步DQN(N-step DQN)
pytorch·python·深度学习·强化学习