【大模型】RMS Normalization原理及实现

1.RMS Normalization的原理

说RMS Normalization之前,先讲Layer Normalization 和 Batch Normalization。

BN和LN是两种常见的归一化方法。它们的目的是帮助模型训练更稳定、收敛更快。BN是在Batch维度上进行归一化,即对同一batch中每个特征维度的值进行归一化。LN则是在层的维度上进行归一化,即对每一个样本的特征进行归一化。

RMS Normalization属于LN。

再来说RMS Normalization和Layer Normalization。

Layer Normalization:利用均值方差对特征进行归一化。

RMS Normalization:利用均方根对特征进行归一化。

LLaMA架构中采用RMS Normalization的原因是通过只计算均方根,从而减少计算量,同时在实验中也确实获得了更加稳定的训练。

在这里插入一点NLP任务中,对于将特征进行"归一化"目的的一些个人小理解:在NLP中,使用Layer Normalization进行归一化是为了使输入特征在每一层的神经元中保持稳定的分布,避免特征值之间出现过大的波动。通过归一化,Layer Normalization 将特征重新调整为均值为 0、方差为 1 的分布,从而让模型的训练更加稳定和高效,使得数据变得更加**"平滑"** 。这里的"平滑"是指数值的尺度更一致、更稳定 ,不会有特别大的数值差异,能够防止特征值在网络层中传递时变得过大或过小。这种一致性有助于缓解模型训练中的一些问题,如梯度爆炸梯度消失 ,并能让模型更容易优化。在使用RMS Normalization进行归一化则是直接使特征本身的数值变得更加"平滑"。

2.RMS Normalization公式

2.RMS Normalization的实现

该函数在神经网络中需要对输入的数据进行处理,再输出相应的处理好的数据,对应的实现方式就用层来实现

因为RMS Normalization属于LN,所以,x-->[batch_size, hidden_states]

python 复制代码
import torch


class RMSNorm(torch.nn.Module):  # nn.Module是所有层的父类,层元素就必须继承nn.Module
    def __init__(self, dim, eps):  # 用于储存层的元素
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))  # 初始化权重参数
        self.eps = eps  # 防止根号下为0

    def _norm(self, x):  # 定义类函数里的方法("_"表示只在该类的内部调用)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # x.pow(2):求平方
        # x.pow(2).mean(-1, keepdim=True):所有的平方求一个均值
        # x.pow(2).mean(-1, keepdim=True) + self.eps:加上一个防止根号下为0的元素
        # torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps):开平方再求导
        # rsqrt(x) = 1 / sqrt(x)
        # x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps):最后用求得的导数乘以x

    def forward(self, x):  # 数据流
        output = self._norm(x.float().type_as(x))  # 将x变成浮点数进行归一化,并保持x原始的数据类型
        return output * self.weight  # 将归一化后的输出乘以可学习的参数 weight,调整每一个维度的缩放


if __name__ == '__main__':

    batch_size = 1
    dim = 4  # 特征维度
    x = torch.Tensor([0.1, 0.1, 0.2, 0.3])
    # 初始化RMSNorm对象
    rms_norm = RMSNorm(dim=dim, eps=0)
    output = rms_norm(x)

    print("输入数据: \n", x)
    print("RMSNorm输出数据: \n", output)
相关推荐
qq 6178276752 分钟前
PIC16F648A-I/SS 8位微控制器 -MCU 微芯 PIC16F648AT-I/SS 8位微控制器 -MCU 在售公司有完全可替代PIC16F648
人工智能·单片机·嵌入式硬件·车载系统·自动驾驶
HyperAI超神经4 分钟前
贝式计算的 AI4S 观察:使用机器学习对世界进行感知与推演,最大魅力在于横向扩展的有效性
人工智能·深度学习·机器学习·数据集·ai4s·科研领域·工科
发菜君6 分钟前
LangChain大模型应用开发指南:打造个性化LLM
人工智能·学习·langchain·大模型·大模型学习·大模型入门·大模型教程
孙同学要努力8 分钟前
《目标检测》——基础理论知识(目标检测的数据集、评价指标:IOU、mAP、非极大抑制NMS)
人工智能·目标检测·目标跟踪
Jurio.39 分钟前
【IEEE出版,稳定检索】第六届国际科技创新学术交流大会暨信息技术与计算机应用学术会议(ITCA 2024,12月06-08)
图像处理·人工智能·科技·深度学习·机器学习·云计算·学术会议
fajianchen43 分钟前
大语言模型工作原理笔记
人工智能·笔记·语言模型
___Dream1 小时前
【MRAN】情感分析中情态缺失问题的多模态重构和对齐网络
人工智能·深度学习·机器学习·人机交互
网易智企1 小时前
“双十一”电商狂欢进行时,在AI的加持下看网易云信IM、RTC如何助力商家!
大数据·人工智能·音视频·实时音视频·娱乐·交友·教育电商
子午1 小时前
【野生动物识别系统】Python+深度学习+人工智能+卷积神经网络算法+TensorFlow+ResNet+图像识别
人工智能·python·深度学习
王哈哈^_^2 小时前
【数据集】【YOLO】【目标检测】摔跤识别数据集 5097 张,YOLO行人摔倒识别算法实战训练教程!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt