PyTorch|BatchNorm 的两种方差

在阅读 PyTorch 的 BatchNorm1d 文档时,你可能会注意到这样一句话:训练阶段前向传播中使用的方差是 biased estimator(unbiased=False),而存入 running_var 的却是 unbiased estimator(unbiased=True)。

At train time in the forward pass, the variance is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False). However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent to torch.var(input, unbiased=True).

乍一看,这似乎是一个「训练与推理不一致(mismatch)」的设计,尤其在 batch size 很小的情况下,biased 方差的偏差会非常明显,让人直觉上感到不安。

但理解了 BatchNorm 在训练和推理阶段的目标差异之后,这种设计其实是非常自然、甚至是刻意为之的。

首先,训练阶段 forward pass 中使用的 mini-batch 方差,并不是在做统计意义上的「总体方差估计」。BatchNorm 在训练时的核心目的,是对当前 batch 的激活进行一次重标定(rescaling):减去 batch 均值、除以 batch 的标准差,让激活落在一个稳定的数值区间中,从而改善优化过程。这一步里的方差本质上只是一个归一化用的尺度因子,而不是一个我们希望「期望上等于真实分布方差」的统计量。biased 方差与 unbiased 方差之间只差一个常数比例 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N − 1 \frac{N}{N-1} </math>N−1N ,把训练阶段的除数从 N 换成 N-1,在数学上完全成立,但等价于对归一化后的激活整体乘上一个固定常数,而这个尺度差异会被后续的可学习参数 γ 自动吸收,因此不会改变模型的表达能力或最终能学到的函数形式。

换句话说,即便 batch 很小、biased 方差在统计意义上偏得很厉害,这种偏差也不会破坏 BatchNorm 在训练中「稳定激活分布」的功能。更重要的是,小 batch 情况下真正的问题并不是 bias,而是 noise。无论使用 biased 还是 unbiased 方差,当样本数量很少时,方差估计本身都会高度不稳定,在 batch 之间剧烈波动。unbiased estimator 只能在期望意义上校正偏差,却无法降低估计的随机性,因此它并不能解决小 batch 下 BatchNorm 不稳定的问题。这也是为什么在小 batch 场景中,我们通常会选择 SyncBatchNorm、GroupNorm 或 LayerNorm,而不是纠结于是否该把 biased 方差换成 unbiased 方差。

但当我们转向 running_var 时,目标就完全不同了。running_var 并不是用来服务当前 batch 的,而是要在训练过程中逐步逼近整个数据分布的总体方差,并在推理阶段作为固定统计量使用。此时,这个问题才真正变成了一个统计估计问题:我们希望每个 mini-batch 提供的是对总体方差的一个「无偏抽样估计」,然后通过指数滑动平均的方式,将这些估计平滑地累积起来。因此,PyTorch 在更新 running_var 时,先对每个 batch 计算 unbiased 方差(除以 N-1),再用 momentum 做 EMA 更新。这使得 running_var 在长期意义上更接近真实分布的方差,而不是系统性偏小。

因此,看似「训练与推理不一致」的设计,实际上是同一机制在服务两种完全不同的目标:训练阶段的 batch 方差只是一个即时归一化工具,关注的是数值稳定和优化行为;而推理阶段使用的 running_var 则是一个跨 batch 累积的统计近似,关注的是对整体数据分布的合理刻画。前者不需要无偏,后者必须尽量无偏。这正是 PyTorch 在 BatchNorm 中区分 biased 与 unbiased 方差的根本原因。

相关推荐
房开民4 小时前
可变参数模板
java·开发语言·算法
不知名的忻4 小时前
Morris遍历(力扣第99题)
java·算法·leetcode·morris遍历
状元岐5 小时前
C#反射从入门到精通
java·javascript·算法
_深海凉_5 小时前
LeetCode热题100-除了自身以外数组的乘积
数据结构·算法·leetcode
Kk.08026 小时前
项目《基于Linux下的mybash命令解释器》(一)
前端·javascript·算法
SteveSenna6 小时前
Trossen Arm MuJoCo自定义1:改变目标物体
人工智能·学习·算法·机器人
yong99907 小时前
IHAOAVOA:天鹰优化算法与非洲秃鹫优化算法的混合算法(Matlab实现)
开发语言·算法·matlab
米粒18 小时前
力扣算法刷题 Day 42(股票问题总结)
算法·leetcode·职场和发展
浅念-10 小时前
从LeetCode入门位运算:常见技巧与实战题目全解析
数据结构·数据库·c++·笔记·算法·leetcode·牛客
CoovallyAIHub10 小时前
无人机拍叶片→AI找缺陷:CEA-DETR改进RT-DETR做风电叶片表面缺陷检测,mAP50达89.4%
算法·架构·github