1. 定义
nn.BatchNorm1d
是 PyTorch 中用于一维特征 的批归一化层(Batch Normalization),常用于对全连接层或一维卷积(Conv1d)输出进行归一化。它通过对每个特征通道的输出在批次维度(及可选的时间/序列维度)上做归一化,使网络训练更稳定、收敛更快,并具有一定的正则化效果。
2. 输入与输出
-
输入(Input)
- 形状1 : ( N , C ) (N, C) (N,C),即批大小 N N N,特征通道数 C C C。
- 形状2 : ( N , C , L ) (N, C, L) (N,C,L),用于处理一维序列或时间序列, L L L为序列长度。
- 类型 :浮点型张量(
torch.float32
等)。
-
输出(Output)
- 与输入形状完全相同,每个通道经归一化与可选的仿射变换后返回,类型与输入相同。
3. 底层原理
-
批内统计
对输入的每个通道 c c c计算:
μ c = 1 m ∑ i = 1 m x i , c , σ c 2 = 1 m ∑ i = 1 m ( x i , c − μ c ) 2 \mu_c = \frac{1}{m}\sum_{i=1}^m x_{i,c},\quad \sigma_c^2 = \frac{1}{m}\sum_{i=1}^m (x_{i,c}-\mu_c)^2 μc=m1i=1∑mxi,c,σc2=m1i=1∑m(xi,c−μc)2其中 m = N m=N m=N(或 N × L N\times L N×L)为归一化的元素总数。
-
归一化与仿射变换
x ^ i , c = x i , c − μ c σ c 2 + ε , y i , c = γ c x ^ i , c + β c \hat x_{i,c} = \frac{x_{i,c}-\mu_c}{\sqrt{\sigma_c^2 + \varepsilon}},\quad y_{i,c} = \gamma_c \hat x_{i,c} + \beta_c x^i,c=σc2+ε xi,c−μc,yi,c=γcx^i,c+βc- ε \varepsilon ε(eps)防止除零;
- γ , β \gamma,\beta γ,β(可选仿射参数)恢复表达能力。
-
滑动平均
在训练模式下,维护运行时统计 (running_mean, running_var):
running_mean ← ( 1 − α ) running_mean + α μ c , \text{running\_mean} \leftarrow (1-\alpha)\,\text{running\_mean} + \alpha\,\mu_c, running_mean←(1−α)running_mean+αμc,同理更新
running_var
,其中 α \alpha α为momentum
。在评估模式下,用这些滑动统计量代替当前批次统计量。
4. 构造函数参数详解
参数 | 类型 & 默认 | 说明 |
---|---|---|
num_features |
int |
必填。通道数 (C),即输入的第二维度大小。 |
eps |
float , 默认 1e-5 |
防止除零的微小常数,影响归一化精度。 |
momentum |
float , 默认 0.1 |
滑动平均更新速率 (\alpha)。较大时更快跟踪批内变化,较小时更平滑。 |
affine |
bool , 默认 True |
是否学习仿射参数 (\gamma,\beta)。若为 False ,则不做缩放平移。 |
track_running_stats |
bool , 默认 True |
是否跟踪 running_mean 和 running_var ;若为 False ,则训练/测试都使用当前批统计。 |
device , dtype |
可选 | 指定层参数所在设备和数据类型,若为 None 则继承父模块。 |
5. 使用示例
python
import torch
import torch.nn as nn
# 假设:batch_size=4,通道数=3,序列长度=5
x = torch.randn(4, 3, 5) # e.g. Conv1d 输出
# 定义 BatchNorm1d
bn1d = nn.BatchNorm1d(
num_features=3, # 通道数
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True
)
# 训练模式(更新 running stats,用当前 batch 统计归一化)
bn1d.train()
y_train = bn1d(x)
print("训练模式输出 shape:", y_train.shape)
# 切换到评估模式(使用 running stats,等价于恒定变换)
bn1d.eval()
y_eval = bn1d(x)
print("评估模式输出 shape:", y_eval.shape)
如果应用于全连接层输出,可输入 ((N,C)):
python
# batch_size=8, 特征维=16
x2 = torch.randn(8, 16)
bn_fc = nn.BatchNorm1d(num_features=16)
y2 = bn_fc(x2) # 输出 shape=(8,16)
6. 注意事项
- 小批量问题
- 当 batch size 很小时,批内统计不稳定,可能导致归一化偏差;可考虑 LayerNorm、GroupNorm 等替代。
- 模式切换
- 务必 在训练阶段调用
model.train()
,在测试/推理阶段调用model.eval()
,否则会一直使用当前 batch 统计或不更新滑动统计。
- 务必 在训练阶段调用
momentum
设定- 默认为 0.1,若数据分布剧烈变化,可适当调大;若希望统计更平滑,可调小。
affine=False
场景- 若不需要仿射变换(如已在后续层有缩放和平移),可关闭
affine
以减少参数。
- 若不需要仿射变换(如已在后续层有缩放和平移),可关闭
- 与其他正则化配合
- BatchNorm 本身有轻度正则化效果,但仍可搭配 Dropout、Weight Decay 等手段;在 Transformer 中常与 LayerNorm 互补使用。
- 序列 & 时序模型
- 对 RNN/CNN + 时序数据,输入形状可为 ((N,C,L));若不希望跨时间步归一化,可使用
BatchNorm1d
或专门的nn.GroupNorm
。
- 对 RNN/CNN + 时序数据,输入形状可为 ((N,C,L));若不希望跨时间步归一化,可使用
