78-机器学习与大模型开发数学教程-7-6 自注意力机制的计算复杂度分析

「ml-llm-math.zip」

链接:https://pan.quark.cn/s/35fc37047e5e

------为什么"长序列很贵",以及我们能做些什么

一句话版:标准自注意力的时间/显存复杂度都随序列长度 T T T 呈二次增长 ( ∼ O ( T 2 ) \sim O(T^2) ∼O(T2))。训练时,显存常被 T × T T\times T T×T 的注意力矩阵与中间激活占满;推理时,虽然可用 KV Cache 把时间复杂度降到线性按步增长 ,但显存仍随历史长度线性增长 。解决之道要么减少二次项 (稀疏/低秩/核化),要么减少 I/O(FlashAttention、重算、检查点等)。


1. 先把"账"算清楚:每一步都要花什么钱?

设批大小 B B B,长度 T T T,隐藏维 d d d,头数 H H H,每头维 d h = d / H d_h=d/H dh=d/H。

1.1 自注意力(单层,训练前向)的主要开销

  • Q , K , V Q,K,V Q,K,V 线性变换:

    X ∈ R B × T × d → ( Q , K , V ) ∈ R B × T × d X\in\mathbb{R}^{B\times T\times d}\ \to\ (Q,K,V)\in\mathbb{R}^{B\times T\times d} X∈RB×T×d → (Q,K,V)∈RB×T×d

    FLOPs ≈ 3 ⋅ B ⋅ T ⋅ d ⋅ d \approx 3\cdot B\cdot T\cdot d\cdot d ≈3⋅B⋅T⋅d⋅d(三次 d  ⁣ →  ⁣ d d\!\to\! d d→d 线性)。

  • 打分与权重(对所有头汇总后):

    S = Q K ⊤ d h ⇒ O ( B ⋅ H ⋅ T 2 ⋅ d h ) = O ( B ⋅ T 2 ⋅ d ) S=\frac{QK^\top}{\sqrt{d_h}} \quad\Rightarrow\quad O(B\cdot H\cdot T^2\cdot d_h)=O(B\cdot T^2\cdot d) S=dh QK⊤⇒O(B⋅H⋅T2⋅dh)=O(B⋅T2⋅d)

    A = s o f t m a x ( S ) ⇒ O ( B ⋅ T 2 ) A=\mathrm{softmax}(S)\quad\Rightarrow\quad O(B\cdot T^2) A=softmax(S)⇒O(B⋅T2)

  • 加权求和:

    Y = A   V ⇒ O ( B ⋅ H ⋅ T 2 ⋅ d h ) = O ( B ⋅ T 2 ⋅ d ) Y=A\,V \quad\Rightarrow\quad O(B\cdot H\cdot T^2\cdot d_h)=O(B\cdot T^2\cdot d) Y=AV⇒O(B⋅H⋅T2⋅dh)=O(B⋅T2⋅d)

  • 多头拼接 + 输出投影 W O W_O WO:

    ≈ B ⋅ T ⋅ d ⋅ d \approx B\cdot T\cdot d\cdot d ≈B⋅T⋅d⋅d。

合计主项(忽略常数):

Attn FLOPs ∼ O ( B   T 2   d ) + O ( B   T   d 2 ) \boxed{\text{Attn FLOPs}\ \sim\ O(B\,T^2\,d)\ +\ O(B\,T\,d^2)} Attn FLOPs ∼ O(BT2d) + O(BTd2)

其中 O ( B   T 2   d ) O(B\,T^2\,d) O(BT2d) 来自 Q K ⊤ QK^\top QK⊤ 与 A V A V AV, O ( B   T   d 2 ) O(B\,T\,d^2) O(BTd2) 来自线性投影。

因果 Mask 只把 T 2 T^2 T2 前的常数减半(上三角/下三角),级别仍是 O ( T 2 ) O(T^2) O(T2)

1.2 前馈网络(FFN)

两层按位置独立的 MLP,设 d ff ≈ 4 d d_{\text{ff}}\approx 4d dff≈4d:

