深度学习中的归一化技术详解:BN、LN、IN、GN

1. Batch Normalization (BN, 2015)

核心思想

  • 对Batch维度的每个特征通道进行归一化
  • 训练时用当前batch统计量,测试时用全局移动平均

计算步骤

python 复制代码
# 输入x形状: [B, C, H, W] (CNN) 或 [B, D] (全连接)
mean = x.mean(dim=[0, 2, 3])  # 沿Batch/空间维度求均值
var = x.var(dim=[0, 2, 3], unbiased=False)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta  # 可学习参数γ,β

μₖ = (1/(B×H×W)) × Σ(xᵢⱼₖₗ) (i=1...B, j=1...H, l=1...W)

σₖ² = (1/(B×H×W)) × Σ(xᵢⱼₖₗ - μₖ)²

x̂ᵢⱼₖₗ = (xᵢⱼₖₗ - μₖ) / √(σₖ² + ε)

yᵢⱼₖₗ = γₖ × x̂ᵢⱼₖₗ + βₖ

特点

优点 缺点
✅ 加速收敛 ✅ 允许更大学习率 ✅ 有正则化效果 ❌ 依赖大batch(通常>16) ❌ 不适用于RNN/Dynamic NN

PyTorch实现

python 复制代码
nn.BatchNorm2d(num_features)  # CNN
nn.BatchNorm1d(num_features)  # FC/RNN

2. Layer Normalization (LN, 2016)

核心思想

  • 对每个样本的所有特征进行归一化
  • 常用于Transformer和RNN

计算步骤

