LLM 四阶段和 Transformer 架构(二)

上一篇解释完点积和矩阵乘法,矩阵乘法是一种转换,这一篇看 Transformer 中如何运用的。

LLM 的本质是预测下一个 token,阶段二中,使用大量的互联网内容,给模型做训练,使用自监督学习,不断调整 1750 亿个参数。直到模型能够正确的补全文本内容。

阶段二的产物,只有文本补全功能,不具备问答、对话能力。

现在假设所有参数已经调节完毕,以一个输入是 "The cat sat" 展示模型怎么预测下一个 token 是 "on" 的。

查询向量坐标

输入 "The cat sat" 经过 Tokenizer 会把这句话分为 the、cat、sat 三个 token,再从 Vocab Map<String, Integer> 中找到对应的id,再去 Embedding Table 中找到这三个向量坐标。

得到 3 个形状为 1,4096 的向量,Vector_The: 0.1, -0.5, ...、Vector_cat: 0.8, 0.2, ...、Vector_sat: -0.1, 0.9, ...。把这三个向量合并到一个矩阵里面,得到一个形状是 3,4096 的向量。

X i n = x ⃗ T h e x ⃗ c a t x ⃗ s a t X_{in} = \begin{bmatrix} \vec{x}{The} \\ \vec{x}{cat} \\ \vec{x}_{sat} \end{bmatrix} Xin= x Thex catx sat

进入 Layer 加工

接下来正式进入 Layer 加工。之前说过一共有 96 层 Layers,每一层 Layer 有 MHA (Multi-Head Attention) 多头注意力机制 和 FFN (Feed-Forward Network) 前馈神经网络处理。形状为 3,4096 的矩阵会完整经历所有 Layers,最后得到加工后的 3,4096

x_inNormMHA(+ 残差连接)x_midNormFFN(+ 残差连接)x_out

以上是一层 Layer 完整过程。 x_out 会是下一层 Layer 的 x_in。

Norm 层归一化

其中 Norm 是 Layer Normalization(层归一化),矩阵乘法的结果范围非常大,有的值是 50000 而有的是 0.000003,为了防止计算溢出或者梯度乱跳,需要把这些值统一处理为均值为 0 方差为 1。

Norm 计算包含四个步骤,假设以 ( x ⃗ c a t \vec{x}_{cat} x cat): 10, 2, 12, 0 为例。

  1. 求均值 ( μ \mu μ)

    ( 10 + 2 + 12 + 0 ) ÷ 4 = 6 (10 + 2 + 12 + 0) \div 4 = 6 (10+2+12+0)÷4=6

  2. 求方差 ( σ 2 \sigma^2 σ2)

    10 → ( 10 − 6 ) 2 = 16 10 \to (10-6)^2 = 16 10→(10−6)2=16
    2 → ( 2 − 6 ) 2 = 16 2 \to (2-6)^2 = 16 2→(2−6)2=16
    12 → ( 12 − 6 ) 2 = 36 12 \to (12-6)^2 = 36 12→(12−6)2=36
    0 → ( 0 − 6 ) 2 = 36 0 \to (0-6)^2 = 36 0→(0−6)2=36

    方差 = ( 16 + 16 + 36 + 36 ) ÷ 4 = 26 \text{方差} = (16 + 16 + 36 + 36) \div 4 = 26 方差=(16+16+36+36)÷4=26

    标准差 ( σ \sigma σ): 26 ≈ 5.1 \sqrt{26} \approx 5.1 26 ≈5.1

  3. 归一化 (Normalize)

    公式: x − 均值 标准差 \frac{x - \text{均值}}{\text{标准差}} 标准差x−均值目的是把数据强行拉回到 "均值为 0,方差为 1" 的标准形态。

    10 → ( 10 − 6 ) / 5.1 ≈ 0.78 10 \to (10 - 6) / 5.1 \approx \mathbf{0.78} 10→(10−6)/5.1≈0.78

    2 → ( 2 − 6 ) / 5.1 ≈ − 0.78 2 \to (2 - 6) / 5.1 \approx \mathbf{-0.78} 2→(2−6)/5.1≈−0.78

    12 → ( 12 − 6 ) / 5.1 ≈ 1.17 12 \to (12 - 6) / 5.1 \approx \mathbf{1.17} 12→(12−6)/5.1≈1.17

    0 → ( 0 − 6 ) / 5.1 ≈ − 1.17 0 \to (0 - 6) / 5.1 \approx \mathbf{-1.17} 0→(0−6)/5.1≈−1.17

    结果向量: 0.78, -0.78, 1.17, -1.17

  4. 缩放与平移 (Scale & Shift) 如果每次都强行变成 0 均值,可能会破坏数据的含义。所以模型有两个可变参数: γ \gamma γ (缩放) 和 β \beta β (平移)。这里可以理解成一个一元二次方程的线性函数。 假设模型学到这一层需要数值稍微大一点:

    γ = 2 , 2 , 2 , 2 \gamma = 2, 2, 2, 2 γ=2,2,2,2 β = 1 , 1 , 1 , 1 \beta = 1, 1, 1, 1 β=1,1,1,1

    最终输出 = 归一化结果 × γ + β \text{归一化结果} \times \gamma + \beta 归一化结果×γ+β

    0.78 × 2 + 1 = 2.56 0.78 \times 2 + 1 = \mathbf{2.56} 0.78×2+1=2.56

    ...

