llama源码学习·model.py[1]RMSNorm归一化

一、model.py中的RMSNorm源码

python 复制代码
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
 
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

二、RMSNorm原理

归一化(Normalization)通常指的是将数据按比例缩放,使之落入一个小的特定区间,如0到1。这

个过程通常用于在不同特征或数据点之间建立一致性,以便它们可以在相同的尺度上比较或处理。

在深度学习中,归一化有助于加快训练速度,提高模型性能,因为它确保了不同特征在训练过程中具

有相似的分布。

RMSNorm的基本思想是对网络层的激活输出进行归一化,以使它们具有统一的规模

(scale),这样做可以加速训练过程并提高模型的稳定性。

R M S N o r m ( x ) = x 1 d ∑ i = 1 d x i 2 + ϵ RMSNorm(x) = \frac{x}{\sqrt{\frac{1}{d} \sum^{d}_{i=1} x_i^2 + \epsilon}} RMSNorm(x)=d1∑i=1dxi2+ϵ x

  • x x x 是网络层的原始输出向量
  • d d d 是输出向量的维度
  • x i x_i xi 是输出向量中的第 i i i 个元素
  • ϵ \epsilon ϵ 是一个很小的常数,用来防止除以 0 0 0 ,通常是 1 0 − 7 10^{-7} 10−7 这样的小数,增加数值稳定性

在某些深层网络和序列模型中效果显著,但 R M S N o r m RMSNorm RMSNorm 可能不适用于任何类型的网络

三、源码注释

python 复制代码
class RMSNorm(torch.nn.Module):
    def __i
    nit__(self, dim: int, eps: float = 1e-6):
        # 初始化:dim------维度 eps------epsilon
        super().__init__()
        self.eps = eps
        # weight,一个可以学习的权重,初始值1,维度与输出向量相同
        self.weight = nn.Parameter(torch.ones(dim))
 
    def _norm(self, x):
        # rsqrt------开方分之一
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
    def forward(self, x):
        # type_as------确保归一化的结果和输入x有相同的数据类型
        output = self._norm(x.float()).type_as(x)
        # 将归一化的输出乘以权重参数,得到最终的输出
        return output * self.weight

四、举例说明

  1. 构造输入x

    python 复制代码
    x = [[1, 2], [5, 6]]
    x_tensor = torch.tensor(x, dtype = torch.float)
    x_tensor

    tensor([[1., 2.],

    ​ [5., 6.]])

  2. 对输入数据求平方

    python 复制代码
    x_square = x_tensor.pow(2)
    x_square

    tensor([[ 1., 4.],

    ​ [25., 36.]])

  3. 沿着最后一个维度计算平均值

    1 d ∑ i = 1 d x i 2 \frac{1}{d} \sum^{d}_{i=1} x_i^2 d1i=1∑dxi2

    • d d d 是维度

    对于每一个样本来说,先求出每一个特征的平方,在计算样本平方的均值

    python 复制代码
    x_square_mean = x_square.mean(-1, keepdim=True)
    x_square_mean

    tensor([[ 2.5000],

    ​ [30.5000]])

  4. 计算均方根的倒数

    python 复制代码
    eps = 1e-6
    rsqrt = 1.0 / torch.sqrt(x_square_mean + eps)
    rsqrt

    tensor([[0.6325],

    ​ [0.1811]])

  5. 输入数据,得到归一化的结果

    python 复制代码
    normalized_x = x_tensor * rsqrt
    normalized_x

    tensor([[0.6325, 1.2649],

    ​ [0.9054, 1.0864]])

  6. 假设weight = [2, 3],那么最后的输出将是

    python 复制代码
    weight = torch.tensor([2,3], dtype=torch.float)
    output = normalized_x * weight
    output

    tensor([[1.2649, 3.7947],

    ​ [1.8107, 3.2593]])

相关推荐
LeonDL1684 分钟前
HALCON 深度学习训练 3D 图像的几种方式优缺点
人工智能·python·深度学习·3d·halcon·halcon训练3d图像·深度学习训练3d图像
慧都小妮子1 小时前
跨平台浏览器集成库JxBrowser 支持 Chrome 扩展程序,高效赋能 Java 桌面应用
开发语言·python·api·jxbrowser·chrome 扩展程序
tanyyinyu2 小时前
Python函数参数详解:从位置参数到灵活调用的艺术
运维·开发语言·python
qq_214782612 小时前
mac下通过anaconda安装Python
python·macos·jupyter
junyuz3 小时前
Dify docker内网部署常见问题记录
python·docker
@HNUSTer3 小时前
Python数据可视化科技图表绘制系列教程(一)
python·数据可视化·科技论文·专业制图·科研图表
reasonsummer4 小时前
【办公类-48-04】202506每月电子屏台账汇总成docx-5(问卷星下载5月范围内容,自动获取excel文件名,并转移处理)
python·excel
AmazingKO4 小时前
5分钟申请edu邮箱【方案本周有效】
python·chatgpt·ai编程·竹相左边·edu教育邮箱
幸存者1554 小时前
从零开始:亲手搭建你的第一个AI Agent(简单上手,先跑起来!)
python
点云SLAM4 小时前
Python中os模块详解
开发语言·前端·人工智能·python·计算机视觉