详细分析大语言模型attention的计算复杂度,从数学角度分析

大语言模型(LLM)中 Attention 的计算复杂度:系统与数学视角

下面以**单层多头自注意力(Multi-Head Self-Attention, MHA)**为基准,分别给出时间/空间复杂度、精确到常数的 FLOPs 估算、训练与推理(含 KV Cache)阶段的差异,以及若干改进/近似注意力的复杂度对比与直觉化解释。


1) 标准(密集)自注意力的精确计算量

设序列长度为 nnn,模型维度为 ddd,头数为 hhh,每头维度 dh=d/hd_h=d/hdh=d/h。输入矩阵 X∈Rn×dX\in\mathbb{R}^{n\times d}X∈Rn×d。

1.1 线性投影

通常用三组权重将 XXX 投影为 Q,K,VQ,K,VQ,K,V:

Q=XWQ,K=XWK,V=XWV,WQ,WK,WV∈Rd×d. Q=XW_Q,\quad K=XW_K,\quad V=XW_V,\quad W_Q,W_K,W_V\in\mathbb{R}^{d\times d}. Q=XWQ,K=XWK,V=XWV,WQ,WK,WV∈Rd×d.

  • 每次矩阵乘法(n×dn\times dn×d 与 d×dd\times dd×d)的乘加 FLOPs ≈2nd2\approx 2nd^2≈2nd2。
  • 三个投影合计 FLOPs ≈3×2nd2=6nd2\approx 3\times 2nd^2=6nd^2≈3×2nd2=6nd2。

输出拼接后还要过一次输出投影 WO∈Rd×dW_O\in\mathbb{R}^{d\times d}WO∈Rd×d:再加 2nd22nd^22nd2 FLOPs。

线性投影总 FLOPs(与序列无关项)

8nd2 \boxed{8nd^2} 8nd2

1.2 注意力核心(按头分块计算)

对每个头 iii:

Ai=softmax⁡ ⁣(QiKi⊤dh),Oi=AiVi. A_i=\operatorname{softmax}\!\left(\frac{Q_iK_i^\top}{\sqrt{d_h}}\right),\quad O_i=A_iV_i. Ai=softmax(dh QiKi⊤),Oi=AiVi.

  • QiKi⊤Q_iK_i^\topQiKi⊤: (n×dh)⋅(dh×n)⇒2n2dh(n\times d_h)\cdot(d_h\times n)\Rightarrow 2n^2d_h(n×dh)⋅(dh×n)⇒2n2dh FLOPs。
  • softmax:约 O(n2)O(n^2)O(n2)(相对前两项常数级,可忽略在主项里)。
  • AiViA_iV_iAiVi: (n×n)⋅(n×dh)⇒2n2dh(n\times n)\cdot(n\times d_h)\Rightarrow 2n^2d_h(n×n)⋅(n×dh)⇒2n2dh FLOPs。

对 hhh 个头求和:

Attention 核心 FLOPs≈h⋅(2n2dh+2n2dh)=4n2(hdh)=4n2d. \text{Attention 核心 FLOPs} \approx h\cdot(2n^2d_h+2n^2d_h)=4n^2(hd_h)=\boxed{4n^2d}. Attention 核心 FLOPs≈h⋅(2n2dh+2n2dh)=4n2(hdh)=4n2d.

1.3 单层总 FLOPs(前向)

FLOPs≈8nd2+4n2d \boxed{\text{FLOPs} \approx 8nd^2 + 4n^2d} FLOPs≈8nd2+4n2d

这条式子非常关键:当 n≪dn\ll dn≪d 时,8nd28nd^28nd2(投影与 MLP 类似量级)主导;当 n≫dn\gg dn≫d 时,4n2d4n^2d4n2d(注意力矩阵)主导。

临界处在 8nd2≈4n2d⇒n≈2d8nd^2 \approx 4n^2d \Rightarrow n\approx 2d8nd2≈4n2d⇒n≈2d。

