以nlp为例,区分BatchNorm、LayerNorm、GroupNorm、RMSNorm

以nlp中一个小批次数据,详细区分BatchNorm、LayerNorm、GroupNorm、RMSNorm。这几种归一化的不同。如下表格,从计算范围、统计量、计算复杂度以及应用场景等方面的差异给出。

方法 计算范围 统计量 计算复杂度 应用场景
BatchNorm 跨所有句子的同一维度 使用批次统计量 O(batch_size * seq_len) 适合 CNN,需要较大 batch size
LayerNorm 单个 Token 的所有维度 使用单 Token 的统计量 O(embedding_dim) 适合 Transformer,独立于 batch size
GroupNorm 单个 Token 的维度组 使用组内统计量 O(embedding_dim / num_groups) 适合小 batch size 场景
RMSNorm 单个 Token 的所有维度,但简化计算 只使用 RMS 值 O(embedding_dim),但运算更简单 适合需要高效计算的场景

假设小批量数据是:

python 复制代码
句子1: "我来自加拿大,那是一个美丽的国家" (10个token)
句子2: "你来自加拿大的哪个城市?" (8个token)

假设各个句子的嵌入式数值如下:

python 复制代码
import torch

def create_sample_data():
    """创建示例数据:两个中文句子的词嵌入
    句子1: "我/来自/加拿大/,/那是/一个/美丽的/国家"
    句子2: "你/来自/加拿大/的/哪个/城市/?"
    """
    # 创建两个示例句子的嵌入
    # batch_size=2, max_seq_len=8 (用padding补齐), embedding_dim=10
    data = torch.tensor([
        # 句子1的词嵌入 (8个token)
        [
            [2.1, -1.5, 0.8, 3.2, -0.4, 1.7, -2.3, 0.5, 1.9, -1.1],  # "我"
            [1.5, 2.2, -0.7, 1.8, 2.5, -1.2, 1.6, -0.8, 2.0, 1.4],   # "来自"
            [2.8, -1.9, 1.5, -0.6, 2.1, 1.8, -1.4, 2.2, -0.5, 1.7],  # "加拿大"
            [0.5, 1.2, -1.8, 2.4, -0.9, 1.5, -2.0, 0.7, 1.6, -1.3],  # ","
            [1.9, -2.1, 0.6, 1.7, -1.5, 2.3, -0.8, 1.4, -1.9, 2.5],  # "那是"
            [2.2, 1.6, -1.1, 2.0, -0.3, 1.9, -1.7, 0.9, 2.4, -0.6],  # "一个"
            [1.7, -1.4, 2.3, -0.5, 1.8, -2.2, 0.8, 1.5, -1.2, 2.1],  # "美丽的"
            [2.4, -0.8, 1.6, -1.3, 2.7, -0.4, 1.2, -1.8, 2.3, -0.7], # "国家"
        ],
        # 句子2的词嵌入 (7个token + 1个padding)
        [
            [1.8, -2.0, 1.2, 2.8, -0.9, 2.1, -1.8, 0.3, 2.2, -0.7],  # "你"
            [1.4, 2.5, -0.8, 1.9, 2.3, -1.5, 1.7, -0.6, 2.1, 1.2],   # "来自"
            [2.6, -1.7, 1.4, -0.5, 2.2, 1.6, -1.3, 2.4, -0.4, 1.8],  # "加拿大"
            [0.7, 1.3, -1.6, 2.2, -1.0, 1.4, -1.9, 0.8, 1.5, -1.4],  # "的"
            [2.0, -1.8, 0.9, 1.6, -1.2, 2.5, -0.7, 1.3, -1.7, 2.3],  # "哪个"
            [2.3, 1.5, -1.2, 1.8, -0.5, 1.7, -1.6, 1.0, 2.2, -0.8],  # "城市"
            [1.6, -1.3, 2.1, -0.4, 1.9, -2.0, 0.6, 1.7, -1.1, 2.0],  # "?"
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],      # padding
        ]
    ], dtype=torch.float32)
    
    return data

