【大模型手撕】pytorch实现LayerNorm, RMSNorm

LayerNorm介绍请参考:【AI知识】归一化、批量归一化 、 层归一化 和 实例归一化

RMSNorm介绍请参考:【大模型知识点】RMSNorm(Root Mean Square Normalization)均方根归一化

LayerNorm实现:

python 复制代码
import torch 
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self,dim,eps=1e-5,bias=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # 可训练的缩放参数
        self.gamma = nn.Parameter(torch.ones(dim))

        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
    
    def forward(self,x):
        # x: (batch_size,seq_len,dim)
        # 计算均值 x_mean : (batch_size,seq_len,dim)
        x_mean = x.mean(-1,keepdim=True)
        # 计算均方根 rms :  (batch_size,seq_len,dim)
        rms = torch.sqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)

        if self.bias:
            return self.gamma*((x-x_mean)/rms)+self.bias
        else:
            return self.gamma*((x-x_mean)/rms)

RMSNorm实现:

python 复制代码
import torch 
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self,dim,eps=1e-5,bias=False):
   		super().__init__()
        self.dim = dim 
        self.eps = eps
        # 可训练的缩放参数
        self.gamma = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
    def forward(self,x):
        # 计算输入的均方根
        # x: (batch_size,seq_len,dim)
        # .mean(-1,keepdim=True) : 在最后一个维度(特征维度)上计算平均值,并保持维度不变
        # rms : (batch_size,seq_len,1)
        rms = torch.sqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)

        if self.bias:
            return self.gamma*(x/rms) + self.bias
        else:
            return self.gamma*(x/rms)
相关推荐
炘东5921 分钟前
vscode连接算力平台
pytorch·vscode·深度学习·gpu算力
ZHOU_WUYI27 分钟前
构建实时网络速度监控面板:Python Flask + SSE 技术详解
网络·python·flask
春末的南方城市35 分钟前
复旦&华为提出首个空间理解和生成统一框架UniUGG,支持参考图像和任意视图变换的 3D 场景生成和空间视觉问答 (VQA) 任务。
人工智能·科技·深度学习·计算机视觉·aigc
chinesegf36 分钟前
conda虚拟环境直接复制依赖包可能会报错
python·conda
开心-开心急了1 小时前
PySide6 打印(QPrinter)文本编辑器(QPlaintextEdit)内容
python·ui·pyqt
坐吃山猪1 小时前
Python-UV多环境管理
人工智能·python·uv
带娃的IT创业者1 小时前
从零开始掌握 uv:新一代超快 Python 项目与包管理器(含 Windows 支持)
windows·python·uv
努力也学不会java1 小时前
【Java并发】揭秘Lock体系 -- condition等待通知机制
java·开发语言·人工智能·机器学习·juc·condition
浔川python社1 小时前
浔川 AI 翻译 v7.0正式上线公告
python
Wiktok2 小时前
【WIT】ttkbootstrap全组件中文本地化解决方案
python·ttkbootstrap