注:实际 Transformer 还含有前馈网络(FFN/MLP),其 FLOPs 约为常见扩张倍数 mmm(如 4)下的 ≈2nd⋅md+2nmd⋅d≈4mnd2\approx 2nd\cdot md + 2nmd\cdot d \approx 4m nd^2≈2nd⋅md+2nmd⋅d≈4mnd2(前后两次线性层合计,忽略激活常数),常见 m=4m=4m=4 时约 ∼16nd2\sim 16nd^2∼16nd2,常常与上面的 8nd28nd^28nd2 同量级甚至更大。因此在短上下文 下,MLP 往往比注意力更贵;在超长上下文 下,注意力 n2n^2n2 项会快速成为主导。


2) 空间复杂度(内存/显存)

  • 保存 Q,K,VQ,K,VQ,K,V:O(nd)O(nd)O(nd)。
  • 朴素实现会显式 构造注意力矩阵 A∈Rn×nA\in\mathbb{R}^{n\times n}A∈Rn×n:O(n2)O(n^2)O(n2)。
  • 训练反向需要缓存中间量与梯度:常见为 O(n2)O(n^2)O(n2) 级别的额外显存;可用**激活重计算(checkpointing)**以 \\sim 2\\times 计算换显存,降到 O~(nd)\tilde O(nd)O~(nd) 级别。
  • FlashAttention 通过分块(tiling)与在线 softmax,将峰值显存由 O(n2)O(n^2)O(n2) 降到 O(nd)\boxed{O(nd)}O(nd),时间复杂度仍是 O(n2d)O(n^2d)O(n2d) 但常数显著下降(IO 友好)。

3) 自回归推理与 KV Cache

设提示长度 LLL,需生成 GGG 个新 token,总长度 T=L+GT=L+GT=L+G。

3.1 预填充(prefill,一次性算完前 LLL 个位置)

  • 复杂度与训练前向同型:8Ld2+4L2d\boxed{8Ld^2+4L^2d}8Ld2+4L2d。

3.2 增量生成(decoding,步进式生成)

开启 KV Cache 时,每步只需:

  • 为当前步做投影(∼8d2\sim 8d^2∼8d2 常数级 w.r.t. nnn)。
  • 与历史 K,VK,VK,V 做注意力:对第 ttt 步,注意力代价 ∼O(td)\sim O(td)∼O(td)(读出并点积到长度 ttt 的缓存)。

因此从 t=Lt=Lt=L 到 L+G−1L+G-1L+G−1 的总注意力代价为:

∑t=LL+G−1O(td)  =  O(d⋅(L+G−1+L)G2)  =  O(G(L+G) d). \sum_{t=L}^{L+G-1} O(td)\;=\;O\Big(d\cdot\frac{(L+G-1+L)G}{2}\Big)\;=\;\boxed{O\big(G(L+G)\,d\big)}. t=L∑L+G−1O(td)=O(d⋅2(L+G−1+L)G)=O(G(L+G)d).

直觉:每步与越来越长的缓存相乘,步均线性涨,因此整体是二次和。

3.3 KV Cache 显存大小

每层、每 token 需要缓存 KKK 与 VVV 各 ddd 维(更精确是 h⋅dh=dh\cdot d_h=dh⋅dh=d),共 2d2d2d 个元素。

若用 FP16(2 字节/元素),每层每 token 占:

2d×2 bytes=4d bytes. \boxed{2d \times 2\text{ bytes} = 4d\text{ bytes}}. 2d×2 bytes=4d bytes.

  • 举例:d=4096d=4096d=4096 时,每层每 token =4×4096=16384= 4\times 4096 = 16384=4×4096=16384 字节 ⇒16KB\Rightarrow 16\text{KB}⇒16KB。

  • 若层数 N=32N=32N=32,每 token 的 KV 缓存为 16KB×32=512KB16\text{KB}\times 32=512\text{KB}16KB×32=512KB。

  • 因此总 KV 显存 ≈512KB×T\approx 512\text{KB}\times T≈512KB×T。

    • T=8,192T=8{,}192T=8,192(8K)时:512×8192KB=4,194,304KB≈4  GB512\times 8192\text{KB} = 4{,}194{,}304\text{KB}\approx \mathbf{4\;GB}512×8192KB=4,194,304KB≈4GB。
    • T=32,768T=32{,}768T=32,768(32K)时:≈16  GB\approx 16\;GB≈16GB。
    • T=131,072T=131{,}072T=131,072(128K)时:≈64  GB\approx 64\;GB≈64GB。

