大语言模型(LLM)中 Attention 的计算复杂度:系统与数学视角
下面以**单层多头自注意力(Multi-Head Self-Attention, MHA)**为基准,分别给出时间/空间复杂度、精确到常数的 FLOPs 估算、训练与推理(含 KV Cache)阶段的差异,以及若干改进/近似注意力的复杂度对比与直觉化解释。
1) 标准(密集)自注意力的精确计算量
设序列长度为 nnn,模型维度为 ddd,头数为 hhh,每头维度 dh=d/hd_h=d/hdh=d/h。输入矩阵 X∈Rn×dX\in\mathbb{R}^{n\times d}X∈Rn×d。
1.1 线性投影
通常用三组权重将 XXX 投影为 Q,K,VQ,K,VQ,K,V:
Q=XWQ,K=XWK,V=XWV,WQ,WK,WV∈Rd×d. Q=XW_Q,\quad K=XW_K,\quad V=XW_V,\quad W_Q,W_K,W_V\in\mathbb{R}^{d\times d}. Q=XWQ,K=XWK,V=XWV,WQ,WK,WV∈Rd×d.
- 每次矩阵乘法(n×dn\times dn×d 与 d×dd\times dd×d)的乘加 FLOPs ≈2nd2\approx 2nd^2≈2nd2。
- 三个投影合计 FLOPs ≈3×2nd2=6nd2\approx 3\times 2nd^2=6nd^2≈3×2nd2=6nd2。
输出拼接后还要过一次输出投影 WO∈Rd×dW_O\in\mathbb{R}^{d\times d}WO∈Rd×d:再加 2nd22nd^22nd2 FLOPs。
线性投影总 FLOPs(与序列无关项):
8nd2 \boxed{8nd^2} 8nd2
1.2 注意力核心(按头分块计算)
对每个头 iii:
Ai=softmax (QiKi⊤dh),Oi=AiVi. A_i=\operatorname{softmax}\!\left(\frac{Q_iK_i^\top}{\sqrt{d_h}}\right),\quad O_i=A_iV_i. Ai=softmax(dh QiKi⊤),Oi=AiVi.
- QiKi⊤Q_iK_i^\topQiKi⊤: (n×dh)⋅(dh×n)⇒2n2dh(n\times d_h)\cdot(d_h\times n)\Rightarrow 2n^2d_h(n×dh)⋅(dh×n)⇒2n2dh FLOPs。
- softmax:约 O(n2)O(n^2)O(n2)(相对前两项常数级,可忽略在主项里)。
- AiViA_iV_iAiVi: (n×n)⋅(n×dh)⇒2n2dh(n\times n)\cdot(n\times d_h)\Rightarrow 2n^2d_h(n×n)⋅(n×dh)⇒2n2dh FLOPs。
对 hhh 个头求和:
Attention 核心 FLOPs≈h⋅(2n2dh+2n2dh)=4n2(hdh)=4n2d. \text{Attention 核心 FLOPs} \approx h\cdot(2n^2d_h+2n^2d_h)=4n^2(hd_h)=\boxed{4n^2d}. Attention 核心 FLOPs≈h⋅(2n2dh+2n2dh)=4n2(hdh)=4n2d.
1.3 单层总 FLOPs(前向)
FLOPs≈8nd2+4n2d \boxed{\text{FLOPs} \approx 8nd^2 + 4n^2d} FLOPs≈8nd2+4n2d
这条式子非常关键:当 n≪dn\ll dn≪d 时,8nd28nd^28nd2(投影与 MLP 类似量级)主导;当 n≫dn\gg dn≫d 时,4n2d4n^2d4n2d(注意力矩阵)主导。
临界处在 8nd2≈4n2d⇒n≈2d8nd^2 \approx 4n^2d \Rightarrow n\approx 2d8nd2≈4n2d⇒n≈2d。
注:实际 Transformer 还含有前馈网络(FFN/MLP),其 FLOPs 约为常见扩张倍数 mmm(如 4)下的 ≈2nd⋅md+2nmd⋅d≈4mnd2\approx 2nd\cdot md + 2nmd\cdot d \approx 4m nd^2≈2nd⋅md+2nmd⋅d≈4mnd2(前后两次线性层合计,忽略激活常数),常见 m=4m=4m=4 时约 ∼16nd2\sim 16nd^2∼16nd2,常常与上面的 8nd28nd^28nd2 同量级甚至更大。因此在短上下文 下,MLP 往往比注意力更贵;在超长上下文 下,注意力 n2n^2n2 项会快速成为主导。
2) 空间复杂度(内存/显存)
- 保存 Q,K,VQ,K,VQ,K,V:O(nd)O(nd)O(nd)。
- 朴素实现会显式 构造注意力矩阵 A∈Rn×nA\in\mathbb{R}^{n\times n}A∈Rn×n:O(n2)O(n^2)O(n2)。
- 训练反向需要缓存中间量与梯度:常见为 O(n2)O(n^2)O(n2) 级别的额外显存;可用**激活重计算(checkpointing)**以 \\sim 2\\times 计算换显存,降到 O~(nd)\tilde O(nd)O~(nd) 级别。
- FlashAttention 通过分块(tiling)与在线 softmax,将峰值显存由 O(n2)O(n^2)O(n2) 降到 O(nd)\boxed{O(nd)}O(nd),时间复杂度仍是 O(n2d)O(n^2d)O(n2d) 但常数显著下降(IO 友好)。
3) 自回归推理与 KV Cache
设提示长度 LLL,需生成 GGG 个新 token,总长度 T=L+GT=L+GT=L+G。
3.1 预填充(prefill,一次性算完前 LLL 个位置)
- 复杂度与训练前向同型:8Ld2+4L2d\boxed{8Ld^2+4L^2d}8Ld2+4L2d。
3.2 增量生成(decoding,步进式生成)
开启 KV Cache 时,每步只需:
- 为当前步做投影(∼8d2\sim 8d^2∼8d2 常数级 w.r.t. nnn)。
- 与历史 K,VK,VK,V 做注意力:对第 ttt 步,注意力代价 ∼O(td)\sim O(td)∼O(td)(读出并点积到长度 ttt 的缓存)。
因此从 t=Lt=Lt=L 到 L+G−1L+G-1L+G−1 的总注意力代价为:
∑t=LL+G−1O(td) = O(d⋅(L+G−1+L)G2) = O(G(L+G) d). \sum_{t=L}^{L+G-1} O(td)\;=\;O\Big(d\cdot\frac{(L+G-1+L)G}{2}\Big)\;=\;\boxed{O\big(G(L+G)\,d\big)}. t=L∑L+G−1O(td)=O(d⋅2(L+G−1+L)G)=O(G(L+G)d).
直觉:每步与越来越长的缓存相乘,步均线性涨,因此整体是二次和。
3.3 KV Cache 显存大小
每层、每 token 需要缓存 KKK 与 VVV 各 ddd 维(更精确是 h⋅dh=dh\cdot d_h=dh⋅dh=d),共 2d2d2d 个元素。
若用 FP16(2 字节/元素),每层每 token 占:
2d×2 bytes=4d bytes. \boxed{2d \times 2\text{ bytes} = 4d\text{ bytes}}. 2d×2 bytes=4d bytes.
-
举例:d=4096d=4096d=4096 时,每层每 token =4×4096=16384= 4\times 4096 = 16384=4×4096=16384 字节 ⇒16KB\Rightarrow 16\text{KB}⇒16KB。
-
若层数 N=32N=32N=32,每 token 的 KV 缓存为 16KB×32=512KB16\text{KB}\times 32=512\text{KB}16KB×32=512KB。
-
因此总 KV 显存 ≈512KB×T\approx 512\text{KB}\times T≈512KB×T。
- T=8,192T=8{,}192T=8,192(8K)时:512×8192KB=4,194,304KB≈4 GB512\times 8192\text{KB} = 4{,}194{,}304\text{KB}\approx \mathbf{4\;GB}512×8192KB=4,194,304KB≈4GB。
- T=32,768T=32{,}768T=32,768(32K)时:≈16 GB\approx 16\;GB≈16GB。
- T=131,072T=131{,}072T=131,072(128K)时:≈64 GB\approx 64\;GB≈64GB。
推理瓶颈直觉 :解码阶段往往内存带宽受限 (每步要从显存连续读大量 K,VK,VK,V),而不是纯算力受限。Flash-Decoding、PagedAttention、张量并行/流水并行的 IO 优化都在缓解这点。
3.4 MQA/GQA 对 KV 显存与吞吐的影响
- MQA(Multi-Query Attention) :多个头共享同一组 K,VK,VK,V(即 hK=hV=1h_K=h_V=1hK=hV=1),KV 缓存从 O(h)O(h)O(h) 降到 O(1)O(1)O(1),解码阶段读带宽显著下降,内存占用近似按头数缩小。
- GQA(Grouped-Query Attention) :每组头共享一组 K,VK,VK,V。若分成 ggg 组,KV 显存从 O(h)O(h)O(h) 降到 O(g)O(g)O(g)。
4) 交叉注意力(Cross-Attention)
目标序列长度 nqn_qnq,源序列长度 nkn_knk(如编码器-解码器结构):
- QK⊤QK^\topQK⊤:2nqnkdh2n_q n_k d_h2nqnkdh 每头;合计 ≈2nqnkd\approx 2n_q n_k d≈2nqnkd。
- AVAVAV:同量级,再来 2nqnkd2n_q n_k d2nqnkd。
- 合计注意力核 FLOPs ≈4nqnkd\approx \boxed{4 n_q n_k d}≈4nqnkd;再加四次 nd2nd^2nd2 级的投影项。
5) 近似/稀疏/线性注意力的复杂度对比
目标都是降低 n2n^2n2 到接近线性或 nlognn\log nnlogn 级别,同时尽量控制常数与误差。
- 局部/滑窗注意力(window size www) :
时间 O(nwd)\boxed{O(n w d)}O(nwd),显存 O(nw)\boxed{O(n w)}O(nw)。适合长文本但侧重局部依赖;可配少量全局 token。 - Block/稀疏模式(如 BigBird/Longformer) :
时间 O (n(w+g+r) d)\boxed{O\!\big(n(w+g+r)\,d\big)}O(n(w+g+r)d),其中 www=滑窗宽,ggg=全局 token 数,rrr=随机块连边数。 - LSH/Reformer :期望时间 O(nlogn⋅d)\boxed{O(n\log n\cdot d)}O(nlogn⋅d),但实现与常数较复杂。
- Nyströmformer(秩 rrr) :O(nrd+r2d)\boxed{O(nrd + r^2 d)}O(nrd+r2d);当 r≪nr\ll nr≪n 时接近线性。
- 线性注意力(核化/FAVOR+ 等) :
典型推导将
softmax(QK⊤)\mathrm{softmax}(QK^\top)softmax(QK⊤) 近似为 ϕ(Q)ϕ(K)⊤\phi(Q)\phi(K)^\topϕ(Q)ϕ(K)⊤,可先算
S=ϕ(K)⊤V∈Rdϕ×dS=\phi(K)^\top V\in\mathbb{R}^{d_\phi\times d}S=ϕ(K)⊤V∈Rdϕ×d(代价 O(ndϕd)O(nd_\phi d)O(ndϕd)),再算
ϕ(Q)S\phi(Q)Sϕ(Q)S(代价 O(ndϕd)O(nd_\phi d)O(ndϕd)),总 O(ndϕd)\boxed{O(n d_\phi d)}O(ndϕd)。
若 dϕ∼dd_\phi\sim ddϕ∼d,则为 O(nd2)O(nd^2)O(nd2)------对极长序列 更划算;但当 nnn 不大时常数未必占优。 - FlashAttention :仍是 O(n2d)O(n^2d)O(n2d) 时间 ,但将显存峰值压到 O(nd)O(nd)O(nd),并显著降低 IO,实际速度常常大幅提升。
结论 :当 nnn 远大于 ddd 时,上述方法能把主项从 n2n^2n2 降到近线性;当 nnn 与 ddd 同量级或 nnn 较小,近似法的收益变小甚至不如标准注意力(常数/误差/实现复杂度)。
6) 数学角度的要点与直觉
- 主导项来自两个 n×nn\times nn×n 的乘法 :
⟨Qi,Ki⟩\langle Q_i, K_i\rangle⟨Qi,Ki⟩ 形成 n×nn\times nn×n 的打分矩阵,以及将其与 ViV_iVi 相乘。二者分别贡献 2n2dh2n^2d_h2n2dh FLOPs/头,合计 4n2d4n^2d4n2d。 - 归一化与稳定性 :
除以 dh\sqrt{d_h}dh 保持点积分布方差稳定,避免 softmax 过陡;softmax 的计算量是 O(n2)O(n^2)O(n2) 次指数/加法/除法,通常不是主导项。 - 与快速矩阵乘法的理论界 :
若使用 Strassen/以后算法,正方形矩阵乘法可达 O(nω)O(n^\omega)O(nω), ω≈2.37\omega\approx 2.37ω≈2.37。但注意力的矩阵形状是 (n×dh)⋅(dh×n)(n\times d_h)\cdot(d_h\times n)(n×dh)⋅(dh×n),且在 GPU 上高度优化的 GEMM(经典 O(n3)O(n^3)O(n3) 常数小)更实用;工业界并不使用快速矩阵乘法来降注意力的幂指数。 - 计算-IO(算强度)视角 :
解码阶段每步需要从显存顺序读取 O(td)O(td)O(td) 的 K,VK,VK,V,但计算仅 O(td)O(td)O(td) 级点积/加和,算强度低,内存带宽成为瓶颈。Flash 系列方法通过分块/重排提升局部性与重用,效果显著。
7) 速查表(单层、前向)
-
总 FLOPs:8nd2+4n2d\boxed{8nd^2 + 4n^2d}8nd2+4n2d
- 8nd28nd^28nd2:4 次 d×dd\times dd×d 线性(Q,K,V,OQ,K,V,OQ,K,V,O)
- 4n2d4n^2d4n2d:注意力两次 n×nn\times nn×n 乘法(QK⊤QK^\topQK⊤ 与 AVAVAV)
-
显存峰值(朴素):O(n2)+O(nd)\boxed{O(n^2) + O(nd)}O(n2)+O(nd);FlashAttention:O(nd)\boxed{O(nd)}O(nd)。
-
自回归解码(含 KV Cache)总注意力成本:O(G(L+G) d)\boxed{O(G(L+G)\,d)}O(G(L+G)d);KV 显存:每层每 token 4d bytes\boxed{4d\text{ bytes}}4d bytes(FP16)。
8) 实用建议
- 短上下文(n≲dn \lesssim dn≲d):投影与 MLP 往往主导;优化点在张量并行、低精度、算子融合。
- 长上下文(n≫dn \gg dn≫d) :注意力 n2n^2n2 主导;优先考虑 FlashAttention、稀疏/滑窗/近似注意力、分块推理。
- 推理吞吐 :使用 MQA/GQA 显著降低 KV 显存与带宽;配合 Paged KV、分块解码提升大上下文的可扩展性。
- 训练显存 :激活检查点 + FlashAttention 能将注意力显存从 O(n2)O(n^2)O(n2) 压至近 O(nd)O(nd)O(nd)。