文章目录
- 一、张量运算的计算量
-
- [1. FLOPs定义](#1. FLOPs定义)
- [2. 张量计算顺序对计算量的影响](#2. 张量计算顺序对计算量的影响)
- 二、MLA第一次矩阵吸收的计算量分析
-
- [1. 原始注意力计算](#1. 原始注意力计算)
- [2. MLA源代码中的吸收方式](#2. MLA源代码中的吸收方式)
- [3. 提前吸收](#3. 提前吸收)
- [4. 比较分析](#4. 比较分析)
-
- [4.1 比较顺序1和顺序2](#4.1 比较顺序1和顺序2)
- [4.2 比较顺序2和顺序3](#4.2 比较顺序2和顺序3)
- 三、MLA第二次矩阵吸收的计算量分析
-
- [1. 原始输出计算](#1. 原始输出计算)
- [2. MLA源代码中的吸收方式](#2. MLA源代码中的吸收方式)
- [3. 提前吸收](#3. 提前吸收)
- [4. 比较分析](#4. 比较分析)
-
- [4.1 比较顺序1和顺序2](#4.1 比较顺序1和顺序2)
- [4.2 比较顺序2和顺序3](#4.2 比较顺序2和顺序3)
- 参考链接
一、张量运算的计算量
1. FLOPs定义
FLOPs :Floating Point Operations 指的是浮点运算次数,一般特指乘加运算次数,理解为计算量,可以用来衡量算法/模型时间的复杂度。更大的计算量单位通常包括:
- MFLOPs :百万次浮点运算( 10 6 10^6 106 FLOPs)。
- GFLOPs :十亿次浮点运算( 10 9 10^9 109 FLOPs)。
- TFLOPs :万亿次浮点运算( 10 12 10^{12} 1012 FLOPs)。
张量运算的计算量通常与运算维度和操作类型有关,以pytorch
中线性层nn.Linear
的计算为例,设输入张量的维度为 B × S × D B \times S \times D B×S×D,线性层内部权重矩阵维度为 D × O D \times O D×O:
- 若不考虑
bias
,两个张量相乘的结果维度为 B × S × O B \times S \times O B×S×O,结果中的每个元素是由原始张量分别沿着 D D D维度进行了 D D D次乘法和 D − 1 D-1 D−1次加法而来的,因此总计算量为:
( 2 D − 1 ) × B × S × O (2D-1)\times B \times S \times O (2D−1)×B×S×O
- 若考虑
bias
,则每个元素由原始张量分别沿着 D D D维度进行 D D D次乘法和 D − 1 D-1 D−1次加法后,还需加上bias
,因此一共也执行了 D D D次加法,总计算量为:
2 D × B × S × O 2D \times B \times S \times O 2D×B×S×O
为了简单起见,后续分析时均以考虑bias
来分析,这样FLOPs的计算可直接由相关维度的相乘而来。
2. 张量计算顺序对计算量的影响
张量计算顺序的不同会影响计算量。以下是一个例子:
假设有三个张量 A A A、 B B B 和 C C C,它们的形状分别为:
- A A A: ( m , n ) (m, n) (m,n)
- B B B: ( n , p ) (n, p) (n,p)
- C C C: ( p , q ) (p, q) (p,q)
我们需要计算 A × B × C A \times B \times C A×B×C,其中 × \times × 表示矩阵乘法。
计算顺序 1 :先计算 A × B A \times B A×B,再乘以 C C C
- 计算 A × B A \times B A×B:
- 结果形状为 ( m , p ) (m, p) (m,p)。
- 每个元素的计算量为 2 n 2n 2n( n n n 次乘法和 n n n 次加法)。
- 总计算量 : m × p × 2 n = 2 m n p m \times p \times 2n = 2mnp m×p×2n=2mnp。
- 计算 ( A × B ) × C (A \times B) \times C (A×B)×C:
- 结果形状为 ( m , q ) (m, q) (m,q)。
- 每个元素的计算量为 2 p 2p 2p( p p p 次乘法和 p p p 次加法)。
- 总计算量 : m × q × 2 p = 2 m p q m \times q \times 2p = 2mpq m×q×2p=2mpq。
- 总计算量 : 2 m n p + 2 m p q 2mnp + 2mpq 2mnp+2mpq。
计算顺序 2 :先计算 B × C B \times C B×C,再乘以 A A A
- 计算 B × C B \times C B×C:
- 结果形状为 ( n , q ) (n, q) (n,q)。
- 每个元素的计算量为 2 p 2p 2p( p p p 次乘法和 p p p 次加法)。
- 总计算量 : n × q × 2 p = 2 n p q n \times q \times 2p = 2npq n×q×2p=2npq。
- 计算 A × ( B × C ) A \times (B \times C) A×(B×C):
- 结果形状为 ( m , q ) (m, q) (m,q)。
- 每个元素的计算量为 2 n 2n 2n( n n n 次乘法和 n n n 次加法)。
- 总计算量 : m × q × 2 n = 2 m n q m \times q \times 2n = 2mnq m×q×2n=2mnq。
- 总计算量 : 2 n p q + 2 m n q 2npq + 2mnq 2npq+2mnq。
比较两种计算顺序:
-
计算顺序 1 的总计算量为 2 m n p + 2 m p q 2mnp + 2mpq 2mnp+2mpq。
-
计算顺序 2 的总计算量为 2 n p q + 2 m n q 2npq + 2mnq 2npq+2mnq。
-
将上述两式相减,有:
2 [ m n ( p − q ) + p q ( m − n ) ] 2[mn(p-q)+pq(m-n)] 2[mn(p−q)+pq(m−n)]可见如果 p < q , m < n p<q,m<n p<q,m<n则必定计算顺序1的计算量更小,如果 p > q , m > n p>q,m>n p>q,m>n则反之,其余情况 则需根据具体数值分析。
二、MLA第一次矩阵吸收的计算量分析
我们比较三种计算顺序:
假设原始序列 h \mathbf{h} h经Q低秩压缩后得到 c Q \mathbf{c}^Q cQ,经KV低秩压缩得到 c K V \mathbf{c}^{KV} cKV,它们的上投影矩阵分别为 W U Q W^{UQ} WUQ和 W U K W^{UK} WUK。
1. 原始注意力计算
原始注意力计算如下:
( W U Q c Q ) T ( W U K c K V ) (W^{UQ}\mathbf{c}^Q)^T (W^{UK}\mathbf{c}^{KV}) (WUQcQ)T(WUKcKV)
上述张量的形状如下,箭头右边是简记的符号,并将n_heads × qk_nope_head_dim
进行了拆分:
- W U Q W^{UQ} WUQ :
(q_lora_rank, n_heads × qk_nope_head_dim) -> (q, h, d)
- c Q \mathbf{c}^Q cQ :
(bsz, q_seq_len, q_lora_rank) -> (b, s, q)
- W U K W^{UK} WUK :
(kv_lora_rank, n_heads × qk_nope_head_dim) -> (k, h, d)
- c K V \mathbf{c}^{KV} cKV :
(bsz, k_seq_len, kv_lora_rank) -> (b, t, k)
- Step 1: W U Q c Q W^{UQ}\mathbf{c}^Q WUQcQ:
(bsz, q_seq_len, n_heads, qk_nope_head_dim) -> (b, s, h, d)
- Step 2: W U K c K V W^{UK}\mathbf{c}^{KV} WUKcKV:
(bsz, k_seq_len, n_heads, qk_nope_head_dim) -> (b, t, h, d)
- Step 3: ( W U Q c Q ) T ( W U K c K V ) (W^{UQ}\mathbf{c}^Q)^T (W^{UK}\mathbf{c}^{KV}) (WUQcQ)T(WUKcKV):
(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)
这里区分q_seq_len和k_seq_len,训练或prefill时二者是一致的,decode时q_seq_len是1,k_seq_len是cache的长度。
根据张量计算量分析的规则,计算量如下:
FLOPs order 1 = 2 b s h d q + 2 b t h d k + 2 b h s t d \text{FLOPs}_{\text{order}_1}=2bshdq+2bthdk+2bhstd FLOPsorder1=2bshdq+2bthdk+2bhstd
2. MLA源代码中的吸收方式
( W U Q c Q ) T W U K \] c K V \[(W\^{UQ}\\mathbf{c}\^Q)\^T W\^{UK}\]\\mathbf{c}\^{KV} \[(WUQcQ)TWUK\]cKV * Step 1: W U Q c Q W\^{UQ}\\mathbf{c}\^Q WUQcQ:`(bsz, q_seq_len, n_heads, qk_nope_head_dim) -> (b, s, h, d)` * Step 2: ( W U Q c Q ) T W U K (W\^{UQ}\\mathbf{c}\^Q)\^TW\^{UK} (WUQcQ)TWUK:`(bsz, q_seq_len, n_heads, kv_lora_rank) -> (b, s, h, k)` * Step 3: \[ ( W U Q c Q ) T W U K \] c K V \[(W\^{UQ}\\mathbf{c}\^Q)\^T W\^{UK}\]\\mathbf{c}\^{KV} \[(WUQcQ)TWUK\]cKV:`(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)` 计算量如下: FLOPs order 2 = 2 b s h d q + 2 b s h k d + 2 b h s t k \\text{FLOPs}_{\\text{order}_2}=2bshdq+2bshkd+2bhstk FLOPsorder2=2bshdq+2bshkd+2bhstk ### 3. 提前吸收 c Q T ( W U Q T W U K ) c K V {\\mathbf{c}\^Q}\^T(W\^{UQ\^T} W\^{UK})\\mathbf{c}\^{KV} cQT(WUQTWUK)cKV * Step 1: W U Q T W U K W\^{UQ\^T} W\^{UK} WUQTWUK:`(n_heads, q_lora_rank, kv_lora_rank) -> (h, q, k)` * Step 2: c Q T ( W U Q T W U K ) {\\mathbf{c}\^Q}\^T(W\^{UQ\^T} W\^{UK}) cQT(WUQTWUK):`(bsz, q_seq_len, n_heads, kv_lora_rank) -> (b, s, h, k)` * Step 3: c Q T ( W U Q T W U K ) c K V {\\mathbf{c}\^Q}\^T(W\^{UQ\^T} W\^{UK})\\mathbf{c}\^{KV} cQT(WUQTWUK)cKV:`(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)` 计算量如下: FLOPs order 3 = 2 h q k d + 2 b s h k q + 2 b h s t k \\text{FLOPs}_{\\text{order}_3}=2hqkd+2bshkq+2bhstk FLOPsorder3=2hqkd+2bshkq+2bhstk ### 4. 比较分析 #### 4.1 比较顺序1和顺序2 首先比较 FLOPs order 1 \\text{FLOPs}_{\\text{order}_1} FLOPsorder1和 FLOPs order 2 \\text{FLOPs}_{\\text{order}_2} FLOPsorder2,有: FLOPs order 1 − FLOPs order 2 = 2 b h d k ( t − s ) + 2 b h s t ( d − k ) \\text{FLOPs}_{\\text{order}_1}-\\text{FLOPs}_{\\text{order}_2}= 2bhdk(t-s)+2bhst(d-k) FLOPsorder1−FLOPsorder2=2bhdk(t−s)+2bhst(d−k) 其中: * `t`:`k_seq_len` * `s`:`q_seq_len` * `d`:`qk_nope_head_dim = 128` * `k`:`kv_lora_rank = 512` * `h`:`n_heads = 128` * `b`:`bsz`由于第一项和第二项都有`b`,为简单起见,设为1 在训练或prefill阶段,`t`=`s`,上式结果为 − 98304 s 2 -98304s\^2 −98304s2,此时顺序1的计算量更优。 在decode阶段,`t`是缓存长度,而`s`=1,上式结果为 16777216 ( t − 1 ) − 98304 t = 16678912 t − 16777216 16777216(t-1)-98304t=16678912t-16777216 16777216(t−1)−98304t=16678912t−16777216,可见,推理时随着缓存长度`t`的变大,顺序1需要花费更大的计算量,因此才需要把 W U K W\^{UK} WUK吸收进 W U Q c Q W\^{UQ}\\mathbf{c}\^Q WUQcQ(也就是代码中的`q_nope`)中,避免产生的中间量需要大量的计算。 #### 4.2 比较顺序2和顺序3 然后比较 FLOPs order 2 \\text{FLOPs}_{\\text{order}_2} FLOPsorder2和 FLOPs order 3 \\text{FLOPs}_{\\text{order}_3} FLOPsorder3,有: FLOPs order 2 − FLOPs order 3 = 2 h d q ( b s − k ) + 2 b s h k ( d − q ) \\text{FLOPs}_{\\text{order}_2}-\\text{FLOPs}_{\\text{order}_3}= 2hdq(bs-k)+2bshk(d-q) FLOPsorder2−FLOPsorder3=2hdq(bs−k)+2bshk(d−q) 其中: * `q`:`q_lora_rank = 1536` * `b`:`bsz`第一项的`b`无法作为因子提出,因此先不假定具体值 上式结果中不包含`t`,结果为 50331648 ( b s − 512 ) − 184549376 b s = − 134217728 b s − 25769803776 50331648(bs-512)-184549376bs=-134217728bs-25769803776 50331648(bs−512)−184549376bs=−134217728bs−25769803776,恒小于0,因此顺序2的计算量优于顺序3。其原因是 ( W U Q T W U K ) (W\^{UQ\^T} W\^{UK}) (WUQTWUK)充当了新的 W U Q ′ W\^{UQ'} WUQ′,其形状为`(h, q, k)`,具有100663296个元素。而 W U Q W\^{UQ} WUQ和 W U K W\^{UK} WUK的形状分别为`(q, h, d)`和`(k, h, d)`,二者之和只有33554432个元素,约为 W U Q ′ W\^{UQ'} WUQ′的33%,这就解释了虽然公式上直接将 W U K W\^{UK} WUK吸收进了 W U Q W\^{UQ} WUQ,但为什么代码实现上不这么做的原因。不论是从参数量占用还是计算量上,顺序3都没有优势。 ## 三、MLA第二次矩阵吸收的计算量分析 同样比较三种计算顺序: 假设得到的`score`形状大小为`(bsz, n_heads, q_seq_len, k_seq_len)`, c K V \\mathbf{c}\^{KV} cKV向`value`的上投影矩阵为 W U V W\^{UV} WUV,输出维度变换 矩阵为 W O W\^O WO。 ### 1. 原始输出计算 原始计算顺序如下: W O \[ s c o r e ( W U V c K V ) \] W\^O\[score(W\^{UV} \\mathbf{c}\^{KV})\] WO\[score(WUVcKV)
上述张量的形状如下,将n_heads × v_head_dim
进行了拆分:
- s c o r e score score:
(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)
- c K V \mathbf{c}^{KV} cKV:
(bsz, k_seq_len, kv_lora_rank) -> (b, t, k)
- W U V W^{UV} WUV:
(kv_lora_rank, n_heads × v_head_dim) -> (k, h, v)
- W O W^O WO:
(n_heads × v_head_dim, dim) -> (h, v, e)
- Step 1: W U V c K V W^{UV} \mathbf{c}^{KV} WUVcKV:
(bsz, k_seq_len, n_heads, v_head_dim) -> (b, t, h, v)
- Step 2: [ s c o r e ( W U V c K V ) ] [score(W^{UV} \mathbf{c}^{KV})] [score(WUVcKV)]:
(bsz, n_heads, q_seq_len, v_head_dim) -> (b, h, s, v)
- Step 3: W O [ s c o r e ( W U V c K V ) ] W^O[score(W^{UV} \mathbf{c}^{KV})] WO[score(WUVcKV)]:
(bsz, n_heads, q_seq_len, dim) -> (b, h, s, e)
计算量如下:
FLOPs order 1 = 2 b t h v k + 2 b h s v t + 2 b h s e v \text{FLOPs}_{\text{order}_1}=2bthvk+2bhsvt+2bhsev FLOPsorder1=2bthvk+2bhsvt+2bhsev
2. MLA源代码中的吸收方式
W O [ W U V ( s c o r e c K V ) ] W^O[W^{UV} (score\mathbf{c}^{KV})] WO[WUV(scorecKV)]
- Step 1: s c o r e c K V score\mathbf{c}^{KV} scorecKV:
(bsz, n_heads, q_seq_len, kv_lora_rank) -> (b, h, s, k)
- Step 2: [ W U V ( s c o r e c K V ) ] [W^{UV} (score\mathbf{c}^{KV})] [WUV(scorecKV)]:
(bsz, n_heads, q_seq_len, v_head_dim) -> (b, h, s, v)
- Step 3: W O [ W U V ( s c o r e c K V ) ] W^O[W^{UV} (score\mathbf{c}^{KV})] WO[WUV(scorecKV)]:
(bsz, n_heads, q_seq_len, dim) -> (b, h, s, e)
计算量如下:
FLOPs order 2 = 2 b h s k t + 2 b h s v k + 2 b h s e v \text{FLOPs}_{\text{order}_2}=2bhskt+2bhsvk+2bhsev FLOPsorder2=2bhskt+2bhsvk+2bhsev
3. 提前吸收
( W O W U V ) ( s c o r e c K V ) (W^OW^{UV})(score\mathbf{c}^{KV}) (WOWUV)(scorecKV)
- Step 1: W O W U V W^OW^{UV} WOWUV:
(n_heads, kv_lora_rank, dim) -> (h, k, e)
- Step 2: s c o r e c K V score\mathbf{c}^{KV} scorecKV:
(bsz, n_heads, q_seq_len, kv_lora_rank) -> (b, h, s, k)
- Step 3: ( W O W U V ) ( s c o r e c K V ) (W^OW^{UV})(score\mathbf{c}^{KV}) (WOWUV)(scorecKV):
(bsz, n_heads, q_seq_len, dim) -> (b, h, s, e)
计算量如下:
FLOPs order 3 = 2 h k e v + 2 b h s k t + 2 b h s e k \text{FLOPs}_{\text{order}_3}=2hkev+2bhskt+2bhsek FLOPsorder3=2hkev+2bhskt+2bhsek
4. 比较分析
4.1 比较顺序1和顺序2
首先比较 FLOPs order 1 \text{FLOPs}_{\text{order}1} FLOPsorder1和 FLOPs order 2 \text{FLOPs}{\text{order}2} FLOPsorder2,有:
FLOPs order 1 − FLOPs order 2 = 2 b h v k ( t − s ) + 2 b h s t ( v − k ) \text{FLOPs}{\text{order}1}-\text{FLOPs}{\text{order}_2}=2bhvk(t-s)+2bhst(v-k) FLOPsorder1−FLOPsorder2=2bhvk(t−s)+2bhst(v−k)
其中:
t
:k_seq_len
s
:q_seq_len
v
:v_head_dim = 128
k
:kv_lora_rank = 512
h
:n_heads = 128
b
:bsz
由于第一项和第二项都有b
,为简单起见,设为1
由于v
与d
值大小一样,因此计算结果与与第一次矩阵吸收一致。即在训练或prefill阶段,顺序1更优,在decode阶段,顺序2更优。
4.2 比较顺序2和顺序3
然后比较 FLOPs order 2 \text{FLOPs}_{\text{order}2} FLOPsorder2和 FLOPs order 3 \text{FLOPs}{\text{order}3} FLOPsorder3,有:
FLOPs order 2 − FLOPs order 3 = 2 h v k ( b s − e ) + 2 b h s e ( v − k ) \text{FLOPs}{\text{order}2}-\text{FLOPs}{\text{order}_3}=2hvk(bs-e)+2bhse(v-k) FLOPsorder2−FLOPsorder3=2hvk(bs−e)+2bhse(v−k)
其中:
e
:dim = 7168
b
:bsz
第一项的b
无法作为因子提出,因此先不假定具体值
上式结果为 16777216 ( b s − 7168 ) − 704643072 b s = − 687865856 b s − 120259084288 16777216(bs-7168)-704643072bs=-687865856bs −120259084288 16777216(bs−7168)−704643072bs=−687865856bs−120259084288,可见仍然是顺序2的计算结果更优。