layer norm和 rms norm 对比

Layer norm

python 复制代码
# Layer Norm 公式
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True)
output = (x - mean) / sqrt(var + eps) * gamma + beta

特点:

  • 减去均值(去中心化)
  • 除以标准差(标准化)
  • 包含可学习参数 gamma 和 beta
  • 计算复杂度相对较高

RMS Norm(Root Mean Square归一化):

python 复制代码
# RMS Norm 公式
rms = sqrt(mean(x²))
output = x / rms * gamma

特点:

  • 不减去均值(保持中心)
  • 只除以RMS值
  • 只有一个可学习参数 gamma
  • 计算更简单高效

对比


代码对比

python 复制代码
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps
    
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma
相关推荐
学历真的很重要1 分钟前
Hello-Agents —— 03大语言模型基础 通俗总结
开发语言·人工智能·后端·语言模型·自然语言处理·面试·langchain
OpenCSG1 小时前
OpenCSG 2025年11月月报:智能体平台、AI技术合作与开源生态进展
人工智能·开源·opencsg·csghub
围炉聊科技1 小时前
当AI成为“大脑”:人类如何在机器时代找到不可替代的价值?
人工智能
لا معنى له1 小时前
残差网络论文学习笔记:Deep Residual Learning for Image Recognition全文翻译
网络·人工智能·笔记·深度学习·学习·机器学习
菜只因C2 小时前
深度学习:从技术本质到未来图景的全面解析
人工智能·深度学习
工业机器视觉设计和实现2 小时前
lenet改vgg训练cifar10突破71分
人工智能·机器学习
咚咚王者2 小时前
人工智能之数据分析 Matplotlib:第四章 图形类型
人工智能·数据分析·matplotlib
TTGGGFF2 小时前
人工智能:用Gemini 3一键生成3D粒子电子手部映射应用
人工智能·3d·交互
LitchiCheng2 小时前
Mujoco 基础:获取模型中所有 body 的 name, id 以及位姿
人工智能·python
Allen_LVyingbo2 小时前
面向医学影像检测的深度学习模型参数分析与优化策略研究
人工智能·深度学习