transformer模型相关的计算:参数量、FLOPs、训练显存、最大batch size

模型参数量计算

计算过程如下:

  • transformers:

    • FFNs: 176,160,768 * 8 = 45,097,156,608
    • gate: 4096 * 8 = 1,048,576
    • MA: 4096 * (128 * 48) + 4096 * 4096= 41,943,040
    • LN: 4096 * 2 = 8,192
    • total: (176,160,768 * 8 + 41,943,040 + 8,192 + 4096 * 8) * 32 = 46,440,644,608
  • others:

    • embed & output_w: 262,144,000
  • total:

    • 46,440,644,608 + 262,144,000 = 46,702,788,608 = 46B
  • active params:

    • (176,160,768 * 2 + 41,943,040 + 8,192 + 4096 * 8) * 32 + 262,144,000 = 12,879,921,152 = 12.8B

模型训练的并行方式分为3种,DP(data parallel) / TP(tensor parallel) / PP(pipline parallel),MoE模型在训练时可以同时使用这3种模型外,还可以加入EP(expert parallel)方式。EP的精髓就是多个模型中共享expert,2点理解:

  • 因为MoE每次前向中只用到一小部分expert,如果每个模型保留完整的expert,一定会导致大多数expert空闲的情况;
  • 如果DP是8,EP是2,那么2个模型共用一套完整的experts;

训练并行设定:TP2 DP8 EP8 (megatron方案),需要显存如下:

module MoE参数/单卡 Dense参数/单卡
Emb and Output h * vocab *2 /TP =262 144 h * vocab *2 /TP =262 144
experts hffn3 * num_experts * n_layers / TP / EP =88 080 384 hffn3 * n_layers / TP =88 080 384
gate h*num_experts * n_layers =1 048 576 /
GQA hhdim(nhead+n_kv_heads) * 2 * n_layers / TP =671 088 640 hhdim(nhead+n_kv_heads) * 2 * n_layers / TP =671 088 640
LN h*2 * n_layers =262 144 h*2 * n_layers =262 144
total 3,622,043,648 3,620,995,072
推理所需显存/单卡 7,244,087,296 7,241,990,144
训练所需显存/单卡 57,952,698,368 57,935,921,152
  • 通过加入EP的方式,在7B的模型大小下,MoE训练所需的显存于正常7B相差不大,结论成立。

    注:实际计算中,MoE 的激活值会相比原有增大 EP 倍,和训练长度有关,待后续实测。TODO

FLOPs per token计算方式

每一步对应的FLOPs计算如下:

  • Attention: MHA

    • QKV proj: 6sh^2;
    • K@Q: 2s^2h;
    • attention@value: 2s^2h;
    • output: 2sh^2;
  • FFN: SwiGLU

    • 16sh^2;
  • lm head

    • 2sVh;

其中,s是句长,h是hidden size,V是vocab size,B是bath size;

反向传播是前向的2倍所以最终需要乘以系数3,最后得到公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> F L O P s = 3 ( B l ( 24 s h 2 + 4 s 2 h ) + 2 s V h ) = 72 B l s h 2 ( 1 + s 6 h + 2 V 12 h l ) FLOPs=3(Bl(24sh^2+4s^2h)+2sVh)=72Blsh^2(1+\frac{s}{6h}+\frac{2V}{12hl}) </math>FLOPs=3(Bl(24sh2+4s2h)+2sVh)=72Blsh2(1+6hs+12hl2V)

说明:

  1. C=6ND的由来:

    FFN对应的FLOPs计算公式为:2\times3\times\frac{ffn_dim}{dim}sh^2,llama2的结构计算这里是16,所以最后的公式中主项为72Blsh^2 = 6\times (12lh^2) \times Bs \approx 6ND

  2. 因为真实场景下s,h都是比较大的数字,所以计算时只需要保留对应FLOP为3次项的操作;其余被忽略掉,例如:residual、glu里的位乘、layer norm、softmax,部分对应的FLOP如下:

    • softmax: 3Hs^2,其中H为heads数;
    • residual: sh;
    • RMS norm: 5sh;
  3. 如果不考虑lm head的计算量,就去掉括号里的最后一项;

  4. 实验中设定ffn_dim为3.5倍hidden size,所以FFN为:32 s3.5hh=21h^2;

  5. 最后公式中,Bs也可以写成D,表示数据集的大小;

PS:从公式中也可以看到,当s比较小的时候,括号里的第二项可以忽略,整个模型的计算相当于只有第一项,既FFN的计算量。多小才可以呢,比如s<\frac{h}{14}时候,第二项小于0.01可以忽略。

固定GPU数量下的batch size求解

介绍一下下面用的符号

M: total memory per GPU

W: world size, total GPU count

P: partition size, PP * TP

N: model parameter count

activation计算方式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> a c t i v a t i o n = ( 34 s d + 5 s 2 a ) l b activation = (34sd + 5s^2a)lb </math>activation=(34sd+5s2a)lb

来自zhuanlan.zhihu.com/p/648924115

因此,不考虑梯度累积的情况下,最大GBS计算方法如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G B S m a x = M P − 16 N ( 34 s d + 5 s 2 a ) l GBS_{max}=\frac{MP-16N}{(34sd + 5s^2a)l} \\ </math>GBSmax=(34sd+5s2a)lMP−16N

其中,P为满足显存需求的最小切分数,计算方法如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P = 2 c e i l ( l o g 2 ( 16 N 0.7 M ) ) P = 2^{ceil\big(log_2(\frac{16N}{0.7M})\big)} </math>P=2ceil(log2(0.7M16N))

通过计算得到最大GBS与参数两N的变化如下所示:

通过观察,log(N)与log(GBS)呈线性关系,可以表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> l o g ( G B S ) = a ⋅ l o g ( N ) + d log(GBS)=a\cdot log(N) + d </math>log(GBS)=a⋅log(N)+d

或者写成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> B ( N ) = ( B c N ) α B (6) B(N)=(\frac{B_c}{N})^{\alpha_B} \tag{6} </math>B(N)=(NBc)αB(6)

其中,\alpha_B=-a,B_c=10^{-b/a};通过拟合得到\alpha_B=1.08211307, B_c=3\times 10^{15}。

相关推荐
shymoy24 分钟前
Radix Sorts
数据结构·算法·排序算法
风影小子33 分钟前
注册登录学生管理系统小项目
算法
黑龙江亿林等保35 分钟前
深入探索哈尔滨二级等保下的负载均衡SLB及其核心算法
运维·算法·负载均衡
lucy1530275107938 分钟前
【青牛科技】GC5931:工业风扇驱动芯片的卓越替代者
人工智能·科技·单片机·嵌入式硬件·算法·机器学习
杜杜的man1 小时前
【go从零单排】迭代器(Iterators)
开发语言·算法·golang
小沈熬夜秃头中୧⍤⃝1 小时前
【贪心算法】No.1---贪心算法(1)
算法·贪心算法
木向2 小时前
leetcode92:反转链表||
数据结构·c++·算法·leetcode·链表
阿阿越2 小时前
算法每日练 -- 双指针篇(持续更新中)
数据结构·c++·算法
skaiuijing2 小时前
Sparrow系列拓展篇:对调度层进行抽象并引入IPC机制信号量
c语言·算法·操作系统·调度算法·操作系统内核
Star Patrick2 小时前
算法训练(leetcode)二刷第十九天 | *39. 组合总和、*40. 组合总和 II、*131. 分割回文串
python·算法·leetcode