Transformer-解码器_编码器部分

一、输入部分:词嵌入与位置编码

输入部分是Transformer处理原始文本的第一步,负责将离散的文本符号转化为包含语义和位置信息的连续向量。

1. 词嵌入(embeddings类)

  • ​作用​​:将文本中的每个词(用索引表示)映射到高维向量空间,捕捉词的语义信息。

  • ​核心代码解析​​:

    python 复制代码
    class embeddings(nn.Module):
        def __init__(self, vocab_size, d_model):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, d_model)  # 词嵌入层
    
        def forward(self, x):
            # 乘以缩放系数√d_model,控制嵌入向量的方差
            return self.embedding(x) * math.sqrt(self.d_model)
  • ​关键细节​​:

    • 输入x是词索引张量(形状:[batch_size, seq_len]),输出是词嵌入向量(形状:[batch_size, seq_len, d_model])。

    • 乘以math.sqrt(d_model)的原因:补偿Xavier初始化的非正态分布特性,使嵌入向量保持合理的方差,避免后续计算中数值过大或过小。

2. 位置编码(positional_encoding类)

  • ​作用​​:Transformer没有循环结构,需通过位置编码向模型注入词的位置信息,使模型感知词在序列中的顺序。

  • ​核心代码解析​​:

    python 复制代码
    class positional_encoding(nn.Module):
        def __init__(self, d_model, dropout, max_len=100):
            super().__init__()
            self.droupout = nn.Dropout(p=dropout)
            pe = torch.zeros(max_len, d_model)  # 位置编码矩阵
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # 位置索引
            _2i = torch.arange(0, d_model, 2).float()  # 偶数维度索引
            # 偶数维度用正弦函数,奇数维度用余弦函数
            pe[:, 0::2] = torch.sin(position / (10000 **(_2i / d_model)))
            pe[:, 1::2] = torch.cos(position / (10000** (_2i / d_model)))
            pe = pe.unsqueeze(0)  # 增加batch维度
            self.register_buffer('pe', pe)  # 非参数化缓冲区,不参与训练
    
        def forward(self, x):
            x = x + self.pe[:, :x.size(1)]  # 与词嵌入相加(广播机制)
            return self.droupout(x)
  • ​关键细节​​:

    • 位置编码公式:对于位置pos和维度i,偶数isin(pos/10000^(i/d_model)),奇数icos(pos/10000^(i/d_model))

    • 优势:能表示任意长度的序列(通过外推),且相邻位置的编码具有相似性。

    • 输出形状与词嵌入相同([batch_size, seq_len, d_model]),与词嵌入向量逐元素相加后经dropout输出。

二、核心机制:注意力机制

注意力机制是Transformer的核心,用于捕捉序列中不同词之间的依赖关系(如"猫"和"它"的指代关系)。

1. 基础注意力计算(attention函数)

  • ​作用​​:通过query(查询)、key(键)、value(值)计算注意力分布,输出加权求和的向量。

  • ​核心代码解析​​:

    python 复制代码
    def attention(query, key, value, mask=None, dropout=None):
        d_k = query.shape[-1]  # 每个头的维度
        # 计算注意力分数:(Q*K^T)/√d_k(缩放点积)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # 掩码处理(如屏蔽填充词或未来词)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # 掩码位置置为负无穷
        # 计算注意力权重(softmax归一化)
        p_attn = torch.softmax(scores, dim=-1)
        # dropout防止过拟合
        if dropout is not None:
            p_attn = dropout(p_attn)
        # 加权求和(注意力权重 * value)
        return torch.matmul(p_attn, value), p_attn
  • ​关键细节​​:

    • 缩放因子√d_k的作用:当d_k较大时,点积结果可能过大,导致softmax梯度消失,缩放后可稳定数值范围。

    • 掩码(mask):用于屏蔽无效信息(如编码器中屏蔽填充词,解码器中屏蔽未来词),确保模型只关注有效位置。