残差连接

经过 MHA 和 FFN 后,为了防止原始数据丢失,会加上处理前的原始值。

O u t p u t = New_Process ( x ) + x Output = \text{New\_Process}(x) + x Output=New_Process(x)+x

MHA

回到公式,x_in 是形状为 3,4096 的矩阵,经过 Norm 后还是 3,4096。接着进入 MHA。

MHA 中有 W Q 、 W K 、 W V W_Q、W_K、W_V WQ、WK、WV、 W O W_O WO 四个形状为 d_modle,d_modle4096,4096) 的矩阵。这是在阶段二训练好的。

Q 是 question,K 是 key,V 是 value,这是三个非常抽象的矩阵,他们的作用是把 "The cat sat" 中三个 token 的向量坐标互相融合,比如 the 要更关注 cat,经过融合后 the 这个 token 的向量值里就包含了大量 cat 的向量值。

MHA 叫多头注意力机制,比如有 32 个头,4096/32=128,就是把 W Q 、 W K 、 W V W_Q、W_K、W_V WQ、WK、WV 分为 32 个形状为 4096,128 的小矩阵,并行计算得到 32 个结果,再合并起来乘以 W O W_O WO 得到最终的产物。

大量的并行矩阵计算,这也是算力以 GPU 为核心的原因。

头也是抽象的概念,可以用角度类比,以 "The cat sat on the mat because it was tired." 为例。

Head 1 (语法眼): 专门盯着主谓关系。它发现 "it" 指代 "cat"。

Head 2 (逻辑眼): 专门盯着因果关系。它发现 "because" 导致了 "tired"。

Head 3 (位置眼): 专门盯着方位关系。它关注 "on the mat"。

真个 MHA 的计算过程用一个公式表示。

Z = softmax ( ( X W Q ) ( X W K ) T d m o d e l ) ( X W V ) Z = \text{softmax}\left( \frac{(X W_Q)(X W_K)^T}{\sqrt{d_{model}}} \right) (X W_V) Z=softmax(dmodel (XWQ)(XWK)T)(XWV)

其中, Scores = ( X W Q ) ( X W K ) T \text{Scores} = (X W_Q)(X W_K)^T Scores=(XWQ)(XWK)T

X W Q X W_Q XWQ : 算出 Q 矩阵。
X W K X W_K XWK : 算出 K 矩阵。
( ...   ) T (\dots)^T (...)T : 转置 K 矩阵,为了能进行矩阵乘法(前一个矩阵的横行必须等于后一个矩阵的纵列)。

