Attention-04-decoder部分


场景

arduino 复制代码
源句(英文): "The cat sat"
目标(德文): "Die Katze saß"   ← 模型不知道,要自己生成

编码器已经跑完,Z 已经算好:

scss 复制代码
Z    形状 (3, 512)    ← 3 个源词,固定不变,整个推理过程不再更新

第 0 步:准备开始

解码器输入只有一个开始符:

ini 复制代码
解码器输入: [<BOS>]

<BOS> = Begin of Sentence,告诉解码器"开始生成"。


第 1 步:生成第 1 个词

解码器拿 [<BOS>] 作输入,走完三个子层:

yaml 复制代码
Masked Self-attention: [<BOS>] 只有一个词,只看自己
Cross-attention:       Q 来自 <BOS> 的向量,K/V 来自 Z,查询源句
FFN:                   非线性加工

最后进入 Linear + Softmax,得到词表上的概率分布:

erlang 复制代码
Die      0.41   ← 最高
Der      0.28
Das      0.18
...

取概率最高的词:Die


第 2 步:生成第 2 个词

把 "Die" 加入输入序列:

ini 复制代码
解码器输入: [<BOS>, Die]

走完三个子层:

xml 复制代码
Masked Self-attention: 
  <BOS> 只看 <BOS>
  Die   只看 <BOS>, Die

Cross-attention:
  Q 来自 d[Die](代表"生成了 Die,下一个是什么")
  K/V 来自 Z

Linear + Softmax:

erlang 复制代码
Katze    0.52   ← 最高
Hund     0.21
Maus     0.09
...

取最高:Katze


第 3 步:生成第 3 个词

ini 复制代码
解码器输入: [<BOS>, Die, Katze]
makefile 复制代码
Cross-attention:
  Q 来自 d[Katze](代表"生成了 Die Katze,下一个是什么")
  K/V 来自 Z

Linear + Softmax:

erlang 复制代码
saß      0.61   ← 最高
sitzt    0.18
lag      0.08
...

取最高:saß


第 4 步:生成结束符

ini 复制代码
解码器输入: [<BOS>, Die, Katze, saß]

Linear + Softmax:

xml 复制代码
<EOS>    0.74   ← 最高
auf      0.12
...

取最高: (End of Sentence)

看到 <EOS>,停止生成。


完整结果

makefile 复制代码
生成序列: Die  Katze  saß

整个推理过程的规律

xml 复制代码
步骤 1: 输入 [<BOS>]              → 生成 Die
步骤 2: 输入 [<BOS>, Die]         → 生成 Katze
步骤 3: 输入 [<BOS>, Die, Katze]  → 生成 saß
步骤 4: 输入 [<BOS>, Die, Katze, saß] → 生成 <EOS> → 停止

每步输入序列比上一步多一个词,这就是自回归(auto-regressive) ------用自己已经生成的结果作为下一步的输入。


三个关键点

1. 编码器只跑一次

Z 在推理开始前算好,整个推理过程固定不变。解码器每一步都查询同一个 Z。

2. 解码器每步重新跑完整的三个子层

不是只算新加的那个词,而是把整个输入序列重新过一遍。但因为有 Masked Self-attention,每个位置只看前面的词,结果和上一步一致,不会重复计算错误。

3. 取最高概率不是唯一的策略

这里用的是 Greedy Search(贪心搜索)------每步取概率最高的词。实际中还有 Beam Search(束搜索)------每步保留概率最高的 k 个候选,最后取整体概率最高的序列,效果更好但更慢。


<BOS> 经过 Embedding + 位置编码之后就是 (1, 512)

我们把第 1 步的每个矩阵维度完整走一遍。


起点

xml 复制代码
解码器输入: [<BOS>]

Embedding + 位置编码:
<BOS> → (1, 512)

子层 1:Masked Self-attention

只有一个词,掩码没有实际效果(只有自己,看自己)。

scss 复制代码
输入 X_dec    形状 (1, 512)

投影:
Q = X_dec · W_Q    (1,512) · (512,64) = (1, 64)
K = X_dec · W_K    (1,512) · (512,64) = (1, 64)
V = X_dec · W_V    (1,512) · (512,64) = (1, 64)

点积打分:
Q · Kᵀ    (1,64) · (64,1) = (1, 1)    ← 1个词对自己的分数,就一个数

÷ √64,softmax:
权重      (1, 1)    ← 值为 1.0(只有自己,100%注意自己)

