稀疏注意力机制(ProbSparse Self-attention)
Efficient Self-attention Mechanism
经典的自注意力机制(Vaswani et al. 2017)是基于三元组输入定义的,即:查询(query)、键(key)和值(value),它执行缩放点积,如公式
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A ( Q , K , V ) = Softmax ( Q K T / d ) V \mathcal{A}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Softmax}(\mathbf{Q}\mathbf{K}^T/\sqrt{d})\mathbf{V} </math>A(Q,K,V)=Softmax(QKT/d )V
所示,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q ∈ R L q × d \mathbf{Q} \in \mathbb{R}^{L_q \times d} </math>Q∈RLq×d、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K ∈ R L k × d \mathbf{K} \in \mathbb{R}^{L_k \times d} </math>K∈RLk×d 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V ∈ R L v × d \mathbf{V} \in \mathbb{R}^{L_v \times d} </math>V∈RLv×d,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是输入维度。
为了进一步讨论自注意力机制,让 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i , k i , v i \mathbf{q}_i, \mathbf{k}_i, \mathbf{v}_i </math>qi,ki,vi 分别代表 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q , K , V \mathbf{Q}, \mathbf{K}, \mathbf{V} </math>Q,K,V 中的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 行。
根据 (Tsai et al. 2019) 的定义,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个查询的注意力被定义为概率形式下的核平滑:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A ( q i , K , V ) = ∑ j k ( q i , k j ) ∑ l k ( q i , k l ) v j = E p ( k j ∣ q i ) [ v j ] ( 1 ) \mathcal{A}(\mathbf{q}i, \mathbf{K}, \mathbf{V}) = \sum{j} \frac{k(\mathbf{q}_i, \mathbf{k}j)}{\sum{l} k(\mathbf{q}_i, \mathbf{k}_l)} \mathbf{v}j = \mathbb{E}{p(\mathbf{k}_j|\mathbf{q}_i)}[\mathbf{v}_j](1) </math>A(qi,K,V)=j∑∑lk(qi,kl)k(qi,kj)vj=Ep(kj∣qi)[vj](1)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( k j ∣ q i ) = k ( q i , k j ) / ∑ l k ( q i , k l ) p(\mathbf{k}_j|\mathbf{q}_i) = k(\mathbf{q}_i, \mathbf{k}j)/\sum{l} k(\mathbf{q}_i, \mathbf{k}_l) </math>p(kj∣qi)=k(qi,kj)/∑lk(qi,kl),且 <math xmlns="http://www.w3.org/1998/Math/MathML"> k ( q i , k j ) k(\mathbf{q}_i, \mathbf{k}_j) </math>k(qi,kj) 选择非对称指数核 <math xmlns="http://www.w3.org/1998/Math/MathML"> exp ( q i k j T / d ) \text{exp}(\mathbf{q}_i \mathbf{k}_j^T/\sqrt{d}) </math>exp(qikjT/d )。
自注意力结合了values,并基于计算概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( k j ∣ q i ) p(\mathbf{k}_j|\mathbf{q}_i) </math>p(kj∣qi) 获得输出。它需要进行二次时间的点积计算和 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( L q L k ) \mathcal{O}(L_q L_k) </math>O(LqLk) 的内存使用,这是在提升预测能力时的主要缺陷。
一些先前的尝试揭示了自注意力概率分布可能具有稀疏性,并且他们设计了"选择性"计数策略来计算所有 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( k j ∣ q i ) p(\mathbf{k}_j|\mathbf{q}_i) </math>p(kj∣qi),而不显著影响性能。稀疏Transformer(Child et al. 2019)结合了行输出和列输入,其中稀疏性来自于分离的空间相关性。LogSparse Transformer(Li et al. 2019)注意到自注意力中的周期性模式,并通过指数步长强制每个单元格关注其前一个。Longformer(Beltagy, Peters, and Cohan 2020)扩展了先前的两个工作,以更复杂的稀疏配置。然而,它们受限于从以下启发式方法进行的理论分析,并以相同的策略处理每个多头自注意力,限制了它们的进一步改进。
为了激励我们的方法,我们首先对经典自注意力的学习注意力模式进行定性评估。"稀疏"自注意力得分形成长尾分布,即,少数点积对对主要注意力贡献,而其他的产生微不足道的注意力。然后,下一个问题是如何区分它们?
Query Sparsity Measurement
从公式 (1) 出发,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个查询对所有键的注意力被定义为一个概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( k j ∣ q i ) p(k_j|q_i) </math>p(kj∣qi),输出是这些概率与值 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v 的组合。主导的点积对促进了相应查询的注意力概率分布远离均匀分布。如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( k j ∣ q i ) p(k_j|q_i) </math>p(kj∣qi) 接近于均匀分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( k j ∣ q i ) = 1 / L K q(k_j|q_i) = 1/L_K </math>q(kj∣qi)=1/LK,则自注意力成为简单的值 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 的和,对输入过程而言是冗余的。本质上可以通过分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q 的"相似性"来区分"重要的"查询。我们通过 Kullback-Leibler 散度来测量"相似性":
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> K L ( q ∥ p ) = ln ∑ l = 1 L K e q i k l ⊤ / d − 1 L K ∑ j = 1 L K q i k j ⊤ / d − ln L K KL(q\|p) = \ln \sum_{l=1}^{L_K} e^{q_i k_l^\top / \sqrt{d}} - \frac{1}{L_K} \sum_{j=1}^{L_K} q_i k_j^\top / \sqrt{d} - \ln L_K </math>KL(q∥p)=lnl=1∑LKeqikl⊤/d −LK1j=1∑LKqikj⊤/d −lnLK
去掉常数项后,我们定义第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个查询的稀疏性测度为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> M ( q i , K ) = ln ∑ j = 1 L K e q i k j ⊤ d − 1 L K ∑ j = 1 L K q i k j ⊤ d ( 2 ) M(q_i, \mathbf{K}) = \ln \sum_{j=1}^{L_K} e^{\frac{{q_i k_j^\top}}{\sqrt{d}}} - \frac{1}{L_K} \sum_{j=1}^{L_K} \frac{q_i k_j^\top}{\sqrt{d}} \quad (2) </math>M(qi,K)=lnj=1∑LKed qikj⊤−LK1j=1∑LKd qikj⊤(2)
其中,第一个项是 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 对所有键的 Log-Sum-Exp (LSE),第二个项是它们的算术平均值。如果第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个查询获得更大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> M ( q i , K ) M(q_i, \mathbf{K}) </math>M(qi,K),则其 attention probability p 更具区别性,且更有可能包含长尾 self-attention 分布的头部区域中的主要点积对。如下图红框区域:
ProbSparse Self-attention
ProbSparse 自注意力 基于所提出的度量方法,我们通过允许每个键仅关注于 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u 个主要query来实现 ProbSparse 自注意力:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A ( Q , K , V ) = Softmax ( Q ‾ K ⊤ d ) V ( 3 ) \mathcal{A}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Softmax}\left(\frac{\overline{\mathbf{Q}}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \quad (3) </math>A(Q,K,V)=Softmax(d QK⊤)V(3)
其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q ‾ \overline{\mathbf{Q}} </math>Q 是与 <math xmlns="http://www.w3.org/1998/Math/MathML"> q \mathbf{q} </math>q 尺寸相同的稀疏矩阵,它仅包含在稀疏性度量 <math xmlns="http://www.w3.org/1998/Math/MathML"> M ( q , K ) M(\mathbf{q}, \mathbf{K}) </math>M(q,K) 下的前 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u 个查询。通过常数采样因子 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 控制,我们设置 <math xmlns="http://www.w3.org/1998/Math/MathML"> u = c ⋅ ln L Q u = c \cdot \ln L_Q </math>u=c⋅lnLQ,这使得 ProbSparse 自注意力仅需为每个查询-键查找计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( ln L Q ) \mathcal{O}(\ln L_Q) </math>O(lnLQ) 次点积,且层的内存使用维持在 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( L K ln L Q ) \mathcal{O}(L_K \ln L_Q) </math>O(LKlnLQ)。在多头机制的视角下,这种注意力为每个头生成不同的稀疏查询-键对,从而避免了严重的信息丢失。
然而,遍历所有查询以进行度量 <math xmlns="http://www.w3.org/1998/Math/MathML"> M ( q i , K ) M(q_i, \mathbf{K}) </math>M(qi,K) 需要计算每个点积对,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( L Q L K ) \mathcal{O}(L_Q L_K) </math>O(LQLK) 的平方复杂度,此外 LSE 操作可能会引发潜在的数值稳定性问题。受此启发,我们提出了一种经验近似方法,以有效地获取查询稀疏性度量。
基于引理1(这里跳过),我们提出了最大均值测度:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> M ‾ ( q i , K ) = max j { q i k j ⊤ d } − 1 L K ∑ j = 1 L K { q i k j ⊤ d } \overline{M}(q_i, \mathbf{K}) = \max_j \left\{\frac{q_i k_j^\top}{\sqrt{d}}\right\} - \frac{1}{L_K} \sum_{j=1}^{L_K} \left\{\frac{q_i k_j^\top}{\sqrt{d}}\right\} </math>M(qi,K)=jmax{d qikj⊤}−LK1j=1∑LK{d qikj⊤}
Top-u 的范围大致在命题1 (这里跳过) 的边界放宽中成立。在长尾分布下,我们只需要随机采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> U = L K ln L Q U = L_K \ln L_Q </math>U=LKlnLQ 点积对来计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> M ‾ ( q i , K ) \overline{M}(q_i, \mathbf{K}) </math>M(qi,K),即把其他对填充为零。然后,我们从中选择稀疏的前 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u 个作为 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q ‾ \overline{\mathbf{Q}} </math>Q。 <math xmlns="http://www.w3.org/1998/Math/MathML"> M ‾ ( q i , K ) \overline{M}(q_i, \mathbf{K}) </math>M(qi,K) 中的最大运算符对零值不太敏感且数值稳定。在实践中,查询和键的输入长度通常在自注意力计算中是相等的,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> L Q = L K = L L_Q = L_K = L </math>LQ=LK=L,因此总的 ProbSparse 自注意力时间复杂度和空间复杂度均为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( L ln L ) \mathcal{O}(L \ln L) </math>O(LlnL)。
Encoder:引入自注意力蒸馏机制(Self-attention Distilling)
编码器:在内存使用限制下,允许处理更长序列的输入。
该编码器旨在提取长序列输入的鲁棒长程依赖性。在输入表示之后,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 个序列输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> X t \mathbf{X}^t </math>Xt 被形成为一个矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> X en t ∈ R L x × d model \mathbf{X}{\text{en}}^t \in \mathbb{R}^{L_x \times d{\text{model}}} </math>Xent∈RLx×dmodel。我们在图3中给出编码器的草图以供说明。
自注意力蒸馏 作为 ProbSparse 自注意力机制的自然结果,编码器的特征图存在值 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 的冗余组合。我们使用蒸馏操作来突出具有主导特征的优越特征,并在下一层中生成一个集中的自注意力特征图。它锐化裁剪了输入的时间维度,查看图3中注意力块的 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n-头加权矩阵(重叠的红色方框)。受到稀疏卷积的启发(Yu, Koltun, and Funkhouser 2017; Gupta and Rush 2017),我们的"蒸馏"过程从第 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j 层转移到第 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( j + 1 ) (j+1) </math>(j+1) 层为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X j + 1 t = MaxPool ( ELU ( Conv1d ( [ X j t ] A B ) ) ) \mathbf{X}^{t}_{j+1} = \text{MaxPool}\left(\text{ELU}\left(\text{Conv1d}([\mathbf{X}j^t]{AB})\right)\right) </math>Xj+1t=MaxPool(ELU(Conv1d([Xjt]AB)))
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ ⋅ ] A B [\cdot]_{AB} </math>[⋅]AB 表示注意力块。它包含了多头 ProbSparse 自注意力和基本操作,其中 Conv1d( <math xmlns="http://www.w3.org/1998/Math/MathML"> ⋅ \cdot </math>⋅) 在时间维度上执行一维卷积滤波(内核宽度=3),并使用 ELU( <math xmlns="http://www.w3.org/1998/Math/MathML"> ⋅ \cdot </math>⋅) 激活函数(Clevert, Unterthiner, and Hochreiter 2016)。我们在层堆叠之后添加一个跨度为2的最大池化层,将样本 <math xmlns="http://www.w3.org/1998/Math/MathML"> X t \mathbf{X}^t </math>Xt 向下采样到其一半切片,这减少了整体内存使用至 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( ( 2 − ϵ ) L log L ) \mathcal{O}((2-\epsilon)L\log L) </math>O((2−ϵ)LlogL),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 是一个小数。为了增强蒸馏操作的鲁棒性,我们构建主要堆叠的副本,用半量输入,并通过一次丢弃一层的方式逐步减少自注意力蒸馏层的数量,如图2中的金字塔,使得输出尺寸对齐。因此,我们连接所有堆叠的输出,并获得编码器的最终隐藏表示。
Decoder:一次性生成长序列输出
解码器:通过一次前向过程生成长序列输出
我们在图2中使用了一个标准的解码器结构(Vaswani et al. 2017),它由两个相同的多头注意力层堆叠而成。然而,我们采用生成推理以缓解长预测中的速度下降。我们将以下向量输入到解码器中:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X de t = Concat ( X token t , X 0 t ) ∈ R ( L token + L y ) × d model \mathbf{X}{\text{de}}^t = \text{Concat}(\mathbf{X}{\text{token}}^t, \mathbf{X}0^t) \in \mathbb{R}^{(L{\text{token}}+L_y)\times d_{\text{model}}} </math>Xdet=Concat(Xtokent,X0t)∈R(Ltoken+Ly)×dmodel
其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> X token t ∈ R L token × d model \mathbf{X}{\text{token}}^t \in \mathbb{R}^{L{\text{token}} \times d_{\text{model}}} </math>Xtokent∈RLtoken×dmodel 是起始标记, <math xmlns="http://www.w3.org/1998/Math/MathML"> X 0 t ∈ R L y × d model \mathbf{X}0^t \in \mathbb{R}^{L_y \times d{\text{model}}} </math>X0t∈RLy×dmodel 是作为目标序列的占位符(设定为标量0)。在 ProbSparse 自注意力计算中应用了掩码多头注意力,通过将掩码点积设置为 <math xmlns="http://www.w3.org/1998/Math/MathML"> − ∞ -\infty </math>−∞。这防止了每个位置关注即将到来的位置,从而避免了自回归。一个全连接层获取最终输出,其输出大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> d y d_y </math>dy 取决于我们执行的是单变量预测还是多变量预测。
生成推理 起始标记在NLP的"动态解码"(Devlin et al. 2018)中被有效应用,并且我们以生成方式扩展它。我们不是选择特定标记作为标志,而是从输入序列中采样一个长为 <math xmlns="http://www.w3.org/1998/Math/MathML"> L token L_{\text{token}} </math>Ltoken 的序列,比如输出序列之前的一个较早的切片。以预测168个点(实验部分的7天温度预测)为例,我们将目标序列之前已知的5天作为"起始标记",并将生成式推理解码器输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> X de = { X 5 d , X 0 } \mathbf{X}{\text{de}} = \{\mathbf{X}{5d}, \mathbf{X}_0\} </math>Xde={X5d,X0}。 <math xmlns="http://www.w3.org/1998/Math/MathML"> X 0 \mathbf{X}_0 </math>X0 包含目标序列的时间戳,即目标周的上下文。然后,我们提出的解码器通过一次前向过程来预测输出,而不是在传统的编码-解码架构中耗时的"动态解码"。详细的性能比较在计算效率部分给出。
损失函数 我们选择均方误差(MSE)损失函数用于目标序列的预测,损失通过整个模型的解码器输出进行反向传播。