推理瓶颈直觉 :解码阶段往往内存带宽受限 (每步要从显存连续读大量 K,VK,VK,V),而不是纯算力受限。Flash-Decoding、PagedAttention、张量并行/流水并行的 IO 优化都在缓解这点。

3.4 MQA/GQA 对 KV 显存与吞吐的影响

  • MQA(Multi-Query Attention) :多个头共享同一组 K,VK,VK,V(即 hK=hV=1h_K=h_V=1hK=hV=1),KV 缓存从 O(h)O(h)O(h) 降到 O(1)O(1)O(1),解码阶段读带宽显著下降,内存占用近似按头数缩小。
  • GQA(Grouped-Query Attention) :每组头共享一组 K,VK,VK,V。若分成 ggg 组,KV 显存从 O(h)O(h)O(h) 降到 O(g)O(g)O(g)。

4) 交叉注意力(Cross-Attention)

目标序列长度 nqn_qnq,源序列长度 nkn_knk(如编码器-解码器结构):

  • QK⊤QK^\topQK⊤:2nqnkdh2n_q n_k d_h2nqnkdh 每头;合计 ≈2nqnkd\approx 2n_q n_k d≈2nqnkd。
  • AVAVAV:同量级,再来 2nqnkd2n_q n_k d2nqnkd。
  • 合计注意力核 FLOPs ≈4nqnkd\approx \boxed{4 n_q n_k d}≈4nqnkd;再加四次 nd2nd^2nd2 级的投影项。

5) 近似/稀疏/线性注意力的复杂度对比

目标都是降低 n2n^2n2 到接近线性或 nlog⁡nn\log nnlogn 级别,同时尽量控制常数与误差。

  • 局部/滑窗注意力(window size www)
    时间 O(nwd)\boxed{O(n w d)}O(nwd),显存 O(nw)\boxed{O(n w)}O(nw)。适合长文本但侧重局部依赖;可配少量全局 token。
  • Block/稀疏模式(如 BigBird/Longformer)
    时间 O ⁣(n(w+g+r) d)\boxed{O\!\big(n(w+g+r)\,d\big)}O(n(w+g+r)d),其中 www=滑窗宽,ggg=全局 token 数,rrr=随机块连边数。
  • LSH/Reformer :期望时间 O(nlog⁡n⋅d)\boxed{O(n\log n\cdot d)}O(nlogn⋅d),但实现与常数较复杂。
  • Nyströmformer(秩 rrr) :O(nrd+r2d)\boxed{O(nrd + r^2 d)}O(nrd+r2d);当 r≪nr\ll nr≪n 时接近线性。
  • 线性注意力(核化/FAVOR+ 等)
    典型推导将
    softmax(QK⊤)\mathrm{softmax}(QK^\top)softmax(QK⊤) 近似为 ϕ(Q)ϕ(K)⊤\phi(Q)\phi(K)^\topϕ(Q)ϕ(K)⊤,可先算
    S=ϕ(K)⊤V∈Rdϕ×dS=\phi(K)^\top V\in\mathbb{R}^{d_\phi\times d}S=ϕ(K)⊤V∈Rdϕ×d(代价 O(ndϕd)O(nd_\phi d)O(ndϕd)),再算
    ϕ(Q)S\phi(Q)Sϕ(Q)S(代价 O(ndϕd)O(nd_\phi d)O(ndϕd)),总 O(ndϕd)\boxed{O(n d_\phi d)}O(ndϕd)。
    若 dϕ∼dd_\phi\sim ddϕ∼d,则为 O(nd2)O(nd^2)O(nd2)------对极长序列 更划算;但当 nnn 不大时常数未必占优。
  • FlashAttention仍是 O(n2d)O(n^2d)O(n2d) 时间 ,但将显存峰值压到 O(nd)O(nd)O(nd),并显著降低 IO,实际速度常常大幅提升。

结论 :当 nnn 远大于 ddd 时,上述方法能把主项从 n2n^2n2 降到近线性;当 nnn 与 ddd 同量级或 nnn 较小,近似法的收益变小甚至不如标准注意力(常数/误差/实现复杂度)。


