【Torch】nn.BatchNorm1d算法详解

1. 定义

nn.BatchNorm1d 是 PyTorch 中用于一维特征 的批归一化层(Batch Normalization),常用于对全连接层或一维卷积(Conv1d)输出进行归一化。它通过对每个特征通道的输出在批次维度(及可选的时间/序列维度)上做归一化,使网络训练更稳定、收敛更快,并具有一定的正则化效果。

2. 输入与输出

  • 输入(Input)

    • 形状1 : ( N , C ) (N, C) (N,C),即批大小 N N N,特征通道数 C C C。
    • 形状2 : ( N , C , L ) (N, C, L) (N,C,L),用于处理一维序列或时间序列, L L L为序列长度。
    • 类型 :浮点型张量(torch.float32 等)。
  • 输出(Output)

    • 与输入形状完全相同,每个通道经归一化与可选的仿射变换后返回,类型与输入相同。

3. 底层原理

  1. 批内统计

    对输入的每个通道 c c c计算:
    μ c = 1 m ∑ i = 1 m x i , c , σ c 2 = 1 m ∑ i = 1 m ( x i , c − μ c ) 2 \mu_c = \frac{1}{m}\sum_{i=1}^m x_{i,c},\quad \sigma_c^2 = \frac{1}{m}\sum_{i=1}^m (x_{i,c}-\mu_c)^2 μc=m1i=1∑mxi,c,σc2=m1i=1∑m(xi,c−μc)2

    其中 m = N m=N m=N(或 N × L N\times L N×L)为归一化的元素总数。

  2. 归一化与仿射变换
    x ^ i , c = x i , c − μ c σ c 2 + ε , y i , c = γ c x ^ i , c + β c \hat x_{i,c} = \frac{x_{i,c}-\mu_c}{\sqrt{\sigma_c^2 + \varepsilon}},\quad y_{i,c} = \gamma_c \hat x_{i,c} + \beta_c x^i,c=σc2+ε xi,c−μc,yi,c=γcx^i,c+βc

    • ε \varepsilon ε(eps)防止除零;
    • γ , β \gamma,\beta γ,β(可选仿射参数)恢复表达能力。
  3. 滑动平均

    在训练模式下,维护运行时统计 (running_mean, running_var):
    running_mean ← ( 1 − α )   running_mean + α   μ c , \text{running\_mean} \leftarrow (1-\alpha)\,\text{running\_mean} + \alpha\,\mu_c, running_mean←(1−α)running_mean+αμc,

    同理更新 running_var,其中 α \alpha α为 momentum

    在评估模式下,用这些滑动统计量代替当前批次统计量。

4. 构造函数参数详解

参数 类型 & 默认 说明
num_features int 必填。通道数 (C),即输入的第二维度大小。
eps float, 默认 1e-5 防止除零的微小常数,影响归一化精度。
momentum float, 默认 0.1 滑动平均更新速率 (\alpha)。较大时更快跟踪批内变化,较小时更平滑。
affine bool, 默认 True 是否学习仿射参数 (\gamma,\beta)。若为 False,则不做缩放平移。
track_running_stats bool, 默认 True 是否跟踪 running_meanrunning_var;若为 False,则训练/测试都使用当前批统计。
device, dtype 可选 指定层参数所在设备和数据类型,若为 None 则继承父模块。

5. 使用示例

python 复制代码
import torch
import torch.nn as nn

# 假设:batch_size=4,通道数=3,序列长度=5
x = torch.randn(4, 3, 5)  # e.g. Conv1d 输出

# 定义 BatchNorm1d
bn1d = nn.BatchNorm1d(
    num_features=3,      # 通道数
    eps=1e-5,
    momentum=0.1,
    affine=True,
    track_running_stats=True
)

# 训练模式(更新 running stats,用当前 batch 统计归一化)
bn1d.train()
y_train = bn1d(x)
print("训练模式输出 shape:", y_train.shape)

# 切换到评估模式(使用 running stats,等价于恒定变换)
bn1d.eval()
y_eval = bn1d(x)
print("评估模式输出 shape:", y_eval.shape)

如果应用于全连接层输出,可输入 ((N,C)):

python 复制代码
# batch_size=8, 特征维=16
x2 = torch.randn(8, 16)
bn_fc = nn.BatchNorm1d(num_features=16)
y2 = bn_fc(x2)  # 输出 shape=(8,16)

6. 注意事项

  1. 小批量问题
    • 当 batch size 很小时,批内统计不稳定,可能导致归一化偏差;可考虑 LayerNorm、GroupNorm 等替代。
  2. 模式切换
    • 务必 在训练阶段调用 model.train(),在测试/推理阶段调用 model.eval(),否则会一直使用当前 batch 统计或不更新滑动统计。
  3. momentum 设定
    • 默认为 0.1,若数据分布剧烈变化,可适当调大;若希望统计更平滑,可调小。
  4. affine=False 场景
    • 若不需要仿射变换(如已在后续层有缩放和平移),可关闭 affine 以减少参数。
  5. 与其他正则化配合
    • BatchNorm 本身有轻度正则化效果,但仍可搭配 Dropout、Weight Decay 等手段;在 Transformer 中常与 LayerNorm 互补使用。
  6. 序列 & 时序模型
    • 对 RNN/CNN + 时序数据,输入形状可为 ((N,C,L));若不希望跨时间步归一化,可使用 BatchNorm1d 或专门的 nn.GroupNorm
相关推荐
小徐不徐说41 分钟前
每日一算:华为-批萨分配问题
数据结构·c++·算法·leetcode·华为·动态规划·后端开发
菜鸟555553 小时前
图论:最小生成树
算法·图论
2401_872945094 小时前
【补题】Codeforces Round 735 (Div. 2) C. Mikasa
算法
叫我:松哥4 小时前
基于网络爬虫的在线医疗咨询数据爬取与医疗服务分析系统,技术采用django+朴素贝叶斯算法+boostrap+echart可视化
人工智能·爬虫·python·算法·django·数据可视化·朴素贝叶斯
Star在努力4 小时前
14-C语言:第14天笔记
c语言·笔记·算法
赴3356 小时前
Numpy 库 矩阵数学运算,点积,文件读取和保存等
人工智能·算法·numpy·random·dot
自由随风飘6 小时前
机器学习-SVM支持向量机
算法·机器学习·支持向量机
屁股割了还要学7 小时前
【C语言进阶】柔性数组
c语言·开发语言·数据结构·c++·学习·算法·柔性数组
草莓熊Lotso7 小时前
【LeetCode刷题指南】--有效的括号
c语言·数据结构·其他·算法·leetcode·刷题
Alla T8 小时前
【通识】算法案例
算法