面试常问系列(二)-神经网络参数初始化之自注意力机制

目录

(一)、transformer中的自注意力机制为什么要除以根号d?

[1. 点积的方差问题](#1. 点积的方差问题)

[2. 缩放的作用](#2. 缩放的作用)

[3. 类比初始化方法](#3. 类比初始化方法)

[4. 实验验证](#4. 实验验证)

5.总结


(一)、transformer中的自注意力机制为什么要除以根号d?

在Transformer的自注意力机制中,除以根号维度(即,其中是键向量的维度)的核心目的是为了稳定梯度,避免点积结果因维度增大而爆炸。以下是详细解释:

1. 点积的方差问题

  • 假设:查询向量(Q)和键向量(K)的每个分量独立且服从均值为0、方差为1的分布。
  • 点积计算 的结果是个独立分量的和(因为维度是d的点积,所以就是q1*k1+...+qn*kn)。根据独立变量的方差叠加性质,点积的方差为,标准差为
  • 问题 :当较大时(如512或1024),点积结果可能分布在的宽区间内。若直接输入Softmax,其梯度会因输入值过大而趋近于0(梯度饱和),导致训练困难。
    • 这里我在赘述一下
      • 随着d的增大,点积的范围可能变大,举例子
        • d=10,点击范围可能[-10,10]波动
        • d=1000,点击范围可能[-30,30]波动,甚至更高。
        • 讲到这里肯定有人好奇,这个为什么呢?除了直观上的加和项变多,该怎么理解呢?
          • 首选可以参考我上面的文章面试常问系列(一)-神经网络参数初始化-CSDN博客,会更好理解,下面我简单说一下。
          • 这个的方差就是上面说的,d个独立分量的和
          • 上面的公式是核心,不理解的可以再去看一下我的链接。这也就是说,新的标准差范围取决于维度个数!!!
    • 还有什么问题呢?
      • 1>**数值不稳定:**softmax的公式分子是e的x次幂,分母是求和,如果包含大树,e的指数运算会直接溢出。如果x = [1000,10,10],则1000就直接溢出了。

      • 2>权重分布退化: softmax的公式分子是e的x次幂,分母是求和,如果包含大数,比如,x = [100,10,10] 那么softmax之后的值为

        复制代码
        [1.0, 8.194012623990515e-40, 8.194012623990515e-40]这个结果应该比较直观了,某个权重接近1,其他权重接近0,注意力机制,将会只关注极个别特征,变成hard模式,而没有那么soft!
      • 3>训练困难 :在初始化阶段,过大的点积使 Softmax 输出极端化(接近0或1),梯度趋于零,导致梯度消失,阻碍模型收敛。不太理解的,可以去看下面试常问系列(一)-神经网络参数初始化-CSDN博客

2. 缩放的作用

  • 缩放公式 :将点积结果除以,即Softmax()。
  • 方差控制 :缩放后,点积的方差变为,标准差为1。这使得点积结果分布在更合理的区间(如-5到+5),Softmax的输入不会因维度增大而爆炸。
  • 梯度稳定性:Softmax的梯度在输入值适中时较大,缩放后梯度更稳定,模型训练更高效。

3. 类比初始化方法

  • 类似He初始化(考虑前一层神经元数量调整权重方差),此处通过缩放调整点积的方差,确保网络各层的输入分布稳定。

4. 实验验证

  • 原论文通过实验表明,缩放后模型收敛更快,性能更优。若省略缩放(即直接计算Softmax(Q⋅KT)),训练过程可能因梯度消失而失败。

5.总结

除以dk​​的本质是对点积结果进行方差归一化,确保Softmax的输入值不会随维度增大而失控,从而保持梯度稳定,提升训练效率和模型性能。这一设计是Transformer高效训练的关键细节之一。

6.赘述

后面这一部分,还会再出几个常见面试题,希望大家理论结合例子和实战,深入理解,加油!

相关推荐
生信碱移10 小时前
细胞内与细胞间网络整合分析!神经网络+细胞通讯,这个单细胞分析工具一箭双雕了(scTenifoldXct)
人工智能·经验分享·深度学习·神经网络·机器学习·数据分析·数据可视化
WHATEVER_LEO14 小时前
【每日论文】MetaSpatial: Reinforcing 3D Spatial Reasoning in VLMs for the Metaverse
人工智能·深度学习·神经网络·计算机视觉·3d·自然语言处理
橙色小博14 小时前
最最最基本神经网络及其原理、程序
人工智能·深度学习·神经网络
驼驼学编程15 小时前
卷积神经网络
人工智能·神经网络·cnn
云空16 小时前
《Keras 3 :AI 使用图神经网络和 LSTM 进行交通流量预测》
人工智能·神经网络·keras
豆芽81920 小时前
神经网络检测题
人工智能·python·深度学习·神经网络·学习
m0_7480385621 小时前
跟着StatQuest学知识08-RNN与LSTM
人工智能·rnn·深度学习·神经网络·机器学习·cnn·lstm
Ling_Ze1 天前
从图神经网络入门到gcn+lstm
人工智能·神经网络·lstm
点我头像干啥1 天前
卷积神经网络在图像分割中的应用:原理、方法与进展介绍
人工智能·神经网络·cnn
牛andmore牛2 天前
3、fabric实现多机多卡训练
深度学习·神经网络·fabric·fabric多机多卡