6) 数学角度的要点与直觉

  1. 主导项来自两个 n×nn\times nn×n 的乘法
    ⟨Qi,Ki⟩\langle Q_i, K_i\rangle⟨Qi,Ki⟩ 形成 n×nn\times nn×n 的打分矩阵,以及将其与 ViV_iVi 相乘。二者分别贡献 2n2dh2n^2d_h2n2dh FLOPs/头,合计 4n2d4n^2d4n2d。
  2. 归一化与稳定性
    除以 dh\sqrt{d_h}dh 保持点积分布方差稳定,避免 softmax 过陡;softmax 的计算量是 O(n2)O(n^2)O(n2) 次指数/加法/除法,通常不是主导项。
  3. 与快速矩阵乘法的理论界
    若使用 Strassen/以后算法,正方形矩阵乘法可达 O(nω)O(n^\omega)O(nω), ω≈2.37\omega\approx 2.37ω≈2.37。但注意力的矩阵形状是 (n×dh)⋅(dh×n)(n\times d_h)\cdot(d_h\times n)(n×dh)⋅(dh×n),且在 GPU 上高度优化的 GEMM(经典 O(n3)O(n^3)O(n3) 常数小)更实用;工业界并不使用快速矩阵乘法来降注意力的幂指数。
  4. 计算-IO(算强度)视角
    解码阶段每步需要从显存顺序读取 O(td)O(td)O(td) 的 K,VK,VK,V,但计算仅 O(td)O(td)O(td) 级点积/加和,算强度低,内存带宽成为瓶颈。Flash 系列方法通过分块/重排提升局部性与重用,效果显著。

7) 速查表(单层、前向)

  • 总 FLOPs:8nd2+4n2d\boxed{8nd^2 + 4n^2d}8nd2+4n2d

    • 8nd28nd^28nd2:4 次 d×dd\times dd×d 线性(Q,K,V,OQ,K,V,OQ,K,V,O)
    • 4n2d4n^2d4n2d:注意力两次 n×nn\times nn×n 乘法(QK⊤QK^\topQK⊤ 与 AVAVAV)
  • 显存峰值(朴素):O(n2)+O(nd)\boxed{O(n^2) + O(nd)}O(n2)+O(nd);FlashAttention:O(nd)\boxed{O(nd)}O(nd)。

  • 自回归解码(含 KV Cache)总注意力成本:O(G(L+G) d)\boxed{O(G(L+G)\,d)}O(G(L+G)d);KV 显存:每层每 token 4d bytes\boxed{4d\text{ bytes}}4d bytes(FP16)。


8) 实用建议

  • 短上下文(n≲dn \lesssim dn≲d):投影与 MLP 往往主导;优化点在张量并行、低精度、算子融合。
  • 长上下文(n≫dn \gg dn≫d) :注意力 n2n^2n2 主导;优先考虑 FlashAttention、稀疏/滑窗/近似注意力、分块推理。
  • 推理吞吐 :使用 MQA/GQA 显著降低 KV 显存与带宽;配合 Paged KV、分块解码提升大上下文的可扩展性。
  • 训练显存 :激活检查点 + FlashAttention 能将注意力显存从 O(n2)O(n^2)O(n2) 压至近 O(nd)O(nd)O(nd)。
相关推荐
MiaoChuAI3 分钟前
豆包AI PPT与秒出PPT对比评测:谁更适合你?
人工智能·powerpoint
%KT%21 分钟前
简单聊聊多模态大语言模型MLLM
人工智能·语言模型·自然语言处理
唐某人丶29 分钟前
教你如何用 JS 实现一个 Agent 系统(1)—— 认识 Agentic System
前端·人工智能
泡泡茶壶_ovo36 分钟前
RORPCAP: retrieval-based objects and relations prompt for image captioning
人工智能·深度学习·计算机视觉·语言模型·prompt·多模态·imagecaptioning
MaxCode-140 分钟前
单智能体篇:Prompt工程艺术
大数据·人工智能·prompt
小鹿的工作手帐1 小时前
有鹿机器人:智慧清洁新时代的引领者
人工智能·科技·机器人
这张生成的图像能检测吗1 小时前
(论文速读)Logits DeConfusion-CLIP少样本学习
人工智能·计算机视觉·图像分类·clip
居然JuRan2 小时前
RAG系统开发中的12大痛点及应对策略
人工智能
sinat_286945192 小时前
AI服务器介绍
服务器·人工智能·算法·chatgpt·transformer
Kusunoki_D2 小时前
PyTorch 环境配置
人工智能·pytorch·python