【Torch】nn.BatchNorm1d算法详解

1. 定义

nn.BatchNorm1d 是 PyTorch 中用于一维特征 的批归一化层(Batch Normalization),常用于对全连接层或一维卷积(Conv1d)输出进行归一化。它通过对每个特征通道的输出在批次维度(及可选的时间/序列维度)上做归一化,使网络训练更稳定、收敛更快,并具有一定的正则化效果。

2. 输入与输出

  • 输入(Input)

    • 形状1 : ( N , C ) (N, C) (N,C),即批大小 N N N,特征通道数 C C C。
    • 形状2 : ( N , C , L ) (N, C, L) (N,C,L),用于处理一维序列或时间序列, L L L为序列长度。
    • 类型 :浮点型张量(torch.float32 等)。
  • 输出(Output)

    • 与输入形状完全相同,每个通道经归一化与可选的仿射变换后返回,类型与输入相同。

3. 底层原理

  1. 批内统计

    对输入的每个通道 c c c计算:
    μ c = 1 m ∑ i = 1 m x i , c , σ c 2 = 1 m ∑ i = 1 m ( x i , c − μ c ) 2 \mu_c = \frac{1}{m}\sum_{i=1}^m x_{i,c},\quad \sigma_c^2 = \frac{1}{m}\sum_{i=1}^m (x_{i,c}-\mu_c)^2 μc=m1i=1∑mxi,c,σc2=m1i=1∑m(xi,c−μc)2

    其中 m = N m=N m=N(或 N × L N\times L N×L)为归一化的元素总数。

  2. 归一化与仿射变换
    x ^ i , c = x i , c − μ c σ c 2 + ε , y i , c = γ c x ^ i , c + β c \hat x_{i,c} = \frac{x_{i,c}-\mu_c}{\sqrt{\sigma_c^2 + \varepsilon}},\quad y_{i,c} = \gamma_c \hat x_{i,c} + \beta_c x^i,c=σc2+ε xi,c−μc,yi,c=γcx^i,c+βc

    • ε \varepsilon ε(eps)防止除零;
    • γ , β \gamma,\beta γ,β(可选仿射参数)恢复表达能力。
  3. 滑动平均

    在训练模式下,维护运行时统计 (running_mean, running_var):
    running_mean ← ( 1 − α )   running_mean + α   μ c , \text{running\_mean} \leftarrow (1-\alpha)\,\text{running\_mean} + \alpha\,\mu_c, running_mean←(1−α)running_mean+αμc,

    同理更新 running_var,其中 α \alpha α为 momentum

    在评估模式下,用这些滑动统计量代替当前批次统计量。

4. 构造函数参数详解

参数 类型 & 默认 说明
num_features int 必填。通道数 (C),即输入的第二维度大小。
eps float, 默认 1e-5 防止除零的微小常数,影响归一化精度。
momentum float, 默认 0.1 滑动平均更新速率 (\alpha)。较大时更快跟踪批内变化,较小时更平滑。
affine bool, 默认 True 是否学习仿射参数 (\gamma,\beta)。若为 False,则不做缩放平移。
track_running_stats bool, 默认 True 是否跟踪 running_meanrunning_var;若为 False,则训练/测试都使用当前批统计。
device, dtype 可选 指定层参数所在设备和数据类型,若为 None 则继承父模块。

5. 使用示例

python 复制代码
import torch
import torch.nn as nn

# 假设:batch_size=4,通道数=3,序列长度=5
x = torch.randn(4, 3, 5)  # e.g. Conv1d 输出

# 定义 BatchNorm1d
bn1d = nn.BatchNorm1d(
    num_features=3,      # 通道数
    eps=1e-5,
    momentum=0.1,
    affine=True,
    track_running_stats=True
)

# 训练模式(更新 running stats,用当前 batch 统计归一化)
bn1d.train()
y_train = bn1d(x)
print("训练模式输出 shape:", y_train.shape)

# 切换到评估模式(使用 running stats,等价于恒定变换)
bn1d.eval()
y_eval = bn1d(x)
print("评估模式输出 shape:", y_eval.shape)

如果应用于全连接层输出,可输入 ((N,C)):

python 复制代码
# batch_size=8, 特征维=16
x2 = torch.randn(8, 16)
bn_fc = nn.BatchNorm1d(num_features=16)
y2 = bn_fc(x2)  # 输出 shape=(8,16)

6. 注意事项

  1. 小批量问题
    • 当 batch size 很小时,批内统计不稳定,可能导致归一化偏差;可考虑 LayerNorm、GroupNorm 等替代。
  2. 模式切换
    • 务必 在训练阶段调用 model.train(),在测试/推理阶段调用 model.eval(),否则会一直使用当前 batch 统计或不更新滑动统计。
  3. momentum 设定
    • 默认为 0.1,若数据分布剧烈变化,可适当调大;若希望统计更平滑,可调小。
  4. affine=False 场景
    • 若不需要仿射变换(如已在后续层有缩放和平移),可关闭 affine 以减少参数。
  5. 与其他正则化配合
    • BatchNorm 本身有轻度正则化效果,但仍可搭配 Dropout、Weight Decay 等手段;在 Transformer 中常与 LayerNorm 互补使用。
  6. 序列 & 时序模型
    • 对 RNN/CNN + 时序数据,输入形状可为 ((N,C,L));若不希望跨时间步归一化,可使用 BatchNorm1d 或专门的 nn.GroupNorm
相关推荐
Gyoku Mint27 分钟前
深度学习×第4卷:Pytorch实战——她第一次用张量去拟合你的轨迹
人工智能·pytorch·python·深度学习·神经网络·算法·聚类
葫三生2 小时前
如何评价《论三生原理》在科技界的地位?
人工智能·算法·机器学习·数学建模·量子计算
拓端研究室4 小时前
视频讲解:门槛效应模型Threshold Effect分析数字金融指数与消费结构数据
前端·算法
随缘而动,随遇而安6 小时前
第八十八篇 大数据中的递归算法:从俄罗斯套娃到分布式计算的奇妙之旅
大数据·数据结构·算法
IT古董6 小时前
【第二章:机器学习与神经网络概述】03.类算法理论与实践-(3)决策树分类器
神经网络·算法·机器学习
水木兰亭9 小时前
数据结构之——树及树的存储
数据结构·c++·学习·算法
Jess0710 小时前
插入排序的简单介绍
数据结构·算法·排序算法
老一岁10 小时前
选择排序算法详解
数据结构·算法·排序算法
xindafu10 小时前
代码随想录算法训练营第四十二天|动态规划part9
算法·动态规划
xindafu10 小时前
代码随想录算法训练营第四十五天|动态规划part12
算法·动态规划