layer norm和 rms norm 对比

Layer norm

python 复制代码
# Layer Norm 公式
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True)
output = (x - mean) / sqrt(var + eps) * gamma + beta

特点:

  • 减去均值(去中心化)
  • 除以标准差(标准化)
  • 包含可学习参数 gamma 和 beta
  • 计算复杂度相对较高

RMS Norm(Root Mean Square归一化):

python 复制代码
# RMS Norm 公式
rms = sqrt(mean(x²))
output = x / rms * gamma

特点:

  • 不减去均值(保持中心)
  • 只除以RMS值
  • 只有一个可学习参数 gamma
  • 计算更简单高效

对比


代码对比

python 复制代码
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps
    
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma
相关推荐
雍凉明月夜2 分钟前
深度学习网络笔记Ⅱ(常见网络分类1)
人工智能·笔记·深度学习
北岛寒沫3 分钟前
北京大学国家发展研究院 经济学辅修 经济学原理课程笔记(第十三课 垄断竞争)
人工智能·经验分享·笔记
AI营销实验室4 分钟前
AI 工具何高质量的为销售线索打分?
大数据·人工智能
Wang201220135 分钟前
RNN和LSTM对比
人工智能·算法·架构
xueyongfu9 分钟前
从Diffusion到VLA pi0(π0)
人工智能·算法·stable diffusion
jackylzh25 分钟前
配置pytorch环境,并调试YOLO
人工智能·pytorch·yolo
杜子不疼.33 分钟前
AI Ping双款新模型同步免费解锁:GLM-4.7与MiniMax M2.1实测
人工智能
打码人的日常分享34 分钟前
企业数据资产管控和数据治理解决方案
大数据·运维·网络·人工智能·云计算
百***787535 分钟前
小米MiMo-V2-Flash深度解析:国产开源大模型标杆与海外AI接入方案
人工智能·开源
大数据追光猿37 分钟前
【Prompt】Prompt Caching:原理、实现与高并发价值
人工智能·大模型·prompt·agent