FFN FLOPs ≈ 2 ⋅ B ⋅ T ⋅ d ⋅ d ff ≈ 8 ⋅ B ⋅ T ⋅ d 2 . \text{FFN FLOPs}\ \approx 2\cdot B\cdot T\cdot d\cdot d_{\text{ff}}\ \approx\ 8\cdot B\cdot T\cdot d^2. FFN FLOPs ≈2⋅B⋅T⋅d⋅dff ≈ 8⋅B⋅T⋅d2.

1.3 谁在"主宰"算力:注意力 vs FFN?

比较阶数:

  • 注意力主项: ∼ 2   B   T 2   d \sim 2\,B\,T^2\,d ∼2BT2d
  • FFN 主项: ∼ 8   B   T   d 2 \sim 8\,B\,T\,d^2 ∼8BTd2

二者相当的拐点

2   T 2   d ≈ 8   T   d 2 ⇒ T ≈ 4 d . 2\,T^2\,d \approx 8\,T\,d^2\quad\Rightarrow\quad \boxed{T \approx 4d}. 2T2d≈8Td2⇒T≈4d.

  • 当 T ≪ 4 d T \ll 4d T≪4d:FFN 更占 FLOPs(小序列/大维度)。
  • 当 T ≫ 4 d T \gg 4d T≫4d:注意力成为瓶颈(长序列场景)。

例: d = 1024 d=1024 d=1024 时拐点在 T ≈ 4096 T\approx 4096 T≈4096。当 T = 2 k T=2\text{k} T=2k 时,多数算力花在 FFN;当 T = 8 k T=8\text{k} T=8k 时,注意力成为主导。


2. 显存(内存)复杂度:是谁在吃内存?

训练时需要存下激活以便反传:

  • 注意力权重 A A A:尺寸 B × H × T × T B\times H\times T\times T B×H×T×T,显存 ∼ O ( B   H   T 2 ) \sim O(B\,H\,T^2) ∼O(BHT2)
  • Q , K , V , Y Q,K,V,Y Q,K,V,Y:尺寸 B × H × T × d h B\times H\times T\times d_h B×H×T×dh,合计 ∼ O ( B   T   d ) \sim O(B\,T\,d) ∼O(BTd)。
  • FFN 中间激活(如 GELU 前后): ∼ O ( B   T   d ff ) \sim O(B\,T\,d_{\text{ff}}) ∼O(BTdff)。

粗略记忆:二次项 只有注意力矩阵 A A A(或等价中间张量),其余基本是线性。

2.1 一个具体算例(只看注意力矩阵)

  • B = 1 , H = 16 , T = 4096 B=1,\ H=16,\ T=4096 B=1, H=16, T=4096。
  • A A A 的元素数: 16 × 4096 2 = 268,435,456 16\times 4096^2 = 268{,}435{,}456 16×40962=268,435,456。
  • 若 FP16(2 字节),仅 A A A 就占 ≈ 512   MiB \approx 512\,\text{MiB} ≈512MiB。

这还没算 Q , K , V Q,K,V Q,K,V 与 FFN 激活,多层叠加后你就知道为什么 24GB 显存不经用

2.2 训练期的"省内存"手段

  • 激活检查点 (checkpointing):丢弃一部分中间激活,反向时重算 ,以时间换空间
  • FlashAttention :分块计算 Q K ⊤ QK^\top QK⊤ / softmax / A V A V AV,避免显式存 A A A ,将显存从 O ( T 2 ) O(T^2) O(T2) 降到接近 O ( T   d ) O(T\,d) O(Td),同时减少 HBM 读写(I/O 成本)。
  • 混合精度(bf16/fp16)+ 归约用 fp32 保精度。

3. 推理复杂度:KV Cache 的"线性按步增长"

