深度学习中的归一化技术详解:BN、LN、IN、GN

1. Batch Normalization (BN, 2015)

核心思想

  • 对Batch维度的每个特征通道进行归一化
  • 训练时用当前batch统计量,测试时用全局移动平均

计算步骤

python 复制代码
# 输入x形状: [B, C, H, W] (CNN) 或 [B, D] (全连接)
mean = x.mean(dim=[0, 2, 3])  # 沿Batch/空间维度求均值
var = x.var(dim=[0, 2, 3], unbiased=False)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta  # 可学习参数γ,β

μₖ = (1/(B×H×W)) × Σ(xᵢⱼₖₗ) (i=1...B, j=1...H, l=1...W)

σₖ² = (1/(B×H×W)) × Σ(xᵢⱼₖₗ - μₖ)²

x̂ᵢⱼₖₗ = (xᵢⱼₖₗ - μₖ) / √(σₖ² + ε)

yᵢⱼₖₗ = γₖ × x̂ᵢⱼₖₗ + βₖ

特点

优点 缺点
✅ 加速收敛 ✅ 允许更大学习率 ✅ 有正则化效果 ❌ 依赖大batch(通常>16) ❌ 不适用于RNN/Dynamic NN

PyTorch实现

python 复制代码
nn.BatchNorm2d(num_features)  # CNN
nn.BatchNorm1d(num_features)  # FC/RNN

2. Layer Normalization (LN, 2016)

核心思想

  • 对每个样本的所有特征进行归一化
  • 常用于Transformer和RNN

计算步骤

