【大模型手撕】pytorch实现LayerNorm, RMSNorm

LayerNorm介绍请参考:【AI知识】归一化、批量归一化 、 层归一化 和 实例归一化

RMSNorm介绍请参考:【大模型知识点】RMSNorm(Root Mean Square Normalization)均方根归一化

LayerNorm实现:

python 复制代码
import torch 
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self,dim,eps=1e-5,bias=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # 可训练的缩放参数
        self.gamma = nn.Parameter(torch.ones(dim))

        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
    
    def forward(self,x):
        # x: (batch_size,seq_len,dim)
        # 计算均值 x_mean : (batch_size,seq_len,dim)
        x_mean = x.mean(-1,keepdim=True)
        # 计算均方根 rms :  (batch_size,seq_len,dim)
        rms = torch.sqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)

        if self.bias:
            return self.gamma*((x-x_mean)/rms)+self.bias
        else:
            return self.gamma*((x-x_mean)/rms)

RMSNorm实现:

python 复制代码
import torch 
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self,dim,eps=1e-5,bias=False):
   		super().__init__()
        self.dim = dim 
        self.eps = eps
        # 可训练的缩放参数
        self.gamma = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
    def forward(self,x):
        # 计算输入的均方根
        # x: (batch_size,seq_len,dim)
        # .mean(-1,keepdim=True) : 在最后一个维度(特征维度)上计算平均值,并保持维度不变
        # rms : (batch_size,seq_len,1)
        rms = torch.sqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)

        if self.bias:
            return self.gamma*(x/rms) + self.bias
        else:
            return self.gamma*(x/rms)
相关推荐
枫叶丹49 小时前
ModelEngine应用编排创新实践:通过可视化编排构建大模型应用工作流
开发语言·前端·人工智能·modelengine
轻竹办公PPT9 小时前
AI 自动生成 2026 年工作计划 PPT,哪种更接近可交付
人工智能·python·powerpoint
dagouaofei9 小时前
2026 年工作计划 PPT 框架怎么搭?AI 一步完成
python·powerpoint
zhongtianhulian9 小时前
江苏物联网平台价格解析:5大方案报价与选型指南,助您精准控制
python
亿信华辰软件9 小时前
金融租赁行业迎监管新考:EAST 2.0制度深度解读与高效合规之道
人工智能·金融
lisw059 小时前
AI宠物(AI pets)概述!
人工智能·机器人·宠物
白云千载尽9 小时前
LLaMA-Factory 入门(一):Ubuntu20 下大模型微调与部署
人工智能·算法·大模型·微调·llama
zandy10119 小时前
指标管理的AI自治之路:衡石平台如何实现异常检测、血缘分析与智能推荐的自动化治理
运维·人工智能·自动化·指标·指标管理
Coder_Boy_9 小时前
Spring AI 源码核心分析
java·人工智能·spring
net3m339 小时前
websocket下发mp3帧数据时一个包被分包为几个子包而导致mp3解码失败而播放卡顿有杂音或断播的解决方法
开发语言·数据库·python