自回归生成第 t t t 个 token 时:

  • 计算当前 q t q_t qt: O ( d 2 ) O(d^2) O(d2)(线性层)。
  • 与历史 K 1 : t K_{1:t} K1:t 做点积:每头 O ( t ⋅ d h ) O(t\cdot d_h) O(t⋅dh),合计 O ( t ⋅ d ) O(t\cdot d) O(t⋅d)。
  • 读出 V 1 : t V_{1:t} V1:t 做加权和:同阶 O ( t ⋅ d ) O(t\cdot d) O(t⋅d)。
  • KV Cache 显存 :每步存 ( k t , v t ) (k_t,v_t) (kt,vt),总计 ∼ O ( H ⋅ T ⋅ d h ) = O ( T ⋅ d ) \sim O(H\cdot T\cdot d_h)=O(T\cdot d) ∼O(H⋅T⋅dh)=O(T⋅d)。

Per-token Time ∼ O ( d 2 ) + O ( t   d ) , KV 显存 ∼ O ( T   d ) \boxed{\text{Per-token Time}\ \sim\ O(d^2) + O(t\,d),\quad \text{KV 显存}\ \sim\ O(T\,d)} Per-token Time ∼ O(d2)+O(td),KV 显存 ∼ O(Td)

当 t t t 很大时(长上下文),带宽与缓存大小 成为瓶颈,吞吐随 t t t 下降。

3.1 降低推理显存/带宽的工程方案

  • MQA/GQA (Multi-Query / Grouped-Query):多个头共享 K/V (或按组共享),将缓存从 O ( H   T   d h ) O(H\,T\,d_h) O(HTdh) 降到 O ( T   d h ) O(T\,d_h) O(Tdh) 或 O ( H g   T   d h ) O(\frac{H}{g}\,T\,d_h) O(gHTdh)。
  • KV 压缩/量化 :8-bit/4-bit KV;或滑窗/裁剪远端缓存。
  • 分块解码 :把点积与加权和做成流式分块,降低一次性带宽峰值。

4. 交叉注意力与复杂度形状

在编码器-解码器结构中,交叉注意力打分矩阵尺寸是 T dec × T enc T_{\text{dec}}\times T_{\text{enc}} Tdec×Tenc,

复杂度 ∼ O ( B ⋅ T dec ⋅ T enc ⋅ d ) \sim O(B\cdot T_{\text{dec}} \cdot T_{\text{enc}} \cdot d) ∼O(B⋅Tdec⋅Tenc⋅d)。

当源序列很长(例如多段文档检索)时,交叉注意力也会成为瓶颈


5. 一个"数量级感觉"的 FLOPs 例子

取 B = 1 , T = 4096 , d = 1024 B=1,\ T=4096,\ d=1024 B=1, T=4096, d=1024。

  • 注意力主项:
    2 ⋅ T 2 ⋅ d ≈ 2 ⋅ 16,777,216 ⋅ 1024 ≈ 3.44 × 10 10 2\cdot T^2\cdot d \approx 2\cdot 16{,}777{,}216\cdot 1024 \approx \mathbf{3.44\times 10^{10}} 2⋅T2⋅d≈2⋅16,777,216⋅1024≈3.44×1010 次乘加( ∼ 34 \sim 34 ∼34 GFLOPs)。
  • FFN 主项( d ff = 4096 d_{\text{ff}}=4096 dff=4096):
    2 ⋅ T ⋅ d ⋅ d ff ≈ 2 ⋅ 4096 ⋅ 1024 ⋅ 4096 ≈ 3.44 × 10 10 2\cdot T\cdot d\cdot d_{\text{ff}} \approx 2\cdot 4096\cdot 1024\cdot 4096 \approx \mathbf{3.44\times 10^{10}} 2⋅T⋅d⋅dff≈2⋅4096⋅1024⋅4096≈3.44×1010(巧合相当,因为 T ≈ 4 d T\approx 4d T≈4d 正在拐点)。

这只是一层,多层相乘即可得到粗算总 FLOPs。真实实现里还有常数项、I/O 与 kernel 启动开销。


6. 用 Mermaid 看"复杂度分解"与"选择指南"

6.1 复杂度分解小图

