RMSNorm 类中引入一些参数

RMSNorm 类中,引入可学习的参数,以增强模型的表达能力和适应性。以下是一些常见的方法:

  1. 可学习的缩放参数(Scale)

    除了 self.weight,可以为每个维度引入一个可学习的缩放参数。这可以通过创建一个与输入维度相同的权重矩阵来实现,而不是一个向量。这样,每个特征维度都会有一个独立的缩放因子。

    python 复制代码
    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones((dim, 1)))  # 权重矩阵
    
        def forward(self, x):
            normed = self._norm(x)
            return normed * self.weight
  2. 可学习的偏移参数(Shift)

    除了缩放,还可以为每个维度引入一个可学习的偏移参数。这可以通过添加一个与 self.weight 类似的权重矩阵来实现,但用于添加到归一化后的输出上。

    python 复制代码
    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.scale = nn.Parameter(torch.ones((dim, 1)))  # 缩放权重矩阵
            self.shift = nn.Parameter(torch.zeros((dim, 1)))  # 偏移权重矩阵
    
        def forward(self, x):
            normed = self._norm(x)
            return normed * self.scale + self.shift
  3. 可学习的归一化参数(Custom Normalization)

    可以设计一个自定义的归一化函数,其中包含可学习的参数。例如,可以学习一个参数来控制归一化过程中的动态范围。

python 复制代码
import torch
import torch.nn as nn

class CustomNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(CustomNorm, self).__init__()
        # 可学习的缩放参数 gamma,初始化为1
        self.gamma = nn.Parameter(torch.ones(num_features))
        # 可选的可学习偏移参数 beta,初始化为0
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x):
        # 计算均值和方差
        mean = x.mean(1, keepdim=True)
        var = x.var(1, keepdim=True)

        # 归一化
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # 应用可学习的缩放和偏移
        x_out = self.gamma * x_norm + self.beta

        return x_out

# 示例使用
num_features = 10  # 假设输入特征的维度为10
custom_norm_layer = CustomNorm(num_features)

# 假设有一个随机生成的输入张量
input_tensor = torch.randn(5, num_features)  # 5个样本,每个样本有10个特征

# 前向传播
output_tensor = custom_norm_layer(input_tensor)
print(output_tensor)
  1. 可学习的激活函数参数

    在归一化之后,可以引入一个可学习的激活函数,其参数也可以是可训练的。这可以通过使用 nn.functional 中的激活函数,并将可学习参数作为激活函数的输入。

    python 复制代码
    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.activation_param = nn.Parameter(torch.ones(1))  # 可学习的激活函数参数
    
        def forward(self, x):
            normed = self._norm(x)
            return torch.tanh(self.activation_param * normed)  # 使用tanh激活函数
相关推荐
(・Д・)ノ1 小时前
python打卡day22
python
BioRunYiXue1 小时前
一文了解氨基酸的分类、代谢和应用
人工智能·深度学习·算法·机器学习·分类·数据挖掘·代谢组学
achene_ql1 小时前
深入探索 RKNN 模型转换之旅
python·目标检测·rk3588·模型部署·rk3566
@十八子德月生2 小时前
8天Python从入门到精通【itheima】-1~5
大数据·开发语言·python·学习
每天一个秃顶小技巧3 小时前
02.Golang 切片(slice)源码分析(一、定义与基础操作实现)
开发语言·后端·python·golang
Blossom.1184 小时前
低代码开发:开启软件开发的新篇章
人工智能·深度学习·安全·低代码·机器学习·计算机视觉·数据挖掘
安特尼4 小时前
招行数字金融挑战赛数据赛道赛题一
人工智能·python·机器学习·金融·数据分析
serve the people4 小时前
解决osx-arm64平台上conda默认源没有提供 python=3.7 的官方编译版本的问题
开发语言·python·conda
机器学习之心5 小时前
SHAP分析!Transformer-GRU组合模型SHAP分析,模型可解释不在发愁!
深度学习·gru·transformer·shap分析
RK_Dangerous5 小时前
【深度学习】计算机视觉(18)——从应用到设计
人工智能·深度学习·计算机视觉