以nlp为例,区分BatchNorm、LayerNorm、GroupNorm、RMSNorm

以nlp中一个小批次数据,详细区分BatchNorm、LayerNorm、GroupNorm、RMSNorm。这几种归一化的不同。如下表格,从计算范围、统计量、计算复杂度以及应用场景等方面的差异给出。

方法 计算范围 统计量 计算复杂度 应用场景
BatchNorm 跨所有句子的同一维度 使用批次统计量 O(batch_size * seq_len) 适合 CNN,需要较大 batch size
LayerNorm 单个 Token 的所有维度 使用单 Token 的统计量 O(embedding_dim) 适合 Transformer,独立于 batch size
GroupNorm 单个 Token 的维度组 使用组内统计量 O(embedding_dim / num_groups) 适合小 batch size 场景
RMSNorm 单个 Token 的所有维度,但简化计算 只使用 RMS 值 O(embedding_dim),但运算更简单 适合需要高效计算的场景

假设小批量数据是:

python 复制代码
句子1: "我来自加拿大,那是一个美丽的国家" (10个token)
句子2: "你来自加拿大的哪个城市?" (8个token)

假设各个句子的嵌入式数值如下:

python 复制代码
import torch

def create_sample_data():
    """创建示例数据:两个中文句子的词嵌入
    句子1: "我/来自/加拿大/,/那是/一个/美丽的/国家"
    句子2: "你/来自/加拿大/的/哪个/城市/?"
    """
    # 创建两个示例句子的嵌入
    # batch_size=2, max_seq_len=8 (用padding补齐), embedding_dim=10
    data = torch.tensor([
        # 句子1的词嵌入 (8个token)
        [
            [2.1, -1.5, 0.8, 3.2, -0.4, 1.7, -2.3, 0.5, 1.9, -1.1],  # "我"
            [1.5, 2.2, -0.7, 1.8, 2.5, -1.2, 1.6, -0.8, 2.0, 1.4],   # "来自"
            [2.8, -1.9, 1.5, -0.6, 2.1, 1.8, -1.4, 2.2, -0.5, 1.7],  # "加拿大"
            [0.5, 1.2, -1.8, 2.4, -0.9, 1.5, -2.0, 0.7, 1.6, -1.3],  # ","
            [1.9, -2.1, 0.6, 1.7, -1.5, 2.3, -0.8, 1.4, -1.9, 2.5],  # "那是"
            [2.2, 1.6, -1.1, 2.0, -0.3, 1.9, -1.7, 0.9, 2.4, -0.6],  # "一个"
            [1.7, -1.4, 2.3, -0.5, 1.8, -2.2, 0.8, 1.5, -1.2, 2.1],  # "美丽的"
            [2.4, -0.8, 1.6, -1.3, 2.7, -0.4, 1.2, -1.8, 2.3, -0.7], # "国家"
        ],
        # 句子2的词嵌入 (7个token + 1个padding)
        [
            [1.8, -2.0, 1.2, 2.8, -0.9, 2.1, -1.8, 0.3, 2.2, -0.7],  # "你"
            [1.4, 2.5, -0.8, 1.9, 2.3, -1.5, 1.7, -0.6, 2.1, 1.2],   # "来自"
            [2.6, -1.7, 1.4, -0.5, 2.2, 1.6, -1.3, 2.4, -0.4, 1.8],  # "加拿大"
            [0.7, 1.3, -1.6, 2.2, -1.0, 1.4, -1.9, 0.8, 1.5, -1.4],  # "的"
            [2.0, -1.8, 0.9, 1.6, -1.2, 2.5, -0.7, 1.3, -1.7, 2.3],  # "哪个"
            [2.3, 1.5, -1.2, 1.8, -0.5, 1.7, -1.6, 1.0, 2.2, -0.8],  # "城市"
            [1.6, -1.3, 2.1, -0.4, 1.9, -2.0, 0.6, 1.7, -1.1, 2.0],  # "?"
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],      # padding
        ]
    ], dtype=torch.float32)
    
    return data

这段代码生成了一个 2×8×10 的张量,表示 2 个句子,每个句子有 8个 Token,每个 Token 的嵌入维度为 10。下面是各种norm的实现。具体每种的详细讲解请看:

RMSNormGroupNormLayerNormBatchNorm

提示:关注形状。

python 复制代码
 # 创建掩码来处理不同长度的句子
    mask = torch.zeros(batch_size, max_seq_len)
    mask[0, :8] = 1  # 第一个句子长度为8
    mask[1, :7] = 1  # 第二个句子长度为7
    
    return embeddings, mask

def batch_norm(x, eps=1e-5):
    """BatchNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在batch和seq_len维度上计算均值和方差
    # mean shape: [1, 1, embedding_dim]
    mean = x.mean(dim=(0, 1), keepdim=True)
    # var shape: [1, 1, embedding_dim]
    var = x.var(dim=(0, 1), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def layer_norm(x, eps=1e-5):
    """LayerNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在最后一个维度(embedding_dim)上计算均值和方差
    # mean shape: [batch_size, seq_len, 1]
    mean = x.mean(dim=-1, keepdim=True)
    # var shape: [batch_size, seq_len, 1]
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def group_norm(x, num_groups=2, eps=1e-5):
    """GroupNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        num_groups: 分组数
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    batch_size, seq_len, embedding_dim = x.shape
    
    # 重塑张量以进行分组归一化
    # 将embedding_dim分成num_groups组
    x = x.reshape(batch_size, seq_len, num_groups, embedding_dim // num_groups)
    
    # 在seq_len和每组内计算均值和方差
    # mean shape: [batch_size, 1, num_groups, 1]
    mean = x.mean(dim=(1, 3), keepdim=True)
    # var shape: [batch_size, 1, num_groups, 1]
    var = x.var(dim=(1, 3), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    # 重塑回原始形状
    normalized = normalized.reshape(batch_size, seq_len, embedding_dim)
    
    return normalized

def rms_norm(x, eps=1e-5):
    """RMSNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 计算RMS (Root Mean Square)
    # rms shape: [batch_size, seq_len, 1]
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    
    # 归一化 (只除以RMS,不减均值)
    normalized = x / rms
    
    return normalized
相关推荐
QQ_77813297417 分钟前
基于深度学习的图像超分辨率重建
人工智能·机器学习·超分辨率重建
清 晨30 分钟前
Web3 生态全景:创新与发展之路
人工智能·web3·去中心化·智能合约
公众号Codewar原创作者1 小时前
R数据分析:工具变量回归的做法和解释,实例解析
开发语言·人工智能·python
IT古董1 小时前
【漫话机器学习系列】020.正则化强度的倒数C(Inverse of regularization strength)
人工智能·机器学习
进击的小小学生1 小时前
机器学习连载
人工智能·机器学习
Trouvaille ~1 小时前
【机器学习】从流动到恒常,无穷中归一:积分的数学诗意
人工智能·python·机器学习·ai·数据分析·matplotlib·微积分
dundunmm1 小时前
论文阅读:Deep Fusion Clustering Network With Reliable Structure Preservation
论文阅读·人工智能·数据挖掘·聚类·深度聚类·图聚类
szxinmai主板定制专家2 小时前
【国产NI替代】基于FPGA的4通道电压 250M采样终端边缘计算采集板卡,主控支持龙芯/飞腾
人工智能·边缘计算
是十一月末2 小时前
Opencv实现图像的腐蚀、膨胀及开、闭运算
人工智能·python·opencv·计算机视觉
云空2 小时前
《探索PyTorch计算机视觉:原理、应用与实践》
人工智能·pytorch·python·深度学习·计算机视觉