#mermaid-svg-jPLhks20LoaC1rs0{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-jPLhks20LoaC1rs0 .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-jPLhks20LoaC1rs0 .error-icon{fill:#552222;}#mermaid-svg-jPLhks20LoaC1rs0 .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-jPLhks20LoaC1rs0 .marker{fill:#333333;stroke:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 .marker.cross{stroke:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-jPLhks20LoaC1rs0 p{margin:0;}#mermaid-svg-jPLhks20LoaC1rs0 .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster-label text{fill:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster-label span{color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster-label span p{background-color:transparent;}#mermaid-svg-jPLhks20LoaC1rs0 .label text,#mermaid-svg-jPLhks20LoaC1rs0 span{fill:#333;color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .node rect,#mermaid-svg-jPLhks20LoaC1rs0 .node circle,#mermaid-svg-jPLhks20LoaC1rs0 .node ellipse,#mermaid-svg-jPLhks20LoaC1rs0 .node polygon,#mermaid-svg-jPLhks20LoaC1rs0 .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .rough-node .label text,#mermaid-svg-jPLhks20LoaC1rs0 .node .label text,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape .label,#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape .label{text-anchor:middle;}#mermaid-svg-jPLhks20LoaC1rs0 .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .rough-node .label,#mermaid-svg-jPLhks20LoaC1rs0 .node .label,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape .label,#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape .label{text-align:center;}#mermaid-svg-jPLhks20LoaC1rs0 .node.clickable{cursor:pointer;}#mermaid-svg-jPLhks20LoaC1rs0 .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 .arrowheadPath{fill:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-jPLhks20LoaC1rs0 .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-jPLhks20LoaC1rs0 .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-jPLhks20LoaC1rs0 .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-jPLhks20LoaC1rs0 .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-jPLhks20LoaC1rs0 .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-jPLhks20LoaC1rs0 .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster text{fill:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster span{color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-jPLhks20LoaC1rs0 .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-jPLhks20LoaC1rs0 rect.text{fill:none;stroke-width:0;}#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape p,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape .label rect,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-jPLhks20LoaC1rs0 .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-jPLhks20LoaC1rs0 .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-jPLhks20LoaC1rs0 :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 输入: B,T,d,H
线性投影: O(B*T*d^2)
打分 QK^T: O(B*T^2*d)
softmax: O(B*T^2)
加权和 A*V: O(B*T^2*d)
输出投影: O(B*T*d^2)
输出: O(B*T*d)

说明 :二次项集中在 QK^TA*V ,其余为线性或 d 2 d^2 d2 项。

6.2 长序列"怎么选"决策图

#mermaid-svg-FwUTl91MmEnpaiwk{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-FwUTl91MmEnpaiwk .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-FwUTl91MmEnpaiwk .error-icon{fill:#552222;}#mermaid-svg-FwUTl91MmEnpaiwk .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-FwUTl91MmEnpaiwk .marker{fill:#333333;stroke:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk .marker.cross{stroke:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-FwUTl91MmEnpaiwk p{margin:0;}#mermaid-svg-FwUTl91MmEnpaiwk .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster-label text{fill:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster-label span{color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster-label span p{background-color:transparent;}#mermaid-svg-FwUTl91MmEnpaiwk .label text,#mermaid-svg-FwUTl91MmEnpaiwk span{fill:#333;color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .node rect,#mermaid-svg-FwUTl91MmEnpaiwk .node circle,#mermaid-svg-FwUTl91MmEnpaiwk .node ellipse,#mermaid-svg-FwUTl91MmEnpaiwk .node polygon,#mermaid-svg-FwUTl91MmEnpaiwk .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .rough-node .label text,#mermaid-svg-FwUTl91MmEnpaiwk .node .label text,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape .label,#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape .label{text-anchor:middle;}#mermaid-svg-FwUTl91MmEnpaiwk .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .rough-node .label,#mermaid-svg-FwUTl91MmEnpaiwk .node .label,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape .label,#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape .label{text-align:center;}#mermaid-svg-FwUTl91MmEnpaiwk .node.clickable{cursor:pointer;}#mermaid-svg-FwUTl91MmEnpaiwk .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk .arrowheadPath{fill:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-FwUTl91MmEnpaiwk .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-FwUTl91MmEnpaiwk .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-FwUTl91MmEnpaiwk .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-FwUTl91MmEnpaiwk .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-FwUTl91MmEnpaiwk .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-FwUTl91MmEnpaiwk .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster text{fill:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster span{color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-FwUTl91MmEnpaiwk .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-FwUTl91MmEnpaiwk rect.text{fill:none;stroke-width:0;}#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape p,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape .label rect,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-FwUTl91MmEnpaiwk .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-FwUTl91MmEnpaiwk .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-FwUTl91MmEnpaiwk :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 是



