批标准化概述
批标准化(Batch Normalization,BN)是深度学习中最具影响力的创新之一,由Ioffe和Szegedy于2015年提出。它通过标准化层输入,显著解决了内部协变量偏移问题。
原始输入 卷积/全连接层 批标准化 激活函数 下一层
BN核心原理
-
计算批统计量 :
μB=1m∑i=1mxi\mu_B = \frac{1}{m}\sum_{i=1}^m x_iμB=m1i=1∑mxi
σB2=1m∑i=1m(xi−μB)2\sigma_B^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_B)^2σB2=m1i=1∑m(xi−μB)2 -
标准化输入 :
x^i=xi−μBσB2+ϵ\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}x^i=σB2+ϵ xi−μB -
缩放与偏移 :(为了让隐藏层单元又不同分布)
yi=γx^i+βy_i = \gamma \hat{x}_i + \betayi=γx^i+β
其中:
- γ\gammaγ:可学习缩放参数
- β\betaβ:可学习偏移参数
- ϵ\epsilonϵ:防止除零的小常数(取值10−810^{-8}10−8)
Python手动实现
python
import numpy as np
class BatchNorm:
def __init__(self, num_features, momentum=0.9, eps=1e-5):
self.gamma = np.ones((1, num_features))
self.beta = np.zeros((1, num_features))
self.momentum = momentum
self.eps = eps
self.running_mean = None
self.running_var = None
def forward(self, X, training=True):
if training:
# 计算当前批次的均值和方差
mean = np.mean(X, axis=0, keepdims=True)
var = np.var(X, axis=0, keepdims=True)
# 更新运行平均值
if self.running_mean is None:
self.running_mean = mean
self.running_var = var
else:
self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
else:
mean = self.running_mean
var = self.running_var
# 标准化
X_hat = (X - mean) / np.sqrt(var + self.eps)
# 缩放和偏移
out = self.gamma * X_hat + self.beta
# 缓存反向传播所需的值
if training:
self.cache = (X, X_hat, mean, var)
return out
def backward(self, dout):
X, X_hat, mean, var = self.cache
m = X.shape[0]
eps = self.eps
# 计算梯度
dgamma = np.sum(dout * X_hat, axis=0, keepdims=True)
dbeta = np.sum(dout, axis=0, keepdims=True)
# 计算dX_hat
dX_hat = dout * self.gamma
# 计算dvar
dvar = np.sum(dX_hat * (X - mean) * (-0.5) * (var + eps)**(-1.5), axis=0, keepdims=True)
# 计算dmean
dmean = np.sum(dX_hat * (-1) / np.sqrt(var + eps), axis=0, keepdims=True) + \
dvar * np.sum(-2 * (X - mean), axis=0, keepdims=True) / m
# 计算dX
dX = (dX_hat / np.sqrt(var + eps)) + \
(dvar * 2 * (X - mean) / m) + \
(dmean / m)
return dX, dgamma, dbeta
BN的优势与效果
35% 25% 20% 15% 5% BatchNorm优势分布 训练加速 允许更高学习率 减少梯度消失 轻微正则化效果 降低初始化敏感度
神经网络调优
1. 学习率调优策略
学习率预热(Learning Rate Warmup)
python
def lr_warmup(current_step, warmup_steps, base_lr):
return base_lr * min(1.0, current_step / warmup_steps)
# 余弦退火学习率
def cosine_annealing(step, total_steps, base_lr):
return 0.5 * base_lr * (1 + np.cos(np.pi * step / total_steps))
周期性学习率(CLR)
python
def cyclical_lr(step, step_size, base_lr, max_lr):
cycle = np.floor(1 + step / (2 * step_size))
x = np.abs(step / step_size - 2 * cycle + 1)
return base_lr + (max_lr - base_lr) * np.maximum(0, (1 - x))
2. 权重初始化策略比较
初始化方法 | 适用激活函数 | 公式 | 特点 |
---|---|---|---|
Xavier/Glorot | Tanh/Sigmoid | W∼U(−6nin+nout,6nin+nout)W \sim \mathcal{U}(-\sqrt{\frac{6}{n_{in}+n_{out}}}, \sqrt{\frac{6}{n_{in}+n_{out}}})W∼U(−nin+nout6 ,nin+nout6 ) | 保持输入输出方差一致 |
He/Kaiming | ReLU | W∼N(0,2nin)W \sim \mathcal{N}(0, \sqrt{\frac{2}{n_{in}}})W∼N(0,nin2 ) | 解决ReLU负半轴问题 |
LeCun | SELU | W∼N(0,1nin)W \sim \mathcal{N}(0, \sqrt{\frac{1}{n_{in}}})W∼N(0,nin1 ) | 自归一化网络专用 |
3. 梯度裁剪(Gradient Clipping)
python
def gradient_clip(grads, max_norm):
total_norm = 0
for grad in grads:
grad_norm = np.sum(grad**2)
total_norm += grad_norm
total_norm = np.sqrt(total_norm)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for grad in grads:
grad *= clip_coef
return grads
神经网络调参技巧
1. 系统化调参流程
不足 满意 定义模型架构 选择优化器 设置学习率 添加正则化 训练模型 验证性能 调整超参数 最终评估
2. 关键超参数调优指南
超参数 | 推荐值/策略 | 调优方法 | 注意事项 |
---|---|---|---|
学习率 | 0.001-0.1 | 学习率扫描 | 配合学习率调度器 |
批量大小 | 32-256 | 2的幂次 | 内存允许下尽量大 |
网络深度 | 3-10层 | 逐步增加 | 配合残差连接 |
层宽度 | 64-1024 | 递增策略 | 与深度平衡 |
激活函数 | ReLU/Swish | 实验比较 | 输出层用Sigmoid/Softmax |
优化器 | Adam/NAdam | 比较收敛速度 | 搭配权重衰减 |
正则化强度 | L2:0.0001-0.01 | 交叉验证 | 配合Dropout(0.2-0.5) |
批标准化最佳实践
-
放置位置:通常在全连接/卷积层后,激活函数前
输入 全连接层 批标准化 激活函数 下一层
-
参数设置:
- 动量:0.9-0.99(推荐0.99)
- ε:e−8e^{-8}e−8
- γ和β:可学习参数,初始值γ=1,β=0
-
注意事项:
- 训练和测试模式需区分
- 小批量数据效果差(批量大小≥32)
- 与Dropout共用时需小心
- 不适合RNN(需用LayerNorm)
总结
-
批标准化是基础:
- 所有现代深度学习架构的标配
- 使训练更稳定、更快
- 允许使用更高学习率
-
系统化调参流程:
数据预处理 初始模型 学习率调优 正则化强度 架构搜索 高级优化
-
黄金组合:
- He初始化 + BatchNorm + AdamW
- 学习率预热 + 余弦退火
- 适度L2正则化 + 数据增强
-
避免常见陷阱:
- 不验证数据预处理一致性
- 忽略训练/测试模式差异
- 在小批量数据上使用BN
- 过早停止调参过程