链接:https://pan.quark.cn/s/35fc37047e5e

------为什么"长序列很贵",以及我们能做些什么
一句话版:标准自注意力的时间/显存复杂度都随序列长度 T T T 呈二次增长 ( ∼ O ( T 2 ) \sim O(T^2) ∼O(T2))。训练时,显存常被 T × T T\times T T×T 的注意力矩阵与中间激活占满;推理时,虽然可用 KV Cache 把时间复杂度降到线性按步增长 ,但显存仍随历史长度线性增长 。解决之道要么减少二次项 (稀疏/低秩/核化),要么减少 I/O(FlashAttention、重算、检查点等)。
1. 先把"账"算清楚:每一步都要花什么钱?
设批大小 B B B,长度 T T T,隐藏维 d d d,头数 H H H,每头维 d h = d / H d_h=d/H dh=d/H。
1.1 自注意力(单层,训练前向)的主要开销
-
Q , K , V Q,K,V Q,K,V 线性变换:
X ∈ R B × T × d → ( Q , K , V ) ∈ R B × T × d X\in\mathbb{R}^{B\times T\times d}\ \to\ (Q,K,V)\in\mathbb{R}^{B\times T\times d} X∈RB×T×d → (Q,K,V)∈RB×T×d
FLOPs ≈ 3 ⋅ B ⋅ T ⋅ d ⋅ d \approx 3\cdot B\cdot T\cdot d\cdot d ≈3⋅B⋅T⋅d⋅d(三次 d → d d\!\to\! d d→d 线性)。
-
打分与权重(对所有头汇总后):
S = Q K ⊤ d h ⇒ O ( B ⋅ H ⋅ T 2 ⋅ d h ) = O ( B ⋅ T 2 ⋅ d ) S=\frac{QK^\top}{\sqrt{d_h}} \quad\Rightarrow\quad O(B\cdot H\cdot T^2\cdot d_h)=O(B\cdot T^2\cdot d) S=dh QK⊤⇒O(B⋅H⋅T2⋅dh)=O(B⋅T2⋅d)
A = s o f t m a x ( S ) ⇒ O ( B ⋅ T 2 ) A=\mathrm{softmax}(S)\quad\Rightarrow\quad O(B\cdot T^2) A=softmax(S)⇒O(B⋅T2)
-
加权求和:
Y = A V ⇒ O ( B ⋅ H ⋅ T 2 ⋅ d h ) = O ( B ⋅ T 2 ⋅ d ) Y=A\,V \quad\Rightarrow\quad O(B\cdot H\cdot T^2\cdot d_h)=O(B\cdot T^2\cdot d) Y=AV⇒O(B⋅H⋅T2⋅dh)=O(B⋅T2⋅d)
-
多头拼接 + 输出投影 W O W_O WO:
≈ B ⋅ T ⋅ d ⋅ d \approx B\cdot T\cdot d\cdot d ≈B⋅T⋅d⋅d。
合计主项(忽略常数):
Attn FLOPs ∼ O ( B T 2 d ) + O ( B T d 2 ) \boxed{\text{Attn FLOPs}\ \sim\ O(B\,T^2\,d)\ +\ O(B\,T\,d^2)} Attn FLOPs ∼ O(BT2d) + O(BTd2)
其中 O ( B T 2 d ) O(B\,T^2\,d) O(BT2d) 来自 Q K ⊤ QK^\top QK⊤ 与 A V A V AV, O ( B T d 2 ) O(B\,T\,d^2) O(BTd2) 来自线性投影。
因果 Mask 只把 T 2 T^2 T2 前的常数减半(上三角/下三角),级别仍是 O ( T 2 ) O(T^2) O(T2)。
1.2 前馈网络(FFN)
两层按位置独立的 MLP,设 d ff ≈ 4 d d_{\text{ff}}\approx 4d dff≈4d:
FFN FLOPs ≈ 2 ⋅ B ⋅ T ⋅ d ⋅ d ff ≈ 8 ⋅ B ⋅ T ⋅ d 2 . \text{FFN FLOPs}\ \approx 2\cdot B\cdot T\cdot d\cdot d_{\text{ff}}\ \approx\ 8\cdot B\cdot T\cdot d^2. FFN FLOPs ≈2⋅B⋅T⋅d⋅dff ≈ 8⋅B⋅T⋅d2.
1.3 谁在"主宰"算力:注意力 vs FFN?
比较阶数:
- 注意力主项: ∼ 2 B T 2 d \sim 2\,B\,T^2\,d ∼2BT2d
- FFN 主项: ∼ 8 B T d 2 \sim 8\,B\,T\,d^2 ∼8BTd2
二者相当的拐点在
2 T 2 d ≈ 8 T d 2 ⇒ T ≈ 4 d . 2\,T^2\,d \approx 8\,T\,d^2\quad\Rightarrow\quad \boxed{T \approx 4d}. 2T2d≈8Td2⇒T≈4d.
- 当 T ≪ 4 d T \ll 4d T≪4d:FFN 更占 FLOPs(小序列/大维度)。
- 当 T ≫ 4 d T \gg 4d T≫4d:注意力成为瓶颈(长序列场景)。
例: d = 1024 d=1024 d=1024 时拐点在 T ≈ 4096 T\approx 4096 T≈4096。当 T = 2 k T=2\text{k} T=2k 时,多数算力花在 FFN;当 T = 8 k T=8\text{k} T=8k 时,注意力成为主导。
2. 显存(内存)复杂度:是谁在吃内存?
训练时需要存下激活以便反传:
- 注意力权重 A A A:尺寸 B × H × T × T B\times H\times T\times T B×H×T×T,显存 ∼ O ( B H T 2 ) \sim O(B\,H\,T^2) ∼O(BHT2)。
- Q , K , V , Y Q,K,V,Y Q,K,V,Y:尺寸 B × H × T × d h B\times H\times T\times d_h B×H×T×dh,合计 ∼ O ( B T d ) \sim O(B\,T\,d) ∼O(BTd)。
- FFN 中间激活(如 GELU 前后): ∼ O ( B T d ff ) \sim O(B\,T\,d_{\text{ff}}) ∼O(BTdff)。
粗略记忆:二次项 只有注意力矩阵 A A A(或等价中间张量),其余基本是线性。
2.1 一个具体算例(只看注意力矩阵)
- B = 1 , H = 16 , T = 4096 B=1,\ H=16,\ T=4096 B=1, H=16, T=4096。
- A A A 的元素数: 16 × 4096 2 = 268,435,456 16\times 4096^2 = 268{,}435{,}456 16×40962=268,435,456。
- 若 FP16(2 字节),仅 A A A 就占 ≈ 512 MiB \approx 512\,\text{MiB} ≈512MiB。
这还没算 Q , K , V Q,K,V Q,K,V 与 FFN 激活,多层叠加后你就知道为什么 24GB 显存不经用。
2.2 训练期的"省内存"手段
- 激活检查点 (checkpointing):丢弃一部分中间激活,反向时重算 ,以时间换空间。
- FlashAttention :分块计算 Q K ⊤ QK^\top QK⊤ / softmax / A V A V AV,避免显式存 A A A ,将显存从 O ( T 2 ) O(T^2) O(T2) 降到接近 O ( T d ) O(T\,d) O(Td),同时减少 HBM 读写(I/O 成本)。
- 混合精度(bf16/fp16)+ 归约用 fp32 保精度。
3. 推理复杂度:KV Cache 的"线性按步增长"
自回归生成第 t t t 个 token 时:
- 计算当前 q t q_t qt: O ( d 2 ) O(d^2) O(d2)(线性层)。
- 与历史 K 1 : t K_{1:t} K1:t 做点积:每头 O ( t ⋅ d h ) O(t\cdot d_h) O(t⋅dh),合计 O ( t ⋅ d ) O(t\cdot d) O(t⋅d)。
- 读出 V 1 : t V_{1:t} V1:t 做加权和:同阶 O ( t ⋅ d ) O(t\cdot d) O(t⋅d)。
- KV Cache 显存 :每步存 ( k t , v t ) (k_t,v_t) (kt,vt),总计 ∼ O ( H ⋅ T ⋅ d h ) = O ( T ⋅ d ) \sim O(H\cdot T\cdot d_h)=O(T\cdot d) ∼O(H⋅T⋅dh)=O(T⋅d)。
Per-token Time ∼ O ( d 2 ) + O ( t d ) , KV 显存 ∼ O ( T d ) \boxed{\text{Per-token Time}\ \sim\ O(d^2) + O(t\,d),\quad \text{KV 显存}\ \sim\ O(T\,d)} Per-token Time ∼ O(d2)+O(td),KV 显存 ∼ O(Td)
当 t t t 很大时(长上下文),带宽与缓存大小 成为瓶颈,吞吐随 t t t 下降。
3.1 降低推理显存/带宽的工程方案
- MQA/GQA (Multi-Query / Grouped-Query):多个头共享 K/V (或按组共享),将缓存从 O ( H T d h ) O(H\,T\,d_h) O(HTdh) 降到 O ( T d h ) O(T\,d_h) O(Tdh) 或 O ( H g T d h ) O(\frac{H}{g}\,T\,d_h) O(gHTdh)。
- KV 压缩/量化 :8-bit/4-bit KV;或滑窗/裁剪远端缓存。
- 分块解码 :把点积与加权和做成流式分块,降低一次性带宽峰值。
4. 交叉注意力与复杂度形状
在编码器-解码器结构中,交叉注意力打分矩阵尺寸是 T dec × T enc T_{\text{dec}}\times T_{\text{enc}} Tdec×Tenc,
复杂度 ∼ O ( B ⋅ T dec ⋅ T enc ⋅ d ) \sim O(B\cdot T_{\text{dec}} \cdot T_{\text{enc}} \cdot d) ∼O(B⋅Tdec⋅Tenc⋅d)。
当源序列很长(例如多段文档检索)时,交叉注意力也会成为瓶颈。
5. 一个"数量级感觉"的 FLOPs 例子
取 B = 1 , T = 4096 , d = 1024 B=1,\ T=4096,\ d=1024 B=1, T=4096, d=1024。
- 注意力主项:
2 ⋅ T 2 ⋅ d ≈ 2 ⋅ 16,777,216 ⋅ 1024 ≈ 3.44 × 10 10 2\cdot T^2\cdot d \approx 2\cdot 16{,}777{,}216\cdot 1024 \approx \mathbf{3.44\times 10^{10}} 2⋅T2⋅d≈2⋅16,777,216⋅1024≈3.44×1010 次乘加( ∼ 34 \sim 34 ∼34 GFLOPs)。 - FFN 主项( d ff = 4096 d_{\text{ff}}=4096 dff=4096):
2 ⋅ T ⋅ d ⋅ d ff ≈ 2 ⋅ 4096 ⋅ 1024 ⋅ 4096 ≈ 3.44 × 10 10 2\cdot T\cdot d\cdot d_{\text{ff}} \approx 2\cdot 4096\cdot 1024\cdot 4096 \approx \mathbf{3.44\times 10^{10}} 2⋅T⋅d⋅dff≈2⋅4096⋅1024⋅4096≈3.44×1010(巧合相当,因为 T ≈ 4 d T\approx 4d T≈4d 正在拐点)。
这只是一层,多层相乘即可得到粗算总 FLOPs。真实实现里还有常数项、I/O 与 kernel 启动开销。
6. 用 Mermaid 看"复杂度分解"与"选择指南"
6.1 复杂度分解小图
#mermaid-svg-jPLhks20LoaC1rs0{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-jPLhks20LoaC1rs0 .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-jPLhks20LoaC1rs0 .error-icon{fill:#552222;}#mermaid-svg-jPLhks20LoaC1rs0 .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-jPLhks20LoaC1rs0 .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-jPLhks20LoaC1rs0 .marker{fill:#333333;stroke:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 .marker.cross{stroke:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-jPLhks20LoaC1rs0 p{margin:0;}#mermaid-svg-jPLhks20LoaC1rs0 .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster-label text{fill:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster-label span{color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster-label span p{background-color:transparent;}#mermaid-svg-jPLhks20LoaC1rs0 .label text,#mermaid-svg-jPLhks20LoaC1rs0 span{fill:#333;color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .node rect,#mermaid-svg-jPLhks20LoaC1rs0 .node circle,#mermaid-svg-jPLhks20LoaC1rs0 .node ellipse,#mermaid-svg-jPLhks20LoaC1rs0 .node polygon,#mermaid-svg-jPLhks20LoaC1rs0 .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .rough-node .label text,#mermaid-svg-jPLhks20LoaC1rs0 .node .label text,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape .label,#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape .label{text-anchor:middle;}#mermaid-svg-jPLhks20LoaC1rs0 .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .rough-node .label,#mermaid-svg-jPLhks20LoaC1rs0 .node .label,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape .label,#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape .label{text-align:center;}#mermaid-svg-jPLhks20LoaC1rs0 .node.clickable{cursor:pointer;}#mermaid-svg-jPLhks20LoaC1rs0 .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 .arrowheadPath{fill:#333333;}#mermaid-svg-jPLhks20LoaC1rs0 .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-jPLhks20LoaC1rs0 .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-jPLhks20LoaC1rs0 .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-jPLhks20LoaC1rs0 .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-jPLhks20LoaC1rs0 .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-jPLhks20LoaC1rs0 .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-jPLhks20LoaC1rs0 .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster text{fill:#333;}#mermaid-svg-jPLhks20LoaC1rs0 .cluster span{color:#333;}#mermaid-svg-jPLhks20LoaC1rs0 div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-jPLhks20LoaC1rs0 .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-jPLhks20LoaC1rs0 rect.text{fill:none;stroke-width:0;}#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape p,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-jPLhks20LoaC1rs0 .icon-shape .label rect,#mermaid-svg-jPLhks20LoaC1rs0 .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-jPLhks20LoaC1rs0 .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-jPLhks20LoaC1rs0 .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-jPLhks20LoaC1rs0 :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 输入: B,T,d,H
线性投影: O(B*T*d^2)
打分 QK^T: O(B*T^2*d)
softmax: O(B*T^2)
加权和 A*V: O(B*T^2*d)
输出投影: O(B*T*d^2)
输出: O(B*T*d)
说明 :二次项集中在 QK^T 与 A*V ,其余为线性或 d 2 d^2 d2 项。
6.2 长序列"怎么选"决策图
#mermaid-svg-FwUTl91MmEnpaiwk{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-FwUTl91MmEnpaiwk .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-FwUTl91MmEnpaiwk .error-icon{fill:#552222;}#mermaid-svg-FwUTl91MmEnpaiwk .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-FwUTl91MmEnpaiwk .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-FwUTl91MmEnpaiwk .marker{fill:#333333;stroke:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk .marker.cross{stroke:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-FwUTl91MmEnpaiwk p{margin:0;}#mermaid-svg-FwUTl91MmEnpaiwk .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster-label text{fill:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster-label span{color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster-label span p{background-color:transparent;}#mermaid-svg-FwUTl91MmEnpaiwk .label text,#mermaid-svg-FwUTl91MmEnpaiwk span{fill:#333;color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .node rect,#mermaid-svg-FwUTl91MmEnpaiwk .node circle,#mermaid-svg-FwUTl91MmEnpaiwk .node ellipse,#mermaid-svg-FwUTl91MmEnpaiwk .node polygon,#mermaid-svg-FwUTl91MmEnpaiwk .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .rough-node .label text,#mermaid-svg-FwUTl91MmEnpaiwk .node .label text,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape .label,#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape .label{text-anchor:middle;}#mermaid-svg-FwUTl91MmEnpaiwk .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .rough-node .label,#mermaid-svg-FwUTl91MmEnpaiwk .node .label,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape .label,#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape .label{text-align:center;}#mermaid-svg-FwUTl91MmEnpaiwk .node.clickable{cursor:pointer;}#mermaid-svg-FwUTl91MmEnpaiwk .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk .arrowheadPath{fill:#333333;}#mermaid-svg-FwUTl91MmEnpaiwk .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-FwUTl91MmEnpaiwk .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-FwUTl91MmEnpaiwk .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-FwUTl91MmEnpaiwk .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-FwUTl91MmEnpaiwk .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-FwUTl91MmEnpaiwk .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-FwUTl91MmEnpaiwk .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster text{fill:#333;}#mermaid-svg-FwUTl91MmEnpaiwk .cluster span{color:#333;}#mermaid-svg-FwUTl91MmEnpaiwk div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-FwUTl91MmEnpaiwk .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-FwUTl91MmEnpaiwk rect.text{fill:none;stroke-width:0;}#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape p,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-FwUTl91MmEnpaiwk .icon-shape .label rect,#mermaid-svg-FwUTl91MmEnpaiwk .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-FwUTl91MmEnpaiwk .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-FwUTl91MmEnpaiwk .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-FwUTl91MmEnpaiwk :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 是
否
是
否
序列很长?
必须全局依赖?
标准注意力即可
允许近似?
FlashAttn/分布式/减小批量
局部/滑窗注意力 O(T*w*d)
核化/线性注意力 O(T*d^2)
低秩/投影(如 Nyström, Linformer)
混合: 局部 + 少量全局
说明:
- 必须全局精确 :优先 FlashAttention、分布式训练、减小 batch 或层数。
- 允许近似 :选择局部/稀疏 、低秩 、或核化路线,视任务而定。
7. 降复杂与提效的几条路线图
7.1 稀疏/局部模式(保留 O ( T w ) O(T\,w) O(Tw))
- 滑窗/块稀疏 (Longformer/Sliding Window):每个 token 只看邻域 w w w。
复杂度 ∼ O ( B T w d ) \sim O(B\,T\,w\,d) ∼O(BTwd),若 w ≪ T w\ll T w≪T 则远小于 T 2 T^2 T2。 - 混合全局(BigBird 等):局部 + 随机 + 若干全局 token,保持理论可表达性。
7.2 低秩/投影近似
- Linformer :把 K , V K,V K,V 的长度维用投影压到 r ≪ T r\ll T r≪T,复杂度 ∼ O ( B T r d ) \sim O(B\,T\,r\,d) ∼O(BTrd)。
- Nyströmformer :用 r r r 个"地标"近似注意力核,同阶 ∼ O ( B T r d ) \sim O(B\,T\,r\,d) ∼O(BTrd)。
7.3 核化/线性注意力
-
令 s o f t m a x ( Q K ⊤ ) ≈ ϕ ( Q ) ϕ ( K ) ⊤ \mathrm{softmax}(QK^\top)\approx \phi(Q)\phi(K)^\top softmax(QK⊤)≈ϕ(Q)ϕ(K)⊤,则
ϕ ( K ) ⊤ V ∈ O ( B T d ⋅ d ϕ ) , ϕ ( Q ) ⋅ ( ⋅ ) ∈ O ( B T d ⋅ d ϕ ) \phi(K)^\top V \in O(B\,T\,d\cdot d_\phi),\quad \phi(Q)\cdot(\cdot) \in O(B\,T\,d\cdot d_\phi) ϕ(K)⊤V∈O(BTd⋅dϕ),ϕ(Q)⋅(⋅)∈O(BTd⋅dϕ)
总体 ∼ O ( B T d ⋅ d ϕ ) \sim O(B\,T\,d\cdot d_\phi) ∼O(BTd⋅dϕ),线性于 T T T。
-
代表:Performer、Linear Transformers 等。注意 :近似误差与 ϕ \phi ϕ 的选择有关。
7.4 IO-aware:把"搬数据"的代价降下来
- FlashAttention :分块计算,不显式存 A A A ,降低 HBM 往返;复杂度级别不变 (仍是 T 2 T^2 T2),但速度/可训练长度大幅改观。
- 融合 kernel:scale + mask + softmax + dropout + matmul 融合,少 kernel 启动。
- 序列并行/张量并行 :跨设备分摊 T T T 或 d d d,减单卡峰值显存。
8. 工程实战清单(训练 & 推理)
训练侧
- 长序列先上 FlashAttention + bf16 + 归约 fp32。
- 激活检查点 配合 重算;必要时关掉注意力 dropout 稳数值。
- clip grad-norm,防止长序列引发梯度爆炸。
- 善用 梯度累积 替代超大 batch。
- 若任务允许,用局部/稀疏注意力 或低秩投影 尝试把 T 2 T^2 T2 打碎。
推理侧
- KV Cache 预分配;优先 MQA/GQA 降内存与带宽。
- Paged KV / 滑窗裁剪远端上下文。
- 分块点积(blockwise)避免一次性读取过长 KV。
- 合理的 prefill:decode 比例与批内并行(例如把多条请求对齐 prefill)。
9. 小练习(带提示)
- 门槛判断 :给定 d = 1536 d=1536 d=1536,估计注意力与 FFN FLOPs 的拐点 T T T;试计算 T = 8 k T=8\text{k} T=8k 时两者的比值。
- 显存估算 : B = 2 , H = 8 , T = 8192 B=2,H=8,T=8192 B=2,H=8,T=8192 时,FP16 下注意力矩阵 A A A 占用多大?如果用 FlashAttention,会省掉多少显存级别的张量?
- 窗口选择 :在一个长文分类任务上,尝试 w ∈ { 128 , 256 , 512 } w\in\{128,256,512\} w∈{128,256,512} 的滑窗注意力,比较准确率与吞吐。
- 核化近似误差 :实现一个随机特征 ϕ \phi ϕ 的线性注意力,与标准 softmax 注意力在相同 Q , K , V Q,K,V Q,K,V 上比较输出误差随 d ϕ d_\phi dϕ 的变化。
- KV 策略:把 16 头多查询注意力(MHA)替换为 MQA 与 GQA,比较相同上下文长度下的显存与 tokens/s。
10. 小结(带走三句话)
- 标准自注意力的瓶颈是 T 2 T^2 T2 :时间 O ( B T 2 d ) O(B\,T^2\,d) O(BT2d),训练显存 O ( B H T 2 ) O(B\,H\,T^2) O(BHT2)。
- 短序列时 FFN 更贵,长序列时注意力更贵 ;拐点约在 T ≈ 4 d T\approx 4d T≈4d。
- 应对策略两类 :改算法形状 (稀疏/低秩/核化,降到 O ( T ) O(T) O(T) 或 O ( T w ) O(Tw) O(Tw))与改执行方式(FlashAttention/融合/并行/KV 策略,降 I/O 与常数因子)。