加权求和:
权重 · V    (1,1) · (1,64) = (1, 64)    ← 单头输出

8个头拼接:
(1, 64) × 8 → (1, 512)

· W_O:
(1,512) · (512,512) = (1, 512)

残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)

输出 d(1, 512)


子层 2:Cross-attention

Q 来自解码器,K 和 V 来自编码器的 Z。

scss 复制代码
d(解码器)    形状 (1, 512)
Z(编码器)    形状 (3, 512)    ← 源句 3 个词

投影:
Q = d · W_Q    (1,512) · (512,64) = (1, 64)    ← 解码器提问
K = Z · W_K    (3,512) · (512,64) = (3, 64)    ← 源句 3 个 Key
V = Z · W_V    (3,512) · (512,64) = (3, 64)    ← 源句 3 个 Value

点积打分:
Q · Kᵀ    (1,64) · (64,3) = (1, 3)
               ↑
               1个解码器位置 对 3个源词 各打一个分

÷ √64,softmax:
权重    (1, 3)    ← 3个数,加起来=1,代表对源句3个词的注意力分配

例:
[0.15, 0.65, 0.20]
  The   cat   sat

加权求和:
权重 · V    (1,3) · (3,64) = (1, 64)    ← 单头输出

8个头拼接:
(1, 64) × 8 → (1, 512)

· W_O:
(1,512) · (512,512) = (1, 512)

残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)

输出:(1, 512)


子层 3:FFN

scss 复制代码
输入    (1, 512)

· W₁:
(1,512) · (512,2048) = (1, 2048)    ← 升维

ReLU:
(1, 2048)                            ← 负值截断

· W₂:
(1,2048) · (2048,512) = (1, 512)    ← 降回

残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)

输出:(1, 512)


第 1 层解码器结束,传入第 2 层,重复 6 次。

第 6 层结束后进入输出层:

scss 复制代码
(1, 512) · Linear(512, 37000) = (1, 37000)

Softmax → (1, 37000)    ← 词表每个词的概率

取最大 → "Die"

整条维度变化总结

scss 复制代码
<BOS> embedding        (1, 512)
↓ Masked Self-attn
  Q,K,V 各            (1, 64)   × 8头
  打分矩阵            (1, 1)
  单头输出            (1, 64)
  拼接+W_O            (1, 512)
  残差+LN             (1, 512)
↓ Cross-attention
  Q                   (1, 64)   ← 来自解码器
  K,V                 (3, 64)   ← 来自编码器Z
  打分矩阵            (1, 3)    ← 1个位置对3个源词打分
  单头输出            (1, 64)
  拼接+W_O            (1, 512)
  残差+LN             (1, 512)
↓ FFN
  升维                (1, 2048)
  降维                (1, 512)
  残差+LN             (1, 512)
↓ × 6层
↓ Linear              (1, 37000)
↓ Softmax             (1, 37000)
↓ 取最大              "Die"

Cross-attention 里 (1, 3) 这个打分矩阵是最关键的地方------1 行代表解码器当前位置,3 列代表源句 3 个词。这一行 softmax 之后的 3 个数,就是"生成这个词时,应该参考源句哪里"的权重分配。

相关推荐
码之气三段.2 小时前
牛客周赛 Round 145-E(写了200行的史山)
算法·深度优先
计算机安禾2 小时前
【算法分析与设计】第13篇:最小生成树:Prim算法与Kruskal算法的比较研究
大数据·人工智能·算法
vortex52 小时前
国密(商用密码)算法核心参数速查
算法·密码学
wuweijianlove3 小时前
算法中的记忆化思想与重复子问题优化的技术5
算法
小江的记录本4 小时前
【JVM虚拟机】垃圾回收GC:垃圾判定算法:引用计数法、可达性分析算法(附《思维导图》+《面试高频考点清单》)
java·jvm·后端·python·算法·spring·面试
Hello.Reader4 小时前
算法基础(十四)—— 随机化快速排序为什么平均表现很好
算法
吴可可1234 小时前
Teigha中OdGe几何库详解及C#使用
算法
爱喝水的鱼丶4 小时前
SAP-ABAP:变量、常量、结构与内表声明(10篇博客合集) 第六篇:ABAP 7.40+新特性:声明语法的简化写法与兼容注意事项
运维·服务器·开发语言·学习·算法·sap·abap
国科安芯4 小时前
AS32S601商业航天级抗辐照MCU芯片:架构设计与技术特性研究
单片机·嵌入式硬件·算法·安全·架构·risc-v