layer norm和 rms norm 对比

Layer norm

python 复制代码
# Layer Norm 公式
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True)
output = (x - mean) / sqrt(var + eps) * gamma + beta

特点:

  • 减去均值(去中心化)
  • 除以标准差(标准化)
  • 包含可学习参数 gamma 和 beta
  • 计算复杂度相对较高

RMS Norm(Root Mean Square归一化):

python 复制代码
# RMS Norm 公式
rms = sqrt(mean(x²))
output = x / rms * gamma

特点:

  • 不减去均值(保持中心)
  • 只除以RMS值
  • 只有一个可学习参数 gamma
  • 计算更简单高效

对比


代码对比

python 复制代码
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps
    
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma
相关推荐
行云流水剑7 分钟前
【学习记录】深入解析 AI 交互中的五大核心概念:Prompt、Agent、MCP、Function Calling 与 Tools
人工智能·学习·交互
love530love15 分钟前
【笔记】在 MSYS2(MINGW64)中正确安装 Rust
运维·开发语言·人工智能·windows·笔记·python·rust
狂小虎15 分钟前
02 Deep learning神经网络的编程基础 逻辑回归--吴恩达
深度学习·神经网络·逻辑回归
A林玖20 分钟前
【机器学习】主成分分析 (PCA)
人工智能·机器学习
Jamence24 分钟前
多模态大语言模型arxiv论文略读(108)
论文阅读·人工智能·语言模型·自然语言处理·论文笔记
tongxianchao25 分钟前
双空间知识蒸馏用于大语言模型
人工智能·语言模型·自然语言处理
苗老大27 分钟前
MMRL: Multi-Modal Representation Learning for Vision-Language Models(多模态表示学习)
人工智能·学习·语言模型
中达瑞和-高光谱·多光谱38 分钟前
中达瑞和SHIS高光谱相机在黑色水彩笔墨迹鉴定中的应用
人工智能·数码相机
F_D_Z1 小时前
8K样本在DeepSeek-R1-7B模型上的复现效果
人工智能·deepseek·deepseek-r1-7b
地藏Kelvin1 小时前
Spring Ai 从Demo到搭建套壳项目(二)实现deepseek+MCP client让高德生成昆明游玩4天攻略
人工智能·spring boot·后端