在阅读 PyTorch 的 BatchNorm1d 文档时,你可能会注意到这样一句话:训练阶段前向传播中使用的方差是 biased estimator(unbiased=False),而存入 running_var 的却是 unbiased estimator(unbiased=True)。
At train time in the forward pass, the variance is calculated via the biased estimator, equivalent to
torch.var(input, unbiased=False). However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent totorch.var(input, unbiased=True).
乍一看,这似乎是一个「训练与推理不一致(mismatch)」的设计,尤其在 batch size 很小的情况下,biased 方差的偏差会非常明显,让人直觉上感到不安。
但理解了 BatchNorm 在训练和推理阶段的目标差异之后,这种设计其实是非常自然、甚至是刻意为之的。
首先,训练阶段 forward pass 中使用的 mini-batch 方差,并不是在做统计意义上的「总体方差估计」。BatchNorm 在训练时的核心目的,是对当前 batch 的激活进行一次重标定(rescaling):减去 batch 均值、除以 batch 的标准差,让激活落在一个稳定的数值区间中,从而改善优化过程。这一步里的方差本质上只是一个归一化用的尺度因子,而不是一个我们希望「期望上等于真实分布方差」的统计量。biased 方差与 unbiased 方差之间只差一个常数比例 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N − 1 \frac{N}{N-1} </math>N−1N ,把训练阶段的除数从 N 换成 N-1,在数学上完全成立,但等价于对归一化后的激活整体乘上一个固定常数,而这个尺度差异会被后续的可学习参数 γ 自动吸收,因此不会改变模型的表达能力或最终能学到的函数形式。
换句话说,即便 batch 很小、biased 方差在统计意义上偏得很厉害,这种偏差也不会破坏 BatchNorm 在训练中「稳定激活分布」的功能。更重要的是,小 batch 情况下真正的问题并不是 bias,而是 noise。无论使用 biased 还是 unbiased 方差,当样本数量很少时,方差估计本身都会高度不稳定,在 batch 之间剧烈波动。unbiased estimator 只能在期望意义上校正偏差,却无法降低估计的随机性,因此它并不能解决小 batch 下 BatchNorm 不稳定的问题。这也是为什么在小 batch 场景中,我们通常会选择 SyncBatchNorm、GroupNorm 或 LayerNorm,而不是纠结于是否该把 biased 方差换成 unbiased 方差。
但当我们转向 running_var 时,目标就完全不同了。running_var 并不是用来服务当前 batch 的,而是要在训练过程中逐步逼近整个数据分布的总体方差,并在推理阶段作为固定统计量使用。此时,这个问题才真正变成了一个统计估计问题:我们希望每个 mini-batch 提供的是对总体方差的一个「无偏抽样估计」,然后通过指数滑动平均的方式,将这些估计平滑地累积起来。因此,PyTorch 在更新 running_var 时,先对每个 batch 计算 unbiased 方差(除以 N-1),再用 momentum 做 EMA 更新。这使得 running_var 在长期意义上更接近真实分布的方差,而不是系统性偏小。
因此,看似「训练与推理不一致」的设计,实际上是同一机制在服务两种完全不同的目标:训练阶段的 batch 方差只是一个即时归一化工具,关注的是数值稳定和优化行为;而推理阶段使用的 running_var 则是一个跨 batch 累积的统计近似,关注的是对整体数据分布的合理刻画。前者不需要无偏,后者必须尽量无偏。这正是 PyTorch 在 BatchNorm 中区分 biased 与 unbiased 方差的根本原因。