python 复制代码
# 输入x形状: [B, T, D] (Transformer)
mean = x.mean(dim=-1, keepdim=True)  # 特征维度
var = x.var(dim=-1, keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta

μᵢ = (1/D) × Σ(xᵢⱼ) (j=1...D)

σᵢ² = (1/D) × Σ(xᵢⱼ - μᵢ)²

x̂ᵢⱼ = (xᵢⱼ - μᵢ) / √(σᵢ² + ε)

yᵢⱼ = γⱼ × x̂ᵢⱼ + βⱼ

特点

优点 缺点
✅ 不依赖batch size ✅ 适合动态网络 ❌ CNN效果不如BN

PyTorch实现

python 复制代码
nn.LayerNorm(normalized_shape)  # normalized_shape=D

3. Instance Normalization (IN, 2017)

核心思想

  • 对每个样本的每个通道单独归一化
  • 风格迁移任务常用

计算步骤

python 复制代码
# 输入x形状: [B, C, H, W]
mean = x.mean(dim=[2, 3], keepdim=True)  # 空间维度
var = x.var(dim=[2, 3], keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta  # 可选

μᵢₖ = (1/(H×W)) × Σ(xᵢₖⱼₗ) (j=1...H, l=1...W)

σᵢₖ² = (1/(H×W)) × Σ(xᵢₖⱼₗ - μᵢₖ)²

x̂ᵢₖⱼₗ = (xᵢₖⱼₗ - μᵢₖ) / √(σᵢₖ² + ε)

yᵢₖⱼₗ = γₖ × x̂ᵢₖⱼₗ + βₖ (可选)

特点

优点 缺点
✅ 保留样本间独立性 ✅ 适合风格迁移 ❌ 破坏通道间相关性

PyTorch实现

python 复制代码
nn.InstanceNorm2d(num_features)

4. Group Normalization (GN, 2018)

核心思想

  • 将通道分组后对每组进行归一化
  • CNN小batch场景的BN替代方案

计算步骤

python 复制代码
# 输入x形状: [B, C, H, W], 设groups=G
x = x.view(B, G, C//G, H, W)  # 分组
mean = x.mean(dim=[2, 3, 4], keepdim=True)
var = x.var(dim=[2, 3, 4], keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = x_hat.view(B, C, H, W) * gamma + beta
分组后形状: [B, G, C//G, H, W]

μᵢ₉ = (1/((C//G)×H×W)) × Σ(xᵢ₉ₖⱼₗ)

σᵢ₉² = (1/((C//G)×H×W)) × Σ(xᵢ₉ₖⱼₗ - μᵢ₉)²

x̂ᵢ₉ₖⱼₗ = (xᵢ₉ₖⱼₗ - μᵢ₉) / √(σᵢ₉² + ε)

恢复形状后:

yᵢₖⱼₗ = γₖ × x̂ᵢₖⱼₗ + βₖ

特点

优点 缺点
✅ 小batch表现好 ✅ 精度接近BN ❌ 计算量稍大

PyTorch实现

python 复制代码
nn.GroupNorm(num_groups, num_channels)

5.对比总结

方法 归一化维度 适用场景 Batch依赖
BN [B, H, W] 大batch/CNN
LN [D] RNN/Transformer
IN [H, W] 风格迁移/生成模型
GN [G, H, W] 小batch CNN

代码示例(四种归一化对比)

python 复制代码
import torch.nn as nn

# 输入假设: [2, 6, 224, 224] (batch=2, channels=6)
bn = nn.BatchNorm2d(6)
ln = nn.LayerNorm([6, 224, 224])  # 全特征归一化
in = nn.InstanceNorm2d(6)
gn = nn.GroupNorm(num_groups=3, num_channels=6)  # 分2组

如何选择?

  1. CNN:优先尝试BN → batch<8时用GN
  2. RNN/Transformer:必选LN
  3. Style Transfer:首选IN
  4. 小batch CNN:GN+LN组合

📌 经验法则:当BN效果不佳时,根据任务特性尝试其他归一化方法


6. Transformer架构中的归一化标准方案

现代大语言模型普遍采用 Pre-LayerNorm 结构,即在注意力/FFN层之前进行归一化:

复制代码
输入 → LayerNorm → Attention → 残差连接 → LayerNorm → FFN → 残差连接

6.1 ChatGPT (OpenAI GPT系列)

模型版本 归一化方案 关键细节
GPT-2 LayerNorm 经典Post-LN
GPT-3 LayerNorm 改为Pre-LN
GPT-4 LayerNorm + 改进 可能引入RMSNorm

特点

  • 始终坚持LayerNorm
  • 从Post-LN转向更稳定的Pre-LN结构

6.2 DeepSeek

模型版本 归一化方案 关键细节
DeepSeek-MoE LayerNorm Pre-LN结构
DeepSeek-Coder LayerNorm 代码模型同样架构

创新点

  • 在MoE架构中保持LayerNorm一致性
  • 对长上下文优化了Norm位置

6.3 Qwen (阿里通义千问)

模型版本 归一化方案 关键细节
Qwen-1.8B LayerNorm 标准实现
Qwen-72B RMSNorm 性能优化

技术演进

  • 大参数模型改用RMSNorm减少计算量
  • 保留LayerNorm的缩放偏移参数

6.4为什么不用BatchNorm?

所有主流LLM都避免使用BN,原因包括:

  1. 序列长度可变:BN需要固定维度,但文本长度动态变化
  2. 小batch推理:预测时batch_size=1,BN统计量失效
  3. 训练不稳定:文本数据的稀疏性导致BN方差估计不准

6.5 进阶变体:RMSNorm

新兴模型(如LLaMA、Qwen-72B)开始采用 RMSNorm(Root Mean Square Normalization):

python 复制代码
def rms_norm(x, eps=1e-6):
    # 去均值操作(相比LayerNorm)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps

RMS(x) = √((1/D) × Σ(xⱼ²) + ε)

yᵢ = (xᵢ / RMS(x)) × γᵢ

优势

  • 计算量减少约20%(适合超大模型)
  • 在Transformer中表现接近LayerNorm

6.6 模型实现对比表

模型 归一化方案 结构位置 是否含β/γ
GPT-4 LayerNorm Pre-LN
LLaMA-2 RMSNorm Pre-LN
Qwen-72B RMSNorm Pre-LN
DeepSeek-MoE LayerNorm Pre-LN

6.7关键结论

  1. LayerNorm仍是主流:90%以上的LLM使用
  2. Pre-LN成为标准:比原始Transformer的Post-LN更稳定
  3. RMSNorm是趋势:新模型为效率逐步转向RMSNorm
  4. 绝对不用BN:所有文本模型都避免BatchNorm
相关推荐
机器之心13 分钟前
FlashAttention-4震撼来袭,原生支持Blackwell GPU,英伟达的护城河更深了?
人工智能·openai
IT_陈寒13 分钟前
Python 3.12 新特性实战:5个让你的代码效率提升50%的技巧!🔥
前端·人工智能·后端
点云SLAM27 分钟前
PyTorch中 nn.Linear详解和实战示例
人工智能·pytorch·python·深度学习·cnn·transformer·mlp
耳东哇1 小时前
在使用spring ai进行llm处理的rag的时候,选择milvus还是neo4j呢?
人工智能·neo4j·milvus
过往入尘土1 小时前
深入浅出 PyTorch:从下载安装到核心知识点全解析
人工智能·pytorch·python
youcans_1 小时前
【AGI使用教程】GPT-OSS 本地部署(2)
人工智能·gpt·大语言模型·模型部署·webui
鲸鱼24011 小时前
支持向量机
人工智能·机器学习·支持向量机
CoovallyAIHub2 小时前
无需ReID网络!FastTracker凭借几何与场景认知实现多目标跟踪新SOTA,助力智慧交通更轻更快
深度学习·算法·计算机视觉
AImatters2 小时前
透视光合组织大会:算力生态重构金融AI落地新实践
人工智能·合合信息·国产算力·海光dcu·光合组织·光合大会·青云
DDC楼宇自控与IBMS集成系统解读2 小时前
BA 楼宇自控系统 + AI:重构楼宇设备管理的 “智能决策” 体系
大数据·网络·数据库·人工智能·3d·重构