2. 多头注意力(multi_head_attn类)

  • ​作用​​:将注意力机制分为多个"头"(head),并行计算不同子空间的注意力,捕捉更丰富的依赖关系。

  • ​核心代码解析​​:

    python 复制代码
    class multi_head_attn(nn.Module):
        def __init__(self, d_model, n_head, dropout=0.1):
            super().__init__()
            assert d_model % n_head == 0  # d_model必须能被头数整除
            self.n_head = n_head  # 头数
            self.d_k = d_model // n_head  # 每个头的维度
            self.linears = clones(nn.Linear(d_model, d_model), 4)  # 4个线性层(Q、K、V投影+输出投影)
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, query, key, value, mask=None):
            if mask is not None:
                mask = mask.unsqueeze(0)  # 扩展掩码维度以适配多头
            batch_size = query.size(0)
            # 1. 线性投影并拆分多头:[batch, seq_len, d_model] → [batch, n_head, seq_len, d_k]
            query, key, value = [
                model(x).reshape(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
                for model, x in zip(self.linears, [query, key, value])
            ]
            # 2. 计算多头注意力
            attn, self.attn_weights = attention(query, key, value, mask=mask, dropout=self.dropout)
            # 3. 合并多头:[batch, n_head, seq_len, d_k] → [batch, seq_len, d_model]
            attn = attn.transpose(1, 2).reshape(batch_size, -1, self.n_head * self.d_k)
            # 4. 输出投影
            return self.linears[-1](attn)
  • ​关键细节​​:

    • 多头拆分:将d_model维度拆分为n_headd_k维度(d_model = n_head * d_k),每个头独立计算注意力。

    • 优势:并行学习不同的注意力模式(如语法依赖、语义关联),提升模型表达能力。

三、编码器部分

编码器负责对输入序列进行特征提取,由N个相同的编码器层堆叠而成。

1. 前馈网络(FeedForward类)

  • ​作用​​:对注意力输出进行非线性变换,增强模型对复杂模式的拟合能力。

  • ​核心代码解析​​:

    python 复制代码
    class FeedForward(nn.Module):
        def __init__(self, d_model, dff, dropout=0.1):
            super().__init__()
            self.linear1 = nn.Linear(d_model, dff)  # 升维:d_model → dff
            self.linear2 = nn.Linear(dff, d_model)  # 降维:dff → d_model
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, x):
            # 线性变换→ReLU激活→dropout→线性变换
            return self.linear2(self.dropout(torch.relu(self.linear1(x))))
  • ​关键细节​​:

    • 输入输出维度均为d_model,中间通过dff(通常为4*d_model)升维,引入非线性(ReLU)后再降维。

2. 层归一化(layer_norm类)

  • ​作用​​:对每一层的输出进行归一化,使数据分布稳定,加速训练收敛。

  • ​核心代码解析​​:

    python 复制代码
    class layer_norm(nn.Module):
        def __init__(self, size, eps=1e-6):
            super().__init__()
            self.a = nn.Parameter(torch.ones(size))  # 缩放参数(可学习)
            self.b = nn.Parameter(torch.zeros(size))  # 偏移参数(可学习)
            self.eps = eps  # 防止除零
    
        def forward(self, x):
            mean = x.mean(dim=-1, keepdim=True)  # 沿最后一维(特征维度)计算均值
            std = x.std(dim=-1, keepdim=True)    # 沿最后一维计算标准差
            return self.a * (x - mean) / (std + self.eps) + self.b  # 归一化+缩放偏移

3. 子层连接(sub_layer_conncetion类)

  • ​作用​​:将注意力层/前馈网络与残差连接、层归一化结合,缓解深层网络的梯度消失问题。

  • ​核心代码解析​​:

    python 复制代码
    class sub_layer_conncetion(nn.Module):
        def __init__(self, size, dropout=0.1):
            super().__init__()
            self.norm = layer_norm(size)  # 层归一化
            self.dropout = nn.Dropout(dropout)  # dropout
    
        def forward(self, x, sub_layer):
            # 残差连接:x + 子层输出(子层输入先归一化)
            return x + self.dropout(sub_layer(self.norm(x)))
  • ​流程​ ​:输入x先经层归一化,再送入子层(注意力或前馈网络),子层输出经dropout后与原始x残差相加。

4. 编码器层(encoder_layer类)

  • ​作用​​:编码器的基本单元,包含"多头自注意力"和"前馈网络"两个子层。

  • ​核心代码解析​​:

    python 复制代码
    class encoder_layer(nn.Module):
        def __init__(self, size, self_attn, feed_forward, dropout):
            super().__init__()
            self.self_attn = self_attn  # 多头自注意力
            self.feed_forward = feed_forward  # 前馈网络
            self.sub_layer = clones(sub_layer_conncetion(size, dropout), 2)  # 2个子层连接
    
        def forward(self, x, mask):
            # 第1个子层:多头自注意力(输入x既是Q、K,也是V)
            x = self.sub_layer[0](x, lambda x: self.self_attn(x, x, x, mask))
            # 第2个子层:前馈网络
            x = self.sub_layer[1](x, lambda x: self.feed_forward(x))
            return x

5. 编码器(encoder类)

  • ​作用​ ​:堆叠N个编码器层(通常N=6),对输入序列进行深度编码。

  • ​核心代码解析​​:

    python 复制代码
    class encoder(nn.Module):
        def __init__(self, layer, N):
            super().__init__()
            self.layers = clones(layer, N)  # 克隆N个编码器层
            self.norm = layer_norm(layer.size)  # 最终层归一化
    
        def forward(self, x, mask):
            for layer in self.layers:
                x = layer(x, mask)  # 依次通过每个编码器层
            return self.norm(x)  # 最终归一化输出

四、解码器部分

解码器负责根据编码器的输出(memory)和目标序列,生成输出序列,由N个相同的解码器层堆叠而成。

