PyTorch|BatchNorm 的两种方差

在阅读 PyTorch 的 BatchNorm1d 文档时,你可能会注意到这样一句话:训练阶段前向传播中使用的方差是 biased estimator(unbiased=False),而存入 running_var 的却是 unbiased estimator(unbiased=True)。

At train time in the forward pass, the variance is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False). However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent to torch.var(input, unbiased=True).

乍一看,这似乎是一个「训练与推理不一致(mismatch)」的设计,尤其在 batch size 很小的情况下,biased 方差的偏差会非常明显,让人直觉上感到不安。

但理解了 BatchNorm 在训练和推理阶段的目标差异之后,这种设计其实是非常自然、甚至是刻意为之的。

首先,训练阶段 forward pass 中使用的 mini-batch 方差,并不是在做统计意义上的「总体方差估计」。BatchNorm 在训练时的核心目的,是对当前 batch 的激活进行一次重标定(rescaling):减去 batch 均值、除以 batch 的标准差,让激活落在一个稳定的数值区间中,从而改善优化过程。这一步里的方差本质上只是一个归一化用的尺度因子,而不是一个我们希望「期望上等于真实分布方差」的统计量。biased 方差与 unbiased 方差之间只差一个常数比例 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N − 1 \frac{N}{N-1} </math>N−1N ,把训练阶段的除数从 N 换成 N-1,在数学上完全成立,但等价于对归一化后的激活整体乘上一个固定常数,而这个尺度差异会被后续的可学习参数 γ 自动吸收,因此不会改变模型的表达能力或最终能学到的函数形式。

换句话说,即便 batch 很小、biased 方差在统计意义上偏得很厉害,这种偏差也不会破坏 BatchNorm 在训练中「稳定激活分布」的功能。更重要的是,小 batch 情况下真正的问题并不是 bias,而是 noise。无论使用 biased 还是 unbiased 方差,当样本数量很少时,方差估计本身都会高度不稳定,在 batch 之间剧烈波动。unbiased estimator 只能在期望意义上校正偏差,却无法降低估计的随机性,因此它并不能解决小 batch 下 BatchNorm 不稳定的问题。这也是为什么在小 batch 场景中,我们通常会选择 SyncBatchNorm、GroupNorm 或 LayerNorm,而不是纠结于是否该把 biased 方差换成 unbiased 方差。

但当我们转向 running_var 时,目标就完全不同了。running_var 并不是用来服务当前 batch 的,而是要在训练过程中逐步逼近整个数据分布的总体方差,并在推理阶段作为固定统计量使用。此时,这个问题才真正变成了一个统计估计问题:我们希望每个 mini-batch 提供的是对总体方差的一个「无偏抽样估计」,然后通过指数滑动平均的方式,将这些估计平滑地累积起来。因此,PyTorch 在更新 running_var 时,先对每个 batch 计算 unbiased 方差(除以 N-1),再用 momentum 做 EMA 更新。这使得 running_var 在长期意义上更接近真实分布的方差,而不是系统性偏小。

因此,看似「训练与推理不一致」的设计,实际上是同一机制在服务两种完全不同的目标:训练阶段的 batch 方差只是一个即时归一化工具,关注的是数值稳定和优化行为;而推理阶段使用的 running_var 则是一个跨 batch 累积的统计近似,关注的是对整体数据分布的合理刻画。前者不需要无偏,后者必须尽量无偏。这正是 PyTorch 在 BatchNorm 中区分 biased 与 unbiased 方差的根本原因。

相关推荐
yuer20252 小时前
我把 GPT 当成 Runtime 用:只用一个客户端,跑一个可控、可审计的投资决策 DEMO
算法
栀秋6662 小时前
面试常考的最长递增子序列(LIS),到底该怎么想、怎么写?
前端·javascript·算法
l1t2 小时前
在duckdb 递归CTE中实现深度优先搜索DFS
sql·算法·深度优先·duckdb·cte
陈陈爱java2 小时前
RRT建模
算法
智算菩萨3 小时前
摩擦电纳米发电机近期进展的理论脉络梳理:从接触起电到统一建模与能量转换
linux·人工智能·算法
xiaolang_8616_wjl3 小时前
c++超级细致的基本框架
开发语言·数据结构·c++·算法
艾醒3 小时前
大模型原理剖析——拆解预训练、微调、奖励建模与强化学习四阶段(以ChatGPT构建流程为例)
算法
冷崖3 小时前
排序--基数排序
c++·算法
F_D_Z3 小时前
哈希表解Two Sum问题
python·算法·leetcode·哈希表