以nlp为例,区分BatchNorm、LayerNorm、GroupNorm、RMSNorm

以nlp中一个小批次数据,详细区分BatchNorm、LayerNorm、GroupNorm、RMSNorm。这几种归一化的不同。如下表格,从计算范围、统计量、计算复杂度以及应用场景等方面的差异给出。

方法 计算范围 统计量 计算复杂度 应用场景
BatchNorm 跨所有句子的同一维度 使用批次统计量 O(batch_size * seq_len) 适合 CNN,需要较大 batch size
LayerNorm 单个 Token 的所有维度 使用单 Token 的统计量 O(embedding_dim) 适合 Transformer,独立于 batch size
GroupNorm 单个 Token 的维度组 使用组内统计量 O(embedding_dim / num_groups) 适合小 batch size 场景
RMSNorm 单个 Token 的所有维度,但简化计算 只使用 RMS 值 O(embedding_dim),但运算更简单 适合需要高效计算的场景

假设小批量数据是:

python 复制代码
句子1: "我来自加拿大,那是一个美丽的国家" (10个token)
句子2: "你来自加拿大的哪个城市?" (8个token)

假设各个句子的嵌入式数值如下:

python 复制代码
import torch

def create_sample_data():
    """创建示例数据:两个中文句子的词嵌入
    句子1: "我/来自/加拿大/,/那是/一个/美丽的/国家"
    句子2: "你/来自/加拿大/的/哪个/城市/?"
    """
    # 创建两个示例句子的嵌入
    # batch_size=2, max_seq_len=8 (用padding补齐), embedding_dim=10
    data = torch.tensor([
        # 句子1的词嵌入 (8个token)
        [
            [2.1, -1.5, 0.8, 3.2, -0.4, 1.7, -2.3, 0.5, 1.9, -1.1],  # "我"
            [1.5, 2.2, -0.7, 1.8, 2.5, -1.2, 1.6, -0.8, 2.0, 1.4],   # "来自"
            [2.8, -1.9, 1.5, -0.6, 2.1, 1.8, -1.4, 2.2, -0.5, 1.7],  # "加拿大"
            [0.5, 1.2, -1.8, 2.4, -0.9, 1.5, -2.0, 0.7, 1.6, -1.3],  # ","
            [1.9, -2.1, 0.6, 1.7, -1.5, 2.3, -0.8, 1.4, -1.9, 2.5],  # "那是"
            [2.2, 1.6, -1.1, 2.0, -0.3, 1.9, -1.7, 0.9, 2.4, -0.6],  # "一个"
            [1.7, -1.4, 2.3, -0.5, 1.8, -2.2, 0.8, 1.5, -1.2, 2.1],  # "美丽的"
            [2.4, -0.8, 1.6, -1.3, 2.7, -0.4, 1.2, -1.8, 2.3, -0.7], # "国家"
        ],
        # 句子2的词嵌入 (7个token + 1个padding)
        [
            [1.8, -2.0, 1.2, 2.8, -0.9, 2.1, -1.8, 0.3, 2.2, -0.7],  # "你"
            [1.4, 2.5, -0.8, 1.9, 2.3, -1.5, 1.7, -0.6, 2.1, 1.2],   # "来自"
            [2.6, -1.7, 1.4, -0.5, 2.2, 1.6, -1.3, 2.4, -0.4, 1.8],  # "加拿大"
            [0.7, 1.3, -1.6, 2.2, -1.0, 1.4, -1.9, 0.8, 1.5, -1.4],  # "的"
            [2.0, -1.8, 0.9, 1.6, -1.2, 2.5, -0.7, 1.3, -1.7, 2.3],  # "哪个"
            [2.3, 1.5, -1.2, 1.8, -0.5, 1.7, -1.6, 1.0, 2.2, -0.8],  # "城市"
            [1.6, -1.3, 2.1, -0.4, 1.9, -2.0, 0.6, 1.7, -1.1, 2.0],  # "?"
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],      # padding
        ]
    ], dtype=torch.float32)
    
    return data

这段代码生成了一个 2×8×10 的张量,表示 2 个句子,每个句子有 8个 Token,每个 Token 的嵌入维度为 10。下面是各种norm的实现。具体每种的详细讲解请看:

RMSNormGroupNormLayerNormBatchNorm

提示:关注形状。

python 复制代码
 # 创建掩码来处理不同长度的句子
    mask = torch.zeros(batch_size, max_seq_len)
    mask[0, :8] = 1  # 第一个句子长度为8
    mask[1, :7] = 1  # 第二个句子长度为7
    
    return embeddings, mask

def batch_norm(x, eps=1e-5):
    """BatchNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在batch和seq_len维度上计算均值和方差
    # mean shape: [1, 1, embedding_dim]
    mean = x.mean(dim=(0, 1), keepdim=True)
    # var shape: [1, 1, embedding_dim]
    var = x.var(dim=(0, 1), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def layer_norm(x, eps=1e-5):
    """LayerNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在最后一个维度(embedding_dim)上计算均值和方差
    # mean shape: [batch_size, seq_len, 1]
    mean = x.mean(dim=-1, keepdim=True)
    # var shape: [batch_size, seq_len, 1]
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def group_norm(x, num_groups=2, eps=1e-5):
    """GroupNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        num_groups: 分组数
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    batch_size, seq_len, embedding_dim = x.shape
    
    # 重塑张量以进行分组归一化
    # 将embedding_dim分成num_groups组
    x = x.reshape(batch_size, seq_len, num_groups, embedding_dim // num_groups)
    
    # 在seq_len和每组内计算均值和方差
    # mean shape: [batch_size, 1, num_groups, 1]
    mean = x.mean(dim=(1, 3), keepdim=True)
    # var shape: [batch_size, 1, num_groups, 1]
    var = x.var(dim=(1, 3), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    # 重塑回原始形状
    normalized = normalized.reshape(batch_size, seq_len, embedding_dim)
    
    return normalized

def rms_norm(x, eps=1e-5):
    """RMSNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 计算RMS (Root Mean Square)
    # rms shape: [batch_size, seq_len, 1]
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    
    # 归一化 (只除以RMS,不减均值)
    normalized = x / rms
    
    return normalized
相关推荐
机器视觉_Explorer11 小时前
【halcon】编程技巧:鼠标擦除
图像处理·人工智能·深度学习·算法·视觉检测
杨航 AI11 小时前
XGBoost · 登录防欺诈示例
人工智能
拖拖76511 小时前
Scaling Laws for Neural Language Models:大模型为什么可以被“规模化预测”?
人工智能
何陋轩11 小时前
Spring AI实战指南:在Java项目中集成大语言模型
人工智能·后端·机器学习
暗夜猎手-大魔王12 小时前
转载--Karpathy 怎么看 AI Agent(三):怎么给 Agent 搭一个真正能用的上下文
人工智能
每日综合12 小时前
UKey Wallet 产品体系:移动端应用、硬件安全设备与助记词备份设备
人工智能
天天进步201512 小时前
Python全栈项目实战:基于深度学习的语音合成(TTS)系统
开发语言·python·深度学习
阿里云大数据AI技术12 小时前
基于 MaxCompute Delta Table 实现 SCD Type 2:Time Travel 驱动的维度变更追踪方案
人工智能
听麟12 小时前
HarmonyOS 6.0+ PC端离线翻译工具开发实战:端侧AI模型集成与多格式内容翻译落地
人工智能·华为·harmonyos
摆烂大大王12 小时前
AI 日报|2026年5月8日:xAI解散、DeepSeek融资450亿美元、苹果AI耳机入DVT尾声
人工智能