Transformer 解码器深度解读 + 代码实战
1. 解码器核心作用
Transformer 解码器的核心任务是基于编码器的语义表示逐步生成目标序列 (如翻译结果、文本续写)。它通过 掩码自注意力 和 编码器-解码器交叉注意力,实现自回归生成并融合源序列信息。与编码器的核心差异:
- 掩码机制:防止解码时看到未来信息(训练时并行,推理时逐步生成)。
- 交叉注意力:将编码器输出作为 Key/Value,解码器当前状态作为 Query。
2. 解码器单层结构详解
每层解码器包含以下模块(附 PyTorch 代码):
2.1 掩码多头自注意力(Masked Multi-Head Self-Attention)
python
class MaskedMultiHeadAttention(nn.Module):
def __init__(self, embed_size, heads):
super().__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
# 生成 Q/K/V 的线性层
self.to_qkv = nn.Linear(embed_size, embed_size * 3)
self.scale = self.head_dim ** -0.5 # 缩放因子
# 输出线性层
self.to_out = nn.Linear(embed_size, embed_size)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 生成 Q/K/V 并分割多头
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)
# 计算注意力分数 QK^T / sqrt(d_k)
attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
# 应用下三角掩码(防止看到未来信息)
if mask is not None:
attn = attn.masked_fill(mask == 0, -1e10) # 掩码位置填充极小值
else:
# 自动生成下三角掩码(训练时使用)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(x.device)
attn = attn.masked_fill(~causal_mask, -1e10)
# Softmax 归一化
attn = torch.softmax(attn, dim=-1)
# 加权求和
out = torch.einsum('bhij,bhjd->bhid', attn, v)
out = out.reshape(batch_size, seq_len, self.embed_size)
return self.to_out(out)
代码解析:
causal_mask
生成下三角矩阵(主对角线及以下为1,其余为0),确保解码时仅能看到当前位置及之前的信息。- 推理时可手动传递掩码,控制生成长度。
2.2 编码器-解码器交叉注意力(Cross-Attention)
python
class CrossAttention(nn.Module):
def __init__(self, embed_size, heads):
super().__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
# 生成 Q 的线性层(解码器输入)
self.to_q = nn.Linear(embed_size, embed_size)
# 生成 K/V 的线性层(编码器输出)
self.to_kv = nn.Linear(embed_size, embed_size * 2)
self.scale = self.head_dim ** -0.5
self.to_out = nn.Linear(embed_size, embed_size)
def forward(self, x, encoder_output, mask=None):
batch_size, seq_len, _ = x.shape
# 生成 Q 来自解码器输入
q = self.to_q(x).view(batch_size, seq_len, self.heads, self.head_dim)
# 生成 K/V 来自编码器输出
k, v = self.to_kv(encoder_output).chunk(2, dim=-1)
k = k.view(batch_size, -1, self.heads, self.head_dim) # 编码器序列长度可能不同
v = v.view(batch_size, -1, self.heads, self.head_dim)
# 计算注意力分数 QK^T / sqrt(d_k)
attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
# 应用掩码(如填充掩码)
if mask is not None:
attn = attn.masked_fill(mask == 0, -1e10)
attn = torch.softmax(attn, dim=-1)
out = torch.einsum('bhij,bhjd->bhid', attn, v)
out = out.reshape(batch_size, seq_len, self.embed_size)
return self.to_out(out)
代码解析:
- Q 来自解码器输入 ,K/V 来自编码器输出,实现跨序列信息融合。
- 支持自定义掩码(如处理源序列的填充位置)。
2.3 解码器单层完整实现
python
class TransformerDecoderLayer(nn.Module):
def __init__(self, embed_size, heads, dropout=0.1):
super().__init__()
self.masked_attn = MaskedMultiHeadAttention(embed_size, heads)
self.cross_attn = CrossAttention(embed_size, heads)
self.ffn = FeedForward(embed_size) # 复用编码器的FFN
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.norm3 = nn.LayerNorm(embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
# 1. 掩码自注意力
masked_attn_out = self.masked_attn(x, tgt_mask)
x = x + self.dropout(masked_attn_out)
x = self.norm1(x)
# 2. 交叉注意力(Q来自x,K/V来自encoder_output)
cross_attn_out = self.cross_attn(x, encoder_output, src_mask)
x = x + self.dropout(cross_attn_out)
x = self.norm2(x)
# 3. 前馈网络
ffn_out = self.ffn(x)
x = x + self.dropout(ffn_out)
x = self.norm3(x)
return x
3. 完整解码器实现
python
class TransformerDecoder(nn.Module):
def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_size)
self.pos_encoding = PositionalEncoding(embed_size)
self.layers = nn.ModuleList([
TransformerDecoderLayer(embed_size, heads, dropout)
for _ in range(layers)
])
self.fc_out = nn.Linear(embed_size, vocab_size) # 输出层预测词表概率
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
x = self.embedding(x)
x = self.pos_encoding(x)
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
logits = self.fc_out(x) # (batch_size, seq_len, vocab_size)
return logits
4. 实战测试:文本翻译模拟
python
# 参数设置
vocab_size = 10000 # 目标语言词表大小
embed_size = 512
layers = 6
heads = 8
# 初始化编码器和解码器
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)
decoder = TransformerDecoder(vocab_size, embed_size, layers, heads)
# 模拟输入(源语言句子)
src = torch.randint(0, vocab_size, (32, 20)) # (batch_size=32, src_seq_len=20)
# 编码器输出
encoder_output = encoder(src)
# 模拟目标输入(目标语言句子,训练时右移一位)
tgt = torch.randint(0, vocab_size, (32, 25)) # (batch_size=32, tgt_seq_len=25)
# 解码器输出
logits = decoder(tgt, encoder_output)
print("输出形状:", logits.shape) # torch.Size([32, 25, 10000])
🎉 恭喜! 至此你已经掌握了Transformer解码器的核心原理与实现。无论是机器翻译、文本生成,还是对话系统,解码器都是生成任务的核心引擎。
下一步建议:
- 尝试在真实数据集(如WMT英德翻译)上训练模型。
- 探索 束搜索(Beam Search) 和 温度采样(Temperature Sampling) 等推理优化技术。
- 访问 Transformer官方代码库 或 Hugging Face库 深入学习工业级实现。
动手实践是掌握AI的最佳方式------赶紧修改代码参数,观察模型变化吧!如果遇到问题,欢迎在评论区留言讨论,我们一起解决! 🌟
希望这篇解析能助你彻底理解Transformer解码器,期待看到你的实战成果! 😊