1. 解码器层(DecoderLayer类)

  • ​作用​​:解码器的基本单元,包含3个子层:"解码器自注意力""编码器-解码器注意力""前馈网络"。

  • ​核心代码解析​​:

    python 复制代码
    class DecoderLayer(nn.Module):
        def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
            super().__init__()
            self.self_attn = self_attn  # 解码器自注意力(目标序列内部)
            self.src_attn = src_attn    # 编码器-解码器注意力(关联输入和目标)
            self.feed_forward = feed_forward  # 前馈网络
            self.sub_layers = clones(sub_layer_conncetion(size, dropout), 3)  # 3个子层连接
    
        def forward(self, x, memory, source_mask, target_mask):
            m = memory  # 编码器输出
            # 第1个子层:解码器自注意力(带目标掩码,防止关注未来词)
            x = self.sub_layers[0](x, lambda x: self.self_attn(x, x, x, target_mask))
            # 第2个子层:编码器-解码器注意力(Q=解码器输出,K=V=编码器输出)
            x = self.sub_layers[1](x, lambda x: self.src_attn(x, m, m, source_mask))
            # 第3个子层:前馈网络
            x = self.sub_layers[2](x, lambda x: self.feed_forward(x))
            return x

2. 解码器(Decoder类)

  • ​作用​ ​:堆叠N个解码器层(通常N=6),生成目标序列的特征表示。

  • ​核心代码解析​​:

    python 复制代码
    class Decoder(nn.Module):
        def __init__(self, layer, N):
            super().__init__()
            self.layers = clones(layer, N)  # 克隆N个解码器层
            self.norm = layer_norm(layer.size)  # 最终层归一化
    
        def forward(self, x, memory, source_mask, target_mask):
            for layer in self.layers:
                x = layer(x, memory, source_mask, target_mask)  # 依次通过每个解码器层
            return self.norm(x)  # 最终归一化输出

五、生成器(Generator类)

  • ​作用​​:将解码器的输出转化为词表上的概率分布,生成最终的输出序列。

  • ​核心代码解析​​:

    python 复制代码
    class Generator(nn.Module):
        def __init__(self, d_model, vocab_size):
            super().__init__()
            self.linear = nn.Linear(d_model, vocab_size)  # 映射到词表大小
    
        def forward(self, x):
            # log_softmax便于计算交叉熵损失
            return torch.log_softmax(self.linear(x), dim=-1)
  • ​输出​ ​:形状为[batch_size, seq_len, vocab_size],每个位置对应词表中所有词的对数概率。

六、关键工具函数

  1. clones函数​ ​:克隆N个相同的模块(如编码器层、线性层),确保参数独立但结构相同。

    python 复制代码
    def clones(module, N):
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
  2. ​掩码函数​​:

    • subsequent_mask(size):生成下三角掩码,用于解码器自注意力,防止模型关注未来的词(如翻译时"我吃"不能关注"饭"来预测"吃")。

七、整体流程总结

  1. ​输入处理​​:文本→词嵌入(+缩放)→+位置编码→输入向量。

  2. ​编码器​ ​:输入向量经N个编码器层(多头自注意力+前馈网络)→memory(编码后的输入特征)。

  3. ​解码器​ ​:目标序列经词嵌入+位置编码后,与memory一起输入N个解码器层(解码器自注意力+编码器-解码器注意力+前馈网络)→解码特征。

  4. ​生成器​​:解码特征→词表概率分布→输出序列。

通过上述模块的协作,Transformer能够高效捕捉序列中的长距离依赖,在机器翻译、文本摘要等任务中表现优异。

相关推荐
悟乙己4 小时前
PandasAI :使用 AI 优化你的分析工作流
人工智能·pandas·pandasai
东临碣石824 小时前
【AI论文】CoDA:面向协作数据可视化的智能体系统
人工智能
【建模先锋】4 小时前
多源信息融合+经典卷积网络故障诊断模型合集
深度学习·信号处理·故障诊断·多源信息融合
中杯可乐多加冰4 小时前
无代码开发实践 | 基于权限管理能力快速开发人力资源管理系统
人工智能·低代码
钊气蓬勃.4 小时前
深度学习笔记:入门
人工智能·笔记·深度学习
神码小Z4 小时前
特斯拉前AI总监开源的一款“小型本地版ChatGPT”,普通家用电脑就能运行!
人工智能·chatgpt
IT_陈寒4 小时前
Redis性能翻倍的7个冷门技巧:从P5到P8都在偷偷用的优化策略!
前端·人工智能·后端
AKAMAI4 小时前
直播监控的生死时速:深夜告警引发的系统崩溃危机
人工智能·云计算·直播
B站计算机毕业设计之家4 小时前
深度学习实战:python动物识别分类检测系统 计算机视觉 Django框架 CNN算法 深度学习 卷积神经网络 TensorFlow 毕业设计(建议收藏)✅
python·深度学习·算法·计算机视觉·分类·毕业设计·动物识别