transformer-实现单层Decoder 层

Decoder Layer

解码器层结构

  • Transformer解码器层由三种核心组件构成:

    1. Masked多头自注意力:关注解码器序列当前位置之前的上下文(因果掩码)

    2. Encoder-Decoder多头注意力:关注编码器输出的相关上下文

    3. 前馈神经网络:进行非线性特征变换

    今天这里实现的是上图中蓝色框中的单层DecoderLayer,不包含 embedding和位置编码,以及最后的Linear和Softmax。

    主要处理流程:

    1. Decoder 的Masked自注意力
    2. Encoder-Decoder自注意力
    3. 前馈神经网络:进行非线性特征变换
    4. 残差连接 + 层归一化
    5. Dropout:最终输出前进行随机失活

数学表达

  • 解码器层计算过程分为三个阶段:

    1. Masked自注意力阶段

    MaskedAtt ( Q , K , V ) = LayerNorm ( MultiHead ( Q , K , V ) + R e s i d u a l ) \text{MaskedAtt}(Q,K,V) = \text{LayerNorm}(\text{MultiHead}(Q,K,V) + Residual) MaskedAtt(Q,K,V)=LayerNorm(MultiHead(Q,K,V)+Residual)
    2. Encoder-Decoder注意力阶段

    CrossAtt ( Q d e c , K e n c , V e n c ) = LayerNorm ( MultiHead ( Q d e c , K e n c , V e n c ) + R e s i d u a l ) \text{CrossAtt}(Q_{dec}, K_{enc}, V_{enc}) = \text{LayerNorm}(\text{MultiHead}(Q_{dec},K_{enc},V_{enc}) + Residual) CrossAtt(Qdec,Kenc,Venc)=LayerNorm(MultiHead(Qdec,Kenc,Venc)+Residual)
    3. 前馈网络阶段

    FFN ( x ) = LayerNorm ( ReLU ( x W 1 + b 1 ) W 2 + b 2 + x ) \text{FFN}(x) = \text{LayerNorm}(\text{ReLU}(xW_1 + b_1)W_2 + b_2 + x) FFN(x)=LayerNorm(ReLU(xW1+b1)W2+b2+x)

    其中:

    1. d_model 为模型维度
    2. Residual 为残差连接
    3. 下标dec来源于Decoder自己的输出,下标enc为Encoder的输出

代码实现

  • 实现单层

    其他层的实现

    层名 链接
    PositionEncoding https://blog.csdn.net/hbkybkzw/article/details/147431820
    calculate_attention https://blog.csdn.net/hbkybkzw/article/details/147462845
    MultiHeadAttention https://blog.csdn.net/hbkybkzw/article/details/147490387
    FeedForward https://blog.csdn.net/hbkybkzw/article/details/147515883
    LayerNorm https://blog.csdn.net/hbkybkzw/article/details/147516529
    EncoderLayer https://blog.csdn.net/hbkybkzw/article/details/147591824

    下面统一在before.py中导入

  • 实现单层的DecoderLayer

    python 复制代码
    import torch 
    from torch import nn
    
    from before import PositionEncoding,calculate_attention,MultiHeadAttention,FeedForward,LayerNorm
    
    
    class DecoderLayer(nn.Module):
        def __init__(self, n_heads, d_model, ffn_hidden, dropout_prob=0.1):
            super(DecoderLayer, self).__init__()
            self.masked_att = MultiHeadAttention(n_heads=n_heads, d_model=d_model, dropout_prob=dropout_prob)
            self.att = MultiHeadAttention(n_heads=n_heads, d_model=d_model, dropout_prob=dropout_prob)
            self.norms = nn.ModuleList([LayerNorm(d_model=d_model) for _ in range(3)])  # 三个归一化层
            self.ffn = FeedForward(d_model=d_model, ffn_hidden=ffn_hidden, dropout_prob=dropout_prob)
            self.dropout = nn.Dropout(dropout_prob)
    
        def forward(self, x, encoder_kv, dst_mask=None, src_dst_mask=None):
            # 第一阶段:Decoder 的Masked自注意力
            _x = x
            mask_att_out = self.masked_att(q=x, k=x, v=x, mask=dst_mask)
            mask_att_out = self.norms[0](mask_att_out + _x)  # 残差连接后归一化
    
            # 第二阶段:Encoder-Decoder注意力
            _x = mask_att_out
            att_out = self.att(q=mask_att_out, k=encoder_kv, v=encoder_kv, mask=src_dst_mask)
            att_out = self.norms[1](att_out + _x)
    
            # 第三阶段:前馈网络
            _x = att_out
            ffn_out = self.ffn(att_out)
            ffn_out = self.norms[2](ffn_out + _x)
            
            return self.dropout(ffn_out)
  • 注意力掩码机制

    掩码类型 作用域 功能描述
    dst_mask 目标序列自注意力 防止当前位置关注未来信息(因果掩码)
    src_dst_mask 编码器-解码器注意力 控制解码器查询对编码器键值对的访问权限
  • 参数说明

    参数名 类型 说明
    n_heads int 注意力头数量
    d_model int 模型隐藏层维度
    ffn_hidden int 前馈网络中间层维度(通常4倍)
    dropout_prob float Dropout概率(默认0.1)

使用示例

  • 测试代码

    python 复制代码
    if __name__ == "__main__":
        # 实例化解码器层:8头,512维,前馈层2048,20% dropout
        decoder_layer = DecoderLayer(n_heads=8, d_model=512, ffn_hidden=2048, dropout_prob=0.2)
    
        # 模拟输入:batch_size=4,目标序列长度50,编码器输出长度80
        x = torch.randn(4, 50, 512)
        encoder_out = torch.randn(4, 80, 512)
    
        tgt_mask = None
        src_mask = None
    
        output = decoder_layer(x, encoder_out, dst_mask=tgt_mask, src_dst_mask=src_mask)
    
        print("输入形状:", x.shape)
        print("encode_kv 形状:", encoder_out.shape)
        print("输出形状:", output.shape)

相关推荐
PersistJiao11 分钟前
Codex、Claude Code、gstack三者的关系
人工智能
数智工坊29 分钟前
【Mask2Former论文阅读】:基于掩码注意力的通用分割Transformer,大一统全景/实例/语义分割
论文阅读·深度学习·transformer
一切皆是因缘际会36 分钟前
AI数字分身的底层原理:破解意识、自我与人格复刻的核心难题
大数据·人工智能·ai·架构
翔云12345640 分钟前
vLLM全解析:定义、用途与竞品对比
人工智能·ai·大模型
ASKED_20191 小时前
KDD Cup 2026 腾讯算法广告大赛赛题解读: UNI-REC (统一序列建模与特征交叉)
人工智能
fpcc1 小时前
AI和大模型——Fine-tuning
人工智能·深度学习
爱问的艾文1 小时前
八周带你手搓AI应用-Day4-赋予你的AI“记忆力”
人工智能
ACP广源盛139246256731 小时前
IX8024与科学大模型的碰撞@ACP#筑牢科研 AI 算力高速枢纽分享
运维·服务器·网络·数据库·人工智能·嵌入式硬件·电脑
向量引擎2 小时前
向量引擎接入 GPT Image 2 和 deepseek v4:一个 api key 把热门模型串起来,开发者终于不用深夜修接口了
人工智能·gpt·计算机视觉·aigc·api·ai编程·key
努力努力再努力FFF2 小时前
医生对AI辅助诊断感兴趣,作为临床人员该怎么了解和学习?
人工智能·学习