7.1 PyTorch Transformer模块详解

PyTorch Transformer模块详解

目录

  • 基础算子层
  • 核心模块层
  • 位置编码
  • 前馈网络
  • 完整架构模块
  • 学习路径建议
  • 输入输出形状速查表

基础算子层

torch.matmul

功能说明: 矩阵乘法,是注意力机制的核心操作,用于计算查询矩阵(Q)和键矩阵(K)的点积。

python 复制代码
# 计算注意力分数
attn_scores = torch.matmul(q, k.transpose(-2, -1))

参数说明:

  • q: 查询矩阵,形状 [B, h, L_q, d]
  • k: 键矩阵,形状 [B, h, L_k, d]
  • k.transpose(-2, -1) : 形状 [B, h, d,L_k]

输出: 注意力分数矩阵,形状 [B, h, L_q, L_k]

torch.softmax

功能说明: 将注意力分数转换为概率分布,使每个位置的注意力权重和为1。

python 复制代码
attn_weights = torch.softmax(attn_scores, dim=-1)

参数说明:

  • attn_scores: 注意力分数矩阵
  • dim: 指定在哪个维度上进行softmax计算,通常为-1(最后一个维度)

输出: 注意力权重矩阵,形状与输入相同,最后一维和为1

torch.masked_fill

功能说明: 掩码操作,用于屏蔽padding位置或未来信息,将指定位置的值替换为指定值。

python 复制代码
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

参数说明:

  • mask: 掩码矩阵,0表示需要屏蔽的位置
  • -1e9: 替换值,通常使用一个很大的负数,经过softmax后会趋近于0

输出: 被掩码处理后的注意力分数矩阵

torch.sqrt

功能说明: 计算平方根,用于缩放点积注意力,防止点积结果过大导致梯度消失。

python 复制代码
scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
attn_scores = attn_scores / scale

参数说明: head_dim - 每个注意力头的维度

输出: 缩放因子

张量形状变换操作

功能说明: 用于多头注意力的拆分与合并,包括view、reshape、transpose、contiguous等操作。

python 复制代码
# 拆分多头
q = q.view(B, L, h, d).transpose(1, 2)  # [B, L, d_model] -> [B, h, L, d]
# 合并多头
out = out.transpose(1, 2).contiguous().view(B, L, d_model)  # [B, h, L, d] -> [B, L, d_model]

核心模块层

nn.Linear

功能说明: 全连接层,用于Q/K/V/O投影和前馈网络中的线性变换。

python 复制代码
self.w_q = nn.Linear(d_model, d_model)

参数说明:

  • in_features: 输入特征维度
  • out_features: 输出特征维度
  • bias: 是否使用偏置项,默认为True

输入: [B, L, d_model]

输出: [B, L, d_model]

nn.Dropout

功能说明: Dropout层,用于防止过拟合,在训练时随机将部分神经元输出置零。

python 复制代码
self.dropout = nn.Dropout(p=0.1)

参数说明: p - 丢弃概率,0.1表示10%的神经元被随机置零

输入: 任意形状张量

输出: 同形状张量,训练时部分元素被置零

nn.LayerNorm

功能说明: 层归一化,对每个样本的特征维度进行归一化,加速训练并提高模型稳定性。

python 复制代码
self.norm = nn.LayerNorm(d_model)

参数说明:

  • normalized_shape: 需要归一化的维度大小
  • eps: 数值稳定性小量,默认1e-5
  • elementwise_affine: 是否学习缩放和平移参数,默认True

输入: [B, L, d_model]

输出: [B, L, d_model],最后一维被归一化

nn.Embedding

功能说明: 词嵌入层,将离散的token ID映射为连续的向量表示。

python 复制代码
self.embedding = nn.Embedding(vocab_size, d_model)

参数说明:

  • num_embeddings: 词汇表大小
  • embedding_dim: 嵌入向量维度

输入: [B, L],元素为token ID的LongTensor

输出: [B, L, d_model]

位置编码

PositionalEncoding

功能说明: 由于Transformer没有循环或卷积结构,需要显式添加位置信息。使用正弦和余弦函数生成位置编码。

python 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:x.size(1)]

参数说明:

  • d_model: 模型维度
  • max_len: 最大序列长度

输入: [B, L, d_model]

输出: [B, L, d_model],加上了位置编码信息

前馈网络

PositionwiseFeedForward

功能说明: 位置级前馈网络,对序列中每个位置独立进行相同的变换。

python 复制代码
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()
    
    def forward(self, x):
        return self.fc2(self.dropout(self.activation(self.fc1(x))))

参数说明:

  • d_model: 模型维度
  • d_ff: 隐藏层维度,通常为4*d_model
  • dropout: Dropout概率

输入: [B, L, d_model]

输出: [B, L, d_model]

完整架构模块

Encoder Layer

结构组成: 多头自注意力 + 前馈网络,每个子层后都有残差连接和层归一化。

复制代码
输入 X → MultiHeadAttention(Q=X, K=X, V=X) → Add(X) → LayerNorm → 
       → FeedForward → Add → LayerNorm → 输出

Decoder Layer

结构组成: 三个子层:掩码多头自注意力、多头注意力(Encoder-Decoder Attention)、前馈网络。

复制代码
输入 Y → Masked MultiHeadAttention(Q=Y, K=Y, V=Y) → Add(Y) → LayerNorm →
       → MultiHeadAttention(Q=Y, K=Encoder输出, V=Encoder输出) → Add → LayerNorm →
       → FeedForward → Add → LayerNorm → 输出

整体架构

复制代码
Input → Embedding + PositionalEncoding → N × EncoderLayer → Encoder输出

Target → Embedding + PositionalEncoding → N × DecoderLayer → Linear → Softmax → 输出概率

学习路径建议

按以下顺序逐个实现,由简到繁:

  1. Scaled Dot-Product Attention(缩放点积注意力)
  2. Multi-Head Attention(多头注意力)
  3. Position-wise Feed Forward(前馈网络)
  4. Positional Encoding(位置编码)
  5. Encoder Layer(编码器层)
  6. Decoder Layer(解码器层)
  7. 完整 Transformer(拼接 Encoder + Decoder)

输入输出形状速查表

模块 输入形状 输出形状 说明
Embedding [B, L] [B, L, d_model] B=batch_size, L=seq_len
PositionalEncoding [B, L, d_model] [B, L, d_model] 添加位置信息
MultiHeadAttention Q/K/V: [B, L, d_model] [B, L, d_model] 多头注意力计算
FeedForward [B, L, d_model] [B, L, d_model] 位置级前馈网络
LayerNorm [B, L, d_model] [B, L, d_model] 层归一化
Linear (vocab投影) [B, L, d_model] [B, L, vocab_size] 词汇表投影

完整Transformer架构图

复制代码
                    Encoder
Input → Embedding → PositionalEncoding → 
        [MultiHeadAttention → Add & Norm → 
         FeedForward → Add & Norm] × N → Encoder输出

                    Decoder
Target → Embedding → PositionalEncoding → 
         [MaskedMultiHeadAttention → Add & Norm →
          MultiHeadAttention → Add & Norm →
          FeedForward → Add & Norm] × N → 
         Linear → Softmax → 输出概率

通过掌握以上模块,您将能够从零开始实现完整的Transformer架构。建议按照学习路径逐步实现,每完成一个模块都进行充分的测试验证。