1 什么是前缀和?
定义: 第 k 个元素的状态依赖于第 k-1 个元素;
公式: 前缀和 = 从第 1 个,一直加到当前位置;
例子:
比如有 4 个数:
A、B、C、D;
那么前缀和的结果为:
txt
S1 = A
S2 = A + B
S3 = A + B + C
S4 = A + B + C + D
在 Linear Attention 中有所体现,即,S2 依赖 S1,S3 依赖 S2,必须串行。
2 问题
如果串行计算的话,时间复杂度相当高;能不能将 O(n)的时间复杂度优化到O(1)呢?
3 并行前缀和加速
第一步:看这个下三角全 1 矩阵
txt
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1
第二步: 让它 乘以向量 [A,B,C,D]
结果就是:
txt
第1行:1*A + 0*B + 0*C + 0*D = A → S1
第2行:1*A + 1*B + 0*C + 0*D = A+B → S2
第3行:1*A + 1*B + 1*C + 0*D = A+B+C → S3
第4行:1*A + 1*B + 1*C + 1*D = A+B+C+D → S4
结果
✅ 一次性得到 S1、S2、S3、S4!
✅ 没有循环!没有等待!全部并行!
4 结论
并行前缀和的本质就是:
用下三角矩阵,模拟 "只累加前面所有值" 的效果,一次性输出所有结果。
对应到 Linear Attention 的 Chunk 加速中:
公式
St=eΔt⋅∑i=1t(eγi−Δi⋅ki⊤vi) S_t = e^{\Delta_t} \cdot \sum_{i=1}^t \left( e^{\gamma_i - \Delta_i} \cdot k_i^\top v_i \right) St=eΔt⋅i=1∑t(eγi−Δi⋅ki⊤vi)
里面每一项:
- Δt\Delta_tΔt:可以用
cumsum并行算 - eΔte^{\Delta_t}eΔt、e−Δie^{-\Delta_i}e−Δi:并行指数
- eγi−Δiki⊤vie^{\gamma_i - \Delta_i} k_i^\top v_ieγi−Δiki⊤vi:全部并行
- ∑i=1t\sum_{i=1}^t∑i=1t:用 下三角矩阵(tril) 并行前缀和
后面这累加的一串就可以用并行前缀和。
对应代码:
python
decay_mask = tril(...) # 下三角矩阵
attn = k_beta @ k.T * decay_mask # 并行前缀和
这一行 = 一次性算出所有 S_t,不需要循环!