转置是沿对角线翻转。比如

K = 1 2 3 4 5 6 K = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} K=142536

转置后是

K T = 1 4 2 5 3 6 K^T = \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix} KT= 123456

d_modle 计算模型设定的维度,除以 d m o d e l \sqrt{d_{model}} dmodel 也是为了防止向量值之间差距太大。

A = softmax ( Scores 4 ) A = \text{softmax}\left( \frac{\text{Scores}}{\sqrt{4}} \right) A=softmax(4 Scores)

softmax 是一种把一堆数字变为总和为 1 的小数算法,假设有一个输入向量(Logits) z = z 1 , z 2 , . . . , z n z = z_1, z_2, ..., z_n z=z1,z2,...,zn。 Softmax 函数 σ ( z ) i \sigma(z)_i σ(z)i 的公式是:

σ ( z ) i = e z i ∑ j = 1 N e z j \sigma(z)i = \frac{e^{z_i}}{\sum{j=1}^N e^{z_j}} σ(z)i=∑j=1Nezjezi

分子 ( e z i e^{z_i} ezi) :算出当前元素的指数值。

分母 ( ∑ e z j \sum e^{z_j} ∑ezj) :算出所有元素指数值的总和。

用到了 e 自然指数,把负数也转成正数( e − 2 ≈ 0.135 e^{-2} \approx 0.135 e−2≈0.135),由于指数函数特性,拉大的原始的差距( e 2.0 ≈ 7.4 e^{2.0} \approx \mathbf{7.4} e2.0≈7.4 e 1.0 ≈ 2.7 e^{1.0} \approx \mathbf{2.7} e1.0≈2.7)。

LLM 中常见的配置 temperature 就是在这个公式的分子和分母同时除以 T。

X W V X W_V XWV : 算出 V 矩阵。

再乘以 softmax 后的概率 A,就是最终的结果。

这里 X 参数是形状为 3,4096 的矩阵,而不是单一的 token,在这个过程中,每个 token 之间都相互融合,由于前一个 token 不能看到后一个 token,所以后一个 token 的向量是包含了前面所有信息的。

后一个 token 看不到前一个 token 通过 Mask 矩阵实现。

Attention = softmax ( Q K T d + M a s k ) V \text{Attention} = \text{softmax}\left( \frac{Q K^T}{\sqrt{d}} + \mathbf{Mask} \right) V Attention=softmax(d QKT+Mask)V

上面的公式简化后是这样的,Mask 矩阵是一个上三角矩阵 (右上角全是负无穷 − ∞ -\infty −∞),例如(为了简化这里 d_modle 是 3):

Mask = 0 − ∞ − ∞ 0 0 − ∞ 0 0 0 \text{Mask} = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix} Mask= 000−∞00−∞−∞0

这样任何矩阵加上这个矩阵右上角都是负无穷,而 e − ∞ = 0 e^{-\infty} = 0 e−∞=0,这保证了在原始句子 "The cat sat" 中 cat 的概率永远是 1。

如果说后一个 token 向量包含了所有信息,那为什么还要算所有向量的 Q K V?这是因为后一个向量计算的过程中,用到了这些数据。

现在用一个 dmodle = 3 的例子,完整演示 MHA 中的计算过程。

入参 X

X = 1 0 0 0 2 0 0 0 2 ← The ← cat ← sat X = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{bmatrix} \begin{matrix} \leftarrow \text{The} \\ \leftarrow \text{cat} \\ \leftarrow \text{sat} \end{matrix} X= 100020002 ←The←cat←sat

阶段二训练后的参数(注意这里是 W Q W_Q WQ,Q 是计算后的矩阵)

W Q = 0 0 0 0 0 0 0 1 0 W_Q = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & \mathbf{1} & 0 \end{bmatrix} WQ= 000001000