序列很长?
必须全局依赖?
标准注意力即可
允许近似?
FlashAttn/分布式/减小批量
局部/滑窗注意力 O(T*w*d)
核化/线性注意力 O(T*d^2)
低秩/投影(如 Nyström, Linformer)
混合: 局部 + 少量全局

说明

  • 必须全局精确 :优先 FlashAttention、分布式训练、减小 batch 或层数。
  • 允许近似 :选择局部/稀疏低秩 、或核化路线,视任务而定。

7. 降复杂与提效的几条路线图

7.1 稀疏/局部模式(保留 O ( T   w ) O(T\,w) O(Tw))

  • 滑窗/块稀疏 (Longformer/Sliding Window):每个 token 只看邻域 w w w。
    复杂度 ∼ O ( B   T   w   d ) \sim O(B\,T\,w\,d) ∼O(BTwd),若 w ≪ T w\ll T w≪T 则远小于 T 2 T^2 T2。
  • 混合全局(BigBird 等):局部 + 随机 + 若干全局 token,保持理论可表达性。

7.2 低秩/投影近似

  • Linformer :把 K , V K,V K,V 的长度维用投影压到 r ≪ T r\ll T r≪T,复杂度 ∼ O ( B   T   r   d ) \sim O(B\,T\,r\,d) ∼O(BTrd)。
  • Nyströmformer :用 r r r 个"地标"近似注意力核,同阶 ∼ O ( B   T   r   d ) \sim O(B\,T\,r\,d) ∼O(BTrd)。

7.3 核化/线性注意力

  • 令 s o f t m a x ( Q K ⊤ ) ≈ ϕ ( Q ) ϕ ( K ) ⊤ \mathrm{softmax}(QK^\top)\approx \phi(Q)\phi(K)^\top softmax(QK⊤)≈ϕ(Q)ϕ(K)⊤,则

    ϕ ( K ) ⊤ V ∈ O ( B   T   d ⋅ d ϕ ) , ϕ ( Q ) ⋅ ( ⋅ ) ∈ O ( B   T   d ⋅ d ϕ ) \phi(K)^\top V \in O(B\,T\,d\cdot d_\phi),\quad \phi(Q)\cdot(\cdot) \in O(B\,T\,d\cdot d_\phi) ϕ(K)⊤V∈O(BTd⋅dϕ),ϕ(Q)⋅(⋅)∈O(BTd⋅dϕ)

    总体 ∼ O ( B   T   d ⋅ d ϕ ) \sim O(B\,T\,d\cdot d_\phi) ∼O(BTd⋅dϕ),线性于 T T T

  • 代表:Performer、Linear Transformers 等。注意 :近似误差与 ϕ \phi ϕ 的选择有关。

7.4 IO-aware:把"搬数据"的代价降下来

  • FlashAttention :分块计算,不显式存 A A A ,降低 HBM 往返;复杂度级别不变 (仍是 T 2 T^2 T2),但速度/可训练长度大幅改观。
  • 融合 kernel:scale + mask + softmax + dropout + matmul 融合,少 kernel 启动。
  • 序列并行/张量并行 :跨设备分摊 T T T 或 d d d,减单卡峰值显存。

