【大模型手撕】pytorch实现LayerNorm, RMSNorm

LayerNorm介绍请参考:【AI知识】归一化、批量归一化 、 层归一化 和 实例归一化

RMSNorm介绍请参考:【大模型知识点】RMSNorm(Root Mean Square Normalization)均方根归一化

LayerNorm实现:

python 复制代码
import torch 
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self,dim,eps=1e-5,bias=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # 可训练的缩放参数
        self.gamma = nn.Parameter(torch.ones(dim))

        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
    
    def forward(self,x):
        # x: (batch_size,seq_len,dim)
        # 计算均值 x_mean : (batch_size,seq_len,dim)
        x_mean = x.mean(-1,keepdim=True)
        # 计算均方根 rms :  (batch_size,seq_len,dim)
        rms = torch.sqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)

        if self.bias:
            return self.gamma*((x-x_mean)/rms)+self.bias
        else:
            return self.gamma*((x-x_mean)/rms)

RMSNorm实现:

python 复制代码
import torch 
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self,dim,eps=1e-5,bias=False):
   		super().__init__()
        self.dim = dim 
        self.eps = eps
        # 可训练的缩放参数
        self.gamma = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
    def forward(self,x):
        # 计算输入的均方根
        # x: (batch_size,seq_len,dim)
        # .mean(-1,keepdim=True) : 在最后一个维度(特征维度)上计算平均值,并保持维度不变
        # rms : (batch_size,seq_len,1)
        rms = torch.sqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)

        if self.bias:
            return self.gamma*(x/rms) + self.bias
        else:
            return self.gamma*(x/rms)
相关推荐
结局无敌2 分钟前
深度探究cann仓库下的infra:AI计算的底层基础设施底座
人工智能
m0_466525292 分钟前
绿盟科技风云卫AI安全能力平台成果重磅发布
大数据·数据库·人工智能·安全
慢半拍iii4 分钟前
从零搭建CNN:如何高效调用ops-nn算子库
人工智能·神经网络·ai·cnn·cann
java干货6 分钟前
为什么 “File 10“ 排在 “File 2“ 前面?解决文件名排序的终极算法:自然排序
开发语言·python·算法
机器懒得学习8 分钟前
智能股票分析系统
python·深度学习·金融
毕设源码-郭学长8 分钟前
【开题答辩全过程】以 基于python的二手房数据分析与可视化为例,包含答辩的问题和答案
开发语言·python·数据分析
晟诺数字人8 分钟前
2026年海外直播变革:数字人如何改变游戏规则
大数据·人工智能·产品运营
蛋王派9 分钟前
DeepSeek-OCR-v2 模型解析和部署应用
人工智能·ocr
SR_shuiyunjian11 分钟前
Python第三次作业
python
vx_biyesheji000112 分钟前
豆瓣电影推荐系统 | Python Django 协同过滤 Echarts可视化 深度学习 大数据 毕业设计源码
大数据·爬虫·python·深度学习·django·毕业设计·echarts