W K = 0 0 0 0 1 0 0 0 0 W_K = \begin{bmatrix} 0 & 0 & 0 \\ 0 & \mathbf{1} & 0 \\ 0 & 0 & 0 \end{bmatrix} WK= 000010000

W V = 1 0 0 0 1 0 0 0 1 W_V = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix} WV= 100010001

矩阵乘法后

Q = 0 0 0 0 0 0 0 2 0 ← The (无需求) ← cat (无需求) ← sat (有需求) Q = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 2 & 0 \end{bmatrix} \begin{matrix} \leftarrow \text{The (无需求)} \\ \leftarrow \text{cat (无需求)} \\ \leftarrow \text{sat (有需求)} \end{matrix} Q= 000002000 ←The (无需求)←cat (无需求)←sat (有需求)

K = 0 0 0 0 2 0 0 0 0 → K T = 0 0 0 0 2 0 0 0 0 K = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 0 \end{bmatrix} \quad \rightarrow \quad K^T = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 0 \end{bmatrix} K= 000020000 →KT= 000020000

V = 1 0 0 0 2 0 0 0 2 ← The 的货 ← cat 的货 ← sat 的货 V = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{bmatrix} \begin{matrix} \leftarrow \text{The 的货} \\ \leftarrow \text{cat 的货} \\ \leftarrow \text{sat 的货} \end{matrix} V= 100020002 ←The 的货←cat 的货←sat 的货

根据公式计算 scores

Scores = 0 0 0 0 0 0 0 4 0 \text{Scores} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & \mathbf{4} & 0 \end{bmatrix} Scores= 000004000

Masked = 0 0 0 0 0 0 0 4 0 + 0 − ∞ − ∞ 0 0 − ∞ 0 0 0 = 0 − ∞ − ∞ 0 0 − ∞ 0 4 0 \text{Masked} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 4 & 0 \end{bmatrix} + \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix} = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & \mathbf{4} & 0 \end{bmatrix} Masked= 000004000 + 000−∞00−∞−∞0 = 000−∞04−∞−∞0

softmax 后结果

A = 1.0 0 0 0.5 0.5 0 0.08 0.84 0.08 ← The ← cat ← s a t (聚焦 cat) A = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.5 & 0.5 & 0 \\ \mathbf{0.08} & \mathbf{0.84} & \mathbf{0.08} \end{bmatrix} \begin{matrix} \leftarrow \text{The} \\ \leftarrow \text{cat} \\ \leftarrow \mathbf{sat} \text{ (聚焦 cat)} \end{matrix} A= 1.00.50.0800.50.84000.08 ←The←cat←sat (聚焦 cat)

解释下第三行怎么算的

公式: softmax ( Scores 3 ) \text{softmax}(\frac{\text{Scores}}{\sqrt{3}}) softmax(3 Scores)。

3 ≈ 1.73 \sqrt{3} \approx 1.73 3 ≈1.73。

  • cat 得分: 4 / 1.73 ≈ 2.3 4 / 1.73 \approx 2.3 4/1.73≈2.3
  • The/sat 得分: 0 0 0
  • e 2.3 ≈ 10 e^{2.3} \approx 10 e2.3≈10, e 0 = 1 e^0 = 1 e0=1。
  • 概率: 10 / ( 1 + 10 + 1 ) ≈ 0.84 10 / (1+10+1) \approx 0.84 10/(1+10+1)≈0.84

所以是 0.08,0.84,0.08

最后计算 Z 矩阵( Z = A × V Z = A \times V Z=A×V)

Z = 1.0 0 0 0.5 1.0 0 0.08 1.68 0.16 ← 输出给下一层的 The ← 输出给下一层的 cat ← 输出给下一层的 sat Z = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.5 & 1.0 & 0 \\ \mathbf{0.08} & \mathbf{1.68} & \mathbf{0.16} \end{bmatrix} \begin{matrix} \leftarrow \text{输出给下一层的 The} \\ \leftarrow \text{输出给下一层的 cat} \\ \leftarrow \text{输出给下一层的 sat} \end{matrix} Z= 1.00.50.0801.01.68000.16 ←输出给下一层的 The←输出给下一层的 cat←输出给下一层的 sat

