Transformer解码器终极指南:从Masked Attention到Cross-Attention的PyTorch逐行实现


Transformer 解码器深度解读 + 代码实战


1. 解码器核心作用

Transformer 解码器的核心任务是基于编码器的语义表示逐步生成目标序列 (如翻译结果、文本续写)。它通过 掩码自注意力编码器-解码器交叉注意力,实现自回归生成并融合源序列信息。与编码器的核心差异:

  • 掩码机制:防止解码时看到未来信息(训练时并行,推理时逐步生成)。
  • 交叉注意力:将编码器输出作为 Key/Value,解码器当前状态作为 Query。

2. 解码器单层结构详解

每层解码器包含以下模块(附 PyTorch 代码):


2.1 掩码多头自注意力(Masked Multi-Head Self-Attention)
python 复制代码
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        # 生成 Q/K/V 的线性层
        self.to_qkv = nn.Linear(embed_size, embed_size * 3)
        self.scale = self.head_dim ** -0.5  # 缩放因子
        
        # 输出线性层
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q/K/V 并分割多头
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)
        
        # 计算注意力分数 QK^T / sqrt(d_k)
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        
        # 应用下三角掩码(防止看到未来信息)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)  # 掩码位置填充极小值
        else:
            # 自动生成下三角掩码(训练时使用)
            causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(x.device)
            attn = attn.masked_fill(~causal_mask, -1e10)
        
        # Softmax 归一化
        attn = torch.softmax(attn, dim=-1)
        
        # 加权求和
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.reshape(batch_size, seq_len, self.embed_size)
        return self.to_out(out)

代码解析

  • causal_mask 生成下三角矩阵(主对角线及以下为1,其余为0),确保解码时仅能看到当前位置及之前的信息。
  • 推理时可手动传递掩码,控制生成长度。

2.2 编码器-解码器交叉注意力(Cross-Attention)
python 复制代码
class CrossAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        # 生成 Q 的线性层(解码器输入)
        self.to_q = nn.Linear(embed_size, embed_size)
        # 生成 K/V 的线性层(编码器输出)
        self.to_kv = nn.Linear(embed_size, embed_size * 2)
        
        self.scale = self.head_dim ** -0.5
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, encoder_output, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q 来自解码器输入
        q = self.to_q(x).view(batch_size, seq_len, self.heads, self.head_dim)
        
        # 生成 K/V 来自编码器输出
        k, v = self.to_kv(encoder_output).chunk(2, dim=-1)
        k = k.view(batch_size, -1, self.heads, self.head_dim)  # 编码器序列长度可能不同
        v = v.view(batch_size, -1, self.heads, self.head_dim)
        
        # 计算注意力分数 QK^T / sqrt(d_k)
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        
        # 应用掩码(如填充掩码)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)
        
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.reshape(batch_size, seq_len, self.embed_size)
        return self.to_out(out)

代码解析

  • Q 来自解码器输入K/V 来自编码器输出,实现跨序列信息融合。
  • 支持自定义掩码(如处理源序列的填充位置)。

2.3 解码器单层完整实现
python 复制代码
class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_size, heads, dropout=0.1):
        super().__init__()
        self.masked_attn = MaskedMultiHeadAttention(embed_size, heads)
        self.cross_attn = CrossAttention(embed_size, heads)
        self.ffn = FeedForward(embed_size)  # 复用编码器的FFN
        
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # 1. 掩码自注意力
        masked_attn_out = self.masked_attn(x, tgt_mask)
        x = x + self.dropout(masked_attn_out)
        x = self.norm1(x)
        
        # 2. 交叉注意力(Q来自x,K/V来自encoder_output)
        cross_attn_out = self.cross_attn(x, encoder_output, src_mask)
        x = x + self.dropout(cross_attn_out)
        x = self.norm2(x)
        
        # 3. 前馈网络
        ffn_out = self.ffn(x)
        x = x + self.dropout(ffn_out)
        x = self.norm3(x)
        return x

3. 完整解码器实现
python 复制代码
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_size, heads, dropout)
            for _ in range(layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)  # 输出层预测词表概率
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        logits = self.fc_out(x)  # (batch_size, seq_len, vocab_size)
        return logits

4. 实战测试:文本翻译模拟
python 复制代码
# 参数设置
vocab_size = 10000  # 目标语言词表大小
embed_size = 512
layers = 6
heads = 8

# 初始化编码器和解码器
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)
decoder = TransformerDecoder(vocab_size, embed_size, layers, heads)

# 模拟输入(源语言句子)
src = torch.randint(0, vocab_size, (32, 20))  # (batch_size=32, src_seq_len=20)
# 编码器输出
encoder_output = encoder(src)

# 模拟目标输入(目标语言句子,训练时右移一位)
tgt = torch.randint(0, vocab_size, (32, 25))  # (batch_size=32, tgt_seq_len=25)
# 解码器输出
logits = decoder(tgt, encoder_output)
print("输出形状:", logits.shape)  # torch.Size([32, 25, 10000])

🎉 恭喜! 至此你已经掌握了Transformer解码器的核心原理与实现。无论是机器翻译、文本生成,还是对话系统,解码器都是生成任务的核心引擎。

下一步建议

  1. 尝试在真实数据集(如WMT英德翻译)上训练模型。
  2. 探索 束搜索(Beam Search)温度采样(Temperature Sampling) 等推理优化技术。
  3. 访问 Transformer官方代码库Hugging Face库 深入学习工业级实现。

动手实践是掌握AI的最佳方式------赶紧修改代码参数,观察模型变化吧!如果遇到问题,欢迎在评论区留言讨论,我们一起解决! 🌟


希望这篇解析能助你彻底理解Transformer解码器,期待看到你的实战成果! 😊

相关推荐
Chatopera 研发团队33 分钟前
使用 AlexNet 实现图片分类 | PyTorch 深度学习实战
pytorch·深度学习·分类·cnn·cv·alexnet
新加坡内哥谈技术1 小时前
ChunkKV:优化 KV 缓存压缩,让 LLM 长文本推理更高效
人工智能·科技·深度学习·语言模型·机器人
珠江上上上1 小时前
支持向量机原理
人工智能·深度学习·算法·机器学习·支持向量机·数据挖掘
郑万通2 小时前
10.推荐系统的用户研究
深度学习·推荐系统
机器学习之心3 小时前
三角拓扑聚合优化器TTAO-Transformer-BiLSTM多变量回归预测(Maltab)
深度学习·回归·transformer·bilstm·多变量回归预测
Francek Chen3 小时前
【DeepSeek】在本地计算机上部署DeepSeek-R1大模型实战(完整版)
人工智能·深度学习·语言模型·ai编程·deepseek
FF-Studio3 小时前
读 DeepSeek-R1 论文笔记
论文阅读·人工智能·深度学习·机器学习·语言模型·自然语言处理·deepseek
曦云沐4 小时前
手撕Transformer编码器:从Self-Attention到Positional Encoding的PyTorch逐行实现
pytorch·深度学习·transformer
AuGuSt_814 小时前
N-Beats:一种用于时间序列预测的纯前馈神经网络模型
人工智能·深度学习·神经网络