混合注意力学习(1): 线性注意力
目录
- [混合注意力学习(1): 线性注意力](#混合注意力学习(1): 线性注意力)
- Prefill、Decode与KVCache
- 混合注意力架构
- 线性注意力
- [Transformers 就是 RNNs[4]](#Transformers 就是 RNNs[4])
- [Fast Weight Programmers与DeltaNet[9]](#Fast Weight Programmers与DeltaNet[9])
- [Gated DeltaNet[11]](#Gated DeltaNet[11])
- 线性注意力并行机制[7:3]
- 传统FWP形式的线性注意力并行
- FWP分块并行机制
- [DeltaNet 分块并行机制](#DeltaNet 分块并行机制)
- [重新定义 \(\mathbf{u}_t\) 降低内存占用](#重新定义 \mathbf{u}_t 降低内存占用)
- DeltaNet的分块并行
- [GatedDeltaNet 分块并行机制[11:1]](#GatedDeltaNet 分块并行机制[11:1])
- 线性注意力机制总结
- 线性注意力
Prefill、Decode与KVCache
在开始本文之前,首先应该介绍一下什么是prefill,什么是decode,以及对应的KVCache。这样可以更好理解内存复杂度。
小学生们都日益了解了,现有的LLM大语言模型的组成成分主要是Transformer Block。其中,注意力机制具有如下的计算公式:
\\\begin{aligned} \\text{Attn}(Q,K,V)=\\text{softmax}\\left(\\dfrac{QK\^\\top}{h}\\right)V \\end{aligned} \\
在推理过程中,我们采用 decoder-only 架构,因此注意力机制为"掩码自注意力"。流程如下:
从上面的流程图中,我们看到Key和Value是复用的,即每一次新的计算都需要用到先前的输入生成的Key和Value。因此我们将其称为KVCache 。我们可以发现,每一个Transformer Block的KVCache都是随着序列长度线性增长的。因此整体的空间复杂度为 \(O(Lnd)\) (假定\(d_k=d_v=d\))。
因此我们的prefill和decode流程简化来说是如下所示的,prefill要求多个seq并行进入进行计算,decode则每次接收一个上次生成的token进行计算。重要的两个指标为:TTFT(Prefill开始到第一个token生成所需要的时间),TPOT(每个token生成之间的时间)[[1]](#[1])。
[[1:1]](#[1:1])
[[1:2]](#[1:2])
同时我们也可以用roofline模型来刻画我们的序列长度,请求次数导致的计算上限和存储上限。对于短序列prefill和decode阶段 ,主要是存储受限(计算强度低,但是需要大量读写KVCache)。对于长序列prefill阶段,主要是计算受限(计算强度高,GEMM)
这样我们就介绍完了KVCache,Prefill和Decode,了解了他们的不同和受限情况。
混合注意力架构
混合注意力架构包括了稀疏注意力与线性注意力,而线性注意力机制又起源于这篇Transformers are RNNs的文章。我们都大致介绍一下作为我们的基本背景。
线性注意力
这里也可以阅读苏神的文章[[2]](#[2])[[3]](#[3])。我们总结了一张图如下:
Transformers 就是 RNNs^[4](#Transformers 就是 RNNs[4])^
在上文中我们已经提到了经典的自注意力机制的计算,我们这里再展示一次:
\\\begin{aligned} Q\&=xW\^Q\\\\ K\&=xW\^K\\\\ V\&=xW\^V\\\\ A_{l}(x)\&=\\text{softmax}\\left(\\dfrac{QK\^\\top}{\\sqrt{d_k}}\\right)V\\\\ x\&\\in\\mathbb{R}\^{s\\times f}, \\\\ W\^Q, W\^K\&\\in\\mathbb{R}\^{f\\times d_k}\\\\ W\^V \&\\in\\mathbb{R}\^{f\\times d_v} \\end{aligned} \\
对于上述式子,我们很容易想当然地采用近似函数来代替。我们首先分析我们的矩阵运算结果如下:
\\\begin{aligned} O_{ij}\&=\\dfrac{e\^{q_ik_j\^\\top}}{\\sum\\limits_{l=1}\^{s}e\^{q_ik_l\^\\top}}\\\\ A_{i}\&=\\sum_{l=1}\^{s}O_{il}\\cdot v_{l}=\\sum_{j=1}\^{s}\\dfrac{e\^{q_ik_j\^\\top}}{\\sum\\limits_{l=1}\^{s}e\^{q_ik_l\^\\top}}v_{j}=\\dfrac{\\sum\\limits_{j=1}\^{s}e\^{q_ik_j\^\\top}v_{j}}{\\sum\\limits_{l=1}\^{s}e\^{q_ik_l\^\\top}} \\end{aligned} \\
接下来我们就可以尝试用近似函数代替\(exp(\cdot)\),也就是 \(sim(\cdot)\):
\A_i=\\dfrac{\\sum\\limits_{j=1}\^{s}sim(q_i,k_j)v_j}{\\sum\\limits_{j=1}\^{s}sim(q_i,k_j)} \\
这里距离我们的线性注意力还有一段距离,但是已经不远了。我们回想在古老的分离向量机中,为了实现非线性支持向量机,我们需要使用到核函数技巧。
- 为了将非线性问题转换成线性问题,我们采用核函数技巧。对于输入空间 \(X\to H\),\(H\) 为特征空间(希尔伯特空间),如果存在映射函数 \(\phi(x):X\to H\) 满足 \(sim(x, z)=\phi(x)^\top\cdot \phi(z)\), 则 \(sim(\cdot)\) 为核函数,\(\phi(\cdot)\) 为映射函数。[[5]](#[5])
- 在注意力机制中,我们只需要要求\(sim(\cdot)\)函数非负来表示其概率性。对此可以参见[[6]](#[6])
因此,我们采用核函数技巧进行分解,就可以得到如下的公式:
\\\begin{aligned} A_i=\\dfrac{\\sum\\limits_{j=1}\^{s}\\phi\^\\top(q_i)\\phi(k_j)v_j}{\\sum\\limits_{j=1}\^{s}\\phi\^\\top(q_i)\\phi(k_j)} \\end{aligned} \\
很容易注意到我们可以采用结合律来提取出公因式 \(\phi^\top(q_i)\)。注意到 \(\phi(k_j)\in\mathbb{R}^{d_k}, \phi(v_j)\in\mathbb{R}^{d_v}\),采用结合律需要转置。因此我们有:
\A_i=\\dfrac{\\phi\^\\top(q_i)\\sum\\limits_{j=1}\^{s}\\phi(k_j)v_j\^\\top}{\\phi\^\\top(q_i)\\sum\\limits_{j=1}\^{s}\\phi(k_j)} \\
这样我们很容易可以看出:原先softmax的计算复杂度是 \(O(s^2(d_k+d_v))\),随着序列以 \(O(s^2)\) 的平方复杂度增长。而核函数在维持原有隐维度的条件下保持 \(O(sd_kd_v)\) 的线性复杂度增长,并且隐维度 \(d_k\) 仍然具有可优化的空间。这就是线性注意力的由来。
这样,如果我们考虑到推理模式下的decode-only掩码自注意力场景下,我们不会计算全部的序列 \(s\),而是计算到当前序列 \(i\),这样我们就有:
\\\begin{aligned} A_i\&=\\dfrac{\\sum\\limits_{j=1}\^{i}\\phi\^\\top(q_i)\\phi(k_j)v_j}{\\sum\\limits_{j=1}\^{i}\\phi\^\\top(q_i)\\phi(k_j)}=\\dfrac{\\sum\\limits_{j=1}\^{i}\\phi\^\\top(q_i)\\phi(k_j)v_j}{\\sum\\limits_{j=1}\^{i}\\phi\^\\top(q_i)\\phi(k_j)}=\\dfrac{\\phi\^\\top(q_i)\\sum\\limits_{j=1}\^{i}\\phi(k_j)v_j\^\\top}{\\phi\^\\top(q_i)\\sum\\limits_{j=1}\^{i}\\phi(k_j)} \\end{aligned} \\
我们令 \(S_i=\sum\limits_{j=1}^{i}\phi(k_j)v_j^\top, Z_i=\sum\limits_{j=1}^{i}\phi(k_j)\),就可以得到:
\\\begin{aligned} S_0\&=0\\\\ Z_0\&=0\\\\ S_i\&=S_{i-1}+\\phi(k_i)v_i\^\\top\\\\ Z_i\&=Z_{i-1}+\\phi(k_i)\\\\ A_i\&=\\dfrac{\\phi\^\\top(q_i)S_i}{\\phi\^\\top(q_i)Z_i} \\end{aligned} \\
很明显就是我们对应的最简单的RNN架构。这也说明了:(1) Transformer其实是大号RNN;(2)线性注意力在理论层面是可行的。大量的实验发现分母会导致严重的数值不稳定问题,并且可以无需映射函数直接采用分子参与计算。[[7]](#[7])。这样实际上为 \(A_i=q_i^\top S_i\).
我们最后再看看梯度的计算。注意在训练场景下我们是全序列,因此我们在给定分子 \(A_n\) 和损失函数 \(\mathcal{L}\) 的条件下,参考[[8]](#[8])处的运算,我们有:
\\\begin{aligned} \\bar{A\^{(i)}_n}\&=\\phi\^\\top(q_i)S_n=\\phi\^\\top(q_i)\\sum\\limits_{j=1}\^{n}\\phi(k_j)v_j\^\\top\\in\\mathbb{R}\^{1\\times d_v} \\\\ \\nabla_{\\phi(q_i)}\\mathcal{L}\&=\\left(\\nabla_{\\bar{A\^{(i)}_n}}\\mathcal{L}\\cdot\\nabla_{\\phi(q_i)}\\bar{A_n\^{(i)}}\\right)\^\\top=\\sum\\limits_{j=1}\^{i}\\phi(k_j)v_j\^\\top\\cdot\\left(\\nabla_{\\bar{A\^{(i)}_n}}\\mathcal{L}\\right)\^\\top\\in\\mathbb{R}\^{d_k} \\\\ \\nabla_{\\phi(k_i)}\\mathcal{L}\&=\\left(\\sum_{j=i}\^{n}\\nabla_{\\bar{A\^{(j)}_n}}\\mathcal{L}\\cdot\\nabla_{\\phi(k_i)}\\bar{A\^{(j)}_n}\\right)\^\\top=\\sum\\limits_{j=i}\^{n}\\left(\\nabla_{\\bar{A\^{(j)}_n}}\\mathcal{L}\\cdot v_i\\cdot\\phi(q_j)\\right)\\in\\mathbb{R}\^{d_k}\\\\ \\nabla_{v_i}\\mathcal{L}\&=\\sum\\limits_{j=i}\^{n}\\nabla_{\\bar{A_n\^{(j)}}}\\mathcal{L}\\cdot\\nabla_{v_i}\\bar{A_n\^{(j)}}=\\left(\\sum\\limits_{j=i}\^{n}\\phi(q_j)\^\\top\\phi(k_i)\\cdot\\left(\\nabla_{\\bar{A_n\^{(j)}}}\\mathcal{L}\\right)\^\\top\\right)\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
这样我们就把所有的梯度都计算出来了。
小贴士:在参考原有的链式法则基础上,适当通过各种转置方法保证梯度张量和参数张量保持一致(因为梯度张量应该和参数张量相同,这样才能利用梯度下降更新)可以提升我们的计算速度。
Fast Weight Programmers与DeltaNet^[9](#Fast Weight Programmers与DeltaNet[9])^
线性注意力允许我们使用如下的方式更新当前状态模仿RNN:
\\\begin{aligned} \\mathbf{S}_{t+1}=\\mathbf{S}_{t}+\\mathbf{k}_t\\mathbf{v}_t\^\\top\\in\\mathbb{R}\^{d_k\\times d_v}, \~\~\~ \\mathbf{O}_{t+1}=\\mathbf{S}\^\\top_{t+1}\\mathbf{q}_{t+1}\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
这样对于一个长为 \(s\) 的序列:我们每次计算步数 为 \(O(s)\),计算复杂度为 \(O(s)\times O(d^2)=O(sd^2)\),训练过程中需要的空间复杂度为 \(O(sd^2)\),推理过程中需要的时间复杂度为 \(O(d^2)\).
并且,在这篇论文[[9:1]](#[9:1])中,\(\mathbf{S}\in\mathbb{R}^{d_k\times d_v}\) 实际上是一个关联性的内存,存储了当前瞬态从key到Value的映射。这样的更新可以看作是一个无上界的关联损失函数的梯度下降,从而持续强化最近的键值对,没有任何遗忘。也就是文章[[7:1]](#[7:1])中说的:
\\\mathcal{L}_t(\\mathbf{S})=-\\left\<\\mathbf{S}\^\\top\\mathbf{k}_t,\\mathbf{v}_t\\right\> \\
这样的持续无遗忘将会在长上下文中造成严重的干扰。(我们的人脑也会通过忘记无关紧要的,久远的非必要记忆来保证我们对当前上下文的专注)。
我的某位朋友指出我需要提供为什么梯度更新和上面的线性注意力更新是等价的。在此给出确切的证明。
\\\begin{aligned} \\mathcal{L}_{t}(\\mathbf{S})\&=-\\left\<\\mathbf{S}\^\\top\\mathbf{k}_t,\\mathbf{v}_t\\right\>\\\\ \&= -\\mathbf{k}_t\^\\top\\mathbf{S}\\mathbf{v}_t\\\\ \\mathbf{S}_{t+1}\&=\\mathbf{S}_{t}-\\eta\\nabla_{\\mathbf{S}}\\mathcal{L}_t(\\mathbf{S}_t)\\\\ \&=\\mathbf{S}_{t}-\\eta(-\\mathbf{k}_t\\mathbf{v}_t\^\\top)\\\\ \&=\\mathbf{S}_t+\\eta\\mathbf{k}_t\\mathbf{v}_t\^\\top \\end{aligned} \\
对于此,[[9:2]](#[9:2]) [[10]](#[10])提出了如下的更新方式,也称为Delta Rule。
- 新的 \(\{\mathbf{k}{t+1}, \mathbf{v}{t+1}\}\) 到来。
- (Read)取出上一次的 \(\mathbf{S}{t}\),构造未更新前 的前Key-我们看到Key和Value关联模式:\(\bar{\mathbf{v}}{t+1} = \mathbf{S}{t}^\top\mathbf{k}{t+1}\)
- 通过一个学习率网络 \(\mathbf{W}\beta\) 来构造动态学习率 \(\textcolor{red}{\beta{t+1}}=\sigma(\mathbf{W}\beta\mathbf{x}{t+1})\),\(\sigma(\cdot)\) 是激励函数
- 通过学习率控制K-V关联性:\(\mathbf{v}'{t+1}\leftarrow\textcolor{red}{\beta{t+1}}\mathbf{v}{t+1} + (1-\textcolor{red}{\beta{t+1}})\bar{\mathbf{v}}_{t+1}\)。
- 更新状态矩阵实现(Write)遗忘。
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\mathbf{S}_{t}+\\underbrace{\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^{'\\top}}_{Write}-\\underbrace{\\mathbf{k}_{t+1}\\bar{\\mathbf{v}}_{t+1}\^\\top}_{forget}\\\\ \&=\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}(\\mathbf{v}_{t+1}-\\bar{\\mathbf{v}}_{t+1})\^\\top\\\\ \&=\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}(\\mathbf{v}_{t+1}-\\mathbf{S}_{t}\^\\top\\mathbf{k}_{t+1})\^\\top\\\\ \&=\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^\\top-\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{k}_{t+1}\^\\top\\mathbf{S}_t\\\\ \&=\\mathbf{S}_t-\\textcolor{red}{\\beta_{t+1}}\\nabla_{\\mathbf{S}}\\mathcal{L}(\\mathbf{S}_{t}) \\end{aligned} \\
这样也等价于重建了一个无上界的Loss重建成如下形式:
\\\mathcal{L}_t(\\mathbf{S})=\\dfrac{1}{2}\\Vert\\mathbf{S}\^\\top\\mathbf{k}_t-\\mathbf{v}_t\\Vert_2\^2 \\
通过学习系数 \(\beta\) 来单步梯度下降,修正自己当前时间步下的记忆关联 \(\mathbf{k}_t\to\mathbf{v}_t\)。这样的变换允许了硬件通过分块并行来提升计算速度。这篇文章中详细说明了针对线性注意力的高效并行[[7:2]](#[7:2])。
证明如下:
\\\begin{aligned} \\mathcal{L}_t(\\mathbf{S})\&=\\left(\\dfrac{1}{2}\\mathbf{S}\^\\top\\mathbf{k}_t-\\mathbf{v}_t\\right)\^\\top\\left(\\mathbf{S}\^\\top \\mathbf{k}_t-\\mathbf{v}_t\\right)\\\\ \&=\\dfrac{1}{2}\\left(\\mathbf{k}_t\^\\top\\mathbf{S}\\mathbf{S}\^\\top\\mathbf{k}_t-\\mathbf{k}_t\^\\top\\mathbf{S}\\mathbf{v}_t-\\mathbf{v}_t\^\\top\\mathbf{S}\^\\top\\mathbf{k}_t+\\mathbf{v}_t\^\\top\\mathbf{v}_t\\right)\\\\ \\nabla_{\\mathbf{s}}\\mathcal{L}_t(\\mathbf{S})\&=\\mathbf{k}_t\\mathbf{k}_t\^\\top\\mathbf{S}-\\mathbf{k}_t\\mathbf{v}_t\^\\top \\end{aligned} \\
然后通过 \(\textcolor{red}{\beta_{t+1}}\) 来更新我们的权重。更新的公式则为:
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\mathbf{S}_{t}-\\textcolor{red}{\\beta_{t+1}}\\nabla_{\\mathbf{S}}\\mathcal{L}_t(\\mathbf{S}_{t})=(\\mathbf{I}-\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{k}_{t+1}\^\\top)\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^\\top\\in\\mathbb{R}\^{d_k\\times d_v}\\\\ \\mathbf{O}_{t+1}\&=\\mathbf{S}\^\\top_{t+1}\\mathbf{q}_{t+1}\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
这就是DeltaNet提出的重建损失函数------来实现一定的遗忘性。\(\textcolor{red}{\beta_{t+1}}\) 也被称为 Delta 系数。
Gated DeltaNet^[11](#Gated DeltaNet[11])^
为了进一步减少历史记忆和状态对现有的影响,Mamba2进一步通过权重衰减来实现对过去的遗忘:
\\\begin{aligned} \\mathbf{S}_{t+1}=\\textcolor{blue}{\\alpha_{t+1}}\\mathbf{S}_{t}+\\mathbf{k}_t\\mathbf{v}_t\^\\top\\in\\mathbb{R}\^{d_k\\times d_v}, \~\~\~ \\mathbf{O}_{t+1}=\\mathbf{S}\^\\top_{t+1}\\mathbf{q}_{t+1}\\in\\mathbb{R}\^{d_v},\~\~\~\\textcolor{blue}{\\alpha_t}\\in(0,1) \\end{aligned} \\
更进一步,结合Delta规则实现遗忘:
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\textcolor{blue}{\\alpha_{t+1}}(\\mathbf{I}-\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{k}_{t+1}\^\\top)\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^\\top\\in\\mathbb{R}\^{d_k\\times d_v}\\\\ \\mathbf{O}_{t+1}\&=\\mathbf{S}\^\\top_{t+1}\\mathbf{q}_{t+1}\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
这样就等价于通过如下的损失函数进行梯度下降。
\\\mathcal{L}_t(\\mathbf{S})=\\dfrac{\\textcolor{red}{\\beta_t}}{2\\textcolor{blue}{\\alpha_t}}\\Vert\\textcolor{blue}{\\alpha_t}\\mathbf{S}\^\\top\\mathbf{k}_t-\\mathbf{v}_t\\Vert_2\^2+\\dfrac{1-\\textcolor{blue}{\\alpha_t}}{2}\\Vert\\mathbf{S}\\Vert_F\^2 \\
因此 \(\textcolor{blue}{\alpha_t}\in(0,1)\) 成为门控权重衰减系数(gating weight decay),这就是Gated Delta Network的具体形式。
线性注意力并行机制[[7:3]](#线性注意力并行机制[7:3])
传统FWP形式的线性注意力并行
线性注意力在时间迭代的形式如下:
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\mathbf{S}_{t}+\\mathbf{k}_t\\mathbf{v}_t\^\\top\\in\\mathbb{R}\^{d_k\\times d_v}\\\\ \\mathbf{O}_{t+1}\&=\\mathbf{S}\^\\top_{t+1}\\mathbf{q}_{t+1}\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
我们对比一下线性注意力的并行形式与迭代形式。针对并行形式,我们将 \(\mathbf{q}, \mathbf{k}\in\mathbb{R}^{d_k}, \mathbf{v}\in\mathbb{R}^{d_v}\) 堆叠在一起形成一个整体,这样就形成了 \(\mathbf{Q}, \mathbf{K}\in\mathbb{R}^{s\times d_k}, \mathbf{V}\in\mathbb{R}^{s\times d_v}\)。注意堆叠的方式是:\(\mathbf{K}=\\mathbf{k}_1\^\\top \~\~ \\mathbf{k}_2\^\\top \~\~ \\cdots \~\~ \\mathbf{k}_s\^\\top^\top\),\(\mathbf{Q}, \mathbf{V}\) 均采用如同 \(\mathbf{K}\) 的堆叠方式。这样我们就拥有了如下的计算公式:
\\\mathbf{S}=\\sum_{i=1}\^{s}\\mathbf{k}_t\\mathbf{v}_t\^\\top=\\mathbf{K}\^\\top\\mathbf{V} \\
而最终计算结果需要要求查询不能看到未来的键和值(不然就成透视未来了~),因此我们加入一个下三角掩码\(\mathbf{M}_s\in\mathbb{R}^{s\times s}\)即可。这样最终输出就应该是:
\\\begin{aligned} \\mathbf{O}=(\\mathbf{Q}\\mathbf{K}\^\\top\\odot \\mathbf{M}_s)\\mathbf{V} \\end{aligned} \\
我们比较一下并行算法和迭代算法的不同。在复杂度计算中,我们假设 \(d_k=d_v=d\),这样更直观一些。我们每一层Transformer Block的复杂度如下:
| 算法 | 时间复杂度 | 空间复杂度 | 计算步数 |
|---|---|---|---|
| \(\mathbf{S}_t\) 时间步迭代 | \(O(sd^2)\) | 推理\(O(d^2)\),训练\(O(sd^2)\) | \(O(s)\) |
| \(\mathbf{S}\) 并行计算 | \(O(s^2d)+O(sd^2)\) | \(O(sd)\) | \(O(1)\) |
诶?为什么并行计算方式的时间复杂度更高,但是执行时间更低呢?不要忘记并行计算的优势在于同时计算速度快 ,瓶颈因素转移到了计算步数上。对于时间迭代算法,我们无法充分发挥并行计算的优势。因此在长序列上很明显并行算法在计算步数上远小于迭代算法。在GPU上还可以充分利用tensorCore等用于GEMM的优势,时间步迭代则不行。但是并行算法内存占用很高,我们可以看到空间复杂度呈平方增长,这又失去了线性注意力的优势。
为了实现高效的计算,分块并行就成为了一个权衡两者的利弊的一个有效方式,这样可以充分利用计算资源的同时降低内存占用。
FWP分块并行机制
首先我们规定对应的符号。
- \(\mathbf{Q}{t}\in\mathbb{R}^{C\times d_k}\) 代表第 \(t\) 个分块。这个分块通过 \(\mathbf{Q}{t}=\\mathbf{q}\^\\top_{tC} \~\~ \\cdots \~\~ \\mathbf{q}\^\\top_{(t+1)C}^\top\) 的方式堆叠。
- \(\mathbf{q}{t}^r\) 代表 "第\(t\)个分块中的第\(r\)个列向量 \(\mathbf{q}{tC+r}\)"。
- 堆叠里面有\(C+1\)个向量,是因为我们还规定 \(\mathbf{q}{t+1}^0=\mathbf{q}{t}^C=\mathbf{q}_{t}\).
- 上面的记号对 \(\mathbf{k}, \mathbf{v}\) 都成立。
- \(\mathbf{S}{t}^r=\mathbf{S}{tC+r}\in\mathbb{R}^{d_k\times d_v}\),代表第t个分块中的第 \(r\) 个状态矩阵。
- \(\mathbf{S}{t}=\mathbf{S}{t}^0=\mathbf{S}_{t-1}^C\)。
这样,我们就可以改写我们的迭代步骤成混合形式。块内的某一个元素 \(r\) 则为:
\\\begin{aligned} \\mathbf{S}_{\[t}^{r}&=\mathbf{S}{t}^{r-1}+\mathbf{k}{t}^r\mathbf{v}{t}^{r\top}\\ &=\mathbf{S}{t}+\sum_{i=1}^{r}\mathbf{k}{t}^i\mathbf{v}{t}^{i\top}\in\mathbb{R}^{d_k\times d_v}\\ \mathbf{O}{t}^r&=\mathbf{S}{t}^{r\top}\mathbf{q}{t}^r\\ &=\mathbf{S}{t}^\top\mathbf{q}{t}^r+\sum{i=1}^r\mathbf{v}{t}^i\mathbf{k}{t}^{i\top}\mathbf{q}_{t}^r\in\mathbb{R}^{d_v} \end{aligned} \]
这样对整一个块我们有如下公式,注意查询不能看到未来的key和value~:
\\\begin{aligned} \\mathbf{S}_{\[t}&=\mathbf{S}{t-1}+\mathbf{K}{t}^\top\mathbf{V}{t}\in\mathbb{R}^{d_k\times d_v}\\ \mathbf{O}{t}&=\mathbf{Q}{t}\mathbf{S}{t-1}+\left(\mathbf{Q}{t}\mathbf{K}^\top{t}\odot\mathbf{M}c\right)\mathbf{V}{t}\in\mathbb{R}^{C\times d_v} \end{aligned} \]
这样我们就实现了分块并行的线性注意力策略。对于每一层Transformer block,计算步数则为 \(O(\lceil \dfrac{s}{C} \rceil)\),内存复杂度变为 \(O(C\times d)\),计算复杂度则为 \(O(\dfrac{s}{C})\times O(C^2d+Cd^2)=O(sCd+sd^2)\),再次回到了线性注意力的计算复杂度!空间复杂度上,空间复杂度为 \(O(Cd+d^2)\),训练情况下则需要保存每一个chunk为\(O(sd^2)\),也维持了线性的增长!
DeltaNet 分块并行机制
DeltaNet 的更新公式如下:
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\mathbf{S}_{t}+\\underbrace{\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^{'\\top}}_{Write}-\\underbrace{\\mathbf{k}_{t+1}\\bar{\\mathbf{v}}_{t+1}\^\\top}_{forget}\\\\ \&=\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}(\\mathbf{v}_{t+1}-\\bar{\\mathbf{v}}_{t+1})\^\\top\\\\ \&=\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}(\\mathbf{v}_{t+1}-\\mathbf{S}_{t}\^\\top\\mathbf{k}_{t+1})\^\\top\\\\ \\mathbf{O}_{t+1}\&=\\mathbf{S}\^\\top_{t+1}\\mathbf{q}_{t+1}\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
其中我们有:
\\\begin{aligned} \\mathbf{v}'_{t+1}\&=\\textcolor{red}{\\beta_{t+1}}\\mathbf{v}_{t+1} + (1-\\textcolor{red}{\\beta_{t+1}})\\bar{\\mathbf{v}}_{t+1}\\\\ \\bar{\\mathbf{v}}_{t+1} \&= \\mathbf{S}_{t}\^\\top\\mathbf{k}_{t+1} \\end{aligned} \\
最直观的方式就是将后面的一部分重新表示成一个新的向量 \(\mathbf{u}t\),并且将 \(\textcolor{red}{\beta{t}}\) 吸收进去,可以得到如下公式:
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\mathbf{S}_{t}+\\mathbf{k}_{t+1}\\underbrace{\\textcolor{red}{\\beta_{t+1}}(\\mathbf{v}_{t+1}-\\bar{\\mathbf{v}}_{t+1})\^\\top}_{\\mathbf{u}_{t+1}\^\\top}\\\\ \&=\\mathbf{S}_{t}+\\mathbf{k}_{t+1}\\underbrace{\\textcolor{red}{\\beta_{t+1}}(\\mathbf{v}_{t+1}-\\mathbf{S}_{t}\^\\top\\mathbf{k}_{t+1})\^\\top}_{\\mathbf{u}_{t+1}\^\\top}\\\\ \\mathbf{S}\&=\\sum_{i=1}\^t\\mathbf{k}_{i}\\mathbf{u}_i\^\\top \\end{aligned} \\
从而得到整体的计算公式。这样,通过上一节的堆叠和掩码方式得到并行计算的方式,也就是 \(\mathbf{U}=\\mathbf{u}_1\^\\top, \\mathbf{u}_2\^\\top, \\ldots, \\mathbf{u}_n\^\\top^\top\)
\\\mathbf{O}=(\\mathbf{Q}\\mathbf{K}\^\\top\\odot \\mathbf{M}_s)\\mathbf{U} \\
但是,这样的表示真的正确吗?我们存在如下的问题:(1) \(\mathbf{u}_t\) 需要上一个状态的直接计算,导致实际上无法针对 \(\mathbf{U}\) 进行并行计算。(2) 计算每个 \(\mathbf{u}_t\) 都需要上一个状态的状态矩阵 \(\mathbf{S}_t\),导致内存占用从 \(O(sd)\) 上升到 \(O(sd^2)\)。
重新定义 \(\mathbf{u}_t\) 降低内存占用
但是这同时也带来了新的问题:我们的空间复杂度从 \(O(sd)\) 上升到了 \(O(sd^2)\)。回顾我们的导出过程,以及上面的公式,我们可以得到:
\\\mathbf{u}_t=\\textcolor{red}{\\beta_{t+1}}(\\mathbf{v}_t-\\bar{\\mathbf{v}}_t)=\\textcolor{red}{\\beta_{t+1}}(\\mathbf{v}_t-\\mathbf{S}_t\^\\top\\mathbf{k}_{t}) \\
这就意味着我们在计算矩阵 \(\mathbf{u}_t\) 的时候,总是需要保证我们至少存取了上一次的关联矩阵 \(\mathbf{S}_t\),这样我们的实际的内存复杂度就应该为 \(O(sd^2)\)(\(d_k=d_v=d\))。对此,我们需要重新定义 \(\mathbf{u}_t\),不再存储过去的状态矩阵。通过数学归纳法来得到新的 \(\mathbf{u}_t\)。
回顾我们的导出过程,我们有:
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^\\top-\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{k}_{t+1}\^\\top\\mathbf{S}_t\\\\ \&=(\\mathbf{I}-\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{k}_{t+1}\^\\top)\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^\\top \\end{aligned} \\
假设 \(\mathbf{S}t=\sum\limits{i=1}^{t}\mathbf{k}_i\mathbf{u}_i^\top\in\mathbb{R}^{d_k\times d_v}\),\(\mathbf{u}_i\in\mathbb{R}^{d_v}\). 这样我们有归纳起始条件:\(\mathbf{S}_1=\beta_1\mathbf{k}_1\mathbf{v}_1^\top\),\(\mathbf{u}1=\beta_1\mathbf{v}1\). 如果有 \(\mathbf{S}{t-1}=\sum\limits{i=1}^{t-1}\mathbf{k}_i\mathbf{u}_i^\top\),则有:
\\\begin{aligned} \\mathbf{S}_t\&=(\\mathbf{I}-\\textcolor{red}{\\beta_t}\\mathbf{k}_t\\mathbf{k}_t\^\\top)\\left(\\sum_{i=1}\^{t-1}\\mathbf{k}_i\\mathbf{u}_i\^\\top\\right)+\\textcolor{red}{\\beta_t}\\mathbf{k}_t\\mathbf{v}_t\^\\top\\\\ \&=\\sum_{i=1}\^{t-1}\\mathbf{k}_i\\mathbf{u}_i\^\\top+\\textcolor{red}{\\beta_t}\\mathbf{k}_t\\mathbf{v}_t\^\\top-\\textcolor{red}{\\beta_t}\\mathbf{k}_t\\mathbf{k}_t\^\\top\\left(\\sum_{i=1}\^{t-1}\\mathbf{k}_i\\mathbf{u}_i\^\\top\\right)\\\\ \&=\\sum_{i=1}\^{t-1}\\mathbf{k}_i\\mathbf{u}_i\^\\top+\\mathbf{k}_t\\underbrace{\\left\[\\textcolor{red}{\\beta_t}\\left(\\mathbf{v}_t\^\\top-\\mathbf{k}_t\^\\top\\sum_{i=1}\^{t-1}\\mathbf{k}_i\\mathbf{u}_i\^\\top\\right)\\right}_{\mathbf{u}_t^\top}\\ \mathbf{u}_t&=\textcolor{red}{\beta_t}\left(\mathbf{v}t-\sum{i=1}^{t-1}\mathbf{u}_i(\mathbf{k}_i^\top\mathbf{k}_t)\right) \end{aligned} \]
这样 \(\mathbf{u}t\) 将不再需要读取上一次的状态矩阵 \(\mathbf{S}{t-1}\),每次计算只需要存取 \(\mathbf{u}_{1:t}, \mathbf{v}t, \mathbf{k}{1:t}\). 此时的内存复杂度再次回到曾经的 \(O(sd)\)。
实际上上面使用的数学归纳法的灵感来源于如下的矩阵计算相关:
HouseHolder变换和WY表示 [[12]](#[12])HouseHolder 变换:对于一个非零向量 \(v\in\mathbb{R}^m\),如果一个矩阵 \(P\in\mathbb{R}^{m\times m}\) 满足
\H=I-\\beta vv\^\\top=I-\\dfrac{2vv\^\\top}{v\^\\top v} \\
则这个矩阵\(P\)称为 HouseHolder 矩阵。\(v\) 称为 HouseHolder 向量。对于一个向量 \(x\in\mathbf{R}^m\), \(y=Hx\) 称为HouseHolder变换。
我们很容易发现 HouseHolder 变换是一个 rank-1 的修正。因为非零向量外积的秩永远为1,所有的列都落在 \(\text{span}\{v\}\) 中。
WY表示:假设\(P=H_1H_2\ldots H_r\) 是一个 rank-r 的修正,这样我们有 \(P=I_m-WY^\top\), \(W,Y\in\mathbb{R}^{m\times r}\),\(P,I\in\mathbb{R}^{m\times m}\).
证明:采用数学归纳法。我们假定 \(P=I-WY^\top, P_+=PH\). 因此我们有:
\\\begin{aligned} P_+\&=(I-WY\^\\top)(I-\\beta vv\^\\top)\\\\ \&=I-WY\^\\top-\\beta vv\^\\top+\\beta WY\^\\top vv\^\\top\\\\ \&=I-WY\^\\top-\\beta(I-WY\^\\top)vv\^\\top\\\\ \&=I-WY\^\\top-\\beta Pvv\^\\top\\\\ \&=I-\[W \\,\|\\, \\beta PvY \\,\|\\, v^\top\\ &=I - W_+Y_+^\top \end{aligned} \]
这样很明显我们的WY表示是成立的。
但是,如同先前需要实现并行或者分块并行的理由相同------GPU等加速器更适合并行矩阵运算,时间步数 \(O(t)\) 的算法不适合在对应硬件上实现。因此我们需要实现DeltaNet的分块并行运算。
DeltaNet的分块并行
针对我们的DeltaNet的分块并行算法,我们需要做一系列比较复杂的变换。我们首先将原本的公式表示成如下所示。
\\\begin{aligned} \\mathbf{S}_{t}\&=(\\mathbf{I}-\\textcolor{red}{\\beta_{t}}\\mathbf{k}_{t}\\mathbf{k}_{t}\^\\top)\\mathbf{S}_{t-1}+\\textcolor{red}{\\beta_{t}}\\mathbf{k}_{t}\\mathbf{v}_{t}\^\\top\\in\\mathbb{R}\^{d_k\\times d_v}\\\\ \\mathbf{O}_{t}\&=\\mathbf{S}\^\\top_{t}\\mathbf{q}_{t}\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
我们需要充分利用广义householder变换的性质。因此我们定义如下:
\\\mathbf{h}_{t}=(\\mathbf{I}-\\textcolor{red}{\\beta_{t}}\\mathbf{k}_t\\mathbf{k}_t\^\\top)\\in\\mathbb{R}\^{d_k\\times d_k},\~\~\~\\mathbf{b}_t=\\textcolor{red}{\\beta_{t}}\\mathbf{k}_{t}\\mathbf{v}_t\^\\top\\in\\mathbb{R}\^{d_k\\times d_v} \\
因此我们重写 \(\mathbf{S}_t=\mathbf{h}t\mathbf{S}{t-1}+\mathbf{b}_t\)。通过循环迭代我们得到:
\\\begin{aligned} \\prod_{i=1}\^{n}\\overleftarrow{\\mathbf{h}_i}\&=\\mathbf{h}_{n}\\mathbf{h}_{n-1}\\cdots\\mathbf{h}_{2}\\mathbf{h}_{1}\\\\ \\mathbf{S}_t\&=\\mathbf{h}_t\\mathbf{S}_{t-1}+\\mathbf{b}_t\\\\ \&=\\prod_{i=1}\^t\\overleftarrow{\\mathbf{h}_i}\\mathbf{S}_0+\\sum_{i=1}\^t\\left(\\prod_{j=i+1}\^{t}\\overleftarrow{\\mathbf{h}_j}\\right)\\mathbf{b}_{i} \\end{aligned} \\
接着我们定义分块矩阵相关符号。
- \(\mathbf{S}{t}^r=\mathbf{S}{tC+r}\in\mathbb{R}^{d_k\times d_v}\),代表第t个分块中的第 \(r\) 个状态矩阵并且\(\mathbf{S}{t}=\mathbf{S}{t}^0=\mathbf{S}_{t-1}^C\)。
- \(\mathbf{H}{i}^{j}=\begin{cases}1, &i>j \\ \prod\limits{t=i}^{j}\overleftarrow{\mathbf{h}_{t}}, &i\leqslant j\end{cases}\)
- \(\mathbf{H}{t}^r=\prod\limits{i=1}^r\overleftarrow{\mathbf{h}{tC+i}}=\mathbf{H}{tC+1}^{tC+r}\)
- \(\mathbf{P}{t}^{r}=\sum\limits{i=1}^{r}\left(\mathbf{H}{tC+i+1}^{tC+r}\right)\mathbf{b}{tC+i}=\sum\limits_{i=1}^{r}\left(\prod\limits_{j=i+1}^{r}\overleftarrow{\mathbf{h}{tC+j}}\right)\mathbf{b}{tC+i}\)
- \(\textcolor{red}{\beta_{t}^r}=\textcolor{red}{\beta_{tC+r}}\). 这对 \(\mathbf{k}{t}^r,\mathbf{v}{t}^r,\mathbf{q}_{t}^r\) 同样适用。
\\\begin{aligned} \\mathbf{S}_{\[t}^r&=\prod_{i=1}^r\overleftarrow{\mathbf{h}{tC+i}}\mathbf{S}{t}^0+\sum_{i=1}^{r}\left(\prod_{j=i+1}^{r}\overleftarrow{\mathbf{h}{tC+j}}\right)\mathbf{b}{tC+i}\\ &=\mathbf{H}{t}^r\mathbf{S}{t}^0+\mathbf{P}_{t}^r \end{aligned} \]
这就是我们原文的初始分块模式。但是存储 \(\mathbf{H}{t}^r\) 和 \(\mathbf{P}{t}^r\) 在训练/prefill过程中需要 \(O(sd^2)\) 的存储空间(假设 \(d_k=d_v=d\)),我们可以通过类似上面的 WY表示的数学归纳法 来降低内存占用到 \(O(sd)\)。
我们需要分析一下 \(\mathbf{H}{t}^r\) 和 \(\mathbf{P}{t}^r\)。 我们将他们展开,可以得到:
\\\begin{aligned} \\mathbf{H}_{\[t}^r&=\prod_{i=1}^{r}\overleftarrow{(\mathbf{I}-\textcolor{red}{\beta_{t}^i}\mathbf{k}{t}^{i}\mathbf{k}{t}^{i\top})}\\ \mathbf{P}{t}^r&=\sum{i=1}^{r}\left(\prod_{j=i+1}^{r}\overleftarrow{(\mathbf{I}-\textcolor{red}{\beta_{t}^j}\mathbf{k}{t}^{j}\mathbf{k}{t}^{j\top})}\left(\textcolor{red}{\beta_{t}^i}\mathbf{k}{t}^{i}\mathbf{v}{t}^{i}\right)\right) \end{aligned} \]
接下来进行归纳假设。归纳起点和归纳假设分别如下:
\\\begin{aligned} \\mathbf{H}_{\[t}^{1}&=\mathbf{I}-\textcolor{red}{\beta_{t}^{1}}\mathbf{k}{t}^{1}\mathbf{k}{t}^{1\top}\\ \mathbf{P}{t}^{1}&=\textcolor{red}{\beta{t}^{1}}\mathbf{k}{t}^{1}\mathbf{v}{t}^{1\top}\\ \mathbf{H}{t}^{r}&=\mathbf{I}-\sum{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\\ \mathbf{P}{t}^{r}&=\sum{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top} \end{aligned} \]
这样我们通过数学归纳法可以得到:
\\\begin{aligned} \\mathbf{H}_{\[t}^r&=\mathbf{h}{tC+r}\mathbf{H}{t}^{r-1}\\ &=\left(\mathbf{I}-\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{k}{t}^{r\top}\right)\left(\mathbf{I}-\sum_{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\right)\\ &=\mathbf{I}-\sum_{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}-\mathbf{k}{t}^{r}\underbrace{\left\\textcolor{red}{\\beta_{\[t}^{r}}\left(\mathbf{k}{t}^{r\top}-\mathbf{k}{t}^{r\top}\sum{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\right)\right]}{\mathbf{w}{t}^{r\top}}\\ &=\mathbf{I}-\sum_{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\in\mathbb{R}^{d_k\times d_k}\\ \mathbf{P}{t}^{r}&=\mathbf{h}{tC+r}\mathbf{P}{t}^{r-1}+\mathbf{b}{tC+r}\\ &=(\mathbf{I}-\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{k}{t}^{r\top})\sum_{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}+\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{v}{t}^{r\top}\\ &=\sum_{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}+\mathbf{k}{t}^{r}\underbrace{\left\\textcolor{red}{\\beta_{\[t}^{r}}\left(\mathbf{v}{t}^{r\top}-\mathbf{k}{t}^{r\top}\sum{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\right)\right]}{\mathbf{u}{t}^{r\top}}\\ &=\sum_{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\in\mathbb{R}^{d_k\times d_v} \end{aligned} \]
这样我们总结一下最终的 \(\mathbf{w}{t}^{r\top}\) 和 \(\mathbf{u}{t}^{r\top}\) 如下:
\\\begin{aligned} \\mathbf{w}_{\[t}^{r}&=\textcolor{red}{\beta_{t}^{r}}\left(\mathbf{k}{t}^{r}-\sum{i=1}^{r-1}\mathbf{w}{t}^{i}\left(\mathbf{k}{t}^{i\top}\mathbf{k}{t}^{r}\right)\right)\in\mathbb{R}^{d_k}\\ \mathbf{u}{t}^{r}&=\textcolor{red}{\beta_{t}^{r}}\left(\mathbf{v}{t}^{r}-\sum{i=1}^{r-1}\mathbf{u}{t}^{i}\left(\mathbf{k}{t}^{i\top}\mathbf{k}_{t}^{r}\right)\right)\in\mathbb{R}^{d_v} \end{aligned} \]
这样我们只需要存储的"KVCache"就应该是 \(\mathbf{k},\mathbf{v},\mathbf{u},\mathbf{w}\), 同时我们需要在训练过程中存取整个序列长度,因此我们的空间复杂度再次回归到 \(O(sd)(d_k=d_v=d)\)。
最后,我们来完善整体的分块并行机制推导。我们可以得到,按照类似如下的方式来堆叠矩阵形成 \(\mathbf{W},\mathbf{Q},\mathbf{K}\in\mathbb{R}^{c\times d_k},\mathbf{U},\mathbf{V},\mathbf{O}\in\mathbb{R}^{c\times d_v}\) 形成分块矩阵:
\\\begin{aligned} \\mathbf{Q}_{\[t}&=\\mathbf{q}_{\[t}^{1}\in\mathbb{R}^{d_k}~~\mathbf{q}{t}^{2}~~ \cdots ~~ \mathbf{q}{t}^{c}]^\top\in\mathbb{R}^{c\times d_k}\\ \mathbf{V}{t}&=\\mathbf{v}_{\[t}^{1}\in\mathbb{R}^{d_v}~~\mathbf{v}{t}^{2}~~\cdots~~\mathbf{v}_{t}^{c}]^\top\in\mathbb{R}^{c\times d_v} \end{aligned} \]
可以得到如下的并行方式:(此处论文[[7:4]](#[7:4])处公式 \(\mathbf{o}_{t}^r\) 展开可能出现了笔误)
\\\begin{aligned} \\mathbf{S}_{\[t}^{r}&=\mathbf{H}{t}^r\mathbf{S}{t}^0+\mathbf{P}{t}^r\\ &=\left(\mathbf{I}-\sum{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\right)\mathbf{S}{t}+\sum{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\\ &=\mathbf{S}{t}+\sum{i=1}^{r}\mathbf{k}{t}^{i}\left(\mathbf{u}{t}^{i\top}-\mathbf{w}{t}^{i\top}\mathbf{S}{t} \right)\in\mathbb{R}^{d_k\times d_v}\\ \mathbf{o}{t}^{r}&=\mathbf{S}{t}^{r\top}\mathbf{q}{t}^{r}=\mathbf{S}{t}^\top\mathbf{q}{t}^{r}+\sum{i=1}^{r}\left(\mathbf{u}{t}^{i}-\mathbf{S}{t}^\top\mathbf{w}{t}^{i} \right)\left(\mathbf{k}{t}^{i\top}\mathbf{q}{t}^{r}\right)\in\mathbb{R}^{d_v}\\ \mathbf{S}{t+1}&=\mathbf{S}{t}+\mathbf{K}{t}^\top\left(\mathbf{U}{t}-\mathbf{W}{t}\mathbf{S}{t}\right)\\ \mathbf{O}{t}&=\mathbf{Q}{t}\mathbf{S}{t}+\left(\mathbf{Q}{t}\mathbf{K}{t}^\top\odot M_c\right)\left(\mathbf{U}{t}-\mathbf{W}{t}\mathbf{S}_{t}\right) \end{aligned} \]
更进一步,我们发现现有的 \(\mathbf{w}{t}^r,\mathbf{u}{t}^r\) 计算仍然是迭代的。为了进一步消除迭代,我们可以分析现有的行为来进一步实现并行。观察现有的矩阵形式,我们有:
\\\begin{aligned} \&\\underbrace{\\begin{bmatrix} \\mathbf{w}_{\[t}^{1\top}\\ \mathbf{w}{t}^{2\top}\\ \vdots\\ \textcolor{purple}{\mathbf{w}{t}^{r\top}}\\ \vdots \end{bmatrix}}{\mathbf{W}{t}\in\mathbb{R}^{c\times d_k}}= \underbrace{\begin{bmatrix} \textcolor{red}{\beta_{t}^1} & 0 & \cdots & 0\\ 0 & \textcolor{red}{\beta_{t}^2} & \cdots & 0\\ \vdots & \vdots & \ddots & \vdots\\ \textcolor{purple}{0} & \textcolor{purple}{\cdots} & \textcolor{purple}{\beta_{t}^r} & \textcolor{purple}{\cdots} \\ \vdots & \vdots & \vdots & \vdots \end{bmatrix}}{\textrm{diag}(\textcolor{red}{\beta{t}})\in\mathbb{R}^{c\times c}} \underbrace{\begin{bmatrix} \mathbf{k}{t}^{1\top}\\ \mathbf{k}{t}^{2\top}\\ \vdots\\ \textcolor{purple}{\mathbf{k}{t}^{r\top}}\\ \vdots \end{bmatrix}}{\mathbf{K}{t}\in\mathbb{R}^{c\times d_k}}- \underbrace{\begin{bmatrix} \textcolor{red}{\beta{t}^1} & 0 & \cdots & 0\\ 0 & \textcolor{red}{\beta_{t}^2} & \cdots & 0\\ \vdots & \vdots & \ddots & \vdots\\ \textcolor{purple}{0} & \textcolor{purple}{\cdots} & \textcolor{purple}{\beta_{t}^r} & \textcolor{purple}{\cdots} \\ \vdots & \vdots & \vdots & \vdots \end{bmatrix}}{\textrm{diag}(\textcolor{red}{\beta{t}})\in\mathbb{R}^{c\times c}} \underbrace{\begin{bmatrix} 0 & 0 & \cdots & 0 & 0\\ \mathbf{k}{t}^{2\top}\mathbf{k}{t}^{1} & 0 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ \textcolor{purple}{\mathbf{k}{t}^{r\top}\mathbf{k}{t}^{1}} & \textcolor{purple}{\cdots} & \textcolor{purple}{\mathbf{k}{t}^{r\top}\mathbf{k}{t}^{r-1}} & \textcolor{purple}{\cdots} & \textcolor{purple}{0}\\ \vdots & \vdots & \vdots & \vdots & \vdots \end{bmatrix}}{\textrm{StrictLower}(\mathbf{K}{t}\mathbf{K}{t}^{\top})\in\mathbb{R}^{c\times c}} \begin{bmatrix} \mathbf{w}{t}^{1\top}\\ \mathbf{w}{t}^{2\top}\\ \vdots\\ \textcolor{purple}{\mathbf{w}{t}^{r\top}}\\ \vdots \end{bmatrix}\\ &\mathbf{W}{t}+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\mathbf{W}{t}=\textrm{diag}(\textcolor{red}{\beta{t}})\mathbf{K}{t}\\ &\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)\mathbf{W}{t}=\textrm{diag}(\textcolor{red}{\beta_{t}})\mathbf{K}{t}\\ &\mathbf{W}{t}=\underbrace{\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta_{t}})}{\mathbf{T}{t}}\mathbf{K}{t}\\ &\mathbf{U}{t}=\underbrace{\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta_{t}})}{\mathbf{T}{t}}\mathbf{V}{t}\\ &\mathbf{T}{t}=\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta_{t}}) \end{aligned} \]
这样我们就完全消除了DeltaNet中的所有时序依赖,允许我们直接进行分块并行运算。最后我们总结一下DeltaNet的分块并行计算的原理如下:
\\\begin{aligned} \\mathbf{S}_{\[t+1}&=\mathbf{S}{t}+\mathbf{K}{t}^\top\left(\mathbf{U}{t}-\mathbf{W}{t}\mathbf{S}{t}\right)\\ \mathbf{O}{t}&=\mathbf{Q}{t}\mathbf{S}{t}+\left(\mathbf{Q}{t}\mathbf{K}{t}^\top\odot M_c\right)\left(\mathbf{U}{t}-\mathbf{W}{t}\mathbf{S}{t}\right)\\ \mathbf{W}{t}&=\mathbf{T}{t}\mathbf{K}{t}, \mathbf{U}{t}=\mathbf{T}{t}\mathbf{V}{t}\\ \mathbf{T}{t}&=\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta_{t}}) \end{aligned} \]
最终DeltaNet的实现图如下所示:
GatedDeltaNet 分块并行机制^[11:1](#GatedDeltaNet 分块并行机制[11:1])^
接下来我们将上面的技巧全部集合到一起,就形成了GDN网络,也是目前混合注意力中采用线性注意力的主要结构。我们首先还是回顾我们的基础迭代变换:
\\\begin{aligned} \\mathbf{S}_{t+1}\&=\\textcolor{blue}{\\alpha_{t+1}}(\\mathbf{I}-\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{k}_{t+1}\^\\top)\\mathbf{S}_{t}+\\textcolor{red}{\\beta_{t+1}}\\mathbf{k}_{t+1}\\mathbf{v}_{t+1}\^\\top\\in\\mathbb{R}\^{d_k\\times d_v}\\\\ \\mathbf{O}_{t+1}\&=\\mathbf{S}\^\\top_{t+1}\\mathbf{q}_{t+1}\\in\\mathbb{R}\^{d_v} \\end{aligned} \\
我们仍然按照上面所述的方式来重新进行递归展开。这样我们可以得到:
\\\begin{aligned} \\mathbf{S}_{\[t}^{r}&=\underbrace{\left(\prod_{i=1}^{r}\textcolor{blue}{\alpha_{t}^{i}}\overleftarrow{\left(\mathbf{I}-\beta_{t}^{i}\mathbf{k}{t}^{i}\mathbf{k}{t}^{i\top}\right)}\right)}{\mathbf{H}{t}^{r}}\mathbf{S}{t}+\underbrace{\sum{i=1}^{r}\left(\prod_{j=i+1}^{r}\textcolor{blue}{\alpha_{t}^{j}}\overleftarrow{\left(\mathbf{I}-\beta_{t}^{j}\mathbf{k}{t}^{j}\mathbf{k}{t}^{j\top}\right)}\right)\left(\beta_{t}^{i}\mathbf{k}{t}^{i}\mathbf{v}{t}^{i\top}\right)}{\mathbf{P}{t}^{r}} \end{aligned} \]
这样,我们发现,我们的Gated DeltaNet对 \(\mathbf{H}{t}^{r}, \mathbf{P}{t}^{r}\) 进行了修改。我们需要重新推断出正确的式子。对此我们仍然使用 WY表示的数学归纳法,起点和定义如下。
\\\begin{aligned} \\gamma_{\[t}^{r}&=\prod_{i=1}^{r}\textcolor{blue}{\alpha_{t}^{i}}\in\mathbb{R}\\ \mathbf{H}{t}^{1}&=\gamma{t}^{1}(\mathbf{I}-\textcolor{red}{\beta_{t}^{1}}\mathbf{k}{t}^{1}\mathbf{k}{t}^{1\top}),~\mathbf{H}{t}^{r}=\gamma{t}^{r}(\mathbf{I}-\sum\limits_{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top})\\ \mathbf{H}{t}^{r}&=\textcolor{blue}{\alpha{t}^{r}}(\mathbf{I}-\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{k}{t}^{r\top})\mathbf{H}{t}^{r-1}\\ &=\textcolor{blue}{\alpha{t}^{r}}(\mathbf{I}-\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{k}{t}^{r\top})\gamma_{t}^{r-1}(\mathbf{I}-\sum\limits_{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top})\\ &=\gamma_{t}^{r}\left(\mathbf{I}-\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{k}{t}^{r\top}\right)\left(\mathbf{I}-\sum_{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\right)\\ &=\gamma_{t}^{r}\left(\mathbf{I}-\sum_{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}-\mathbf{k}{t}^{r}\underbrace{\left\\textcolor{red}{\\beta_{\[t}^{r}}\left(\mathbf{k}{t}^{r\top}-\mathbf{k}{t}^{r\top}\sum{i=1}^{r-1}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\right)\right]}{\mathbf{w}{t}^{r\top}}\right)\\ &=\gamma_{t}^{r}\left(\mathbf{I}-\sum\limits_{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\right)\\ \mathbf{w}{t}^{r}&=\textcolor{red}{\beta{t}^{r}}\left(\mathbf{k}{t}^{r}-\sum{i=1}^{r-1}\mathbf{w}{t}^{i}\left(\mathbf{k}{t}^{i\top}\mathbf{k}_{t}^{r}\right)\right)\in\mathbb{R}^{d_k} \end{aligned} \]
对于 \(\mathbf{P}{t}^{r}\) 则稍微有些复杂。我们推导如下:(这个系数(紫色部分)挺厉害的,但是论文[[11:2]](#[11:2])附录A部分的证明第二行漏了一个 \(k{t+1}\))。
\\\begin{aligned} \\mathbf{P}_{\[t}^{1}&=\textcolor{red}{\beta_{t}^{1}}\mathbf{k}{t}^{1}\mathbf{v}{t}^{1\top},~\mathbf{P}{t}^{r}=\sum{i=1}^{r}\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{i}}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\\ \mathbf{P}{t}^{r}&=\textcolor{blue}{\alpha{t}^{r}}(\mathbf{I}-\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{k}{t}^{r\top})P_{t}^{r-1}+\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{v}{t}^{r\top}\\ &=\textcolor{blue}{\alpha_{t}^{r}}(\mathbf{I}-\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{k}{t}^{r\top})\left(\sum_{i=1}^{r-1}\dfrac{\gamma_{t}^{r-1}}{\gamma_{t}^{i}}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\right)+\textcolor{red}{\beta_{t}^{r}}\mathbf{k}{t}^{r}\mathbf{v}{t}^{r\top}\\ &(\textcolor{blue}{\alpha_{t}^{r}}\gamma_{t}^{r-1}=\gamma_{t}^{r})\\ &=\sum_{i=1}^{r-1}\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{i}}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}+\textcolor{purple}{\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{r}}}\mathbf{k}{t}^{r}\underbrace{\left\\textcolor{red}{\\beta_{\[t}^{r}}\left(\mathbf{v}{t}^{r\top}-\sum_{i=1}^{r-1}\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{i}}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\right)\right]}{\mathbf{u}{t}^{r\top}}\\ &=\sum_{i=1}^{r}\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{i}}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\\ \mathbf{u}{t}^{r}&=\textcolor{red}{\beta{t}^{r}}\left(\mathbf{v}{t}^{r}-\sum{i=1}^{r-1}\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{i}}\mathbf{u}{t}^{i}\left(\mathbf{k}{t}^{i\top}\mathbf{k}_{t}^{r}\right)\right)\in\mathbb{R}^{d_v} \end{aligned} \]
我们仍然沿用DeltaNet的展开方式进一步实现并行:
\\\begin{aligned} \\mathbf{S}_{\[t}^{r}&=\mathbf{H}{t}^r\mathbf{S}{t}^0+\mathbf{P}{t}^r\\ &=\gamma{t}^{r}\left(\mathbf{I}-\sum_{i=1}^{r}\mathbf{k}{t}^{i}\mathbf{w}{t}^{i\top}\right)\mathbf{S}{t}+\sum{i=1}^{r}\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{i}}\mathbf{k}{t}^{i}\mathbf{u}{t}^{i\top}\\ &=\gamma_{t}^{r}\mathbf{S}{t}+\sum{i=1}^{r}\mathbf{k}{t}^{i}\left(\dfrac{\gamma{t}^{r}}{\gamma_{t}^{i}}\mathbf{u}{t}^{i\top}-\gamma{t}^{r}\mathbf{w}{t}^{i\top}\mathbf{S}{t} \right)\in\mathbb{R}^{d_k\times d_v}\\ \mathbf{o}{t}^{r}&=\mathbf{S}{t}^{r\top}\mathbf{q}{t}^{r}=\gamma{t}^{r}\mathbf{S}{t}^\top\mathbf{q}{t}^{r}+\sum_{i=1}^{r}\left(\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{i}}\mathbf{u}{t}^{i\top}-\gamma{t}^{r}\mathbf{w}{t}^{i\top}\mathbf{S}{t} \right)\left(\mathbf{k}{t}^{i\top}\mathbf{q}{t}^{r}\right)\in\mathbb{R}^{d_v} \end{aligned} \]
我们继续沿用 DeltaNet 的公式,不过如果要符合论文[[11:3]](#[11:3])的形式,需要进行一些改写。于是我们有如下需要改造的部分,我们用不同的颜色进行了标注。
\\\begin{aligned} \\mathbf{S}_{\[t+1}&=\textcolor{orange}{\mathbf{S}{t}}+\textcolor{brown}{\mathbf{K}{t}^\top}\left(\textcolor{purple}{\mathbf{U}{t}}-\textcolor{green}{\mathbf{W}{t}}\mathbf{S}{t}\right)\\ \mathbf{O}{t}&=\textcolor{green}{\mathbf{Q}{t}}\mathbf{S}{t}+\left(\textcolor{green}{\mathbf{Q}{t}}\textcolor{brown}{\mathbf{K}{t}^\top}\odot M_c\right)\left(\textcolor{purple}{\mathbf{U}{t}}-\textcolor{green}{\mathbf{W}{t}}\mathbf{S}_{t}\right) \end{aligned} \]
首先我们还是先来写出 \(\textcolor{purple}{\mathbf{U}_{t}}\) 这个比较困难的部分,以消除计算过程中的时间依赖关系。我们有如下的过程:
\\\begin{aligned} \&\\underbrace{\\begin{bmatrix} \\mathbf{u}_{\[t}^{1\top}\\ \mathbf{u}{t}^{2\top}\\ \vdots\\ \textcolor{purple}{\mathbf{u}{t}^{r\top}}\\ \vdots \end{bmatrix}}{\textcolor{purple}{\mathbf{U}{t}}\in\mathbb{R}^{c\times d_v}}= \underbrace{\begin{bmatrix} \textcolor{red}{\beta_{t}^1} & 0 & \cdots & 0\\ 0 & \textcolor{red}{\beta_{t}^2} & \cdots & 0\\ \vdots & \vdots & \ddots & \vdots\\ \textcolor{purple}{0} & \textcolor{purple}{\cdots} & \textcolor{purple}{\beta_{t}^r} & \textcolor{purple}{\cdots} \\ \vdots & \vdots & \vdots & \vdots \end{bmatrix}}{\textrm{diag}(\textcolor{red}{\beta{t}})\in\mathbb{R}^{c\times c}} \underbrace{\begin{bmatrix} \mathbf{v}{t}^{1\top}\\ \mathbf{v}{t}^{2\top}\\ \vdots\\ \textcolor{purple}{\mathbf{v}{t}^{r\top}}\\ \vdots \end{bmatrix}}{\mathbf{V}{t}\in\mathbb{R}^{c\times d_v}}\\ &-\underbrace{\begin{bmatrix} \textcolor{red}{\beta{t}^1} & 0 & \cdots & 0\\ 0 & \textcolor{red}{\beta_{t}^2} & \cdots & 0\\ \vdots & \vdots & \ddots & \vdots\\ \textcolor{purple}{0} & \textcolor{purple}{\cdots} & \textcolor{purple}{\beta_{t}^r} & \textcolor{purple}{\cdots} \\ \vdots & \vdots & \vdots & \vdots \end{bmatrix}}{\textrm{diag}(\textcolor{red}{\beta{t}})\in\mathbb{R}^{c\times c}} \left(\underbrace{\begin{bmatrix} 0 & 0 & 0 & \cdots & 0\\ \vdots & \vdots & \ddots & \cdots & \vdots\\ \textcolor{purple}{\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{1}}} & \textcolor{purple}{\cdots} & \textcolor{purple}{\dfrac{\gamma_{t}^{r}}{\gamma_{t}^{r-1}}} & \textcolor{purple}{0} & \textcolor{purple}{\cdots}\\ \vdots & \vdots & \vdots & \cdots & \vdots \end{bmatrix}}{\textrm{StrictLower}\left(\textcolor{purple}{\mathbf{\Gamma}{t}}\right)\in\mathbb{R}^{c\times c}, \Gamma_{t}^{(ij)}=\frac{\gamma_{t}^{i}}{\gamma_{t}^{j}}}\odot \underbrace{\begin{bmatrix} 0 & 0 & \cdots & 0 & 0\\ \mathbf{k}{t}^{2\top}\mathbf{k}{t}^{1} & 0 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ \textcolor{purple}{\mathbf{k}{t}^{r\top}\mathbf{k}{t}^{1}} & \textcolor{purple}{\cdots} & \textcolor{purple}{\mathbf{k}{t}^{r\top}\mathbf{k}{t}^{r-1}} & \textcolor{purple}{\cdots} & \textcolor{purple}{0}\\ \vdots & \vdots & \vdots & \vdots & \vdots \end{bmatrix}}{\textrm{StrictLower}(\mathbf{K}{t}\mathbf{K}{t}^{\top})\in\mathbb{R}^{c\times c}}\right) \begin{bmatrix} \mathbf{u}{t}^{1\top}\\ \mathbf{u}{t}^{2\top}\\ \vdots\\ \textcolor{purple}{\mathbf{u}{t}^{r\top}}\\ \vdots \end{bmatrix}\\ &\textcolor{purple}{\mathbf{U}{t}}+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\textcolor{purple}{\mathbf{\Gamma}{t}}\odot\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\textcolor{purple}{\mathbf{U}{t}}=\textrm{diag}(\textcolor{red}{\beta_{t}})\mathbf{V}{t}\\ &\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\textcolor{purple}{\mathbf{\Gamma}{t}}\odot\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)\textcolor{purple}{\mathbf{U}{t}}=\textrm{diag}(\textcolor{red}{\beta{t}})\mathbf{V}{t}\\ &\textcolor{purple}{\mathbf{U}{t}}=\underbrace{\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\textcolor{purple}{\mathbf{\Gamma}{t}}\odot\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta{t}})}{\textcolor{purple}{\mathbf{T}{t}}}\mathbf{V}{t}\\ &\textcolor{purple}{\mathbf{T}{t}}=\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\textcolor{purple}{\mathbf{\Gamma}{t}}\odot\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta{t}})\\ &\mathbf{W}{t}=\mathbf{T}{t}\mathbf{K}{t}, \mathbf{T}{t}=\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta_{t}}) \end{aligned} \]
接下来,我们进一步改写 \(\textcolor{orange}{\mathbf{S}{t}}, \textcolor{brown}{\mathbf{K}{t}}, \textcolor{green}{\mathbf{W}{t}}, \textcolor{green}{\mathbf{Q}{t}}\)。我们有:
\\\begin{aligned} \\mathbf{S}_{\[t+1}&=\mathbf{S}{t}^{c}=\gamma{t}^{c}\mathbf{S}{t}+\sum{i=1}^{r}\mathbf{k}{t}^{i}\left(\dfrac{\gamma{t}^{r}}{\gamma_{t}^{i}}\textcolor{purple}{\mathbf{u}{t}^{i\top}}-\gamma{t}^{r}\mathbf{w}{t}^{i\top}\mathbf{S}{t} \right)\in\mathbb{R}^{d_k\times d_v}\\ &=\underbrace{\gamma_{t}^{c}\mathbf{S}{t}}{\textcolor{orange}{\mathbf{S}{t}}}+\sum{i=1}^{c}\underbrace{\left(\dfrac{\gamma_{t}^{c}}{\gamma_{t}^{i}}\mathbf{k}{t}^{i}\right)}{\textcolor{brown}{\mathbf{k}{t}^{i}}}\left(\textcolor{purple}{\mathbf{u}{t}^{i\top}}-\underbrace{\left(\gamma_{t}^{i}\mathbf{w}{t}^{i\top}\right)}{\textcolor{green}{\mathbf{w}{t}^{i}}}\mathbf{S}{t} \right)\in\mathbb{R}^{d_k\times d_v}\\ &=\textcolor{orange}{\mathbf{S}{t}}+\textcolor{brown}{\mathbf{K}{t}^\top}\left(\textcolor{purple}{\mathbf{U}{t}}-\textcolor{green}{\mathbf{W}{t}}\mathbf{S}{t}\right)\\ \textcolor{orange}{\mathbf{S}{t}}&=\textcolor{orange}{\gamma_{t}^{c}}\mathbf{S}{t}, \textcolor{brown}{\mathbf{K}{t}}=\left\\cdots\~\\textcolor{brown}{\\dfrac{\\gamma_{\[t}^{c}}{\gamma_{t}^{i}}}\mathbf{k}{t}^{i}~\cdots\right]^\top\\ \textcolor{green}{\mathbf{W}{t}}&=\left\\cdots\~\\textcolor{green}{\\gamma_{\[t}^{i}}\mathbf{w}{t}^{i}~\cdots\right]^\top, \textcolor{green}{\mathbf{Q}{t}}=\left\\cdots\~\\textcolor{green}{\\gamma_{\[t}^{i}}\mathbf{q}_{t}^{i}~\cdots\right]^\top \end{aligned} \]
这样我们就完成了对Gated DeltaNet的整体分块并行改造,实现块内并行,块间迭代的方式。我们最后总结如下:
\\\begin{aligned} \\mathbf{S}_{\[t+1}&=\textcolor{orange}{\mathbf{S}{t}}+\textcolor{brown}{\mathbf{K}{t}^\top}\left(\textcolor{purple}{\mathbf{U}{t}}-\textcolor{green}{\mathbf{W}{t}}\mathbf{S}{t}\right)\\ \mathbf{O}{t}&=\textcolor{green}{\mathbf{Q}{t}}\mathbf{S}{t}+\left(\textcolor{green}{\mathbf{Q}{t}}\textcolor{brown}{\mathbf{K}{t}^\top}\odot M_c\right)\left(\textcolor{purple}{\mathbf{U}{t}}-\textcolor{green}{\mathbf{W}{t}}\mathbf{S}{t}\right)\\ \textcolor{purple}{\mathbf{T}{t}}&=\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\textcolor{purple}{\mathbf{\Gamma}{t}}\odot\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta{t}})\\ \textcolor{purple}{\mathbf{\Gamma}{t}^{(ij)}}&=\dfrac{\gamma{t}^{i}}{\gamma_{t}^{j}}, \textcolor{green}{\gamma_{t}}=\left\\gamma_{\[t}^{1}~\cdots~\gamma_{t}^{i}~\cdots~\gamma_{t}^{c}\right]^\top,\textcolor{brown}{\gamma_{t}}=\left\\dfrac{\\gamma_{\[t}^{c}}{\gamma_{t}^{1}}~\cdots~\dfrac{\gamma_{t}^{c}}{\gamma_{t}^{i}}~\cdots~\dfrac{\gamma_{t}^{c}}{\gamma_{t}^{r}}\right]^\top\\ \mathbf{T}{t}&=\left(\mathbf{I}c+\textrm{diag}(\textcolor{red}{\beta{t}})\textrm{StrictLower}\left(\mathbf{K}{t}\mathbf{K}{t}^{\top}\right)\right)^{-1}\textrm{diag}(\textcolor{red}{\beta{t}})\\ \textcolor{purple}{\mathbf{U}{t}}&=\textcolor{purple}{\mathbf{T}{t}}\mathbf{V}{t}, \textcolor{brown}{\mathbf{K}{t}}=\textcolor{brown}{\gamma_{t}}\odot\mathbf{K}{t},\textcolor{orange}{\mathbf{S}{t}}=\textcolor{orange}{\gamma_{t}^{c}}\mathbf{S}{t},\textcolor{orange}{\gamma{t}^{c}}=\gamma_{t}^{c}\\ \textcolor{green}{\mathbf{W}{t}}&=\textcolor{green}{\gamma{t}}\odot\left(\mathbf{T}{t}\mathbf{K}{t}\right),\textcolor{green}{\mathbf{Q}{t}}=\textcolor{green}{\gamma{t}}\odot\mathbf{Q}_{t} \end{aligned} \]
具有这样的架构:
线性注意力机制总结
下次我们有时间将会继续学习总结(不知道什么时候填坑,等下次假期了):
- 线性注意力在混合注意力中的应用------Kimi KDA,Qwen GDN;
- 其他稀疏注意力机制,包括滑动窗口注意力SWA以及他们的变体,MLA->DSA稀疏注意力机制(Deepseek v3),CSA和HCA机制(Deepseek v4)
尾注碎碎念:断断续续写了快一个星期来了解和学习线性注意力的机制,公式太难推了智商跟不上,这些工作也确实厉害。以后可能就没有大段大段的时间来写笔者喜欢看的东西了。谢谢大家的陪伴。
-
Benjamin Merkel, Prefill and Decode for Concurrent Requests - Optimizing LLM Performance ↩︎ ↩︎ ↩︎
-
李航,《统计学习方法》第二版,133页到142页 ↩︎
-
Transformer Dissection: A Unified Understanding of Transformer's Attention via the Lens of Kernel ↩︎
-
Parallelizing Linear Transformers with the Delta Rule over Sequence Length ↩︎ ↩︎ ↩︎ ↩︎ ↩︎
-
Linear Transformers Are Secretly Fast Weight Programmers ↩︎ ↩︎ ↩︎
-
GATED DELTA NETWORKS: IMPROVING MAMBA2 WITH DELTA RULE ↩︎ ↩︎ ↩︎ ↩︎
-
Matrix Computation, 4th Edition, P233-P239 ↩︎