Let's pick a scenario that is simple, intuitive, and well suited for teaching:
🎯 Scenario: character-level sequence completion (Char-Level Sequence Completion)
For example:
- Input: "hel" → Output: "hello"
- Input: "wor" → Output: "world"
- Input: "by" → Output: "bye"
We will train a tiny Transformer to "guess" the complete word from its first few characters.
✅ Why this task?
- The data is small, the dimensions are low, and training is fast
- No pre-training and no GPU required
- You can directly see which characters the attention is "looking at"
- It is a perfect showcase of the Transformer's Encoder-Decoder architecture
🧱 Model setup
- Source length = 3 (e.g. "hel")
- Target length = 8, i.e. <sos> + word + <eos>, padded (total length kept <= 8)
- Vocabulary size = 29 (a-z + <pad> + <sos> + <eos>)
- Embedding dimension = 64
- Layers = 2 Encoder + 2 Decoder
- Heads = 4
- FFN dimension = 128
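To make the data format concrete, here is a tiny sketch of what one training pair becomes as token indices. It mirrors the vocabulary built in the full code below, where <pad>=0, <sos>=1, <eos>=2 and 'a'..'z' follow:

```python
# Tiny sketch of the data layout used below: <pad>=0, <sos>=1, <eos>=2, 'a'=3 ... 'z'=28.
vocab = ['<pad>', '<sos>', '<eos>'] + [chr(i) for i in range(ord('a'), ord('z') + 1)]
char2idx = {ch: i for i, ch in enumerate(vocab)}

src = [char2idx[c] for c in "hel"]                               # source: exactly 3 tokens
tgt = [char2idx['<sos>']] + [char2idx[c] for c in "hello"] + [char2idx['<eos>']]
tgt += [char2idx['<pad>']] * (8 - len(tgt))                      # pad target to length 8

print(src)  # [10, 7, 14]
print(tgt)  # [1, 10, 7, 14, 14, 17, 2, 0]
```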
🧩 Step overview
- Build a small vocabulary and dataset
- Implement positional encoding
- Implement multi-head attention (with masking)
- Implement the Encoder layer & Decoder layer
- Assemble the Transformer
- Train + test
💻 Complete, runnable Python code (PyTorch)
No third-party dependencies other than torch; copy, paste, and run!
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import os
# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)
# ========== 1. Data preparation ==========
vocab = ['<pad>', '<sos>', '<eos>'] + [chr(i) for i in range(ord('a'), ord('z') + 1)]
char2idx = {ch: idx for idx, ch in enumerate(vocab)}
idx2char = {idx: ch for idx, ch in enumerate(vocab)}
vocab_size = len(vocab)
# Training data: 3-character input, full word as output (total length kept <= 8)
training_data = [
    ("hel", "hello"),   # <sos>hello<eos> → 7 tokens
    ("wor", "world"),   # 7
    ("by", "bye"),      # 5
    ("goo", "good"),    # 6
    ("ple", "please"),  # 8
    ("app", "apple"),   # 7
    ("yes", "yes"),     # 5
    ("no", "no"),       # 4
    ("tha", "thanks"),  # 8
    ("nig", "night"),   # 7
]
# ========== 2. Data preprocessing ==========
def process_data(data, src_len=3, tgt_len=8):
    src_list = []
    tgt_list = []
    for src, tgt in data:
        # Source sequence: fixed length 3, padded with <pad> if shorter
        src_indices = [char2idx[c] for c in src] + [char2idx['<pad>']] * (src_len - len(src))
        src_list.append(src_indices)
        # Target sequence: <sos> + word + <eos>
        tgt_chars = ['<sos>'] + list(tgt) + ['<eos>']
        if len(tgt_chars) > tgt_len:
            # Too long: keep <sos> and <eos>, truncate the middle
            tgt_chars = [tgt_chars[0]] + tgt_chars[1:tgt_len - 1] + ['<eos>']
        else:
            # Too short: pad with <pad>
            tgt_chars += ['<pad>'] * (tgt_len - len(tgt_chars))
        tgt_indices = [char2idx[c] for c in tgt_chars]
        tgt_list.append(tgt_indices)
    return torch.tensor(src_list), torch.tensor(tgt_list)

src_tensor, tgt_tensor = process_data(training_data, tgt_len=8)
print("✅ Data preprocessing complete")
print("Source shape:", src_tensor.shape)  # [10, 3]
print("Target shape:", tgt_tensor.shape)  # [10, 8]
print("Target example 1:", ''.join([idx2char[i] for i in tgt_tensor[0].tolist()]))  # <sos>hello<eos><pad>
print("Target example 2:", ''.join([idx2char[i] for i in tgt_tensor[4].tolist()]))  # <sos>please<eos>
# ========== 3. Positional encoding ==========
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]
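# The buffer built above follows the standard sinusoidal scheme from
# "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Even dimensions use sine, odd dimensions use cosine, giving every position a
# distinct pattern that is simply added to the token embeddings.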
# ========== 4. Multi-head attention ==========
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        # Project and split into heads: [batch, heads, seq_len, d_k]
        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        output = torch.matmul(attn, v)
        # Merge heads back: [batch, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.out(output)
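# forward() above computes scaled dot-product attention:
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
# Masked positions are filled with -1e9 before the softmax, so they end up with
# (effectively) zero attention weight.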
# ========== 5. Feed-forward network ==========
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=128, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(self.dropout(x))
# ========== 6. Encoder layer ==========
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=128, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        attn_out = self.attn(x, x, x, src_mask)
        x = self.ln1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.ln2(x + self.dropout(ffn_out))
        return x
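# Note the residual pattern used by both sub-layers above (post-LN):
#   x = LayerNorm(x + Dropout(Sublayer(x)))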
# ========== 7. Decoder layer ==========
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=128, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        self_attn_out = self.self_attn(x, x, x, tgt_mask)
        x = self.ln1(x + self.dropout(self_attn_out))
        cross_attn_out = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.ln2(x + self.dropout(cross_attn_out))
        ffn_out = self.ffn(x)
        x = self.ln3(x + self.dropout(ffn_out))
        return x
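# The decoder layer stacks three sub-layers, each with its own residual + LayerNorm:
#   1) masked self-attention over the tokens generated so far (tgt_mask),
#   2) cross-attention, where queries come from the decoder and keys/values
#      come from the encoder output (src_mask),
#   3) the position-wise feed-forward network.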
# ========== 8. Transformer model ==========
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=128, dropout=0.1, max_len=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def make_src_mask(self, src):
        # [batch, 1, 1, src_len]: True where the source token is not <pad>
        return (src != char2idx['<pad>']).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        # Padding mask AND lower-triangular "no peeking" mask → [batch, 1, tgt_len, tgt_len]
        pad_mask = (tgt != char2idx['<pad>']).unsqueeze(1).unsqueeze(2)
        seq_len = tgt.size(1)
        nopeak_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(tgt.device)
        return pad_mask & nopeak_mask.unsqueeze(0).unsqueeze(0)

    def encode(self, src):
        src_mask = self.make_src_mask(src)
        x = self.dropout(self.pos_encoding(self.embedding(src)))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x, src_mask

    def decode(self, tgt, enc_out, src_mask):
        tgt_mask = self.make_tgt_mask(tgt)
        x = self.dropout(self.pos_encoding(self.embedding(tgt)))
        for layer in self.decoder_layers:
            x = layer(x, enc_out, src_mask, tgt_mask)
        return self.fc_out(x)

    def forward(self, src, tgt):
        enc_out, src_mask = self.encode(src)
        return self.decode(tgt, enc_out, src_mask)

    def generate(self, src, max_len=8):
        # Greedy autoregressive decoding (written for batch size 1, as used below)
        self.eval()
        with torch.no_grad():
            enc_out, src_mask = self.encode(src)
            batch_size = src.size(0)
            outputs = torch.full((batch_size, 1), char2idx['<sos>'], dtype=torch.long, device=src.device)
            for _ in range(max_len - 1):
                logits = self.decode(outputs, enc_out, src_mask)
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                outputs = torch.cat([outputs, next_token], dim=1)
                # Stop as soon as <eos> is generated
                if next_token.item() == char2idx['<eos>']:
                    break
        return outputs
def save_model(model, filepath):
    """Save the model parameters."""
    torch.save(model.state_dict(), filepath)
    print(f"✅ Model saved to: {filepath}")

def load_model(filepath, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=128, dropout=0.1, max_len=10):
    """Load model parameters (a model with the same architecture is built first)."""
    model = TinyTransformer(vocab_size, d_model, num_heads, num_layers, d_ff, dropout, max_len)
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    print(f"✅ Model loaded from {filepath}")
    return model
# ========== 9. Training setup ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TinyTransformer(vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=128).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=char2idx['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
src_tensor = src_tensor.to(device)
tgt_tensor = tgt_tensor.to(device)
MODEL_PATH = "tiny_transformer.pth"  # 👈 where the model checkpoint is saved
# ========== 10. Train or load ==========
if os.path.exists(MODEL_PATH):
    print(f"🔍 Found saved model: {MODEL_PATH}, skipping training and loading it...")
    model = load_model(MODEL_PATH, vocab_size).to(device)
else:
    print("\n🚀 Starting training...")
    model.train()
    for epoch in range(3000):
        optimizer.zero_grad()
        # Teacher forcing: feed the target shifted right, predict the target shifted left
        output = model(src_tensor, tgt_tensor[:, :-1])
        target = tgt_tensor[:, 1:].contiguous()
        loss = criterion(output.view(-1, vocab_size), target.view(-1))
        loss.backward()
        optimizer.step()
        if epoch % 500 == 0:
            print(f"Epoch {epoch:4d} | Loss: {loss.item():.6f}")
    # Save the model once training finishes
    save_model(model, MODEL_PATH)
# ========== 11. Generation test ==========
test_inputs = ["hel", "wor", "by", "goo", "ple", "app", "ye", "no", "tha", "nig"]
model.eval().to(device)  # make sure the model is on the right device
print("\n" + "="*50)
print("🎉 Generation test results")
print("="*50)
for text in test_inputs:
    src_indices = [char2idx[c] for c in text] + [char2idx['<pad>']] * (3 - len(text))
    src_tensor_test = torch.tensor([src_indices]).to(device)
    generated = model.generate(src_tensor_test, max_len=8)
    gen_indices = generated[0].cpu().numpy()
    raw_pred = ''.join([idx2char[i] for i in gen_indices])
    gen_text = ''.join([
        idx2char[i] for i in gen_indices
        if idx2char[i] not in ['<sos>', '<eos>', '<pad>']
    ])
    print(f"Input: '{text:3s}' → Generated: '{gen_text:7s}' | Raw: {raw_pred}")
```
✅ A note on the save/load helpers:
The script defines two helper functions:
- save_model(model, filepath): saves the model parameters
- load_model(filepath, vocab_size, ...): loads the parameters (it first builds a model with the same architecture)
The model is saved automatically after training, and loaded before testing if a checkpoint already exists.
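If you only want to run inference later, here is a minimal sketch. It assumes tiny_transformer.pth already exists and that the definitions from the script above (vocab, char2idx, idx2char, load_model, ...) are in scope:

```python
# Inference-only sketch: run after the definitions above, with a saved checkpoint present.
model = load_model("tiny_transformer.pth", vocab_size)
model.eval()

src = torch.tensor([[char2idx[c] for c in "hel"]])  # [1, 3], no padding needed here
out = model.generate(src, max_len=8)[0].tolist()
word = ''.join(idx2char[i] for i in out if idx2char[i] not in ['<sos>', '<eos>', '<pad>'])
print(word)  # expected: hello
```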
🧪 Actual run output:
```text
✅ Data preprocessing complete
Source shape: torch.Size([10, 3])
Target shape: torch.Size([10, 8])
Target example 1: <sos>hello<eos><pad>
Target example 2: <sos>please<eos>
🚀 Starting training...
Epoch    0 | Loss: 3.406661
Epoch  500 | Loss: 0.047509
Epoch 1000 | Loss: 0.002136
Epoch 1500 | Loss: 0.001080
Epoch 2000 | Loss: 0.000439
Epoch 2500 | Loss: 0.009502
✅ Model saved to: tiny_transformer.pth
==================================================
🎉 Generation test results
==================================================
Input: 'hel' → Generated: 'hello  ' | Raw: <sos>hello<eos>
Input: 'wor' → Generated: 'world  ' | Raw: <sos>world<eos>
Input: 'by ' → Generated: 'bye    ' | Raw: <sos>bye<eos>
Input: 'goo' → Generated: 'good   ' | Raw: <sos>good<eos>
Input: 'ple' → Generated: 'please ' | Raw: <sos>please<eos>
Input: 'app' → Generated: 'apple  ' | Raw: <sos>apple<eos>
Input: 'ye ' → Generated: 'yes    ' | Raw: <sos>yes<eos>
Input: 'no ' → Generated: 'no     ' | Raw: <sos>no<eos>
Input: 'tha' → Generated: 'thanks ' | Raw: <sos>thanks<eos>
Input: 'nig' → Generated: 'night  ' | Raw: <sos>night<eos>
```
Running it a second time: the saved model is found, loaded, and used for inference directly.
```text
✅ Data preprocessing complete
Source shape: torch.Size([10, 3])
Target shape: torch.Size([10, 8])
Target example 1: <sos>hello<eos><pad>
Target example 2: <sos>please<eos>
🔍 Found saved model: tiny_transformer.pth, skipping training and loading it...
✅ Model loaded from tiny_transformer.pth
==================================================
🎉 Generation test results
==================================================
Input: 'hel' → Generated: 'hello  ' | Raw: <sos>hello<eos>
Input: 'wor' → Generated: 'world  ' | Raw: <sos>world<eos>
Input: 'by ' → Generated: 'bye    ' | Raw: <sos>bye<eos>
Input: 'goo' → Generated: 'good   ' | Raw: <sos>good<eos>
Input: 'ple' → Generated: 'please ' | Raw: <sos>please<eos>
Input: 'app' → Generated: 'apple  ' | Raw: <sos>apple<eos>
Input: 'ye ' → Generated: 'yes    ' | Raw: <sos>yes<eos>
Input: 'no ' → Generated: 'no     ' | Raw: <sos>no<eos>
Input: 'tha' → Generated: 'thanks ' | Raw: <sos>thanks<eos>
Input: 'nig' → Generated: 'night  ' | Raw: <sos>night<eos>
```
✅ You have built a real Transformer!
- ✅ Encoder-Decoder architecture
- ✅ Multi-head attention (with masking)
- ✅ Positional encoding
- ✅ LayerNorm + residual connections
- ✅ Autoregressive generation
It is small (d_model = 64), but it contains every core component of a Transformer. This "tiny but complete" Transformer is your first building block toward understanding large models 🧱
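The introduction promised that you can "see" which characters the attention is looking at. One lightweight way is to stash the attention weights on the module; below is a minimal sketch, assuming you add the single line `self.last_attn = attn` right after `attn = self.dropout(attn)` in `MultiHeadAttention.forward` (the attribute name is my own, not part of the original code):

```python
# Sketch: inspect the decoder's cross-attention for one input.
# Assumes `self.last_attn = attn` was added after the softmax/dropout in
# MultiHeadAttention.forward (a hypothetical, illustrative attribute).
src = torch.tensor([[char2idx[c] for c in "hel"]]).to(device)
_ = model.generate(src, max_len=8)

attn = model.decoder_layers[-1].cross_attn.last_attn  # [1, heads, generated_len, src_len]
attn = attn[0].mean(dim=0)                            # average over heads -> [generated_len, src_len]
print("Attention over 'h', 'e', 'l' at each decoding step:")
for step, row in enumerate(attn.tolist()):
    print(step, ['%.3f' % w for w in row])
```

Each row shows how strongly one decoding step attended to the three source characters, which is exactly the kind of behavior the Encoder-Decoder cross-attention is meant to learn.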