上一篇解释完点积和矩阵乘法,矩阵乘法是一种转换,这一篇看 Transformer 中如何运用的。
LLM 的本质是预测下一个 token,阶段二中,使用大量的互联网内容,给模型做训练,使用自监督学习,不断调整 1750 亿个参数。直到模型能够正确的补全文本内容。
阶段二的产物,只有文本补全功能,不具备问答、对话能力。
现在假设所有参数已经调节完毕,以一个输入是 "The cat sat" 展示模型怎么预测下一个 token 是 "on" 的。
查询向量坐标
输入 "The cat sat" 经过 Tokenizer 会把这句话分为 the、cat、sat 三个 token,再从 Vocab Map<String, Integer> 中找到对应的id,再去 Embedding Table 中找到这三个向量坐标。
得到 3 个形状为 1,4096 的向量,Vector_The: 0.1, -0.5, ...、Vector_cat: 0.8, 0.2, ...、Vector_sat: -0.1, 0.9, ...。把这三个向量合并到一个矩阵里面,得到一个形状是 3,4096 的向量。
Xin= x Thex catx sat
进入 Layer 加工
接下来正式进入 Layer 加工。之前说过一共有 96 层 Layers,每一层 Layer 有 MHA (Multi-Head Attention) 多头注意力机制 和 FFN (Feed-Forward Network) 前馈神经网络处理。形状为 3,4096 的矩阵会完整经历所有 Layers,最后得到加工后的 3,4096。
x_in ➔ Norm ➔ MHA ➔ (+ 残差连接) ➔ x_mid ➔ Norm ➔ FFN ➔ (+ 残差连接) ➔ x_out
以上是一层 Layer 完整过程。 x_out 会是下一层 Layer 的 x_in。
Norm 层归一化
其中 Norm 是 Layer Normalization(层归一化),矩阵乘法的结果范围非常大,有的值是 50000 而有的是 0.000003,为了防止计算溢出或者梯度乱跳,需要把这些值统一处理为均值为 0 方差为 1。
Norm 计算包含四个步骤,假设以 ( x cat): 10, 2, 12, 0 为例。
-
求均值 ( μ)
(10+2+12+0)÷4=6
-
求方差 ( σ2)
10→(10−6)2=16
2→(2−6)2=16
12→(12−6)2=36
0→(0−6)2=36
方差=(16+16+36+36)÷4=26
标准差 ( σ): 26 ≈5.1
-
归一化 (Normalize)
公式: 标准差x−均值目的是把数据强行拉回到 "均值为 0,方差为 1" 的标准形态。
10→(10−6)/5.1≈0.78
2→(2−6)/5.1≈−0.78
12→(12−6)/5.1≈1.17
0→(0−6)/5.1≈−1.17
结果向量: 0.78, -0.78, 1.17, -1.17
-
缩放与平移 (Scale & Shift) 如果每次都强行变成 0 均值,可能会破坏数据的含义。所以模型有两个可变参数: γ (缩放) 和 β (平移)。这里可以理解成一个一元二次方程的线性函数。 假设模型学到这一层需要数值稍微大一点:
γ=2,2,2,2 β=1,1,1,1
最终输出 = 归一化结果×γ+β
0.78×2+1=2.56
...
残差连接
经过 MHA 和 FFN 后,为了防止原始数据丢失,会加上处理前的原始值。
Output=New_Process(x)+x
MHA
回到公式,x_in 是形状为 3,4096 的矩阵,经过 Norm 后还是 3,4096。接着进入 MHA。
MHA 中有 WQ、WK、WV、 WO 四个形状为 d_modle,d_modle(4096,4096) 的矩阵。这是在阶段二训练好的。
Q 是 question,K 是 key,V 是 value,这是三个非常抽象的矩阵,他们的作用是把 "The cat sat" 中三个 token 的向量坐标互相融合,比如 the 要更关注 cat,经过融合后 the 这个 token 的向量值里就包含了大量 cat 的向量值。
MHA 叫多头注意力机制,比如有 32 个头,4096/32=128,就是把 WQ、WK、WV 分为 32 个形状为 4096,128 的小矩阵,并行计算得到 32 个结果,再合并起来乘以 WO 得到最终的产物。
大量的并行矩阵计算,这也是算力以 GPU 为核心的原因。
头也是抽象的概念,可以用角度类比,以 "The cat sat on the mat because it was tired." 为例。
Head 1 (语法眼): 专门盯着主谓关系。它发现 "it" 指代 "cat"。
Head 2 (逻辑眼): 专门盯着因果关系。它发现 "because" 导致了 "tired"。
Head 3 (位置眼): 专门盯着方位关系。它关注 "on the mat"。
真个 MHA 的计算过程用一个公式表示。
Z=softmax(dmodel (XWQ)(XWK)T)(XWV)
其中, Scores=(XWQ)(XWK)T
XWQ : 算出 Q 矩阵。
XWK : 算出 K 矩阵。
(...)T : 转置 K 矩阵,为了能进行矩阵乘法(前一个矩阵的横行必须等于后一个矩阵的纵列)。
转置是沿对角线翻转。比如
K=142536
转置后是
KT= 123456
d_modle 计算模型设定的维度,除以 dmodel 也是为了防止向量值之间差距太大。
A=softmax(4 Scores)
softmax 是一种把一堆数字变为总和为 1 的小数算法,假设有一个输入向量(Logits) z=z1,z2,...,zn。 Softmax 函数 σ(z)i 的公式是:
σ(z)i=∑j=1Nezjezi
分子 ( ezi) :算出当前元素的指数值。
分母 ( ∑ezj) :算出所有元素指数值的总和。
用到了 e 自然指数,把负数也转成正数( e−2≈0.135),由于指数函数特性,拉大的原始的差距( e2.0≈7.4 e1.0≈2.7)。
LLM 中常见的配置 temperature 就是在这个公式的分子和分母同时除以 T。
XWV : 算出 V 矩阵。
再乘以 softmax 后的概率 A,就是最终的结果。
这里 X 参数是形状为 3,4096 的矩阵,而不是单一的 token,在这个过程中,每个 token 之间都相互融合,由于前一个 token 不能看到后一个 token,所以后一个 token 的向量是包含了前面所有信息的。
后一个 token 看不到前一个 token 通过 Mask 矩阵实现。
Attention=softmax(d QKT+Mask)V
上面的公式简化后是这样的,Mask 矩阵是一个上三角矩阵 (右上角全是负无穷 −∞),例如(为了简化这里 d_modle 是 3):
Mask= 000−∞00−∞−∞0
这样任何矩阵加上这个矩阵右上角都是负无穷,而 e−∞=0,这保证了在原始句子 "The cat sat" 中 cat 的概率永远是 1。
如果说后一个 token 向量包含了所有信息,那为什么还要算所有向量的 Q K V?这是因为后一个向量计算的过程中,用到了这些数据。
现在用一个 dmodle = 3 的例子,完整演示 MHA 中的计算过程。
入参 X
X= 100020002 ←The←cat←sat
阶段二训练后的参数(注意这里是 WQ,Q 是计算后的矩阵)
WQ= 000001000
WK= 000010000
WV= 100010001
矩阵乘法后
Q= 000002000 ←The (无需求)←cat (无需求)←sat (有需求)
K= 000020000 →KT= 000020000
V= 100020002 ←The 的货←cat 的货←sat 的货
根据公式计算 scores
Scores= 000004000
Masked= 000004000 + 000−∞00−∞−∞0 = 000−∞04−∞−∞0
softmax 后结果
A= 1.00.50.0800.50.84000.08 ←The←cat←sat (聚焦 cat)
解释下第三行怎么算的
公式: softmax(3 Scores)。
3 ≈1.73。
- cat 得分: 4/1.73≈2.3
- The/sat 得分: 0
- e2.3≈10, e0=1。
- 概率: 10/(1+10+1)≈0.84
所以是 0.08,0.84,0.08
最后计算 Z 矩阵( Z=A×V)
Z= 1.00.50.0801.01.68000.16 ←输出给下一层的 The←输出给下一层的 cat←输出给下一层的 sat
至此初始值 X 向量经过 MHA 后变为了 Z 向量,可以看到 sat 的原始向量向 cat 倾斜了。
32 个头同时这么进行,每个头都会得到一个 Z 向量结果,取最后一行的 Zcat 把他们全部拼接起来,又变成了一个形状为 1,4096 的向量。(推理是只用最后一个 token,训练都用了)
但向量的 0128 位是头 1 的表示的是语法,头 2 的 128 256 位表示位置等等,一直到头 32,最终需要再乘以 Wo 矩阵,把这些独立的信息融为一体,得到最终的结果。
这就是 MHA 的全部过程。
FFN
Z 经过残差连接和层归一化后,进入 FFN。
FFN 里面有两个在阶段二训练好的矩阵, W1 和 W2。
W1 是升维矩阵,形状是 d_modle, 4*d_modle, W2 是降维矩阵,形状是 4*d_modle,d_modle。
假设 Z 已经完成残差连接和层归一化,乘以 W1 矩阵
Hup=Zsat×W1
升维后有更多更细节的信息,比如 Apple 这个 token 升维前的向量是 x = [0.8, 0.1, 0.5] 分别代码是水果,是公司,是红色的。
W1 就像包含 6 个问题的问卷:
- 是电子产品吗?
- 是红色的吗?
- 能吃吗?
- 是交通工具吗?
- 有毛吗?
- 是液体吗?
计算 H=x×W1:
结果向量(6维)可能是这样的:
10, 8, 9, -5, -10, -2
- 10 (是电子产品): 命中!
- 8 (是红色的): 命中!
- 9 (能吃): 命中!
- -5 (是交通工具): 完全不是,负分!
- -10 (有毛): 负分滚粗!
- -2 (是液体): 不太像。
然后经过 ReLU 激活函数处理, f(x)=max(0,x),把小于 0 的置位 0,结果就变成了 10, 8, 9, 0, 0, 0。
再通过 W2 把结果压缩回去,去掉无用信息,比如变成了 5.0, 2.0, 8.0。
再把这个值做一道残差连接,整个 Layer 层执行结束,把结果作为输入传给下一个 Layer 层。