至此初始值 X 向量经过 MHA 后变为了 Z 向量,可以看到 sat 的原始向量向 cat 倾斜了。

32 个头同时这么进行,每个头都会得到一个 Z 向量结果,取最后一行的 Z c a t Z_{cat} Zcat 把他们全部拼接起来,又变成了一个形状为 1,4096 的向量。(推理是只用最后一个 token,训练都用了)

但向量的 0128 位是头 1 的表示的是语法,头 2 的 128 256 位表示位置等等,一直到头 32,最终需要再乘以 W o W_o Wo 矩阵,把这些独立的信息融为一体,得到最终的结果。

这就是 MHA 的全部过程。

FFN

Z 经过残差连接和层归一化后,进入 FFN。

FFN 里面有两个在阶段二训练好的矩阵, W 1 W_1 W1 和 W 2 W_2 W2。

W 1 W_1 W1 是升维矩阵,形状是 d_modle, 4*d_modle W 2 W_2 W2 是降维矩阵,形状是 4*d_modle,d_modle

假设 Z 已经完成残差连接和层归一化,乘以 W 1 W_1 W1 矩阵

H u p = Z s a t × W 1 H_{up} = Z_{sat} \times W_1 Hup=Zsat×W1

升维后有更多更细节的信息,比如 Apple 这个 token 升维前的向量是 x = [0.8, 0.1, 0.5] 分别代码是水果,是公司,是红色的。

W 1 W_1 W1 就像包含 6 个问题的问卷:

  • 是电子产品吗?
  • 是红色的吗?
  • 能吃吗?
  • 是交通工具吗?
  • 有毛吗?
  • 是液体吗?

计算 H = x × W 1 H = x \times W_1 H=x×W1:

结果向量(6维)可能是这样的:

10, 8, 9, -5, -10, -2

  • 10 (是电子产品): 命中!
  • 8 (是红色的): 命中!
  • 9 (能吃): 命中!
  • -5 (是交通工具): 完全不是,负分!
  • -10 (有毛): 负分滚粗!
  • -2 (是液体): 不太像。

然后经过 ReLU 激活函数处理, f ( x ) = max ⁡ ( 0 , x ) f(x) = \max(0, x) f(x)=max(0,x),把小于 0 的置位 0,结果就变成了 10, 8, 9, 0, 0, 0

再通过 W 2 W_2 W2 把结果压缩回去,去掉无用信息,比如变成了 5.0, 2.0, 8.0

再把这个值做一道残差连接,整个 Layer 层执行结束,把结果作为输入传给下一个 Layer 层。

相关推荐
带刺的坐椅1 天前
从 Claude Code 隐私争议,看 SolonCode 的设计选择
ai·llm·agent·claudecode·soloncode·codingplan
MomentYY1 天前
Temperature:AI 的“脑洞旋钮”
前端·llm·ai编程
Darling噜啦啦2 天前
上下文工程实战:从 Prompt 到 Harness 的三次 AI 工程化浪潮
llm·ai编程
Hyyy2 天前
Function Calling / Tool Use的原理和实现模式
前端·llm·ai编程
智泊AI2 天前
Loop Engineering 为什么会出现?一个 Loop 的组成部分有哪些?
llm
凌奕3 天前
别用文档约束你的 Agent:聊聊 Agent 开发流程的思想
llm·github·agent
Java之美3 天前
vLLM 是怎么工作的?
llm
JouYY4 天前
聊一下多 Agent 编排架构的应用实践
架构·llm·agent
To_OC5 天前
数据集划分不是随便切:手把手切分大众点评情感数据集
人工智能·llm·agent