一、核心思想(非技术语言理解)
Transformer Layer的计算复杂度,本质由两个核心模块决定:
- 多头注意力(MHA) :需要计算「每个token与所有其他token的关联」------ 比如序列长度为
L(有L个token),每个token要和L个token比对应关系,这就产生了L×L的"平方级"计算; - 前馈网络(FFN) :每个token独立做线性变换(不依赖其他token),计算量是"线性级"(和
L成正比)。
当序列变长(L增大)时,"平方级"的注意力计算会快速主导复杂度,这也是Transformer处理长序列效率低的核心原因(比如L=1000时平方项是1e6,L=10000时就变成1e8,直接扩大100倍)。
二、精确推导(含公式与符号定义)
1. 符号定义(固定模型参数,仅L变化)
| 符号 | 含义 | 典型值(如BERT-base) |
|---|---|---|
L |
序列长度(seq_len) | 512 / 1024 |
d |
模型维度(token嵌入维度) | 768 |
h |
多头注意力的头数 | 12 |
d_k = d/h |
单个注意力头的维度 | 768/12=64 |
d_ff |
前馈网络中间层维度 | 4d=3072(标准设置) |
2. 计算复杂度衡量标准
以「浮点运算次数(FLOPs)」为指标,忽略常数项(如加法、除法),仅保留主导项(影响最大的项),最终复杂度用「大O表示法」描述增长趋势。
三、分模块推导复杂度
模块1:多头注意力(MHA)------ 核心平方级来源
MHA的计算流程可拆解为6步,仅保留有计算量的步骤:
- Q/K/V线性投影 :输入
L×d,通过3个独立线性层(权重d×d)得到Q、K、V,每个投影的FLOPs为L×d×d(矩阵乘法:(L×d) × (d×d) = L×d),总FLOPs:
3×Ld23 \times L d^23×Ld2 - 注意力分数计算 :Q(
L×d_k)与K的转置(d_k×L)相乘,得到L×L的注意力矩阵,每个头的FLOPs为L×d_k×L,h个头总FLOPs:
h×L2dk=h×L2×dh=L2dh \times L^2 d_k = h \times L^2 \times \frac{d}{h} = L^2 dh×L2dk=h×L2×hd=L2d(代入d_k=d/h) - 注意力加权V :注意力矩阵(
L×L)与V(L×d_k)相乘,每个头的FLOPs为L×L×d_k,h个头总FLOPs:
h×L2dk=L2dh \times L^2 d_k = L^2 dh×L2dk=L2d(同步骤2推导) - 最终线性投影 :拼接多头结果(
L×d)通过线性层(d×d),FLOPs:
Ld2L d^2Ld2
MHA总复杂度 :
3Ld2+L2d+L2d+Ld2=4Ld2+2L2d3Ld^2 + L^2d + L^2d + Ld^2 = 4Ld^2 + 2L^2d3Ld2+L2d+L2d+Ld2=4Ld2+2L2d
模块2:前馈网络(FFN)------ 线性级补充
FFN结构:Linear(d→d_ff) → ReLU → Linear(d_ff→d),ReLU无计算量,仅看两个线性层:
- 第一层(d→d_ff):输入
L×d,权重d×d_ff,FLOPs:L×d×d_ff - 第二层(d_ff→d):输入
L×d_ff,权重d_ff×d,FLOPs:L×d_ff×d
标准设置d_ff=4d,代入后FFN总复杂度 :
Ld⋅4d+L⋅4d⋅d=8Ld2Ld \cdot 4d + L \cdot 4d \cdot d = 8Ld^2Ld⋅4d+L⋅4d⋅d=8Ld2
模块3:LayerNorm与残差连接------可忽略项
- LayerNorm:对每个token的
d维向量做归一化(均值/方差计算+线性缩放),总FLOPs为O(Ld)(线性级); - 残差连接:元素-wise加法,FLOPs为
O(Ld)(线性级)。
当L较大时(如L>100),O(Ld)远小于O(L²d)和O(Ld²),可忽略。
四、Transformer Layer总复杂度与趋势
1. 总复杂度(合并MHA+FFN)
总FLOPs=(4Ld2+2L2d)+8Ld2=12Ld2+2L2d\text{总FLOPs} = (4Ld^2 + 2L^2d) + 8Ld^2 = 12Ld^2 + 2L^2d总FLOPs=(4Ld2+2L2d)+8Ld2=12Ld2+2L2d
2. 随seq_len(L)的增长趋势
- 模型维度
d是固定值(如768),因此:- 次要项:
12Ld²→ 随L线性增长(O(L)); - 主导项:
2L²d→ 随L平方增长(O(L²))。
- 次要项:
3. 结论
- Transformer Layer的计算复杂度为 O(L2d+Ld2)\boxed{O(L^2 d + L d^2)}O(L2d+Ld2);
- 当
seq_len(L)增加时,O(L²)项主导复杂度增长 ,这是Transformer处理长序列(如L>2048)时效率低下的根本原因。
示例验证(直观感受)
假设d=768(BERT-base),不同L对应的主导项(L²d)增长:
| seq_len(L) | 主导项计算(L²×768) |
相对增长(以L=128为基准) |
|---|---|---|
| 128 | 128²×768 ≈ 12.5M | 1倍 |
| 256 | 256²×768 ≈ 50.3M | 4倍 |
| 512 | 512²×768 ≈ 201.3M | 16倍 |
| 1024 | 1024²×768 ≈ 805.3M | 64倍 |
可见,L翻倍时,主导项复杂度直接翻4倍,这也是后续长序列模型(如Longformer、Linformer)需要优化O(L²)项的核心动机。