深度学习理论-直观理解 Attention

本文首先介绍 Attention 的原始公式,然后以 Self-Attention 为例,简化后逐步分析 Attention 计算结果表达的含义

Attention

Attention 公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> A t t e n t i o n = s o f t m a x ( Q ⋅ K T d k ) ⋅ V Attention = softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) \cdot V </math>Attention=softmax(dk Q⋅KT)⋅V

其中 softmax 作用是归一化,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( x ) = e x ∑ i = 1 n e x i softmax(x) = \frac{e^x}{\sum_{i=1}^n{e^{x_i}}} </math>softmax(x)=∑i=1nexiex

我们将 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q ⋅ K T d k \frac{Q \cdot K^T}{\sqrt{d_k}} </math>dk Q⋅KT 称为 attention score,归一化后 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( Q ⋅ K T d k ) softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) </math>softmax(dk Q⋅KT) 称为 attention weight

Self-Attention

在 Self-Attention 中,输入为 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X ,乘以不同的权重矩阵,就得到了不同的 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V

<math xmlns="http://www.w3.org/1998/Math/MathML"> Q = X ⋅ W q Q = X \cdot W_q </math>Q=X⋅Wq

<math xmlns="http://www.w3.org/1998/Math/MathML"> K = X ⋅ W k K = X \cdot W_k </math>K=X⋅Wk

<math xmlns="http://www.w3.org/1998/Math/MathML"> V = X ⋅ W v V = X \cdot W_v </math>V=X⋅Wv

为了方便理解,我们先做简化,把权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q , W k , W v W_q, W_k, W_v </math>Wq,Wk,Wv 和缩放因子 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 都假设为 1

简化后,Self-Attention 长这样

<math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( X ⋅ X T ) ⋅ X softmax(X\cdot X^T) \cdot X </math>softmax(X⋅XT)⋅X

1. Attention Score

首先来看 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ⋅ X T X \cdot X^T </math>X⋅XT 的含义,我们先复习一下,向量内积表示的是两个向量的相关性

假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × d n \times d </math>n×d 的矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 是输入的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是特征的维度, <math xmlns="http://www.w3.org/1998/Math/MathML"> X ⋅ X T X \cdot X^T </math>X⋅XT 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × n n \times n </math>n×n 的矩阵,表示输入的每个元素,与其它元素的相关性

2. Attention Weight

softmax 就是做归一化,使得权重的和为 1,表达的含义跟 score 一致,相关性高的权重也高,通过非线性函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> e x e^x </math>ex 后,变成了概率分布

3. Attention Value

用相关性矩阵乘以输入向量,得到了 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × d n \times d </math>n×d 的矩阵,跟输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 的尺度一致,含义也一致,依然表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个输入变量 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 维特征,但这个特征已经是经过注意力加权的特征,相关性更高的元素响应更高。

看到这里,相信你对 attention 机制已经有了直观的理解。下面就把之前简化的细枝末节加回来。

4. 权重矩阵

权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q , W k , W v W_q, W_k, W_v </math>Wq,Wk,Wv 都是可训练的参数,具有以下作用

  • 使用不同的权重矩阵,可以提升模型的表达能力。
  • 通过调整 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q , W k W_q, W_k </math>Wq,Wk,模型可以学习到不同的注意力模式,使得某些输入 token 之间的关联更强或更弱。
  • 权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W v W_v </math>Wv 影响模型如何聚合信息,使得某些 token 在最终表示中占更重要的比重。

5. 缩放因子

dk​​ 作为缩放因子有如下两个作用:

5.1 防止数值过大,避免梯度消失或梯度爆炸

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT 是两个向量的点积,其值范围随着维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 增大而增大。
  • 如果 不除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ​​,那么较大维度时,点积的数值会变得非常大,导致 softmax 结果变得极端(接近 0 或 1),从而导致梯度消失,影响训练稳定性。
  • 除以 dkdk 后,使得点积值的范围保持在适当区间,从而让 softmax 更平滑。

5.2 保持不同维度下的数值稳定性

  • 在深度学习中,通常希望输入数据的方差保持在一个稳定范围,否则网络难以收敛。
  • 设 Query 和 Key 向量的分量服从均值为 0、方差为 1 的标准正态分布,则点积的期望和方差为: <math xmlns="http://www.w3.org/1998/Math/MathML"> E [ Q K T ] = 0 , V a r [ Q K T ] = d k E[QK^T]=0,Var[QK^T]=d_k </math>E[QKT]=0,Var[QKT]=dk
  • 这意味着 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 越大,点积值的方差也会随之增大,从而影响 softmax 的输出。
  • 除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 后,点积值的方差变为 1,保持了数值稳定性,使不同维度的注意力机制都能较好地工作。

复杂度

假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 的维度分别为 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × d N \times d </math>N×d, <math xmlns="http://www.w3.org/1998/Math/MathML"> M × d M \times d </math>M×d, <math xmlns="http://www.w3.org/1998/Math/MathML"> M × d M \times d </math>M×d

  • 时间复杂度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)

    • 计算 Q, K, V: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N d 2 ) O(Nd^2) </math>O(Nd2) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( M d 2 ) O(Md^2) </math>O(Md2) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( M d 2 ) O(Md^2) </math>O(Md2)(线性变换)
    • 计算 Attention-Score <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)
    • 计算 Softmax: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M ) O(NM) </math>O(NM)
    • 计算加权求和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( Q K T ) V softmax(QK^T)V </math>softmax(QKT)V: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)
    • 总体上,主要瓶颈是 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT 的计算和加权求和,因此时间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)。
  • 空间复杂度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M ) O(NM) </math>O(NM)

    • 由于需要存储 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT(一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × M N×M </math>N×M 的矩阵),因此空间复杂度是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M ) O(NM) </math>O(NM)。

对于 Self-Attention,由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = M N=M </math>N=M,时间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 d ) O(N^2d) </math>O(N2d),空间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) O(N^2) </math>O(N2)

思考

如果想屏蔽某些特征,应该如何做?mask 是怎样实现的?

Google T5 不除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 为什么也能够收敛?

参考资料

相关推荐
果冻人工智能2 分钟前
LLM 的注意力黑洞:为什么第一个 Token 吸走了所有注意力?
人工智能
意.远3 分钟前
批量归一化(Batch Normalization)原理与PyTorch实现
人工智能·pytorch·python·深度学习·神经网络·分类·batch
掘金安东尼9 分钟前
DeepSeek-R1 全托管无服务器上线亚马逊云 Bedrock,为何值得关注?
人工智能·llm
Love绘梨衣的Mr.lu9 分钟前
【benepar】benepar安装会自动更新pytorch
人工智能·pytorch·python
果冻人工智能17 分钟前
AI争霸新拐点:谷歌靠Gemini 2.5能翻盘吗?
人工智能
蹦蹦跳跳真可爱58919 分钟前
Python----机器学习(基于PyTorch的蘑菇逻辑回归)
开发语言·人工智能·pytorch·python·机器学习·逻辑回归
掘金安东尼43 分钟前
颠覆 LLM?Meta 提出 LCM 这个新范式
人工智能·llm
Goboy1 小时前
Java版的深度学习 · 手撕 DeepLearning4J实现手写数字识别 (附UI效果展示)
llm·aigc·ai编程
cnbestec1 小时前
介电弹性体传感器如何实现高灵敏度?Delfa多层结构设计详解
人工智能