深度学习理论-直观理解 Attention

本文首先介绍 Attention 的原始公式,然后以 Self-Attention 为例,简化后逐步分析 Attention 计算结果表达的含义

Attention

Attention 公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> A t t e n t i o n = s o f t m a x ( Q ⋅ K T d k ) ⋅ V Attention = softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) \cdot V </math>Attention=softmax(dk Q⋅KT)⋅V

其中 softmax 作用是归一化,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( x ) = e x ∑ i = 1 n e x i softmax(x) = \frac{e^x}{\sum_{i=1}^n{e^{x_i}}} </math>softmax(x)=∑i=1nexiex

我们将 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q ⋅ K T d k \frac{Q \cdot K^T}{\sqrt{d_k}} </math>dk Q⋅KT 称为 attention score,归一化后 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( Q ⋅ K T d k ) softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) </math>softmax(dk Q⋅KT) 称为 attention weight

Self-Attention

在 Self-Attention 中,输入为 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X ,乘以不同的权重矩阵,就得到了不同的 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V

<math xmlns="http://www.w3.org/1998/Math/MathML"> Q = X ⋅ W q Q = X \cdot W_q </math>Q=X⋅Wq

<math xmlns="http://www.w3.org/1998/Math/MathML"> K = X ⋅ W k K = X \cdot W_k </math>K=X⋅Wk

<math xmlns="http://www.w3.org/1998/Math/MathML"> V = X ⋅ W v V = X \cdot W_v </math>V=X⋅Wv

为了方便理解,我们先做简化,把权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q , W k , W v W_q, W_k, W_v </math>Wq,Wk,Wv 和缩放因子 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 都假设为 1

简化后,Self-Attention 长这样

<math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( X ⋅ X T ) ⋅ X softmax(X\cdot X^T) \cdot X </math>softmax(X⋅XT)⋅X

1. Attention Score

首先来看 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ⋅ X T X \cdot X^T </math>X⋅XT 的含义,我们先复习一下,向量内积表示的是两个向量的相关性

假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × d n \times d </math>n×d 的矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 是输入的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是特征的维度, <math xmlns="http://www.w3.org/1998/Math/MathML"> X ⋅ X T X \cdot X^T </math>X⋅XT 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × n n \times n </math>n×n 的矩阵,表示输入的每个元素,与其它元素的相关性

2. Attention Weight

softmax 就是做归一化,使得权重的和为 1,表达的含义跟 score 一致,相关性高的权重也高,通过非线性函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> e x e^x </math>ex 后,变成了概率分布

3. Attention Value

用相关性矩阵乘以输入向量,得到了 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × d n \times d </math>n×d 的矩阵,跟输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 的尺度一致,含义也一致,依然表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个输入变量 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 维特征,但这个特征已经是经过注意力加权的特征,相关性更高的元素响应更高。

看到这里,相信你对 attention 机制已经有了直观的理解。下面就把之前简化的细枝末节加回来。

4. 权重矩阵

权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q , W k , W v W_q, W_k, W_v </math>Wq,Wk,Wv 都是可训练的参数,具有以下作用

  • 使用不同的权重矩阵,可以提升模型的表达能力。
  • 通过调整 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q , W k W_q, W_k </math>Wq,Wk,模型可以学习到不同的注意力模式,使得某些输入 token 之间的关联更强或更弱。
  • 权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W v W_v </math>Wv 影响模型如何聚合信息,使得某些 token 在最终表示中占更重要的比重。

5. 缩放因子

dk​​ 作为缩放因子有如下两个作用:

5.1 防止数值过大,避免梯度消失或梯度爆炸

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT 是两个向量的点积,其值范围随着维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 增大而增大。
  • 如果 不除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ​​,那么较大维度时,点积的数值会变得非常大,导致 softmax 结果变得极端(接近 0 或 1),从而导致梯度消失,影响训练稳定性。
  • 除以 dkdk 后,使得点积值的范围保持在适当区间,从而让 softmax 更平滑。

5.2 保持不同维度下的数值稳定性

  • 在深度学习中,通常希望输入数据的方差保持在一个稳定范围,否则网络难以收敛。
  • 设 Query 和 Key 向量的分量服从均值为 0、方差为 1 的标准正态分布,则点积的期望和方差为: <math xmlns="http://www.w3.org/1998/Math/MathML"> E [ Q K T ] = 0 , V a r [ Q K T ] = d k E[QK^T]=0,Var[QK^T]=d_k </math>E[QKT]=0,Var[QKT]=dk
  • 这意味着 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 越大,点积值的方差也会随之增大,从而影响 softmax 的输出。
  • 除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 后,点积值的方差变为 1,保持了数值稳定性,使不同维度的注意力机制都能较好地工作。

复杂度

假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 的维度分别为 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × d N \times d </math>N×d, <math xmlns="http://www.w3.org/1998/Math/MathML"> M × d M \times d </math>M×d, <math xmlns="http://www.w3.org/1998/Math/MathML"> M × d M \times d </math>M×d

  • 时间复杂度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)

    • 计算 Q, K, V: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N d 2 ) O(Nd^2) </math>O(Nd2) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( M d 2 ) O(Md^2) </math>O(Md2) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( M d 2 ) O(Md^2) </math>O(Md2)(线性变换)
    • 计算 Attention-Score <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)
    • 计算 Softmax: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M ) O(NM) </math>O(NM)
    • 计算加权求和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( Q K T ) V softmax(QK^T)V </math>softmax(QKT)V: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)
    • 总体上,主要瓶颈是 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT 的计算和加权求和,因此时间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M d ) O(NMd) </math>O(NMd)。
  • 空间复杂度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M ) O(NM) </math>O(NM)

    • 由于需要存储 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT(一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × M N×M </math>N×M 的矩阵),因此空间复杂度是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N M ) O(NM) </math>O(NM)。

对于 Self-Attention,由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = M N=M </math>N=M,时间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 d ) O(N^2d) </math>O(N2d),空间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) O(N^2) </math>O(N2)

思考

如果想屏蔽某些特征,应该如何做?mask 是怎样实现的?

Google T5 不除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 为什么也能够收敛?

参考资料

相关推荐
jndingxin8 分钟前
OpenCV 图像哈希类cv::img_hash::AverageHash
人工智能·opencv·哈希算法
堆栈future16 分钟前
深度剖析Manus:如何打造低幻觉、高效率、安全可靠的Agentic AI系统
llm·aigc·mcp
Jamence21 分钟前
多模态大语言模型arxiv论文略读(153)
论文阅读·人工智能·语言模型·自然语言处理·论文笔记
晨曦54321029 分钟前
量子计算突破:8比特扩散模型实现指数级加速
人工智能
Albert_Lsk41 分钟前
【2025/07/11】GitHub 今日热门项目
人工智能·开源·github·开源协议
莫彩43 分钟前
【大模型推理论文阅读】Enhancing Latent Computation in Transformerswith Latent Tokens
论文阅读·人工智能·语言模型
康斯坦丁师傅44 分钟前
全球最强模型Grok4,国内已可免费使用!(附教程)
人工智能·grok
崔高杰44 分钟前
微调性能赶不上提示工程怎么办?Can Gradient Descent Simulate Prompting?——论文阅读笔记
论文阅读·人工智能·笔记·语言模型
元气小嘉1 小时前
前端技术小结
开发语言·前端·javascript·vue.js·人工智能
AI大模型1 小时前
大模型炼丹术(七):LLM微调实战:训练一个垃圾邮件分类器
程序员·llm·agent