8. 工程实战清单(训练 & 推理)

训练侧

  • 长序列先上 FlashAttention + bf16 + 归约 fp32
  • 激活检查点 配合 重算;必要时关掉注意力 dropout 稳数值。
  • clip grad-norm,防止长序列引发梯度爆炸。
  • 善用 梯度累积 替代超大 batch。
  • 若任务允许,用局部/稀疏注意力低秩投影 尝试把 T 2 T^2 T2 打碎。

推理侧

  • KV Cache 预分配;优先 MQA/GQA 降内存与带宽。
  • Paged KV / 滑窗裁剪远端上下文。
  • 分块点积(blockwise)避免一次性读取过长 KV。
  • 合理的 prefill:decode 比例与批内并行(例如把多条请求对齐 prefill)。

9. 小练习(带提示)

  1. 门槛判断 :给定 d = 1536 d=1536 d=1536,估计注意力与 FFN FLOPs 的拐点 T T T;试计算 T = 8 k T=8\text{k} T=8k 时两者的比值。
  2. 显存估算 : B = 2 , H = 8 , T = 8192 B=2,H=8,T=8192 B=2,H=8,T=8192 时,FP16 下注意力矩阵 A A A 占用多大?如果用 FlashAttention,会省掉多少显存级别的张量?
  3. 窗口选择 :在一个长文分类任务上,尝试 w ∈ { 128 , 256 , 512 } w\in\{128,256,512\} w∈{128,256,512} 的滑窗注意力,比较准确率与吞吐。
  4. 核化近似误差 :实现一个随机特征 ϕ \phi ϕ 的线性注意力,与标准 softmax 注意力在相同 Q , K , V Q,K,V Q,K,V 上比较输出误差随 d ϕ d_\phi dϕ 的变化。
  5. KV 策略:把 16 头多查询注意力(MHA)替换为 MQA 与 GQA,比较相同上下文长度下的显存与 tokens/s。

10. 小结(带走三句话)

  1. 标准自注意力的瓶颈是 T 2 T^2 T2 :时间 O ( B   T 2   d ) O(B\,T^2\,d) O(BT2d),训练显存 O ( B   H   T 2 ) O(B\,H\,T^2) O(BHT2)。
  2. 短序列时 FFN 更贵,长序列时注意力更贵 ;拐点约在 T ≈ 4 d T\approx 4d T≈4d。
  3. 应对策略两类改算法形状 (稀疏/低秩/核化,降到 O ( T ) O(T) O(T) 或 O ( T w ) O(Tw) O(Tw))与改执行方式(FlashAttention/融合/并行/KV 策略,降 I/O 与常数因子)。
相关推荐
新加坡内哥谈技术1 小时前
Claude Code 中动态工作流(Dynamic Workflows)
人工智能
XMAIPC_Robot1 小时前
基于RK3588 ARM+FPGA电火花数控机床控制系统设计,兼顾ethercat软硬件实时
linux·arm开发·人工智能·嵌入式硬件·fpga开发
见合八方1 小时前
【滤波器】热调谐FP滤波器
人工智能·算法
古城小栈1 小时前
cargo-pprof:Rust性能调优
人工智能·算法·rust
程序大视界1 小时前
Google I/O 2026 全解析:Gemini 3.5 Flash 免费用、4倍速碾压 GPT-5.5,AI 迎来“Agent 时代“
人工智能·gpt
sunneo1 小时前
S1.2损失厌恶与用户忠诚度的关系:让用户觉得离开是一种损失
人工智能·产品运营·产品经理·用户运营·用户体验
段一凡-华北理工大学1 小时前
工业领域的Hadoop架构学习~系列文章05:Kafka消息队列 - 工业数据流传输
人工智能·hadoop·学习·架构·kafka·工业智能体·高炉炼铁智能化
zcg19421 小时前
如何在CV中使用transformer
人工智能·深度学习·transformer
xiaobangsky1 小时前
AI 时代来临,我该何去何从
人工智能