PyTorch Transformer模块详解
目录
- 基础算子层
- 核心模块层
- 位置编码
- 前馈网络
- 完整架构模块
- 学习路径建议
- 输入输出形状速查表
基础算子层
torch.matmul
功能说明: 矩阵乘法,是注意力机制的核心操作,用于计算查询矩阵(Q)和键矩阵(K)的点积。
python
# 计算注意力分数
attn_scores = torch.matmul(q, k.transpose(-2, -1))
参数说明:
q: 查询矩阵,形状[B, h, L_q, d]k: 键矩阵,形状[B, h, L_k, d]- k.transpose(-2, -1) : 形状
[B, h, d,L_k]
输出: 注意力分数矩阵,形状 [B, h, L_q, L_k]
torch.softmax
功能说明: 将注意力分数转换为概率分布,使每个位置的注意力权重和为1。
python
attn_weights = torch.softmax(attn_scores, dim=-1)
参数说明:
attn_scores: 注意力分数矩阵dim: 指定在哪个维度上进行softmax计算,通常为-1(最后一个维度)
输出: 注意力权重矩阵,形状与输入相同,最后一维和为1
torch.masked_fill
功能说明: 掩码操作,用于屏蔽padding位置或未来信息,将指定位置的值替换为指定值。
python
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
参数说明:
mask: 掩码矩阵,0表示需要屏蔽的位置-1e9: 替换值,通常使用一个很大的负数,经过softmax后会趋近于0
输出: 被掩码处理后的注意力分数矩阵
torch.sqrt
功能说明: 计算平方根,用于缩放点积注意力,防止点积结果过大导致梯度消失。
python
scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
attn_scores = attn_scores / scale
参数说明: head_dim - 每个注意力头的维度
输出: 缩放因子
张量形状变换操作
功能说明: 用于多头注意力的拆分与合并,包括view、reshape、transpose、contiguous等操作。
python
# 拆分多头
q = q.view(B, L, h, d).transpose(1, 2) # [B, L, d_model] -> [B, h, L, d]
# 合并多头
out = out.transpose(1, 2).contiguous().view(B, L, d_model) # [B, h, L, d] -> [B, L, d_model]
核心模块层
nn.Linear
功能说明: 全连接层,用于Q/K/V/O投影和前馈网络中的线性变换。
python
self.w_q = nn.Linear(d_model, d_model)
参数说明:
in_features: 输入特征维度out_features: 输出特征维度bias: 是否使用偏置项,默认为True
输入: [B, L, d_model]
输出: [B, L, d_model]
nn.Dropout
功能说明: Dropout层,用于防止过拟合,在训练时随机将部分神经元输出置零。
python
self.dropout = nn.Dropout(p=0.1)
参数说明: p - 丢弃概率,0.1表示10%的神经元被随机置零
输入: 任意形状张量
输出: 同形状张量,训练时部分元素被置零
nn.LayerNorm
功能说明: 层归一化,对每个样本的特征维度进行归一化,加速训练并提高模型稳定性。
python
self.norm = nn.LayerNorm(d_model)
参数说明:
normalized_shape: 需要归一化的维度大小eps: 数值稳定性小量,默认1e-5elementwise_affine: 是否学习缩放和平移参数,默认True
输入: [B, L, d_model]
输出: [B, L, d_model],最后一维被归一化
nn.Embedding
功能说明: 词嵌入层,将离散的token ID映射为连续的向量表示。
python
self.embedding = nn.Embedding(vocab_size, d_model)
参数说明:
num_embeddings: 词汇表大小embedding_dim: 嵌入向量维度
输入: [B, L],元素为token ID的LongTensor
输出: [B, L, d_model]
位置编码
PositionalEncoding
功能说明: 由于Transformer没有循环或卷积结构,需要显式添加位置信息。使用正弦和余弦函数生成位置编码。
python
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(1)]
参数说明:
d_model: 模型维度max_len: 最大序列长度
输入: [B, L, d_model]
输出: [B, L, d_model],加上了位置编码信息
前馈网络
PositionwiseFeedForward
功能说明: 位置级前馈网络,对序列中每个位置独立进行相同的变换。
python
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.activation = nn.ReLU()
def forward(self, x):
return self.fc2(self.dropout(self.activation(self.fc1(x))))
参数说明:
d_model: 模型维度d_ff: 隐藏层维度,通常为4*d_modeldropout: Dropout概率
输入: [B, L, d_model]
输出: [B, L, d_model]
完整架构模块
Encoder Layer
结构组成: 多头自注意力 + 前馈网络,每个子层后都有残差连接和层归一化。
输入 X → MultiHeadAttention(Q=X, K=X, V=X) → Add(X) → LayerNorm →
→ FeedForward → Add → LayerNorm → 输出
Decoder Layer
结构组成: 三个子层:掩码多头自注意力、多头注意力(Encoder-Decoder Attention)、前馈网络。
输入 Y → Masked MultiHeadAttention(Q=Y, K=Y, V=Y) → Add(Y) → LayerNorm →
→ MultiHeadAttention(Q=Y, K=Encoder输出, V=Encoder输出) → Add → LayerNorm →
→ FeedForward → Add → LayerNorm → 输出
整体架构
Input → Embedding + PositionalEncoding → N × EncoderLayer → Encoder输出
Target → Embedding + PositionalEncoding → N × DecoderLayer → Linear → Softmax → 输出概率
学习路径建议
按以下顺序逐个实现,由简到繁:
- Scaled Dot-Product Attention(缩放点积注意力)
- Multi-Head Attention(多头注意力)
- Position-wise Feed Forward(前馈网络)
- Positional Encoding(位置编码)
- Encoder Layer(编码器层)
- Decoder Layer(解码器层)
- 完整 Transformer(拼接 Encoder + Decoder)
输入输出形状速查表
| 模块 | 输入形状 | 输出形状 | 说明 |
|---|---|---|---|
| Embedding | [B, L] |
[B, L, d_model] |
B=batch_size, L=seq_len |
| PositionalEncoding | [B, L, d_model] |
[B, L, d_model] |
添加位置信息 |
| MultiHeadAttention | Q/K/V: [B, L, d_model] |
[B, L, d_model] |
多头注意力计算 |
| FeedForward | [B, L, d_model] |
[B, L, d_model] |
位置级前馈网络 |
| LayerNorm | [B, L, d_model] |
[B, L, d_model] |
层归一化 |
| Linear (vocab投影) | [B, L, d_model] |
[B, L, vocab_size] |
词汇表投影 |
完整Transformer架构图
Encoder
Input → Embedding → PositionalEncoding →
[MultiHeadAttention → Add & Norm →
FeedForward → Add & Norm] × N → Encoder输出
Decoder
Target → Embedding → PositionalEncoding →
[MaskedMultiHeadAttention → Add & Norm →
MultiHeadAttention → Add & Norm →
FeedForward → Add & Norm] × N →
Linear → Softmax → 输出概率
通过掌握以上模块,您将能够从零开始实现完整的Transformer架构。建议按照学习路径逐步实现,每完成一个模块都进行充分的测试验证。