This section walks line by line through a classic Transformer implementation in PyTorch, mapping each piece back to the running example "猫坐在垫子上" ("the cat sat on the mat").
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Assume: vocabulary size = 1000, model dimension = 512, batch size = 2
# Inputs: ["猫坐在垫子上" (the cat sat on the mat), "狗在跑" (the dog is running)]
```
1. Positional Encoding
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Create an all-zero matrix of shape [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        # position: [0, 1, 2, ..., max_len-1] as a column vector
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # div_term: the 1 / 10000^(2i/d_model) scaling factors
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # sin on even dimensions, cos on odd dimensions
        pe[:, 0::2] = torch.sin(position * div_term)  # dimensions 0, 2, 4, ...
        pe[:, 1::2] = torch.cos(position * div_term)  # dimensions 1, 3, 5, ...
        # Keep a leading batch dimension so it broadcasts over batch-first inputs
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)  # a constant buffer, not a trainable parameter

    def forward(self, x):
        # x: [batch_size, seq_len, d_model] (batch-first, matching the rest of the code)
        # Add the positional encoding for the first seq_len positions
        x = x + self.pe[:, :x.size(1), :]  # "猫" is at position 0, "坐" at position 1, ...
        return x
```
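As a quick sanity check, the module can be fed a zero tensor shaped like the embedded example sentences; the output is then exactly the positional encoding, and the shape is unchanged. This is a minimal sketch assuming the `PositionalEncoding` class above is in scope:

```python
# Minimal sanity check for PositionalEncoding (assumes the class above is defined)
pos_enc = PositionalEncoding(d_model=512)
dummy = torch.zeros(2, 6, 512)   # [batch=2, seq_len=6, d_model=512], like the two example sentences
out = pos_enc(dummy)
print(out.shape)                 # torch.Size([2, 6, 512]) -- the shape is unchanged
# With an all-zero input, the output is the positional encoding itself,
# so position 0 ("猫") and position 1 ("坐") receive different vectors:
print(torch.equal(out[0, 0], out[0, 1]))  # False
```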
2. Scaled Dot-Product Attention
```python
def attention(q, k, v, mask=None, dropout=None):
    # q, k, v: [batch_size, num_heads, seq_len, d_k]  (d_k = d_model / num_heads)
    d_k = q.size(-1)
    # 1) Q @ K^T / sqrt(d_k)
    #    scores: [batch, heads, seq_len, seq_len]
    #    Entry (i, j) is the attention score of token i toward token j
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    # 2) Apply the mask (used by the decoder to block future tokens)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # fill blocked positions with a very negative value
    # 3) Softmax to obtain the attention weights
    #    attn_weights: [batch, heads, seq_len, seq_len]
    #    Each row sums to 1: the attention distribution of one token over all tokens
    attn_weights = F.softmax(scores, dim=-1)
    # 4) Dropout to reduce overfitting
    if dropout is not None:
        attn_weights = dropout(attn_weights)
    # 5) Weighted sum: weights @ V
    #    output: [batch, heads, seq_len, d_k]
    #    At this point the vector for "垫子" (mat) has absorbed information from "猫" (cat)
    output = torch.matmul(attn_weights, v)
    return output, attn_weights
```
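To make the shapes concrete, here is a small sketch that calls `attention` on random tensors with the running example's dimensions (batch 2, 8 heads, 6 tokens, d_k = 64); the numbers are illustrative only:

```python
# Sketch: calling attention() with the example dimensions
q = k = v = torch.randn(2, 8, 6, 64)      # [batch, heads, seq_len, d_k]
out, w = attention(q, k, v)
print(out.shape)                          # torch.Size([2, 8, 6, 64])
print(w.shape)                            # torch.Size([2, 8, 6, 6])
print(w.sum(dim=-1))                      # every row sums to 1 (softmax over the 6 tokens)
# w[0, h, i, j] is how much token i attends to token j in the first sentence, for head h.
```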
3. Multi-Head Attention
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0  # d_model must be divisible by the number of heads
        self.d_k = d_model // h  # dimension per head, e.g. 512 / 8 = 64
        self.h = h               # number of heads
        # Four linear layers: Q, K, V, and the output projection
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        if mask is not None:
            # Add a head dimension so the same mask broadcasts across all heads
            mask = mask.unsqueeze(1)
        # 1) Linear projection + split into heads
        #    q: [batch, seq_len, d_model] -> [batch, seq_len, heads, d_k]
        q = self.linear_q(q).view(batch_size, -1, self.h, self.d_k)
        k = self.linear_k(k).view(batch_size, -1, self.h, self.d_k)
        v = self.linear_v(v).view(batch_size, -1, self.h, self.d_k)
        # 2) Transpose: [batch, seq_len, heads, d_k] -> [batch, heads, seq_len, d_k]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # 3) Compute attention
        #    attn_output:  [batch, heads, seq_len, d_k]
        #    attn_weights: [batch, heads, seq_len, seq_len]
        attn_output, attn_weights = attention(q, k, v, mask, self.dropout)
        # 4) Concatenate the heads
        #    transpose back to [batch, seq_len, heads, d_k],
        #    then view as [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.h * self.d_k)
        # 5) Final output projection
        output = self.linear_out(attn_output)
        return output, attn_weights
```
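A short sketch of one self-attention pass through this module, again with illustrative shapes only:

```python
# Sketch: one forward pass through MultiHeadAttention
mha = MultiHeadAttention(h=8, d_model=512)
x = torch.randn(2, 6, 512)     # [batch, seq_len, d_model]
out, w = mha(x, x, x)          # self-attention: Q, K, V all come from x
print(out.shape)               # torch.Size([2, 6, 512]) -- same shape as the input
print(w.shape)                 # torch.Size([2, 8, 6, 6]) -- one 6x6 attention map per head
```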
4. Position-wise Feed-Forward Network
```python
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        # Two fully connected layers: d_model -> d_ff -> d_model
        # d_ff is usually 4 * d_model, e.g. 2048
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        # First layer + ReLU activation
        # intermediate: [batch, seq_len, d_ff]
        # This step learns non-linear features, e.g. "坐" (sit) + "垫子" (mat) -> "resting"
        intermediate = F.relu(self.w1(x))
        # Dropout to reduce overfitting
        intermediate = self.dropout(intermediate)
        # Second layer projects back to the original dimension
        output = self.w2(intermediate)
        return output
```
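The FFN keeps the outer shape and only widens the hidden dimension internally; a brief sketch:

```python
# Sketch: the FFN expands to d_ff internally but returns to d_model
ffn = FeedForward(d_model=512, d_ff=2048)
x = torch.randn(2, 6, 512)
print(ffn(x).shape)            # torch.Size([2, 6, 512])
# Each of the 6 positions is transformed independently with the same weights;
# the two linear layers hold roughly 512*2048 + 2048*512 ≈ 2.1M weights (plus biases).
```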
5. Encoder Layer
```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(num_heads, d_model, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # x: [batch, seq_len, d_model]
        # 1) Multi-head self-attention
        #    In the encoder, Q, K, V all come from x itself
        #    attn_output: [batch, seq_len, d_model]
        attn_output, _ = self.attention(x, x, x, mask)
        # 2) Residual connection + layer norm
        #    Dropout first, then add the original input (skip connection),
        #    which mitigates vanishing gradients and stabilizes training
        x = self.norm1(x + self.dropout(attn_output))
        # 3) Feed-forward network
        ff_output = self.feed_forward(x)
        # 4) Residual connection + layer norm
        x = self.norm2(x + self.dropout(ff_output))
        return x
```
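Because the layer returns the same [batch, seq_len, d_model] shape it receives, any number of these layers can be stacked. A one-line check (a sketch):

```python
# Sketch: an encoder layer preserves the input shape, which is what allows stacking
layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048, dropout=0.1)
x = torch.randn(2, 6, 512)
print(layer(x).shape)          # torch.Size([2, 6, 512])
```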
6. Encoder
```python
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout):
        super(Encoder, self).__init__()
        self.d_model = d_model
        # Embedding layer: maps token ids to vectors
        # e.g. the id of "猫" (cat) is 5 -> a 512-dimensional vector
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model)
        # A stack of encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)  # e.g. 6 layers
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # x: [batch, seq_len] token ids
        # e.g. [[5, 10, 2, 50, 3, 1], [8, 4, 2, 30, 1, 0]]
        # 1) Token embedding
        #    x: [batch, seq_len, d_model]
        x = self.embedding(x)
        # 2) Scale by sqrt(d_model), as in the original paper
        x = x * math.sqrt(self.d_model)
        # 3) Positional encoding
        x = self.pos_encoding(x)
        # 4) Dropout
        x = self.dropout(x)
        # 5) Pass through all encoder layers in turn
        #    Each layer deepens the interaction between tokens:
        #    early layers: "垫子" (mat) learns that there is a "猫" (cat)
        #    middle layers: the scene "猫坐在垫子上" starts to take shape
        #    final layers: the full meaning and its implications are captured
        for layer in self.layers:
            x = layer(x, mask)
        return x
```
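Putting the pieces together, the encoder maps token ids directly to contextualized vectors. A sketch using the toy ids of the running example:

```python
# Sketch: token ids in, contextualized vectors out
enc = Encoder(vocab_size=1000, d_model=512, num_heads=8, d_ff=2048,
              num_layers=6, dropout=0.1)
src = torch.tensor([[5, 10, 2, 50, 3, 1],    # "猫坐在垫子上"
                    [8, 4, 2, 30, 1, 0]])    # "狗在跑" (padded with 0)
memory = enc(src)
print(memory.shape)            # torch.Size([2, 6, 512])
```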
7. Decoder Layer
```python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        # 1) Masked self-attention (blocks access to future tokens)
        self.self_attn = MultiHeadAttention(num_heads, d_model, dropout)
        # 2) Encoder-decoder (cross) attention
        self.enc_attn = MultiHeadAttention(num_heads, d_model, dropout)
        # 3) Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        # Three layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # x:          decoder input  [batch, tgt_len, d_model]
        # enc_output: encoder output [batch, src_len, d_model]
        # 1) Masked self-attention
        #    While generating token i, positions i+1 and beyond stay hidden
        self_attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        # 2) Encoder-decoder attention
        #    Q comes from the decoder; K and V come from the encoder,
        #    letting the decoder attend over the encoder's output
        enc_attn_output, attn_weights = self.enc_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(enc_attn_output))
        # 3) Feed-forward network
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x, attn_weights
```
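The code above passes `tgt_mask` around but never constructs it. A common construction (a sketch, not part of the walkthrough's code) is a lower-triangular matrix in which position i may only attend to positions up to i; a value of 0 means blocked, matching the `mask == 0` check inside `attention`:

```python
# Sketch: building the causal ("subsequent") mask used for decoder self-attention
def subsequent_mask(size):
    # [1, size, size]: 1 = may attend, 0 = masked out
    return torch.tril(torch.ones(1, size, size, dtype=torch.uint8))

print(subsequent_mask(4)[0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.uint8)
```

With the extra head dimension added inside `MultiHeadAttention`, this [1, size, size] mask broadcasts across both the batch and the heads.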
8. Decoder
```python
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Embed and scale by sqrt(d_model), mirroring the encoder
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x, attn_weights = layer(x, enc_output, src_mask, tgt_mask)
        return x, attn_weights
```
9. The Complete Transformer
```python
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512,
                 num_heads=8, d_ff=2048, num_layers=6, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab, d_model, num_heads, d_ff, num_layers, dropout)
        self.decoder = Decoder(tgt_vocab, d_model, num_heads, d_ff, num_layers, dropout)
        # Final linear layer projecting to vocabulary logits
        # (the softmax is applied later, in the loss or during generation)
        self.linear_out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # src: [batch, src_len] source-language token ids
        # tgt: [batch, tgt_len] target-language token ids
        # 1) Encode the source sentence
        #    enc_output: [batch, src_len, d_model]
        enc_output = self.encoder(src, src_mask)
        # 2) Decode into the target language
        #    dec_output: [batch, tgt_len, d_model]
        dec_output, attn_weights = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        # 3) Project into vocabulary space
        #    output: [batch, tgt_len, tgt_vocab]
        output = self.linear_out(dec_output)
        return output, attn_weights
```
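The model returns raw logits, so training typically applies `nn.CrossEntropyLoss` directly (it includes the softmax). The following is a sketch of one teacher-forcing training step; the shift-the-target convention and the use of id 0 as padding are assumptions, not something the walkthrough specifies:

```python
# Sketch of one training step with teacher forcing (pad id 0 is an assumption)
model = Transformer(src_vocab=1000, tgt_vocab=1000)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

src = torch.randint(1, 1000, (2, 6))               # source token ids
tgt = torch.randint(1, 1000, (2, 7))               # target token ids
tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]          # decoder input vs. prediction targets
tgt_mask = torch.tril(torch.ones(1, 6, 6, dtype=torch.uint8))  # causal mask for 6 positions

logits, _ = model(src, tgt_in, src_mask=None, tgt_mask=tgt_mask)
loss = criterion(logits.reshape(-1, 1000), tgt_out.reshape(-1))
loss.backward()
optimizer.step()
print(loss.item())                                 # roughly log(1000) ≈ 6.9 at initialization
```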
10. Usage Example
```python
# Hyperparameters
src_vocab = 1000   # Chinese vocabulary size
tgt_vocab = 1000   # English vocabulary size
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6

# Build the model
model = Transformer(src_vocab, tgt_vocab, d_model, num_heads, d_ff, num_layers)

# Simulated inputs
src = torch.tensor([[5, 10, 2, 50, 3, 1], [8, 4, 2, 30, 1, 0]])  # [2, 6]
tgt = torch.tensor([[1, 15, 20, 4, 2, 0], [1, 12, 6, 7, 2, 0]])  # [2, 6]

# Forward pass
output, attn = model(src, tgt)
# output: [2, 6, 1000], a score over the English vocabulary at every position
# e.g. output[0, 0] scores the token that should follow "<BOS>" (say, "The"),
#      output[0, 1] scores the token that should follow "<BOS> The", and so on.
```
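At inference time there is no full `tgt` to feed in, so generation proceeds one token at a time. Below is a sketch of greedy decoding using the `model` and `src` defined above; treating id 1 as `<BOS>` is an assumption read off the toy tensors, and stopping at `<EOS>` is omitted for brevity:

```python
# Sketch: greedy decoding, one token at a time (id 1 assumed to be <BOS>)
def greedy_decode(model, src, max_len=10, bos_id=1):
    model.eval()
    with torch.no_grad():
        enc_output = model.encoder(src)
        ys = torch.full((src.size(0), 1), bos_id, dtype=torch.long)
        for _ in range(max_len - 1):
            tgt_mask = torch.tril(torch.ones(1, ys.size(1), ys.size(1), dtype=torch.uint8))
            dec_output, _ = model.decoder(ys, enc_output, None, tgt_mask)
            logits = model.linear_out(dec_output[:, -1])      # only the newest position
            next_token = logits.argmax(dim=-1, keepdim=True)  # [batch, 1]
            ys = torch.cat([ys, next_token], dim=1)
    return ys

print(greedy_decode(model, src))   # a [2, 10] tensor of token ids (random-looking, since the model is untrained)
```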
Summary of Key Tensor Shape Transformations
```
Token ids:         [batch, seq_len]               e.g. [2, 6]
    ↓ embedding
Word vectors:      [batch, seq_len, d_model]      e.g. [2, 6, 512]
    ↓ multi-head attention (split into heads)
Per-head view:     [batch, heads, seq_len, d_k]   e.g. [2, 8, 6, 64]
    ↓ attention computation
Attention output:  [batch, heads, seq_len, d_k]   e.g. [2, 8, 6, 64]
    ↓ concatenate heads
Merged heads:      [batch, seq_len, d_model]      e.g. [2, 6, 512]
    ↓ feed-forward network
Final output:      [batch, seq_len, d_model]      e.g. [2, 6, 512]
```
Every layer operates in the same vector space, but the attention mechanism keeps exchanging information between tokens, so the model gradually builds a layered understanding: from surface form → word meaning → syntax → sentence-level semantics.
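If desired, these shapes can be confirmed directly with forward hooks on the model from the usage example; a short sketch (the printed labels are illustrative):

```python
# Sketch: tracing the shape evolution with forward hooks (uses model, src, tgt from section 10)
def report(name):
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        print(f"{name:<16} {tuple(out.shape)}")
    return hook

model.encoder.embedding.register_forward_hook(report("embedding"))
model.encoder.layers[0].attention.register_forward_hook(report("self-attention"))
model.encoder.layers[0].feed_forward.register_forward_hook(report("feed-forward"))
model(src, tgt)
# embedding        (2, 6, 512)
# self-attention   (2, 6, 512)
# feed-forward     (2, 6, 512)
# (the per-head [2, 8, 6, 64] view exists only inside MultiHeadAttention.forward)
```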