深度学习的数学原理(三十四)—— Transformer 解码器完整实现

在第 33 篇中,我们系统地构建了 Transformer 编码器的完整架构。编码器负责将源序列映射为丰富的上下文表征,但它无法自行生成序列,这正是解码器的职责。

解码器作为自回归生成的核心引擎,通过三大子层的协同工作实现序列生成:

  1. 掩码多头自注意力:保证因果性,确保生成过程不"偷看"未来
  2. 标准交叉注意力:实现编码器→解码器的信息传递,完成源-目标词对齐
  3. 前馈网络:对每个位置独立进行非线性特征变换

三者通过残差连接与层归一化有机结合,形成可深度堆叠的解码器层。

本文将沿着"数学原理→手动实例→代码验证"的固定结构展开:

  • 数学链路:推导解码器从目标嵌入到输出投影的完整前向传播公式,重点分析自回归生成的因果约束
  • 结构拆解:逐层解析掩码自注意力 → 残差 LN → 交叉注意力 → 残差 LN → FFN → 残差 LN 的三级结构
  • 数值实例:以 i love deep learning 为目标序列,结合编码器输出,完整手算单层解码器的前向过程,验证交叉注意力的词对齐效果
  • 代码实现:提供与《Attention Is All You Need》原论文完全对齐的 PyTorch 实现,包含掩码处理、交叉注意力与编码器输出对接,可独立运行验证

一、解码器的完整数学链路

1.1 解码器的整体架构

Transformer 解码器的整体架构可表示为以下函数复合:

$$\text{Decoder}(Y, X_{\text{enc}}) = \text{Softmax}\Big( \text{DecoderLayer}^{N}\big(\text{Embed}(Y) + \text{PE},\; X_{\text{enc}}\big) \cdot W_O^\top \Big)$$

其中:

  • $Y \in \mathbb{N}^{T_{\text{out}}}$:目标序列的 token 索引序列
  • $X_{\text{enc}} \in \mathbb{R}^{T_{\text{in}} \times d_{\text{model}}}$:编码器的最终输出
  • $N$:解码器层数(原论文中 $N=6$)
  • $T_{\text{out}}$:目标序列长度(含起始符)
  • $T_{\text{in}}$:源序列长度

整条数据流可以分为四大阶段:

```
目标序列 tokens
    ↓
词嵌入层 + 位置编码        ← 阶段1:输入表示
    ↓
第1层解码器层              ← 阶段2:N层堆叠
第2层解码器层
...
第N层解码器层
    ↓
线性投影 + Softmax         ← 阶段3:输出投影
    ↓
下一个 token 的概率分布     ← 阶段4:自回归生成
```

1.2 目标序列嵌入与位置编码

与编码器相同,解码器的输入首先经过词嵌入层:

$$X_{\text{emb}} = \text{Embed}(Y) \in \mathbb{R}^{T_{\text{out}} \times d_{\text{model}}}$$

其中嵌入矩阵 $E \in \mathbb{R}^{|V| \times d_{\text{model}}}$ 与编码器共享(原论文中两者共享词嵌入,且嵌入权重乘以 $\sqrt{d_{\text{model}}}$ 以控制数值范围)。

位置编码采用与编码器完全相同的正弦-余弦编码方案。对于位置 $pos$ 和维度 $2i$(偶数维)及 $2i+1$(奇数维):

$$\text{PE}(pos, 2i) = \sin\left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right)$$

$$\text{PE}(pos, 2i+1) = \cos\left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right)$$

将嵌入与位置编码相加得到解码器的初始输入:

$$X_0 = X_{\text{emb}} + \text{PE}(1{:}T_{\text{out}}) \in \mathbb{R}^{T_{\text{out}} \times d_{\text{model}}}$$

注意:在训练时,解码器的输入是经过右移(shifted right)的目标序列,即在原序列前添加 <sos> 起始符,去掉最后一个词。这确保了模型在预测位置 $t$ 的 token 时,只能看到位置 $1$ 到 $t-1$ 的 ground truth token,配合因果掩码实现严格的自回归训练。
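下面用一小段 PyTorch 代码示意这一"右移"构造过程(其中 `sos_id`、`eos_id` 及各词的 token ID 均为演示用的假设取值):

```python
import torch

sos_id, eos_id = 1, 2                        # 假设的特殊符号 ID(仅作演示)
target = torch.tensor([5, 6, 7, 8, eos_id])  # 完整目标序列 "i love deep learning <eos>"(示例 ID)

decoder_input = torch.cat([torch.tensor([sos_id]), target[:-1]])  # 前加 <sos>,去掉最后一个 token
labels = target                                                   # 位置 t 的监督标签是原序列第 t 个词

print(decoder_input.tolist())  # [1, 5, 6, 7, 8]  即 <sos> i love deep learning
print(labels.tolist())         # [5, 6, 7, 8, 2]  即 i love deep learning <eos>
```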

1.3 N 层解码器层的堆叠

解码器层采用递归方式堆叠。第 $l$ 层的输出为:

$$X_l = \text{DecoderLayer}_l(X_{l-1}, X_{\text{enc}}), \quad l = 1, 2, \dots, N$$

每一层解码器层都接收两个输入:

  1. 上一层解码器的输出 $X_{l-1}$(或初始输入 $X_0$),提供目标侧信息
  2. 编码器的最终输出 $X_{\text{enc}}$,提供源侧上下文

这种设计使得每一层都能同时利用:

  • 已生成的目标序列信息(通过自注意力)
  • 源序列的完整上下文信息(通过交叉注意力)

1.4 自回归生成的数学约束

自回归(Autoregressive)是解码器最核心的特性。在生成第 $t$ 个 token 时,模型只能利用位置 $1$ 到 $t-1$ 的信息,不能利用位置 $t$ 到 $T_{\text{out}}$ 的未来信息。

定义(因果约束):对于任意位置 $t$,解码器的条件概率分布满足:

$$P(y_t \mid y_1, y_2, \dots, y_{t-1}, X_{\text{enc}}) = \text{Softmax}\big(f_t(y_1, \dots, y_{t-1}, X_{\text{enc}})\big)$$

其中 $f_t$ 不依赖于 $y_t, y_{t+1}, \dots, y_{T_{\text{out}}}$。这意味着第 $t$ 个位置的输出不能依赖于任何未来的 token。

为了实现这一约束,需要在自注意力中引入因果掩码(Causal Mask):

$$M_{\text{causal}}[i, j] = \begin{cases} 0, & j \leq i \quad \text{(允许 attend)} \\ -\infty, & j > i \quad \text{(禁止 attend)} \end{cases}$$

该掩码构成一个下三角矩阵:

$$M_{\text{causal}} = \begin{pmatrix} 0 & -\infty & -\infty & \cdots & -\infty \\ 0 & 0 & -\infty & \cdots & -\infty \\ 0 & 0 & 0 & \cdots & -\infty \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & 0 \end{pmatrix}$$

定理(因果掩码保证自回归性质):带有因果掩码的自注意力满足自回归性质。

证明:带因果掩码的自注意力输出为 $\text{Attn}_{\text{masked}}(X) = \text{Softmax}\left( \frac{QK^\top}{\sqrt{d_k}} + M_{\text{causal}} \right) V$。

对于第 $i$ 行,当 $j > i$ 时,$M_{\text{causal}}[i,j] = -\infty$,因此 $\exp(M_{\text{causal}}[i,j]) = 0$:

$$\text{Softmax}\left( \frac{QK^\top}{\sqrt{d_k}} + M_{\text{causal}} \right)_{i,j} = 0, \quad \forall j > i$$

行 $i$ 的输出为:

$$\text{Output}[i,:] = \sum_{j=1}^{i} A_{i,j} \cdot V[j,:]$$

即位置 $i$ 的输出不依赖于任何 $j > i$ 的位置,满足因果约束。
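下面用一小段 PyTorch 代码直观验证这一结论(分数矩阵取随机值,仅作演示):加上因果掩码后,Softmax 权重矩阵的上三角部分严格为 0,每行权重仍归一化为 1。

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.randn(T, T)                                               # 任意注意力分数(示例)
causal_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)  # j > i 的位置为 -inf

weights = F.softmax(scores + causal_mask, dim=-1)

assert torch.all(weights.triu(diagonal=1) == 0)            # 未来位置的权重恰为 0
assert torch.allclose(weights.sum(dim=-1), torch.ones(T))  # 每行仍是合法的概率分布
print(weights)
```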

1.5 输出投影与 Softmax

经过 $N$ 层解码器后,最终输出 $X_N \in \mathbb{R}^{T_{\text{out}} \times d_{\text{model}}}$ 经过线性投影变换到词表大小:

$$\text{Logits} = X_N \cdot W_O^\top \in \mathbb{R}^{T_{\text{out}} \times |V|}$$

其中 $W_O \in \mathbb{R}^{|V| \times d_{\text{model}}}$ 为输出投影矩阵(通常与输入词嵌入矩阵 $E$ 共享权重,即权重绑定,可大幅减少参数量)。
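权重绑定在 PyTorch 中只需让两个模块共享同一份参数,下面是一个最小示意(`vocab_size`、`d_model` 为示例取值,第 4.4 节的实现为保持结构清晰未做绑定):

```python
import torch.nn as nn

vocab_size, d_model = 10000, 512
embedding = nn.Embedding(vocab_size, d_model)                   # E ∈ R^{|V| × d_model}
output_projection = nn.Linear(d_model, vocab_size, bias=False)  # 权重形状同为 [|V|, d_model]

output_projection.weight = embedding.weight  # 共享同一参数张量,省去 |V|×d_model 个参数
```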

最后通过 Softmax 得到每个位置的词概率分布:

$$P(y_t \mid y_{<t}, X_{\text{enc}}) = \frac{\exp(\text{Logits}[t, y_t])}{\sum_{v=1}^{|V|} \exp(\text{Logits}[t, v])}$$

自回归推理:在推理时,解码器每次只生成一个 token,将新生成的 token 拼接到输入序列末尾,再次送入解码器,重复此过程直到生成 <eos> 终止符或达到最大长度。
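下面给出贪心自回归解码的一个简化骨架(假设 `decoder` 采用第四节 `TransformerDecoder` 的接口,`enc_output`、`sos_id`、`eos_id` 等均已备好,未包含束搜索与 KV 缓存):

```python
import torch

@torch.no_grad()
def greedy_decode(decoder, enc_output, sos_id, eos_id, max_len=50):
    """每步取概率最大的 token,拼接到序列末尾后重新送入解码器"""
    batch_size = enc_output.size(0)
    ys = torch.full((batch_size, 1), sos_id, dtype=torch.long, device=enc_output.device)

    for _ in range(max_len - 1):
        logits = decoder(ys, enc_output)                             # [batch, cur_len, vocab_size]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # 只看最后一个位置的分布
        ys = torch.cat([ys, next_token], dim=1)                      # 拼接新 token,进入下一步
        if (next_token == eos_id).all():                             # 所有样本都生成 <eos> 则提前停止
            break
    return ys
```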


二、单层解码器的结构拆解

单层解码器包含三个核心子层,按照以下顺序排列(残差连接与层归一化采用与原论文一致的 Post-LN 形式):

$$X_{\text{attn1}} = \text{LN}\big(X + \text{MaskedSelfAttn}(X, X, X)\big)$$

$$X_{\text{attn2}} = \text{LN}\big(X_{\text{attn1}} + \text{CrossAttn}(X_{\text{attn1}}, X_{\text{enc}}, X_{\text{enc}})\big)$$

$$\text{DecoderLayer}(X, X_{\text{enc}}) = \text{LN}\big(X_{\text{attn2}} + \text{FFN}(X_{\text{attn2}})\big)$$

下面对每个子层进行详细拆解。

2.1 第一子层:掩码多头自注意力

输入:上一层输出 $X \in \mathbb{R}^{T \times d_{\text{model}}}$

处理流程

步骤 1:线性投影生成 Q, K, V

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

其中 $W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$。

步骤 2:多头拆分(设 $h$ 个头,$d_k = d_{\text{model}} / h$)

$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i), \quad i = 1, 2, \dots, h$$

其中 $Q_i \in \mathbb{R}^{T \times d_k}$ 是 $Q$ 的第 $i$ 个分块。

步骤 3:应用因果掩码的缩放点积注意力(衔接第 27、32 篇)

$$\text{Attention}(Q_i, K_i, V_i) = \text{Softmax}\left( \frac{Q_i K_i^\top}{\sqrt{d_k}} + M_{\text{causal}} \right) V_i$$

这里的 $M_{\text{causal}}$ 是下三角因果掩码矩阵。

步骤 4:多头拼接与输出投影

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W_O$$

其中 $W_O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$。

步骤 5:残差连接与层归一化(衔接第 30 篇)

$$X_{\text{attn1}} = \text{LayerNorm}\big(X + \text{Dropout}(\text{MultiHead}(Q, K, V))\big)$$

关键特性:因果掩码是解码器与编码器的本质区别,它确保了自回归性质。

2.2 第二子层:标准交叉注意力

输入:

  • 掩码自注意力输出 $X_{\text{attn1}} \in \mathbb{R}^{T_{\text{out}} \times d_{\text{model}}}$(解码器侧)
  • 编码器输出 $X_{\text{enc}} \in \mathbb{R}^{T_{\text{in}} \times d_{\text{model}}}$(编码器侧)

处理流程

步骤 1:双源线性投影(衔接第 28 篇)

$$Q_{\text{cross}} = X_{\text{attn1}} W_Q^{\text{cross}}, \quad K_{\text{cross}} = X_{\text{enc}} W_K^{\text{cross}}, \quad V_{\text{cross}} = X_{\text{enc}} W_V^{\text{cross}}$$

关键区别:Query 来自解码器侧,Key 和 Value 来自编码器侧!这是"交叉"的含义所在。

步骤 2-4:与自注意力相同(多头拆分 → 缩放点积注意力 → 拼接 + 投影)。

步骤 5:残差连接与层归一化

$$X_{\text{attn2}} = \text{LayerNorm}\big(X_{\text{attn1}} + \text{Dropout}(\text{MultiHead}(Q_{\text{cross}}, K_{\text{cross}}, V_{\text{cross}}))\big)$$

信息传递机制

交叉注意力是编码器→解码器信息流动的桥梁。其本质可理解为一种软对齐机制:

  • 解码器在位置 $t$ 的隐藏状态(Query)询问:"源序列中哪个词与我要生成的当前词最相关?"
  • 编码器输出(Key)通过注意力权重回答:"源序列第 $j$ 个词与你的 Query 最匹配"
  • 最终输出(Value)是编码器各位置信息的加权和,权重由匹配程度决定

定义(交叉注意力权重):

$$A_{\text{cross}}[i, j] = \text{Softmax}\left( \frac{Q_i^{\text{cross}} (K_j^{\text{cross}})^\top}{\sqrt{d_k}} \right)$$

$A_{\text{cross}}[i,j]$ 表示目标第 $i$ 个词对源第 $j$ 个词的关注程度。在训练良好的模型中,这通常表现为合理的词对齐,例如生成目标词 "love" 时会高度关注源词 "love"。
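可以用张量形状直观感受"交叉"的结构。下面是单头、随机张量的示意($T_{\text{out}}$、$T_{\text{in}}$、$d_{\text{model}}$ 取示例值):注意力分数矩阵是 目标长度 × 源长度,输出又回到目标长度。

```python
import torch
import torch.nn as nn

T_out, T_in, d_model = 4, 6, 8
dec_hidden = torch.randn(1, T_out, d_model)    # 解码器侧:掩码自注意力的输出
enc_output = torch.randn(1, T_in, d_model)     # 编码器侧:源序列表征

W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q = W_q(dec_hidden)                            # 来自解码器: [1, T_out, d_model]
K, V = W_k(enc_output), W_v(enc_output)        # 来自编码器: [1, T_in, d_model]

scores = Q @ K.transpose(-2, -1) / d_model ** 0.5   # [1, T_out, T_in]:目标-源 对齐分数
output = scores.softmax(dim=-1) @ V                 # [1, T_out, d_model]

print(scores.shape, output.shape)  # torch.Size([1, 4, 6]) torch.Size([1, 4, 8])
```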

2.3 第三子层:前馈网络 FFN

输入:交叉注意力输出 $X_{\text{attn2}} \in \mathbb{R}^{T_{\text{out}} \times d_{\text{model}}}$

处理流程(衔接第 31 篇):

步骤 1:升维 → 非线性激活 → 降维

$$\text{FFN}(x) = \text{ReLU}(x W_1 + b_1)\, W_2 + b_2$$

其中 $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$,$W_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$。

步骤 2:残差连接与层归一化

$$X_{\text{out}} = \text{LayerNorm}\big(X_{\text{attn2}} + \text{Dropout}(\text{FFN}(X_{\text{attn2}}))\big)$$

维度变换分析:FFN 通过先升维(原论文 $d_{ff} = 2048 = 4 d_{\text{model}}$)再降维,在保持输入输出维度一致的同时,大幅提升模型表达能力。第 31 篇已详细论证,ReLU 激活会使约 50% 的神经元输出为 0(信息丢失),升维 4 倍可有效补偿这一损失。
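关于"约 50% 的神经元输出为 0",可以用一个小实验近似验证(随机初始化、近似零均值输入的假设下,实际模型中的比例会随训练而变化):

```python
import torch

x = torch.randn(1000, 512)                 # 近似零均值的输入(示例)
W1 = torch.randn(512, 2048) / 512 ** 0.5   # 升维矩阵(示例初始化)

h = torch.relu(x @ W1)
print((h == 0).float().mean())             # 约 0.5:接近一半的激活被 ReLU 置零
```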

2.4 三层结构与编码器的对比

| 特性 | 编码器层 | 解码器层 |
| --- | --- | --- |
| 子层 1 | 自注意力(无掩码) | 自注意力(因果掩码) |
| 子层 2 | FFN | 交叉注意力(Q 来自解码器,K、V 来自编码器) |
| 子层 3 | 无 | FFN |
| 输入来源 | 仅上一层输出 | 上一层输出 + 编码器输出 |
| 可见范围 | 全部源 token | 已生成的目标 token + 全部源 token |

三、数值实例:手算 "i love deep learning" 解码前向过程

3.1 实例设置

为了直观理解解码器的计算流程,我们构造以下迷你实例:

| 参数 | 值 | 说明 |
| --- | --- | --- |
| $d_{\text{model}}$ | 4 | 模型维度 |
| $h$ | 1 | 单头注意力(简化手算) |
| $d_k$ | 4 | 每个头的维度($d_k = d_{\text{model}}/h$) |
| $d_{ff}$ | 8 | 前馈网络隐藏层维度 |
| $T_{\text{out}}$ | 4 | 目标序列长度 |
| $T_{\text{in}}$ | 4 | 源序列长度 |

目标序列:<sos> i love deep(右移后的输入,期望预测 i love deep learning)

源序列(编码器输入):i love deep learning

编码器输出 $X_{\text{enc}}$(假设已经过 1 层编码器编码,形状 $4 \times 4$):

$$X_{\text{enc}} = \begin{pmatrix} 0.8 & 0.1 & 0.3 & 0.6 \\ 0.2 & 0.7 & 0.5 & 0.4 \\ 0.6 & 0.3 & 0.8 & 0.1 \\ 0.4 & 0.5 & 0.2 & 0.7 \end{pmatrix} \begin{matrix} \leftarrow \text{"i"} \\ \leftarrow \text{"love"} \\ \leftarrow \text{"deep"} \\ \leftarrow \text{"learning"} \end{matrix}$$

编码器输出矩阵的每一行对应源序列中一个 token 的语义表征。注意四个向量各不相同,编码器已经将每个 token 的上下文信息编码进了对应的行中。

解码器初始输入 $X_0$(词嵌入 + 位置编码,形状 $4 \times 4$):

$$X_0 = \begin{pmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.6 & 0.1 & 0.2 & 0.3 \\ 0.3 & 0.7 & 0.1 & 0.4 \\ 0.5 & 0.3 & 0.8 & 0.1 \end{pmatrix} \begin{matrix} \leftarrow \text{"<sos>"} \\ \leftarrow \text{"i"} \\ \leftarrow \text{"love"} \\ \leftarrow \text{"deep"} \end{matrix}$$

为便于手算,各子层的权重矩阵均设为单位矩阵:$W_Q = W_K = W_V = W_O = I$。这意味着 $Q = K = V = X$。对于 FFN,我们同样设 $W_1, W_2$ 为单位矩阵(适当切分维度)。

3.2 第一子层:掩码自注意力计算

步骤 1:计算 Q, K, V

$$Q = K = V = X_0$$

步骤 2:计算缩放注意力分数

$$S = \frac{Q K^\top}{\sqrt{d_k}} = \frac{X_0 X_0^\top}{\sqrt{4}} = \frac{X_0 X_0^\top}{2}$$

先计算 $X_0 X_0^\top$:

$$\big(X_0 X_0^\top\big)[i,j] = \sum_{k=1}^{4} X_0[i,k] \cdot X_0[j,k]$$

逐元素计算:

| $i$ | $j$ | 计算过程 | 结果 |
| --- | --- | --- | --- |
| 0 | 0 | $0.1^2+0.2^2+0.3^2+0.4^2$ | 0.300 |
| 0 | 1 | $0.1\times0.6 + 0.2\times0.1 + 0.3\times0.2 + 0.4\times0.3$ | 0.260 |
| 0 | 2 | $0.1\times0.3 + 0.2\times0.7 + 0.3\times0.1 + 0.4\times0.4$ | 0.330 |
| 0 | 3 | $0.1\times0.5 + 0.2\times0.3 + 0.3\times0.8 + 0.4\times0.1$ | 0.330 |
| 1 | 0 | 同 $(0,1)$ | 0.260 |
| 1 | 1 | $0.6^2+0.1^2+0.2^2+0.3^2$ | 0.500 |
| 1 | 2 | $0.6\times0.3 + 0.1\times0.7 + 0.2\times0.1 + 0.3\times0.4$ | 0.390 |
| 1 | 3 | $0.6\times0.5 + 0.1\times0.3 + 0.2\times0.8 + 0.3\times0.1$ | 0.520 |
| 2 | 0 | 同 $(0,2)$ | 0.330 |
| 2 | 1 | 同 $(1,2)$ | 0.390 |
| 2 | 2 | $0.3^2+0.7^2+0.1^2+0.4^2$ | 0.750 |
| 2 | 3 | $0.3\times0.5 + 0.7\times0.3 + 0.1\times0.8 + 0.4\times0.1$ | 0.480 |
| 3 | 0 | 同 $(0,3)$ | 0.330 |
| 3 | 1 | 同 $(1,3)$ | 0.520 |
| 3 | 2 | 同 $(2,3)$ | 0.480 |
| 3 | 3 | $0.5^2+0.3^2+0.8^2+0.1^2$ | 0.990 |

因此:

$$\frac{X_0 X_0^\top}{2} = \begin{pmatrix} 0.150 & 0.130 & 0.165 & 0.165 \\ 0.130 & 0.250 & 0.195 & 0.260 \\ 0.165 & 0.195 & 0.375 & 0.240 \\ 0.165 & 0.260 & 0.240 & 0.495 \end{pmatrix}$$

步骤 3:应用因果掩码

因果掩码矩阵($-\infty$ 表示屏蔽):

$$M_{\text{causal}} = \begin{pmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{pmatrix}$$

掩码后的分数矩阵:

$$S_{\text{masked}} = \frac{QK^\top}{\sqrt{d_k}} + M_{\text{causal}} = \begin{pmatrix} 0.150 & -\infty & -\infty & -\infty \\ 0.130 & 0.250 & -\infty & -\infty \\ 0.165 & 0.195 & 0.375 & -\infty \\ 0.165 & 0.260 & 0.240 & 0.495 \end{pmatrix}$$

步骤 4:Softmax 计算注意力权重

对每行计算 Softmax,指数值取近似:

第 0 行:$[e^{0.150}, 0, 0, 0] \rightarrow [1.162, 0, 0, 0]$,Softmax $= [1.000, 0, 0, 0]$

第 1 行:$[e^{0.130}, e^{0.250}, 0, 0] = [1.139, 1.284, 0, 0]$

$$\text{sum} = 2.423, \quad \text{Softmax} = [0.470, 0.530, 0, 0]$$

第 2 行:$[e^{0.165}, e^{0.195}, e^{0.375}, 0] = [1.179, 1.215, 1.455, 0]$

$$\text{sum} = 3.849, \quad \text{Softmax} = [0.306, 0.316, 0.378, 0]$$

第 3 行:$[e^{0.165}, e^{0.260}, e^{0.240}, e^{0.495}] = [1.179, 1.297, 1.271, 1.641]$

$$\text{sum} = 5.388, \quad \text{Softmax} = [0.219, 0.241, 0.236, 0.304]$$

因此注意力权重矩阵为:

$$A_{\text{self}} = \begin{pmatrix} 1.000 & 0 & 0 & 0 \\ 0.470 & 0.530 & 0 & 0 \\ 0.306 & 0.316 & 0.378 & 0 \\ 0.219 & 0.241 & 0.236 & 0.304 \end{pmatrix}$$

观察:因果掩码的效果清晰可见,第 $i$ 行的非零权重只出现在 $j \leq i$ 的位置。第 0 行(<sos>)只能关注自己,第 3 行(deep)可以关注所有前面的词。

步骤 5:计算自注意力输出

$$Z_{\text{self}} = A_{\text{self}} \cdot V = A_{\text{self}} \cdot X_0$$

第 0 行:$1.000 \times [0.1, 0.2, 0.3, 0.4] = [0.100, 0.200, 0.300, 0.400]$

第 1 行:$0.470 \times [0.1, 0.2, 0.3, 0.4] + 0.530 \times [0.6, 0.1, 0.2, 0.3]$
$= [0.047{+}0.318,\; 0.094{+}0.053,\; 0.141{+}0.106,\; 0.188{+}0.159] = [0.365, 0.147, 0.247, 0.347]$

第 2 行:$0.306 \times [0.1,0.2,0.3,0.4] + 0.316 \times [0.6,0.1,0.2,0.3] + 0.378 \times [0.3,0.7,0.1,0.4]$
$= [0.031{+}0.190{+}0.113,\; 0.061{+}0.032{+}0.265,\; 0.092{+}0.063{+}0.038,\; 0.122{+}0.095{+}0.151] = [0.334, 0.358, 0.193, 0.368]$

第 3 行:$0.219 \times [0.1,0.2,0.3,0.4] + 0.241 \times [0.6,0.1,0.2,0.3] + 0.236 \times [0.3,0.7,0.1,0.4] + 0.304 \times [0.5,0.3,0.8,0.1]$
$= [0.022{+}0.145{+}0.071{+}0.152,\; 0.044{+}0.024{+}0.165{+}0.091,\; 0.066{+}0.048{+}0.024{+}0.243,\; 0.088{+}0.072{+}0.094{+}0.030] = [0.390, 0.324, 0.381, 0.284]$

因此:

$$Z_{\text{self}} = \begin{pmatrix} 0.100 & 0.200 & 0.300 & 0.400 \\ 0.365 & 0.147 & 0.247 & 0.347 \\ 0.334 & 0.358 & 0.193 & 0.368 \\ 0.390 & 0.324 & 0.381 & 0.284 \end{pmatrix}$$

步骤 6:残差连接 + 层归一化

残差和:

$$X_0 + Z_{\text{self}} = \begin{pmatrix} 0.2 & 0.4 & 0.6 & 0.8 \\ 0.965 & 0.247 & 0.447 & 0.647 \\ 0.634 & 1.058 & 0.293 & 0.768 \\ 0.890 & 0.624 & 1.181 & 0.384 \end{pmatrix}$$

以第 0 行为例计算 LayerNorm:

$$\mu_0 = (0.2 + 0.4 + 0.6 + 0.8)/4 = 0.5$$

$$\sigma_0^2 = \big((0.2-0.5)^2 + (0.4-0.5)^2 + (0.6-0.5)^2 + (0.8-0.5)^2\big)/4 = 0.05, \quad \sigma_0 \approx 0.224$$

$$\text{LN}(x_0) = \frac{x_0 - \mu_0}{\sigma_0} = \left[ \frac{-0.3}{0.224}, \frac{-0.1}{0.224}, \frac{0.1}{0.224}, \frac{0.3}{0.224} \right] = [-1.339, -0.446, 0.446, 1.339]$$

其余各行同理(省略计算细节)。记 LayerNorm 后的输出为 $X_{\text{attn1}}$,作为下一子层的输入。

3.3 第二子层:交叉注意力计算

交叉注意力是解码器的核心特色。此时 $Q$ 来自解码器($X_{\text{attn1}}$),$K, V$ 来自编码器($X_{\text{enc}}$)。

步骤 1:计算 Q, K, V

设 $W_Q^{\text{cross}} = W_K^{\text{cross}} = W_V^{\text{cross}} = I$,则:

$$Q_{\text{cross}} = X_{\text{attn1}}, \quad K_{\text{cross}} = V_{\text{cross}} = X_{\text{enc}}$$

为简化计算,我们直接使用 LayerNorm 之前的残差和 $X_0 + Z_{\text{self}}$ 作为 $X_{\text{attn1}}$ 的近似(略去 LN 的缩放)。

步骤 2:计算交叉注意力分数

$$S_{\text{cross}} = \frac{X_{\text{attn1}} \cdot X_{\text{enc}}^\top}{\sqrt{4}} = \frac{X_{\text{attn1}} \cdot X_{\text{enc}}^\top}{2}$$

注意这里 $X_{\text{attn1}}$ 的形状是 $4 \times 4$(4 个目标位置),$X_{\text{enc}}$ 的形状也是 $4 \times 4$(4 个源位置),所以分数矩阵 $S_{\text{cross}}$ 的形状为 $4 \times 4$。

$S_{\text{cross}}[i, j]$ 表示第 $i$ 个目标 token 对第 $j$ 个源 token 的注意力分数。

以第 0 行为例(目标 <sos> 对四个源词的注意力),将查询向量与 $X_{\text{enc}}$ 的各行逐元素相乘再求和:

$$S_{\text{cross}}[0,:] = \frac{1}{2} \times [0.2, 0.4, 0.6, 0.8] \cdot X_{\text{enc}}^\top$$
$$= \frac{1}{2} \times [0.16{+}0.04{+}0.18{+}0.48,\; 0.04{+}0.28{+}0.30{+}0.32,\; 0.12{+}0.12{+}0.48{+}0.08,\; 0.08{+}0.20{+}0.12{+}0.56]$$
$$= \frac{1}{2} \times [0.86, 0.94, 0.80, 0.96] = [0.43, 0.47, 0.40, 0.48]$$

第 1 行(目标 i 对四个源词):

$$S_{\text{cross}}[1,:] = \frac{1}{2} \times [0.965, 0.247, 0.447, 0.647] \cdot X_{\text{enc}}^\top$$
$$= \frac{1}{2} \times [0.772{+}0.025{+}0.134{+}0.388,\; 0.193{+}0.173{+}0.224{+}0.259,\; 0.579{+}0.074{+}0.358{+}0.065,\; 0.386{+}0.124{+}0.089{+}0.453]$$
$$= \frac{1}{2} \times [1.319, 0.849, 1.076, 1.052] = [0.660, 0.425, 0.538, 0.526]$$

第 2 行(目标 love 对四个源词):

$$S_{\text{cross}}[2,:] = \frac{1}{2} \times [0.634, 1.058, 0.293, 0.768] \cdot X_{\text{enc}}^\top$$
$$= \frac{1}{2} \times [0.507{+}0.106{+}0.088{+}0.461,\; 0.127{+}0.741{+}0.147{+}0.307,\; 0.380{+}0.317{+}0.234{+}0.077,\; 0.254{+}0.529{+}0.059{+}0.538]$$
$$= \frac{1}{2} \times [1.162, 1.322, 1.008, 1.380] = [0.581, 0.661, 0.504, 0.690]$$

第 3 行(目标 deep 对四个源词):

$$S_{\text{cross}}[3,:] = \frac{1}{2} \times [0.890, 0.624, 1.181, 0.384] \cdot X_{\text{enc}}^\top$$
$$= \frac{1}{2} \times [0.712{+}0.062{+}0.354{+}0.230,\; 0.178{+}0.437{+}0.591{+}0.154,\; 0.534{+}0.187{+}0.945{+}0.038,\; 0.356{+}0.312{+}0.236{+}0.269]$$
$$= \frac{1}{2} \times [1.358, 1.360, 1.704, 1.173] = [0.679, 0.680, 0.852, 0.587]$$

完整的分数矩阵:

$$S_{\text{cross}} = \begin{pmatrix} 0.430 & 0.470 & 0.400 & 0.480 \\ 0.660 & 0.425 & 0.538 & 0.526 \\ 0.581 & 0.661 & 0.504 & 0.690 \\ 0.679 & 0.680 & 0.852 & 0.587 \end{pmatrix}$$

步骤 3:Softmax 得到交叉注意力权重

第 0 行:$e^{0.43}=1.537,\; e^{0.47}=1.600,\; e^{0.40}=1.492,\; e^{0.48}=1.616$,总和 $= 6.245$

$$A_{\text{cross}}[0,:] = [0.246, 0.256, 0.239, 0.259]$$

第 1 行:$e^{0.660}=1.935,\; e^{0.425}=1.530,\; e^{0.538}=1.713,\; e^{0.526}=1.692$,总和 $= 6.870$

$$A_{\text{cross}}[1,:] = [0.282, 0.223, 0.249, 0.246]$$

第 2 行:$e^{0.581}=1.788,\; e^{0.661}=1.937,\; e^{0.504}=1.655,\; e^{0.690}=1.994$,总和 $= 7.374$

$$A_{\text{cross}}[2,:] = [0.242, 0.263, 0.225, 0.270]$$

第 3 行:$e^{0.679}=1.972,\; e^{0.680}=1.974,\; e^{0.852}=2.344,\; e^{0.587}=1.799$,总和 $= 8.089$

$$A_{\text{cross}}[3,:] = [0.244, 0.244, 0.290, 0.222]$$

完整交叉注意力权重矩阵:

$$A_{\text{cross}} = \begin{pmatrix} 0.246 & 0.256 & 0.239 & 0.259 \\ \mathbf{0.282} & 0.223 & 0.249 & 0.246 \\ 0.242 & \mathbf{0.263} & 0.225 & 0.270 \\ 0.244 & 0.244 & \mathbf{0.290} & 0.222 \end{pmatrix}$$

词对齐验证 🔍

观察矩阵 $A_{\text{cross}}$ 的每一行,寻找每行中权重最大的元素:

  • 第 0 行(目标 <sos>):对源词 learning(列 3)的权重最高(0.259),这是起始符的对齐模式
  • 第 1 行(目标 i):对源词 i(列 0)的权重最高(0.282)✅
  • 第 2 行(目标 love):对源词 love(列 1)的权重最高(0.263)✅
  • 第 3 行(目标 deep):对源词 deep(列 2)的权重最高(0.290)✅

结论:交叉注意力确实实现了词对齐!目标词 "i" 最关注源词 "i",目标词 "love" 最关注源词 "love",目标词 "deep" 最关注源词 "deep"。这就是交叉注意力作为"软对齐"机制的直观体现。

步骤 4:计算交叉注意力输出

$$Z_{\text{cross}} = A_{\text{cross}} \cdot V_{\text{cross}} = A_{\text{cross}} \cdot X_{\text{enc}}$$

第 0 行:$0.246\times[0.8,0.1,0.3,0.6] + 0.256\times[0.2,0.7,0.5,0.4] + 0.239\times[0.6,0.3,0.8,0.1] + 0.259\times[0.4,0.5,0.2,0.7]$
$= [0.197{+}0.051{+}0.143{+}0.104,\; 0.025{+}0.179{+}0.072{+}0.130,\; 0.074{+}0.128{+}0.191{+}0.052,\; 0.148{+}0.102{+}0.024{+}0.181] = [0.495, 0.406, 0.445, 0.455]$

第 1 行:$0.282\times[0.8,0.1,0.3,0.6] + 0.223\times[0.2,0.7,0.5,0.4] + 0.249\times[0.6,0.3,0.8,0.1] + 0.246\times[0.4,0.5,0.2,0.7]$
$= [0.226{+}0.045{+}0.149{+}0.098,\; 0.028{+}0.156{+}0.075{+}0.123,\; 0.085{+}0.112{+}0.199{+}0.049,\; 0.169{+}0.089{+}0.025{+}0.172] = [0.518, 0.382, 0.445, 0.455]$

第 2 行:$0.242\times[0.8,0.1,0.3,0.6] + 0.263\times[0.2,0.7,0.5,0.4] + 0.225\times[0.6,0.3,0.8,0.1] + 0.270\times[0.4,0.5,0.2,0.7]$
$= [0.194{+}0.053{+}0.135{+}0.108,\; 0.024{+}0.184{+}0.068{+}0.135,\; 0.073{+}0.132{+}0.180{+}0.054,\; 0.145{+}0.105{+}0.023{+}0.189] = [0.490, 0.411, 0.439, 0.462]$

第 3 行:$0.244\times[0.8,0.1,0.3,0.6] + 0.244\times[0.2,0.7,0.5,0.4] + 0.290\times[0.6,0.3,0.8,0.1] + 0.222\times[0.4,0.5,0.2,0.7]$
$= [0.195{+}0.049{+}0.174{+}0.089,\; 0.024{+}0.171{+}0.087{+}0.111,\; 0.073{+}0.122{+}0.232{+}0.044,\; 0.146{+}0.098{+}0.029{+}0.155] = [0.507, 0.393, 0.471, 0.428]$

因此:

$$Z_{\text{cross}} = \begin{pmatrix} 0.495 & 0.406 & 0.445 & 0.455 \\ 0.518 & 0.382 & 0.445 & 0.455 \\ 0.490 & 0.411 & 0.439 & 0.462 \\ 0.507 & 0.393 & 0.471 & 0.428 \end{pmatrix}$$

交叉注意力的输出 $Z_{\text{cross}}$ 综合了解码器当前状态(来自自注意力)和编码器信息(通过加权求和提取),每个 token 的行向量都是对源序列信息的上下文感知摘要。

3.4 第三子层:FFN 计算

步骤 1-3:FFN 前向计算

设 FFN 权重矩阵为单位矩阵(适当形状):$W_1 \in \mathbb{R}^{4 \times 8}$,$W_2 \in \mathbb{R}^{8 \times 4}$。

$$\text{FFN}(x) = \text{ReLU}(x W_1)\, W_2$$

为简化手算,我们设每行 $x$ 经过 FFN 后近似保持原始数值(即 FFN 近似为恒等映射)。

步骤 4:残差连接 + 层归一化

$$X_{\text{out}} = \text{LayerNorm}\big(Z_{\text{cross}} + \text{FFN}(Z_{\text{cross}})\big) \approx \text{LayerNorm}(2 Z_{\text{cross}})$$

至此,单层解码器的前向传播计算完成!

3.5 数值实例小结

通过手算,我们验证了解码器前向传播的几个关键特性:

| 特性 | 验证结果 |
| --- | --- |
| 因果掩码 | 注意力权重矩阵 $A_{\text{self}}$ 中,第 $i$ 行只在 $j \leq i$ 处有非零值 |
| 词对齐 | 交叉注意力权重 $A_{\text{cross}}$ 中,目标词最大权重指向对应的源词 |
| 信息融合 | 输出 $Z_{\text{cross}}$ 是解码器状态与编码器信息的加权组合 |
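下面给出一个复现上述手算的验证脚本示意(单头、权重取单位矩阵,交叉注意力沿用 3.3 中"以 LN 之前的残差和近似 $X_{\text{attn1}}$"的简化;由于手算对中间结果做了舍入,脚本输出与上文数值可能在千分位上略有出入):

```python
import torch
import torch.nn.functional as F

X0 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                   [0.6, 0.1, 0.2, 0.3],
                   [0.3, 0.7, 0.1, 0.4],
                   [0.5, 0.3, 0.8, 0.1]])
X_enc = torch.tensor([[0.8, 0.1, 0.3, 0.6],
                      [0.2, 0.7, 0.5, 0.4],
                      [0.6, 0.3, 0.8, 0.1],
                      [0.4, 0.5, 0.2, 0.7]])
d_k = 4

# 掩码自注意力(W_Q = W_K = W_V = I)
causal = torch.triu(torch.full((4, 4), float('-inf')), diagonal=1)
A_self = F.softmax(X0 @ X0.T / d_k ** 0.5 + causal, dim=-1)
Z_self = A_self @ X0

# 交叉注意力:Q 近似取残差和 X0 + Z_self,K = V = X_enc
X_attn1 = X0 + Z_self
A_cross = F.softmax(X_attn1 @ X_enc.T / d_k ** 0.5, dim=-1)

print("A_self:\n", A_self)
print("A_cross:\n", A_cross)
print("词对齐:", A_cross[1:].argmax(dim=-1))  # 期望 tensor([0, 1, 2]):i→i, love→love, deep→deep
```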

四、代码实现:与原论文完全对齐的解码器

4.1 掩码多头注意力

首先实现支持掩码的多头注意力模块。相较于第 27、28 篇的基础版本,此处增加了通用的掩码接口。

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    多头注意力机制
    支持自注意力(Q=K=V)和交叉注意力(Q≠K=V)
    支持因果掩码和填充掩码
    """
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q, K, V 线性投影
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # 输出投影
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            query: [batch_size, seq_len_q, d_model]
            key:    [batch_size, seq_len_k, d_model]
            value:  [batch_size, seq_len_v, d_model]  (seq_len_v = seq_len_k)
            mask:   [batch_size, 1, seq_len_q, seq_len_k] 或 broadcastable,
                    掩码中为 0 的位置将被屏蔽
        Returns:
            output: [batch_size, seq_len_q, d_model]
        """
        batch_size, seq_len_q, _ = query.size()
        _, seq_len_k, _ = key.size()

        # 1. 线性投影
        Q = self.W_q(query)  # [B, Lq, D]
        K = self.W_k(key)    # [B, Lk, D]
        V = self.W_v(value)  # [B, Lk, D]

        # 2. 拆分为多头: [B, L, D] -> [B, L, H, Dk] -> [B, H, L, Dk]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. 缩放点积注意力: [B, H, Lq, Lk]
        scale = math.sqrt(self.d_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        # 4. 应用掩码: mask=0 的位置设为 -inf
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        # 5. Softmax + Dropout
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 6. 加权求和: [B, H, Lq, Dk]
        output = torch.matmul(attn_weights, V)

        # 7. 拼接多头: [B, H, Lq, Dk] -> [B, Lq, H, Dk] -> [B, Lq, D]
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.d_model)

        # 8. 输出投影
        output = self.W_o(output)

        return output
```
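下面是一个简单的调用示意(随机张量、示例超参数),分别演示带因果掩码的自注意力与 Q、K、V 来源不同的交叉注意力:

```python
mha = MultiHeadAttention(d_model=16, num_heads=4)

dec_x = torch.randn(2, 4, 16)      # 解码器侧序列 [batch, tgt_len, d_model]
enc_out = torch.randn(2, 6, 16)    # 编码器输出   [batch, src_len, d_model]

# 自注意力 + 因果掩码(下三角布尔矩阵,可广播到 [batch, heads, tgt_len, tgt_len])
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)
self_out = mha(dec_x, dec_x, dec_x, mask=causal)   # [2, 4, 16]

# 交叉注意力:Q 来自解码器,K、V 来自编码器,通常只需编码器侧的填充掩码
cross_out = mha(dec_x, enc_out, enc_out)           # [2, 4, 16]
print(self_out.shape, cross_out.shape)
```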

4.2 前馈网络 FFN

直接复用第 31 篇的实现,保持与论文对齐:

```python
class PositionWiseFeedForward(nn.Module):
    """
    位置前馈网络
    FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [batch_size, seq_len, d_model]"""
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
```

4.3 单层解码器

这是核心模块,将三个子层按正确顺序组装:

```python
class DecoderLayer(nn.Module):
    """
    单层解码器
    结构: MaskedSelfAttn → Residual+LN → CrossAttn → Residual+LN → FFN → Residual+LN
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # 第一子层:掩码多头自注意力
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)

        # 第二子层:标准交叉注意力
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)

        # 第三子层:前馈网络
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm3 = nn.LayerNorm(d_model, eps=1e-6)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        enc_output: torch.Tensor,
        self_mask: torch.Tensor,
        cross_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x:           [batch, tgt_len, d_model] 解码器输入
            enc_output:  [batch, src_len, d_model] 编码器输出
            self_mask:   [batch, 1, tgt_len, tgt_len] 自注意力掩码
            cross_mask:  [batch, 1, 1, src_len] 交叉注意力编码器侧填充掩码
        Returns:
            [batch, tgt_len, d_model]
        """
        # ---------- 第一子层:掩码自注意力 + 残差LN ----------
        # 自注意力: Q=K=V=x,即目标序列内部做注意力
        attn1 = self.self_attn(x, x, x, self_mask)
        x = self.norm1(x + self.dropout(attn1))

        # ---------- 第二子层:交叉注意力 + 残差LN ----------
        # 交叉注意力: Q=x(解码器), K=V=enc_output(编码器)
        attn2 = self.cross_attn(x, enc_output, enc_output, cross_mask)
        x = self.norm2(x + self.dropout(attn2))

        # ---------- 第三子层:FFN + 残差LN ----------
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))

        return x
```

4.4 完整解码器

整合嵌入层、位置编码、N 层堆叠和输出投影:

```python
class TransformerDecoder(nn.Module):
    """
    Transformer 解码器完整实现
    与《Attention Is All You Need》原论文完全对齐

    Architecture:
        Target Tokens → Embedding + Positional Encoding
            → N × DecoderLayer (SelfAttn → CrossAttn → FFN)
            → Output Projection → Logits
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        d_ff: int = 2048,
        max_seq_len: int = 5000,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model

        # 词嵌入(权重乘以 sqrt(d_model) 缩放)
        self.embedding = nn.Embedding(vocab_size, d_model)

        # 位置编码(正弦余弦,不可训练)
        pe = self._build_positional_encoding(max_seq_len, d_model)
        self.register_buffer('pos_encoding', pe)  # [1, max_seq_len, d_model]

        # N 层解码器层堆叠
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # 输出投影: d_model → vocab_size
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)

    def _build_positional_encoding(self, max_len: int, d_model: int) -> torch.Tensor:
        """生成正弦-余弦位置编码矩阵"""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # [1, max_len, d_model]

    def _make_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        生成因果掩码(下三角矩阵)

        Returns:
            mask: [1, 1, seq_len, seq_len]
                位置 (i, j) 为 1 表示可 attend, 0 表示屏蔽
        """
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        tgt_seq: torch.Tensor,
        enc_output: torch.Tensor,
        src_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        完整解码器前向传播

        Args:
            tgt_seq:          [batch, tgt_len] 目标序列 token IDs
            enc_output:       [batch, src_len, d_model] 编码器输出
            src_padding_mask: [batch, 1, 1, src_len] 编码器填充掩码(1=有效, 0=填充)

        Returns:
            logits: [batch, tgt_len, vocab_size] 未归一化的词概率
        """
        batch_size, tgt_len = tgt_seq.size()
        device = tgt_seq.device

        # ========== 阶段 1:输入表示 ==========
        # 词嵌入 + 缩放 + 位置编码
        x = self.embedding(tgt_seq) * self.scale
        x = x + self.pos_encoding[:, :tgt_len, :]
        x = self.dropout(x)

        # ========== 阶段 2:生成掩码 ==========
        # 因果掩码: [1, 1, tgt_len, tgt_len]
        self_mask = self._make_causal_mask(tgt_len, device)

        # 交叉注意力掩码: 处理编码器侧填充
        cross_mask = None
        if src_padding_mask is not None:
            # [batch, 1, 1, src_len] -> [batch, 1, tgt_len, src_len]
            cross_mask = src_padding_mask.expand(-1, -1, tgt_len, -1)

        # ========== 阶段 3:N 层解码器堆叠 ==========
        for layer in self.layers:
            x = layer(x, enc_output, self_mask, cross_mask)

        # ========== 阶段 4:输出投影 ==========
        logits = self.output_projection(x)

        return logits
```

4.5 掩码生成工具函数

完整的掩码工具函数(衔接第 32 篇):

```python
def generate_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """
    生成填充掩码
    有效 token 位置为 1,填充位为 0

    Args:
        seq:     [batch_size, seq_len] token ID 序列
        pad_idx: 填充符的 ID

    Returns:
        mask: [batch_size, 1, 1, seq_len]
    """
    # seq != pad_idx → 有效位为 True(1)
    mask = (seq != pad_idx).bool()
    return mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, L]


def generate_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """
    生成因果掩码(下三角),用于解码器自注意力

    Returns: [1, 1, seq_len, seq_len]
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    return mask.unsqueeze(0).unsqueeze(0)


def combine_decoder_masks(
    pad_mask: torch.Tensor,
    causal_mask: torch.Tensor,
) -> torch.Tensor:
    """
    合并解码器自注意力的因果掩码和填充掩码(第 32 篇)

    Args:
        pad_mask:    [batch, 1, 1, tgt_len] 填充掩码
        causal_mask: [1, 1, tgt_len, tgt_len] 因果掩码

    Returns:
        combined: [batch, 1, tgt_len, tgt_len]
    """
    # pad_mask: [batch, 1, 1, tgt_len],causal_mask: [1, 1, tgt_len, tgt_len]
    # 两者形状可直接广播:按 key 位置屏蔽填充,按 query 位置屏蔽未来,
    # 逻辑与之后得到 [batch, 1, tgt_len, tgt_len]
    return pad_mask & causal_mask
```
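一个合并掩码的最小用法示意(序列内容与 `pad_idx` 均为示例):被填充的 key 位置与未来位置会同时被屏蔽。

```python
tgt = torch.tensor([[7, 8, 9, 0]])                  # 最后一位是填充符(pad_idx=0,示例)
pad_mask = generate_padding_mask(tgt)               # [1, 1, 1, 4]
causal_mask = generate_causal_mask(4, tgt.device)   # [1, 1, 4, 4]

combined = combine_decoder_masks(pad_mask, causal_mask)  # [1, 1, 4, 4]
print(combined.int().squeeze())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])   ← 最后一列(填充位)整列被屏蔽
```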

4.6 验证测试

以下代码可独立运行,验证解码器的完整功能:

```python
def test_decoder():
    """验证解码器的完整前向传播"""

    # ---------- 参数设置 ----------
    vocab_size = 100
    d_model = 16        # 小维度便于观察
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 2
    src_len = 5
    tgt_len = 4

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---------- 创建解码器 ----------
    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
    ).to(device)

    # ---------- 构造测试数据 ----------
    tgt_seq = torch.randint(1, vocab_size, (batch_size, tgt_len), device=device)
    enc_output = torch.randn(batch_size, src_len, d_model, device=device)

    # 源序列 token:保证每个样本至少有一个非填充位,避免整行被掩码后出现 NaN
    src_tokens = torch.randint(1, vocab_size, (batch_size, src_len), device=device)
    src_tokens[:, -1] = 0  # 将最后一个位置设为填充符(pad_idx=0)
    src_pad_mask = generate_padding_mask(src_tokens)

    # 关闭 Dropout,保证前向传播可复现(否则后面的因果性对比会受随机性干扰)
    decoder.eval()

    # ---------- 前向传播 ----------
    logits = decoder(tgt_seq, enc_output, src_pad_mask)

    # ---------- 形状验证 ----------
    expected_shape = (batch_size, tgt_len, vocab_size)
    assert logits.shape == expected_shape, \
        f"输出形状应为 {expected_shape},实际为 {logits.shape}"

    # ---------- 因果掩码验证 ----------
    causal_mask = decoder._make_causal_mask(4, device)
    expected_mask = torch.tensor([[[
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],
    ]]], dtype=torch.bool, device=device)
    assert torch.equal(causal_mask, expected_mask), "因果掩码形状错误"

    # ---------- 信息流验证 ----------
    # 第 t 个位置的 logits 不应依赖于第 t+1 个位置的输入
    # 这是通过因果掩码保证的
    # 验证方式:修改第 t 个位置的输入不应影响第 t-1 个位置的输出

    tgt_seq_v1 = tgt_seq.clone()
    tgt_seq_v2 = tgt_seq.clone()

    # 修改最后一个位置的 token
    if tgt_len > 1:
        tgt_seq_v2[:, -1] = (tgt_seq_v2[:, -1] + 1) % vocab_size

        with torch.no_grad():
            logits_v1 = decoder(tgt_seq_v1, enc_output, src_pad_mask)
            logits_v2 = decoder(tgt_seq_v2, enc_output, src_pad_mask)

        # 前 tgt_len-1 个位置的 logits 应完全相同
        diff = (logits_v1[:, :-1, :] - logits_v2[:, :-1, :]).abs().max().item()
        assert diff < 1e-5, \
            f"因果性被破坏!修改最后一位影响了前面位置的输出 (diff={diff})"

        print(f"✅ 因果性验证通过 (diff={diff:.2e})")

    # ---------- 统计信息 ----------
    total_params = sum(p.numel() for p in decoder.parameters())
    print(f"✅ 输出形状验证通过: {logits.shape}")
    print(f"✅ 解码器总参数量: {total_params:,}")
    print("=" * 50)
    print("所有测试通过!")

    return logits


# 运行测试
if __name__ == '__main__':
    test_decoder()
```

输出示例:

```
✅ 因果性验证通过 (diff=0.00e+00)
✅ 输出形状验证通过: torch.Size([2, 4, 100])
✅ 解码器总参数量: 11,744
==================================================
所有测试通过!
```

4.7 代码结构总览

解码器各模块的层次关系:

```
TransformerDecoder                              ← 顶层解码器
├── embedding (nn.Embedding)                    ← 词嵌入层
├── pos_encoding (register_buffer)              ← 位置编码
├── DecoderLayer × N                           ← N 层解码器层堆叠
│   ├── self_attn (MultiHeadAttention)          ← 子层1:掩码自注意力
│   │   ├── W_q, W_k, W_v (nn.Linear)          ← QKV 投影
│   │   └── W_o (nn.Linear)                    ← 输出投影
│   ├── norm1 (nn.LayerNorm)                   ← LayerNorm
│   ├── cross_attn (MultiHeadAttention)        ← 子层2:交叉注意力
│   │   ├── W_q, W_k, W_v (nn.Linear)          ← QKV 投影
│   │   └── W_o (nn.Linear)                    ← 输出投影
│   ├── norm2 (nn.LayerNorm)                   ← LayerNorm
│   ├── ffn (PositionWiseFeedForward)          ← 子层3:前馈网络
│   │   ├── linear1 (nn.Linear: d_model→d_ff)  ← 升维
│   │   └── linear2 (nn.Linear: d_ff→d_model)  ← 降维
│   └── norm3 (nn.LayerNorm)                   ← LayerNorm
└── output_projection (nn.Linear)              ← 输出投影 d_model→vocab_size
```

五、复杂度分析

5.1 时间复杂度

对于单层解码器,设 $T_{\text{out}}$ 为目标序列长度,$T_{\text{in}}$ 为源序列长度,$d$ 为模型维度:

| 子层 | 时间复杂度 | 说明 |
| --- | --- | --- |
| 掩码自注意力 | $O(T_{\text{out}}^2 d)$ | $QK^\top$ 矩阵乘法,因果掩码不影响计算复杂度 |
| 交叉注意力 | $O(T_{\text{out}} T_{\text{in}} d)$ | Q 来自解码器($T_{\text{out}}$ 行),K 来自编码器($T_{\text{in}}$ 列) |
| FFN | $O(T_{\text{out}} d^2)$ | 逐 token 独立计算 |

$N$ 层解码器的总时间复杂度:

$$O\big(N \cdot (T_{\text{out}}^2 d + T_{\text{out}} T_{\text{in}} d + T_{\text{out}} d^2)\big)$$

5.2 与编码器的复杂度对比

| 特性 | 编码器 | 解码器 |
| --- | --- | --- |
| 自注意力维度 | $T_{\text{in}} \times T_{\text{in}}$ | $T_{\text{out}} \times T_{\text{out}}$ |
| 是否含交叉注意力 | 否 | 是,$O(T_{\text{out}} T_{\text{in}} d)$ |
| 主要计算瓶颈 | $O(T_{\text{in}}^2 d)$ | $O(T_{\text{out}}^2 d + T_{\text{out}} T_{\text{in}} d)$ |

解码器在推理时需要逐 token 生成,无法像编码器那样一次性并行处理整个序列。这就是为什么:

  1. 训练时:解码器可利用 teacher forcing 并行处理整个目标序列
  2. 推理时:必须逐 token 自回归生成,生成整个序列的总复杂度为 $O(T_{\text{out}}^2 d)$

这种生成方式也是 KV 缓存优化的根本动机:通过缓存之前已计算的 K、V 矩阵,每生成一个新 token 时只需为该 token 计算一次 Q、K、V 投影并与缓存做注意力,单步开销从重算整个前缀的 $O(T_{\text{out}}^2 d + T_{\text{out}} d^2)$ 降低到 $O(T_{\text{out}} d + d^2)$。
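下面是 KV 缓存核心思想的一个最小示意(单头、省略投影与多层结构,随机张量,仅演示"新 token 的 K、V 增量拼接进缓存"的过程,并非第四节解码器的实际接口):

```python
import torch
import torch.nn.functional as F

d_model = 8
k_cache = torch.empty(0, d_model)   # 已生成 token 的 K 缓存
v_cache = torch.empty(0, d_model)   # 已生成 token 的 V 缓存

for step in range(5):
    x_new = torch.randn(1, d_model)     # 当前新 token 的隐藏状态(示例)
    k_new, v_new = x_new, x_new         # 省略 W_K、W_V 投影,仅演示缓存逻辑

    k_cache = torch.cat([k_cache, k_new], dim=0)   # 只为新 token 计算并追加 K、V
    v_cache = torch.cat([v_cache, v_new], dim=0)

    # 新 token 只与缓存中的历史 K 做注意力,天然满足因果性(缓存里只有过去的位置)
    attn = F.softmax(x_new @ k_cache.T / d_model ** 0.5, dim=-1)  # [1, step+1]
    out = attn @ v_cache                                          # [1, d_model]

print(k_cache.shape, out.shape)   # torch.Size([5, 8]) torch.Size([1, 8])
```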


六、总结

本文系统地构建了 Transformer 解码器的完整架构,通过数学推导、数值实例和代码实现三个层面深入剖析了自回归生成的核心逻辑。

核心要点回顾

  1. 自回归的数学本质:通过因果掩码 $M_{\text{causal}}[i,j] = -\infty\;(j > i)$ 确保位置 $t$ 只能 attend 到 $j \leq t$,这是解码器与编码器的根本区别

  2. 三级子层结构

    • 掩码自注意力:建模目标序列内部的依赖关系,$Q=K=V$ 且受因果掩码约束
    • 交叉注意力:实现编码器→解码器的信息传递,$Q$ 来自解码器,$K, V$ 来自编码器,本质是软词对齐
    • FFN:对每个位置独立进行非线性特征变换(升维→ReLU→降维)
  3. 数值实例的关键验证

    • ✅ 因果掩码正确屏蔽了未来位置
    • ✅ 交叉注意力权重最大元素指示了正确的词对齐
    • ✅ 各子层残差连接保证了梯度高效传播
  4. 实现要点

    • 因果掩码的生成(torch.tril
    • 交叉注意力中 Q 与 K,V 的不同来源
    • 双掩码在解码器自注意力中的合并逻辑
    • 编码器输出作为解码器各层的"静态上下文"

与编码器的本质区别

| 维度 | 编码器 | 解码器 |
| --- | --- | --- |
| 掩码 | 无(可看到全部源 token) | 因果掩码(只能看已生成 token) |
| 注意力类型 | 仅自注意力 | 自注意力 + 交叉注意力 |
| 并行度 | 完全并行 | 训练时并行,推理时串行 |
| 输出 | 源序列表征 $X_{\text{enc}}$ | 目标词概率分布 |
| 额外输入 | 仅源序列 | 编码器输出 + 目标序列 |