这段代码生成了一个 2×8×10 的张量,表示 2 个句子,每个句子有 8个 Token,每个 Token 的嵌入维度为 10。下面是各种norm的实现。具体每种的详细讲解请看:

RMSNormGroupNormLayerNormBatchNorm

提示:关注形状。

python 复制代码
 # 创建掩码来处理不同长度的句子
    mask = torch.zeros(batch_size, max_seq_len)
    mask[0, :8] = 1  # 第一个句子长度为8
    mask[1, :7] = 1  # 第二个句子长度为7
    
    return embeddings, mask

def batch_norm(x, eps=1e-5):
    """BatchNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在batch和seq_len维度上计算均值和方差
    # mean shape: [1, 1, embedding_dim]
    mean = x.mean(dim=(0, 1), keepdim=True)
    # var shape: [1, 1, embedding_dim]
    var = x.var(dim=(0, 1), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def layer_norm(x, eps=1e-5):
    """LayerNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 在最后一个维度(embedding_dim)上计算均值和方差
    # mean shape: [batch_size, seq_len, 1]
    mean = x.mean(dim=-1, keepdim=True)
    # var shape: [batch_size, seq_len, 1]
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    return normalized

def group_norm(x, num_groups=2, eps=1e-5):
    """GroupNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        num_groups: 分组数
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    batch_size, seq_len, embedding_dim = x.shape
    
    # 重塑张量以进行分组归一化
    # 将embedding_dim分成num_groups组
    x = x.reshape(batch_size, seq_len, num_groups, embedding_dim // num_groups)
    
    # 在seq_len和每组内计算均值和方差
    # mean shape: [batch_size, 1, num_groups, 1]
    mean = x.mean(dim=(1, 3), keepdim=True)
    # var shape: [batch_size, 1, num_groups, 1]
    var = x.var(dim=(1, 3), unbiased=False, keepdim=True)
    
    # 归一化
    normalized = (x - mean) / torch.sqrt(var + eps)
    
    # 重塑回原始形状
    normalized = normalized.reshape(batch_size, seq_len, embedding_dim)
    
    return normalized

def rms_norm(x, eps=1e-5):
    """RMSNorm实现
    Args:
        x: shape [batch_size, seq_len, embedding_dim]
        eps: 数值稳定性常数
    Returns:
        normalized: shape [batch_size, seq_len, embedding_dim]
    """
    # 计算RMS (Root Mean Square)
    # rms shape: [batch_size, seq_len, 1]
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    
    # 归一化 (只除以RMS,不减均值)
    normalized = x / rms
    
    return normalized
相关推荐
运器1235 分钟前
【一起来学AI大模型】PyTorch DataLoader 实战指南
大数据·人工智能·pytorch·python·深度学习·ai·ai编程
超龄超能程序猿19 分钟前
(5)机器学习小白入门 YOLOv:数据需求与图像不足应对策略
人工智能·python·机器学习·numpy·pandas·scipy
卷福同学20 分钟前
【AI编程】AI+高德MCP不到10分钟搞定上海三日游
人工智能·算法·程序员
帅次28 分钟前
系统分析师-计算机系统-输入输出系统
人工智能·分布式·深度学习·神经网络·架构·系统架构·硬件架构
AndrewHZ1 小时前
【图像处理基石】如何入门大规模三维重建?
人工智能·深度学习·大模型·llm·三维重建·立体视觉·大规模三维重建
5G行业应用1 小时前
【赠书福利,回馈公号读者】《智慧城市与智能网联汽车,融合创新发展之路》
人工智能·汽车·智慧城市
悟空胆好小1 小时前
分音塔科技(BABEL Technology) 的公司背景、股权构成、产品类型及技术能力的全方位解读
网络·人工智能·科技·嵌入式硬件
探讨探讨AGV1 小时前
以科技赋能未来,科聪持续支持青年创新实践 —— 第七届“科聪杯”浙江省大学生智能机器人创意竞赛圆满落幕
人工智能·科技·机器人
cwn_1 小时前
回归(多项式回归)
人工智能·机器学习·数据挖掘·回归
聚客AI2 小时前
🔥 大模型开发进阶:基于LangChain的异步流式响应与性能优化
人工智能·langchain·agent