以nlp中一个小批次数据,详细区分BatchNorm、LayerNorm、GroupNorm、RMSNorm。这几种归一化的不同。如下表格,从计算范围、统计量、计算复杂度以及应用场景等方面的差异给出。
方法 | 计算范围 | 统计量 | 计算复杂度 | 应用场景 |
---|---|---|---|---|
BatchNorm | 跨所有句子的同一维度 | 使用批次统计量 | O(batch_size * seq_len) | 适合 CNN,需要较大 batch size |
LayerNorm | 单个 Token 的所有维度 | 使用单 Token 的统计量 | O(embedding_dim) | 适合 Transformer,独立于 batch size |
GroupNorm | 单个 Token 的维度组 | 使用组内统计量 | O(embedding_dim / num_groups) | 适合小 batch size 场景 |
RMSNorm | 单个 Token 的所有维度,但简化计算 | 只使用 RMS 值 | O(embedding_dim),但运算更简单 | 适合需要高效计算的场景 |
假设小批量数据是:
python
句子1: "我来自加拿大,那是一个美丽的国家" (10个token)
句子2: "你来自加拿大的哪个城市?" (8个token)
假设各个句子的嵌入式数值如下:
python
import torch
def create_sample_data():
"""创建示例数据:两个中文句子的词嵌入
句子1: "我/来自/加拿大/,/那是/一个/美丽的/国家"
句子2: "你/来自/加拿大/的/哪个/城市/?"
"""
# 创建两个示例句子的嵌入
# batch_size=2, max_seq_len=8 (用padding补齐), embedding_dim=10
data = torch.tensor([
# 句子1的词嵌入 (8个token)
[
[2.1, -1.5, 0.8, 3.2, -0.4, 1.7, -2.3, 0.5, 1.9, -1.1], # "我"
[1.5, 2.2, -0.7, 1.8, 2.5, -1.2, 1.6, -0.8, 2.0, 1.4], # "来自"
[2.8, -1.9, 1.5, -0.6, 2.1, 1.8, -1.4, 2.2, -0.5, 1.7], # "加拿大"
[0.5, 1.2, -1.8, 2.4, -0.9, 1.5, -2.0, 0.7, 1.6, -1.3], # ","
[1.9, -2.1, 0.6, 1.7, -1.5, 2.3, -0.8, 1.4, -1.9, 2.5], # "那是"
[2.2, 1.6, -1.1, 2.0, -0.3, 1.9, -1.7, 0.9, 2.4, -0.6], # "一个"
[1.7, -1.4, 2.3, -0.5, 1.8, -2.2, 0.8, 1.5, -1.2, 2.1], # "美丽的"
[2.4, -0.8, 1.6, -1.3, 2.7, -0.4, 1.2, -1.8, 2.3, -0.7], # "国家"
],
# 句子2的词嵌入 (7个token + 1个padding)
[
[1.8, -2.0, 1.2, 2.8, -0.9, 2.1, -1.8, 0.3, 2.2, -0.7], # "你"
[1.4, 2.5, -0.8, 1.9, 2.3, -1.5, 1.7, -0.6, 2.1, 1.2], # "来自"
[2.6, -1.7, 1.4, -0.5, 2.2, 1.6, -1.3, 2.4, -0.4, 1.8], # "加拿大"
[0.7, 1.3, -1.6, 2.2, -1.0, 1.4, -1.9, 0.8, 1.5, -1.4], # "的"
[2.0, -1.8, 0.9, 1.6, -1.2, 2.5, -0.7, 1.3, -1.7, 2.3], # "哪个"
[2.3, 1.5, -1.2, 1.8, -0.5, 1.7, -1.6, 1.0, 2.2, -0.8], # "城市"
[1.6, -1.3, 2.1, -0.4, 1.9, -2.0, 0.6, 1.7, -1.1, 2.0], # "?"
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # padding
]
], dtype=torch.float32)
return data
这段代码生成了一个 2×8×10 的张量,表示 2 个句子,每个句子有 8个 Token,每个 Token 的嵌入维度为 10。下面是各种norm的实现。具体每种的详细讲解请看:
RMSNorm 、GroupNorm、LayerNorm、BatchNorm
提示:关注形状。
python
# 创建掩码来处理不同长度的句子
mask = torch.zeros(batch_size, max_seq_len)
mask[0, :8] = 1 # 第一个句子长度为8
mask[1, :7] = 1 # 第二个句子长度为7
return embeddings, mask
def batch_norm(x, eps=1e-5):
"""BatchNorm实现
Args:
x: shape [batch_size, seq_len, embedding_dim]
eps: 数值稳定性常数
Returns:
normalized: shape [batch_size, seq_len, embedding_dim]
"""
# 在batch和seq_len维度上计算均值和方差
# mean shape: [1, 1, embedding_dim]
mean = x.mean(dim=(0, 1), keepdim=True)
# var shape: [1, 1, embedding_dim]
var = x.var(dim=(0, 1), unbiased=False, keepdim=True)
# 归一化
normalized = (x - mean) / torch.sqrt(var + eps)
return normalized
def layer_norm(x, eps=1e-5):
"""LayerNorm实现
Args:
x: shape [batch_size, seq_len, embedding_dim]
eps: 数值稳定性常数
Returns:
normalized: shape [batch_size, seq_len, embedding_dim]
"""
# 在最后一个维度(embedding_dim)上计算均值和方差
# mean shape: [batch_size, seq_len, 1]
mean = x.mean(dim=-1, keepdim=True)
# var shape: [batch_size, seq_len, 1]
var = x.var(dim=-1, unbiased=False, keepdim=True)
# 归一化
normalized = (x - mean) / torch.sqrt(var + eps)
return normalized
def group_norm(x, num_groups=2, eps=1e-5):
"""GroupNorm实现
Args:
x: shape [batch_size, seq_len, embedding_dim]
num_groups: 分组数
eps: 数值稳定性常数
Returns:
normalized: shape [batch_size, seq_len, embedding_dim]
"""
batch_size, seq_len, embedding_dim = x.shape
# 重塑张量以进行分组归一化
# 将embedding_dim分成num_groups组
x = x.reshape(batch_size, seq_len, num_groups, embedding_dim // num_groups)
# 在seq_len和每组内计算均值和方差
# mean shape: [batch_size, 1, num_groups, 1]
mean = x.mean(dim=(1, 3), keepdim=True)
# var shape: [batch_size, 1, num_groups, 1]
var = x.var(dim=(1, 3), unbiased=False, keepdim=True)
# 归一化
normalized = (x - mean) / torch.sqrt(var + eps)
# 重塑回原始形状
normalized = normalized.reshape(batch_size, seq_len, embedding_dim)
return normalized
def rms_norm(x, eps=1e-5):
"""RMSNorm实现
Args:
x: shape [batch_size, seq_len, embedding_dim]
eps: 数值稳定性常数
Returns:
normalized: shape [batch_size, seq_len, embedding_dim]
"""
# 计算RMS (Root Mean Square)
# rms shape: [batch_size, seq_len, 1]
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
# 归一化 (只除以RMS,不减均值)
normalized = x / rms
return normalized