Transformer解码器终极指南:从Masked Attention到Cross-Attention的PyTorch逐行实现


Transformer 解码器深度解读 + 代码实战


1. 解码器核心作用

Transformer 解码器的核心任务是基于编码器的语义表示逐步生成目标序列 (如翻译结果、文本续写)。它通过 掩码自注意力编码器-解码器交叉注意力,实现自回归生成并融合源序列信息。与编码器的核心差异:

  • 掩码机制:防止解码时看到未来信息(训练时并行,推理时逐步生成)。
  • 交叉注意力:将编码器输出作为 Key/Value,解码器当前状态作为 Query。

2. 解码器单层结构详解

每层解码器包含以下模块(附 PyTorch 代码):


2.1 掩码多头自注意力(Masked Multi-Head Self-Attention)
python 复制代码
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        # 生成 Q/K/V 的线性层
        self.to_qkv = nn.Linear(embed_size, embed_size * 3)
        self.scale = self.head_dim ** -0.5  # 缩放因子
        
        # 输出线性层
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q/K/V 并分割多头
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)
        
        # 计算注意力分数 QK^T / sqrt(d_k)
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        
        # 应用下三角掩码(防止看到未来信息)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)  # 掩码位置填充极小值
        else:
            # 自动生成下三角掩码(训练时使用)
            causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(x.device)
            attn = attn.masked_fill(~causal_mask, -1e10)
        
        # Softmax 归一化
        attn = torch.softmax(attn, dim=-1)
        
        # 加权求和
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.reshape(batch_size, seq_len, self.embed_size)
        return self.to_out(out)

代码解析

  • causal_mask 生成下三角矩阵(主对角线及以下为1,其余为0),确保解码时仅能看到当前位置及之前的信息。
  • 推理时可手动传递掩码,控制生成长度。

2.2 编码器-解码器交叉注意力(Cross-Attention)
python 复制代码
class CrossAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        # 生成 Q 的线性层(解码器输入)
        self.to_q = nn.Linear(embed_size, embed_size)
        # 生成 K/V 的线性层(编码器输出)
        self.to_kv = nn.Linear(embed_size, embed_size * 2)
        
        self.scale = self.head_dim ** -0.5
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, encoder_output, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q 来自解码器输入
        q = self.to_q(x).view(batch_size, seq_len, self.heads, self.head_dim)
        
        # 生成 K/V 来自编码器输出
        k, v = self.to_kv(encoder_output).chunk(2, dim=-1)
        k = k.view(batch_size, -1, self.heads, self.head_dim)  # 编码器序列长度可能不同
        v = v.view(batch_size, -1, self.heads, self.head_dim)
        
        # 计算注意力分数 QK^T / sqrt(d_k)
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        
        # 应用掩码(如填充掩码)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)
        
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.reshape(batch_size, seq_len, self.embed_size)
        return self.to_out(out)

代码解析

  • Q 来自解码器输入K/V 来自编码器输出,实现跨序列信息融合。
  • 支持自定义掩码(如处理源序列的填充位置)。

2.3 解码器单层完整实现
python 复制代码
class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_size, heads, dropout=0.1):
        super().__init__()
        self.masked_attn = MaskedMultiHeadAttention(embed_size, heads)
        self.cross_attn = CrossAttention(embed_size, heads)
        self.ffn = FeedForward(embed_size)  # 复用编码器的FFN
        
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # 1. 掩码自注意力
        masked_attn_out = self.masked_attn(x, tgt_mask)
        x = x + self.dropout(masked_attn_out)
        x = self.norm1(x)
        
        # 2. 交叉注意力(Q来自x,K/V来自encoder_output)
        cross_attn_out = self.cross_attn(x, encoder_output, src_mask)
        x = x + self.dropout(cross_attn_out)
        x = self.norm2(x)
        
        # 3. 前馈网络
        ffn_out = self.ffn(x)
        x = x + self.dropout(ffn_out)
        x = self.norm3(x)
        return x

3. 完整解码器实现
python 复制代码
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_size, heads, dropout)
            for _ in range(layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)  # 输出层预测词表概率
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        logits = self.fc_out(x)  # (batch_size, seq_len, vocab_size)
        return logits

4. 实战测试:文本翻译模拟
python 复制代码
# 参数设置
vocab_size = 10000  # 目标语言词表大小
embed_size = 512
layers = 6
heads = 8

# 初始化编码器和解码器
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)
decoder = TransformerDecoder(vocab_size, embed_size, layers, heads)

# 模拟输入(源语言句子)
src = torch.randint(0, vocab_size, (32, 20))  # (batch_size=32, src_seq_len=20)
# 编码器输出
encoder_output = encoder(src)

# 模拟目标输入(目标语言句子,训练时右移一位)
tgt = torch.randint(0, vocab_size, (32, 25))  # (batch_size=32, tgt_seq_len=25)
# 解码器输出
logits = decoder(tgt, encoder_output)
print("输出形状:", logits.shape)  # torch.Size([32, 25, 10000])

🎉 恭喜! 至此你已经掌握了Transformer解码器的核心原理与实现。无论是机器翻译、文本生成,还是对话系统,解码器都是生成任务的核心引擎。

下一步建议

  1. 尝试在真实数据集(如WMT英德翻译)上训练模型。
  2. 探索 束搜索(Beam Search)温度采样(Temperature Sampling) 等推理优化技术。
  3. 访问 Transformer官方代码库Hugging Face库 深入学习工业级实现。

动手实践是掌握AI的最佳方式------赶紧修改代码参数,观察模型变化吧!如果遇到问题,欢迎在评论区留言讨论,我们一起解决! 🌟


希望这篇解析能助你彻底理解Transformer解码器,期待看到你的实战成果! 😊

相关推荐
啦啦啦在冲冲冲1 小时前
解释一下roberta,bert-chinese和bert-case有啥区别还有bert-large这些
人工智能·深度学习·bert
2401_897930061 小时前
PyTorch 中训练语言模型过程
人工智能·pytorch·语言模型
张子夜 iiii2 小时前
传统神经网络实现-----手写数字识别(MNIST)项目
人工智能·pytorch·python·深度学习·算法
全息数据2 小时前
DDPM代码讲解【详细!!!】
深度学习·stable diffusion·多模态·ddpm
西猫雷婶2 小时前
神经网络|(十九)概率论基础知识-伽马函数·下
人工智能·深度学习·神经网络·机器学习·回归·scikit-learn·概率论
Honeysea_704 小时前
容器的定义及工作原理
人工智能·深度学习·机器学习·docker·ai·持续部署
大千AI助手5 小时前
梯度消失问题:深度学习中的「记忆衰退」困境与解决方案
人工智能·深度学习·神经网络·梯度·梯度消失·链式法则·vanishing
研梦非凡5 小时前
CVPR 2025|无类别词汇的视觉-语言模型少样本学习
人工智能·深度学习·学习·语言模型·自然语言处理
max5006005 小时前
本地部署开源数据生成器项目实战指南
开发语言·人工智能·python·深度学习·算法·开源
一颗20216 小时前
深度解读:PSPNet(Pyramid Scene Parsing Network) — 用金字塔池化把“场景理解”装进分割网络
人工智能·深度学习·计算机视觉