python 复制代码
# 输入x形状: [B, T, D] (Transformer)
mean = x.mean(dim=-1, keepdim=True)  # 特征维度
var = x.var(dim=-1, keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta

μᵢ = (1/D) × Σ(xᵢⱼ) (j=1...D)

σᵢ² = (1/D) × Σ(xᵢⱼ - μᵢ)²

x̂ᵢⱼ = (xᵢⱼ - μᵢ) / √(σᵢ² + ε)

yᵢⱼ = γⱼ × x̂ᵢⱼ + βⱼ

特点

优点 缺点
✅ 不依赖batch size ✅ 适合动态网络 ❌ CNN效果不如BN

PyTorch实现

python 复制代码
nn.LayerNorm(normalized_shape)  # normalized_shape=D

3. Instance Normalization (IN, 2017)

核心思想

  • 对每个样本的每个通道单独归一化
  • 风格迁移任务常用

计算步骤

python 复制代码
# 输入x形状: [B, C, H, W]
mean = x.mean(dim=[2, 3], keepdim=True)  # 空间维度
var = x.var(dim=[2, 3], keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta  # 可选

μᵢₖ = (1/(H×W)) × Σ(xᵢₖⱼₗ) (j=1...H, l=1...W)

σᵢₖ² = (1/(H×W)) × Σ(xᵢₖⱼₗ - μᵢₖ)²

x̂ᵢₖⱼₗ = (xᵢₖⱼₗ - μᵢₖ) / √(σᵢₖ² + ε)

yᵢₖⱼₗ = γₖ × x̂ᵢₖⱼₗ + βₖ (可选)

特点

优点 缺点
✅ 保留样本间独立性 ✅ 适合风格迁移 ❌ 破坏通道间相关性

PyTorch实现

python 复制代码
nn.InstanceNorm2d(num_features)

4. Group Normalization (GN, 2018)

核心思想

  • 将通道分组后对每组进行归一化
  • CNN小batch场景的BN替代方案

计算步骤

python 复制代码
# 输入x形状: [B, C, H, W], 设groups=G
x = x.view(B, G, C//G, H, W)  # 分组
mean = x.mean(dim=[2, 3, 4], keepdim=True)
var = x.var(dim=[2, 3, 4], keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = x_hat.view(B, C, H, W) * gamma + beta
分组后形状: [B, G, C//G, H, W]

μᵢ₉ = (1/((C//G)×H×W)) × Σ(xᵢ₉ₖⱼₗ)

σᵢ₉² = (1/((C//G)×H×W)) × Σ(xᵢ₉ₖⱼₗ - μᵢ₉)²

x̂ᵢ₉ₖⱼₗ = (xᵢ₉ₖⱼₗ - μᵢ₉) / √(σᵢ₉² + ε)

恢复形状后:

yᵢₖⱼₗ = γₖ × x̂ᵢₖⱼₗ + βₖ

特点

优点 缺点
✅ 小batch表现好 ✅ 精度接近BN ❌ 计算量稍大

PyTorch实现

python 复制代码
nn.GroupNorm(num_groups, num_channels)

5.对比总结

方法 归一化维度 适用场景 Batch依赖
BN [B, H, W] 大batch/CNN
LN [D] RNN/Transformer
IN [H, W] 风格迁移/生成模型
GN [G, H, W] 小batch CNN

代码示例(四种归一化对比)

python 复制代码
import torch.nn as nn

# 输入假设: [2, 6, 224, 224] (batch=2, channels=6)
bn = nn.BatchNorm2d(6)
ln = nn.LayerNorm([6, 224, 224])  # 全特征归一化
in = nn.InstanceNorm2d(6)
gn = nn.GroupNorm(num_groups=3, num_channels=6)  # 分2组

如何选择?

  1. CNN:优先尝试BN → batch<8时用GN
  2. RNN/Transformer:必选LN
  3. Style Transfer:首选IN
  4. 小batch CNN:GN+LN组合

📌 经验法则:当BN效果不佳时,根据任务特性尝试其他归一化方法


6. Transformer架构中的归一化标准方案

现代大语言模型普遍采用 Pre-LayerNorm 结构,即在注意力/FFN层之前进行归一化:

复制代码
输入 → LayerNorm → Attention → 残差连接 → LayerNorm → FFN → 残差连接

6.1 ChatGPT (OpenAI GPT系列)

模型版本 归一化方案 关键细节
GPT-2 LayerNorm 经典Post-LN
GPT-3 LayerNorm 改为Pre-LN
GPT-4 LayerNorm + 改进 可能引入RMSNorm

特点

  • 始终坚持LayerNorm
  • 从Post-LN转向更稳定的Pre-LN结构

6.2 DeepSeek

模型版本 归一化方案 关键细节
DeepSeek-MoE LayerNorm Pre-LN结构
DeepSeek-Coder LayerNorm 代码模型同样架构

创新点

  • 在MoE架构中保持LayerNorm一致性
  • 对长上下文优化了Norm位置

6.3 Qwen (阿里通义千问)

模型版本 归一化方案 关键细节
Qwen-1.8B LayerNorm 标准实现
Qwen-72B RMSNorm 性能优化

技术演进

  • 大参数模型改用RMSNorm减少计算量
  • 保留LayerNorm的缩放偏移参数

6.4为什么不用BatchNorm?

所有主流LLM都避免使用BN,原因包括:

  1. 序列长度可变:BN需要固定维度,但文本长度动态变化
  2. 小batch推理:预测时batch_size=1,BN统计量失效
  3. 训练不稳定:文本数据的稀疏性导致BN方差估计不准

6.5 进阶变体:RMSNorm

新兴模型(如LLaMA、Qwen-72B)开始采用 RMSNorm(Root Mean Square Normalization):

python 复制代码
def rms_norm(x, eps=1e-6):
    # 去均值操作(相比LayerNorm)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps

RMS(x) = √((1/D) × Σ(xⱼ²) + ε)

yᵢ = (xᵢ / RMS(x)) × γᵢ

优势

  • 计算量减少约20%(适合超大模型)
  • 在Transformer中表现接近LayerNorm

6.6 模型实现对比表

模型 归一化方案 结构位置 是否含β/γ
GPT-4 LayerNorm Pre-LN
LLaMA-2 RMSNorm Pre-LN
Qwen-72B RMSNorm Pre-LN
DeepSeek-MoE LayerNorm Pre-LN

6.7关键结论

  1. LayerNorm仍是主流:90%以上的LLM使用
  2. Pre-LN成为标准:比原始Transformer的Post-LN更稳定
  3. RMSNorm是趋势:新模型为效率逐步转向RMSNorm
  4. 绝对不用BN:所有文本模型都避免BatchNorm
相关推荐
北辰alk8 分钟前
如何实现AI多轮对话功能及解决对话记忆持久化问题
人工智能
智驱力人工智能8 分钟前
极端高温下的智慧出行:危险检测与救援
人工智能·算法·安全·行为识别·智能巡航·高温预警·高温监测
Leo.yuan17 分钟前
数据分析师如何构建自己的底层逻辑?
大数据·数据仓库·人工智能·数据挖掘·数据分析
笑稀了的野生俊24 分钟前
ImportError: /lib/x86_64-linux-gnu/libc.so.6: version GLIBC_2.32‘ not found
linux·人工智能·ubuntu·大模型·glibc·flash-attn
吕永强25 分钟前
意识边界的算法战争—脑机接口技术重构人类认知的颠覆性挑战
人工智能·科普
二二孚日1 小时前
自用华为ICT云赛道AI第三章知识点-昇腾芯片硬件架构,昇腾芯片软件架构
人工智能·华为
蹦蹦跳跳真可爱5892 小时前
Python----OpenCV(几何变换--图像平移、图像旋转、放射变换、图像缩放、透视变换)
开发语言·人工智能·python·opencv·计算机视觉
蹦蹦跳跳真可爱5892 小时前
Python----循环神经网络(Transformer ----Layer-Normalization(层归一化))
人工智能·python·rnn·transformer
夜阳朔2 小时前
Conda环境激活失效问题
人工智能·后端·python
小Lu的开源日常2 小时前
AI模型太多太乱?用 OpenRouter,一个接口全搞定!
人工智能·llm·api