场景
arduino
源句(英文): "The cat sat"
目标(德文): "Die Katze saß" ← 模型不知道,要自己生成
编码器已经跑完,Z 已经算好:
scss
Z 形状 (3, 512) ← 3 个源词,固定不变,整个推理过程不再更新
第 0 步:准备开始
解码器输入只有一个开始符:
ini
解码器输入: [<BOS>]
<BOS> = Begin of Sentence,告诉解码器"开始生成"。
第 1 步:生成第 1 个词
解码器拿 [<BOS>] 作输入,走完三个子层:
yaml
Masked Self-attention: [<BOS>] 只有一个词,只看自己
Cross-attention: Q 来自 <BOS> 的向量,K/V 来自 Z,查询源句
FFN: 非线性加工
最后进入 Linear + Softmax,得到词表上的概率分布:
erlang
Die 0.41 ← 最高
Der 0.28
Das 0.18
...
取概率最高的词:Die
第 2 步:生成第 2 个词
把 "Die" 加入输入序列:
ini
解码器输入: [<BOS>, Die]
走完三个子层:
xml
Masked Self-attention:
<BOS> 只看 <BOS>
Die 只看 <BOS>, Die
Cross-attention:
Q 来自 d[Die](代表"生成了 Die,下一个是什么")
K/V 来自 Z
Linear + Softmax:
erlang
Katze 0.52 ← 最高
Hund 0.21
Maus 0.09
...
取最高:Katze
第 3 步:生成第 3 个词
ini
解码器输入: [<BOS>, Die, Katze]
makefile
Cross-attention:
Q 来自 d[Katze](代表"生成了 Die Katze,下一个是什么")
K/V 来自 Z
Linear + Softmax:
erlang
saß 0.61 ← 最高
sitzt 0.18
lag 0.08
...
取最高:saß
第 4 步:生成结束符
ini
解码器输入: [<BOS>, Die, Katze, saß]
Linear + Softmax:
xml
<EOS> 0.74 ← 最高
auf 0.12
...
取最高: (End of Sentence)
看到 <EOS>,停止生成。
完整结果
makefile
生成序列: Die Katze saß
整个推理过程的规律
xml
步骤 1: 输入 [<BOS>] → 生成 Die
步骤 2: 输入 [<BOS>, Die] → 生成 Katze
步骤 3: 输入 [<BOS>, Die, Katze] → 生成 saß
步骤 4: 输入 [<BOS>, Die, Katze, saß] → 生成 <EOS> → 停止
每步输入序列比上一步多一个词,这就是自回归(auto-regressive) ------用自己已经生成的结果作为下一步的输入。
三个关键点
1. 编码器只跑一次
Z 在推理开始前算好,整个推理过程固定不变。解码器每一步都查询同一个 Z。
2. 解码器每步重新跑完整的三个子层
不是只算新加的那个词,而是把整个输入序列重新过一遍。但因为有 Masked Self-attention,每个位置只看前面的词,结果和上一步一致,不会重复计算错误。
3. 取最高概率不是唯一的策略
这里用的是 Greedy Search(贪心搜索)------每步取概率最高的词。实际中还有 Beam Search(束搜索)------每步保留概率最高的 k 个候选,最后取整体概率最高的序列,效果更好但更慢。
<BOS> 经过 Embedding + 位置编码之后就是 (1, 512)。
我们把第 1 步的每个矩阵维度完整走一遍。
起点
xml
解码器输入: [<BOS>]
Embedding + 位置编码:
<BOS> → (1, 512)
子层 1:Masked Self-attention
只有一个词,掩码没有实际效果(只有自己,看自己)。
scss
输入 X_dec 形状 (1, 512)
投影:
Q = X_dec · W_Q (1,512) · (512,64) = (1, 64)
K = X_dec · W_K (1,512) · (512,64) = (1, 64)
V = X_dec · W_V (1,512) · (512,64) = (1, 64)
点积打分:
Q · Kᵀ (1,64) · (64,1) = (1, 1) ← 1个词对自己的分数,就一个数
÷ √64,softmax:
权重 (1, 1) ← 值为 1.0(只有自己,100%注意自己)
加权求和:
权重 · V (1,1) · (1,64) = (1, 64) ← 单头输出
8个头拼接:
(1, 64) × 8 → (1, 512)
· W_O:
(1,512) · (512,512) = (1, 512)
残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)
输出 d:(1, 512)
子层 2:Cross-attention
Q 来自解码器,K 和 V 来自编码器的 Z。
scss
d(解码器) 形状 (1, 512)
Z(编码器) 形状 (3, 512) ← 源句 3 个词
投影:
Q = d · W_Q (1,512) · (512,64) = (1, 64) ← 解码器提问
K = Z · W_K (3,512) · (512,64) = (3, 64) ← 源句 3 个 Key
V = Z · W_V (3,512) · (512,64) = (3, 64) ← 源句 3 个 Value
点积打分:
Q · Kᵀ (1,64) · (64,3) = (1, 3)
↑
1个解码器位置 对 3个源词 各打一个分
÷ √64,softmax:
权重 (1, 3) ← 3个数,加起来=1,代表对源句3个词的注意力分配
例:
[0.15, 0.65, 0.20]
The cat sat
加权求和:
权重 · V (1,3) · (3,64) = (1, 64) ← 单头输出
8个头拼接:
(1, 64) × 8 → (1, 512)
· W_O:
(1,512) · (512,512) = (1, 512)
残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)
输出:(1, 512)
子层 3:FFN
scss
输入 (1, 512)
· W₁:
(1,512) · (512,2048) = (1, 2048) ← 升维
ReLU:
(1, 2048) ← 负值截断
· W₂:
(1,2048) · (2048,512) = (1, 512) ← 降回
残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)
输出:(1, 512)
第 1 层解码器结束,传入第 2 层,重复 6 次。
第 6 层结束后进入输出层:
scss
(1, 512) · Linear(512, 37000) = (1, 37000)
Softmax → (1, 37000) ← 词表每个词的概率
取最大 → "Die"
整条维度变化总结
scss
<BOS> embedding (1, 512)
↓ Masked Self-attn
Q,K,V 各 (1, 64) × 8头
打分矩阵 (1, 1)
单头输出 (1, 64)
拼接+W_O (1, 512)
残差+LN (1, 512)
↓ Cross-attention
Q (1, 64) ← 来自解码器
K,V (3, 64) ← 来自编码器Z
打分矩阵 (1, 3) ← 1个位置对3个源词打分
单头输出 (1, 64)
拼接+W_O (1, 512)
残差+LN (1, 512)
↓ FFN
升维 (1, 2048)
降维 (1, 512)
残差+LN (1, 512)
↓ × 6层
↓ Linear (1, 37000)
↓ Softmax (1, 37000)
↓ 取最大 "Die"
Cross-attention 里 (1, 3) 这个打分矩阵是最关键的地方------1 行代表解码器当前位置,3 列代表源句 3 个词。这一行 softmax 之后的 3 个数,就是"生成这个词时,应该参考源句哪里"的权重分配。