以nlp为例,区分BatchNorm、LayerNorm、GroupNorm、RMSNorm

以nlp中一个小批次数据,详细区分BatchNorm、LayerNorm、GroupNorm、RMSNorm。这几种归一化的不同。如下表格,从计算范围、统计量、计算复杂度以及应用场景等方面的差异给出。

方法 计算范围 统计量 计算复杂度 应用场景
BatchNorm 跨所有句子的同一维度 使用批次统计量 O(batch_size * seq_len) 适合 CNN,需要较大 batch size
LayerNorm 单个 Token 的所有维度 使用单 Token 的统计量 O(embedding_dim) 适合 Transformer,独立于 batch size
GroupNorm 单个 Token 的维度组 使用组内统计量 O(embedding_dim / num_groups) 适合小 batch size 场景
RMSNorm 单个 Token 的所有维度,但简化计算 只使用 RMS 值 O(embedding_dim),但运算更简单 适合需要高效计算的场景

假设小批量数据是:

python 复制代码
句子1: "我来自加拿大,那是一个美丽的国家" (10个token)
句子2: "你来自加拿大的哪个城市?" (8个token)

假设各个句子的嵌入式数值如下:

python 复制代码
import torch

def create_sample_data():
    """创建示例数据:两个中文句子的词嵌入
    句子1: "我/来自/加拿大/,/那是/一个/美丽的/国家"
    句子2: "你/来自/加拿大/的/哪个/城市/?"
    """
    # 创建两个示例句子的嵌入
    # batch_size=2, max_seq_len=8 (用padding补齐), embedding_dim=10
    data = torch.tensor([
        # 句子1的词嵌入 (8个token)
        [
            [2.1, -1.5, 0.8, 3.2, -0.4, 1.7, -2.3, 0.5, 1.9, -1.1],  # "我"
            [1.5, 2.2, -0.7, 1.8, 2.5, -1.2, 1.6, -0.8, 2.0, 1.4],   # "来自"
            [2.8, -1.9, 1.5, -0.6, 2.1, 1.8, -1.4, 2.2, -0.5, 1.7],  # "加拿大"
            [0.5, 1.2, -1.8, 2.4, -0.9, 1.5, -2.0, 0.7, 1.6, -1.3],  # ","
            [1.9, -2.1, 0.6, 1.7, -1.5, 2.3, -0.8, 1.4, -1.9, 2.5],  # "那是"
            [2.2, 1.6, -1.1, 2.0, -0.3, 1.9, -1.7, 0.9, 2.4, -0.6],  # "一个"
            [1.7, -1.4, 2.3, -0.5, 1.8, -2.2, 0.8, 1.5, -1.2, 2.1],  # "美丽的"
            [2.4, -0.8, 1.6, -1.3, 2.7, -0.4, 1.2, -1.8, 2.3, -0.7], # "国家"
        ],
        # 句子2的词嵌入 (7个token + 1个padding)
        [
            [1.8, -2.0, 1.2, 2.8, -0.9, 2.1, -1.8, 0.3, 2.2, -0.7],  # "你"
            [1.4, 2.5, -0.8, 1.9, 2.3, -1.5, 1.7, -0.6, 2.1, 1.2],   # "来自"
            [2.6, -1.7, 1.4, -0.5, 2.2, 1.6, -1.3, 2.4, -0.4, 1.8],  # "加拿大"
            [0.7, 1.3, -1.6, 2.2, -1.0, 1.4, -1.9, 0.8, 1.5, -1.4],  # "的"
            [2.0, -1.8, 0.9, 1.6, -1.2, 2.5, -0.7, 1.3, -1.7, 2.3],  # "哪个"
            [2.3, 1.5, -1.2, 1.8, -0.5, 1.7, -1.6, 1.0, 2.2, -0.8],  # "城市"
            [1.6, -1.3, 2.1, -0.4, 1.9, -2.0, 0.6, 1.7, -1.1, 2.0],  # "?"
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],      # padding
        ]
    ], dtype=torch.float32)
    
    return data

这段代码生成了一个 2×8×10 的张量,表示 2 个句子,每个句子有 8个 Token,每个 Token 的嵌入维度为 10。下面是各种norm的实现。具体每种的详细讲解请看:

RMSNormGroupNormLayerNormBatchNorm

提示:关注形状。

python 复制代码
 # 创建掩码来处理不同长度的句子
    mask = torch.zeros(batch_size, max_seq_len)
    mask[0, :8] = 1  # 第一个句子长度为8
    mask[1, :7] = 1  # 第二个句子长度为7
    
    return embeddings, mask

def batch_norm(x, eps=1e-5):
    """BatchNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在batch和seq_len维度上计算均值和方差
    # mean shape: [1, 1, embedding_dim]
    mean = x.mean(dim=(0, 1), keepdim=True)
    # var shape: [1, 1, embedding_dim]
    var = x.var(dim=(0, 1), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def layer_norm(x, eps=1e-5):
    """LayerNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在最后一个维度(embedding_dim)上计算均值和方差
    # mean shape: [batch_size, seq_len, 1]
    mean = x.mean(dim=-1, keepdim=True)
    # var shape: [batch_size, seq_len, 1]
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def group_norm(x, num_groups=2, eps=1e-5):
    """GroupNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        num_groups: 分组数
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    batch_size, seq_len, embedding_dim = x.shape
    
    # 重塑张量以进行分组归一化
    # 将embedding_dim分成num_groups组
    x = x.reshape(batch_size, seq_len, num_groups, embedding_dim // num_groups)
    
    # 在seq_len和每组内计算均值和方差
    # mean shape: [batch_size, 1, num_groups, 1]
    mean = x.mean(dim=(1, 3), keepdim=True)
    # var shape: [batch_size, 1, num_groups, 1]
    var = x.var(dim=(1, 3), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    # 重塑回原始形状
    normalized = normalized.reshape(batch_size, seq_len, embedding_dim)
    
    return normalized

def rms_norm(x, eps=1e-5):
    """RMSNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 计算RMS (Root Mean Square)
    # rms shape: [batch_size, seq_len, 1]
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    
    # 归一化 (只除以RMS,不减均值)
    normalized = x / rms
    
    return normalized
相关推荐
weixin_437497771 小时前
读书笔记:Context Engineering 2.0 (上)
人工智能·nlp
喝拿铁写前端1 小时前
前端开发者使用 AI 的能力层级——从表面使用到工程化能力的真正分水岭
前端·人工智能·程序员
goodfat1 小时前
Win11如何关闭自动更新 Win11暂停系统更新的设置方法【教程】
人工智能·禁止windows更新·win11优化工具
北京领雁科技1 小时前
领雁科技反洗钱案例白皮书暨人工智能在反洗钱系统中的深度应用
人工智能·科技·安全
落叶,听雪1 小时前
河南建站系统哪个好
大数据·人工智能·python
清月电子2 小时前
杰理AC109N系列AC1082 AC1074 AC1090 芯片停产替代及资料说明
人工智能·单片机·嵌入式硬件·物联网
Dev7z2 小时前
非线性MPC在自动驾驶路径跟踪与避障控制中的应用及Matlab实现
人工智能·matlab·自动驾驶
七月shi人2 小时前
AI浪潮下,前端路在何方
前端·人工智能·ai编程
橙汁味的风2 小时前
1隐马尔科夫模型HMM与条件随机场CRF
人工智能·深度学习·机器学习
itwangyang5202 小时前
AIDD-人工智能药物设计-AI 制药编码之战:预测癌症反应,选对方法是关键
人工智能