Demonstrating How a Transformer Works via "Word Completion" (runnable Python code)

Let's pick a scenario that is simple, intuitive, and well suited for teaching:


🎯 Scenario: Character-Level Sequence Completion

For example:

  • Input: "hel" → Output: "hello"
  • Input: "wor" → Output: "world"
  • Input: "by" → Output: "bye"

We train a tiny Transformer that learns to "guess" the complete word from its first few characters.

✅ Why this task?

  • Tiny dataset, low dimensionality, fast training
  • No pretraining and no GPU required
  • You can literally see which characters the attention is "looking at" (a sketch for inspecting this is included at the end)
  • It cleanly showcases the Transformer's Encoder-Decoder structure

🧱 Model setup

  • Source length = 3 (e.g. "hel")
  • Target length = 8 (e.g. <sos>hello<eos><pad>; every word fits into <sos> + word + <eos> <= 8 tokens)
  • Vocabulary size = 29 (a-z plus <pad>, <sos>, <eos>)
  • Embedding dimension = 64
  • Layers = 2 Encoder layers + 2 Decoder layers
  • Attention heads = 4
  • FFN dimension = 128
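
With these settings the whole model stays tiny. As a quick, optional check, you can count the trainable parameters; this is a minimal sketch, assuming the model object from the script below has already been constructed, and it is not part of the script's printed output. With the settings above the total comes out well under a million parameters.

python
# Optional check, to be run after `model` is created in the script below
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params}")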

🧩 Step overview

  1. Build a small vocabulary and dataset
  2. Implement positional encoding
  3. Implement multi-head attention (with masking)
  4. Implement the Encoder and Decoder layers
  5. Assemble the Transformer
  6. Train + test

💻 Complete, runnable Python code (PyTorch)

No third-party dependencies apart from torch; copy, paste, and run!

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import os

# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)

# ========== 1. Data preparation ==========
vocab = ['<pad>', '<sos>', '<eos>'] + [chr(i) for i in range(ord('a'), ord('z') + 1)]
char2idx = {ch: idx for idx, ch in enumerate(vocab)}
idx2char = {idx: ch for idx, ch in enumerate(vocab)}
vocab_size = len(vocab)

# Training data: 3-character input, full word as output (total target length <= 8)
training_data = [
    ("hel", "hello"),   # <sos>hello<eos> → 7 tokens
    ("wor", "world"),   # 7
    ("by",  "bye"),     # 5
    ("goo", "good"),    # 6
    ("ple", "please"),  # 8
    ("app", "apple"),   # 7
    ("yes", "yes"),     # 5
    ("no",  "no"),      # 4
    ("tha", "thanks"),  # 8
    ("nig", "night"),   # 7
]

# ========== 2. Data preprocessing ==========
def process_data(data, src_len=3, tgt_len=8):
    src_list = []
    tgt_list = []
    for src, tgt in data:
        # Source sequence: fixed length 3, padded with <pad> if shorter
        src_indices = [char2idx[c] for c in src] + [char2idx['<pad>']] * (src_len - len(src))
        src_list.append(src_indices)

        # Target sequence: <sos> + word + <eos>
        tgt_chars = ['<sos>'] + list(tgt) + ['<eos>']
        if len(tgt_chars) > tgt_len:
            # Too long: keep <sos> and <eos>, truncate the middle
            tgt_chars = [tgt_chars[0]] + tgt_chars[1:tgt_len - 1] + ['<eos>']
        else:
            # Too short: pad with <pad>
            tgt_chars += ['<pad>'] * (tgt_len - len(tgt_chars))
        tgt_indices = [char2idx[c] for c in tgt_chars]
        tgt_list.append(tgt_indices)
    return torch.tensor(src_list), torch.tensor(tgt_list)

src_tensor, tgt_tensor = process_data(training_data, tgt_len=8)
print("✅ 数据预处理完成")
print("输入形状:", src_tensor.shape)   # [10, 3]
print("目标形状:", tgt_tensor.shape)  # [10, 8]
print("目标示例1:", ''.join([idx2char[i] for i in tgt_tensor[0].tolist()]))  # <sos>hello<eos><pad>
print("目标示例2:", ''.join([idx2char[i] for i in tgt_tensor[4].tolist()]))  # <sos>please<eos>

# ========== 3. Positional encoding ==========
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
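        # The two lines below implement the standard sinusoidal encoding:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        # div_term is 10000^(-2i/d_model), computed in log space for numerical stability.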
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# ========== 4. Multi-head attention ==========
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

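        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V, computed for all
        # heads in parallel; masked positions are pushed to -1e9 so their softmax
        # weight is effectively zero.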
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.out(output)

# ========== 5. Feed-forward network ==========
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=128, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(self.dropout(x))

# ========== 6. Encoder layer ==========
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=128, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        attn_out = self.attn(x, x, x, src_mask)
        x = self.ln1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.ln2(x + self.dropout(ffn_out))
        return x

# ========== 7. Decoder layer ==========
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=128, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        self_attn_out = self.self_attn(x, x, x, tgt_mask)
        x = self.ln1(x + self.dropout(self_attn_out))
        cross_attn_out = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.ln2(x + self.dropout(cross_attn_out))
        ffn_out = self.ffn(x)
        x = self.ln3(x + self.dropout(ffn_out))
        return x

# ========== 8. The Transformer model ==========
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=128, dropout=0.1, max_len=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def make_src_mask(self, src):
        return (src != char2idx['<pad>']).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
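        # The decoder mask combines a padding mask (hide <pad> positions) with a
        # lower-triangular "no-peek" mask, so position t can only attend to
        # positions <= t (the autoregressive property).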
        pad_mask = (tgt != char2idx['<pad>']).unsqueeze(1).unsqueeze(2)
        seq_len = tgt.size(1)
        nopeak_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(tgt.device)
        return pad_mask & nopeak_mask.unsqueeze(0).unsqueeze(0)

    def encode(self, src):
        src_mask = self.make_src_mask(src)
        x = self.dropout(self.pos_encoding(self.embedding(src)))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x, src_mask

    def decode(self, tgt, enc_out, src_mask):
        tgt_mask = self.make_tgt_mask(tgt)
        x = self.dropout(self.pos_encoding(self.embedding(tgt)))
        for layer in self.decoder_layers:
            x = layer(x, enc_out, src_mask, tgt_mask)
        return self.fc_out(x)

    def forward(self, src, tgt):
        enc_out, src_mask = self.encode(src)
        return self.decode(tgt, enc_out, src_mask)

    def generate(self, src, max_len=8):
        self.eval()
        with torch.no_grad():
            enc_out, src_mask = self.encode(src)
            batch_size = src.size(0)
            outputs = torch.full((batch_size, 1), char2idx['<sos>'], dtype=torch.long, device=src.device)

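            # Greedy decoding: repeatedly feed everything generated so far back into
            # the decoder and append the most likely next token.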
            for _ in range(max_len - 1):
                logits = self.decode(outputs, enc_out, src_mask)
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                outputs = torch.cat([outputs, next_token], dim=1)

                # Stop as soon as <eos> is generated (this check assumes batch size 1)
                if next_token.item() == char2idx['<eos>']:
                    break

            return outputs

def save_model(model, filepath):
    """保存模型参数"""
    torch.save(model.state_dict(), filepath)
    print(f"✅ 模型已保存到: {filepath}")

def load_model(filepath, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=128, dropout=0.1, max_len=10):
    """加载模型参数"""
    model = TinyTransformer(vocab_size, d_model, num_heads, num_layers, d_ff, dropout, max_len)
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    print(f"✅ 模型已从 {filepath} 加载")
    return model


# ========== 9. Training setup ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TinyTransformer(vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=128).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=char2idx['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

src_tensor = src_tensor.to(device)
tgt_tensor = tgt_tensor.to(device)

MODEL_PATH = "tiny_transformer.pth"  # 👈 模型保存路径

# ========== 10. Train or load ==========
if os.path.exists(MODEL_PATH):
    print(f"🔍 发现已保存模型: {MODEL_PATH},跳过训练,直接加载...")
    model = load_model(MODEL_PATH, vocab_size).to(device)
else:
    print("\n🚀 开始训练...")
    model.train()
    for epoch in range(3000):
        optimizer.zero_grad()
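        # Teacher forcing: the decoder input is the target shifted right (keeps <sos>,
        # drops the last token); the loss target is the sequence shifted left (drops
        # <sos>). <pad> positions are ignored via ignore_index in the loss.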
        output = model(src_tensor, tgt_tensor[:, :-1])
        target = tgt_tensor[:, 1:].contiguous()
        loss = criterion(output.view(-1, vocab_size), target.view(-1))
        loss.backward()
        optimizer.step()

        if epoch % 500 == 0:
            print(f"Epoch {epoch:4d} | Loss: {loss.item():.6f}")

    # Save the model once training is done
    save_model(model, MODEL_PATH)

# ========== 11. Generation test ==========
test_inputs = ["hel", "wor", "by", "goo", "ple", "app", "ye", "no", "tha", "nig"]
model.eval().to(device)  # make sure the model is in eval mode and on the right device

print("\n" + "="*50)
print("🎉 生成测试结果")
print("="*50)
for text in test_inputs:
    src_indices = [char2idx[c] for c in text] + [char2idx['<pad>']] * (3 - len(text))
    src_tensor_test = torch.tensor([src_indices]).to(device)
    generated = model.generate(src_tensor_test, max_len=8)
    gen_indices = generated[0].cpu().numpy()
    raw_pred = ''.join([idx2char[i] for i in gen_indices])
    gen_text = ''.join([
        idx2char[i] for i in gen_indices
        if idx2char[i] not in ['<sos>', '<eos>', '<pad>']
    ])
    print(f"输入: '{text:3s}' → 生成: '{gen_text:7s}' | 原始: {raw_pred}")

✅ Save/load support:

The script defines two helper functions:

  1. save_model(model, filepath) saves the model parameters
  2. load_model(filepath, vocab_size, ...) loads them (a model with the same architecture is constructed first)

The model is saved automatically after training, and an existing checkpoint is loaded before testing if one is found.


🧪 Actual run output:

text
✅ Data preprocessing complete
Input shape: torch.Size([10, 3])
Target shape: torch.Size([10, 8])
Target example 1: <sos>hello<eos><pad>
Target example 2: <sos>please<eos>

🚀 Starting training...
Epoch    0 | Loss: 3.406661
Epoch  500 | Loss: 0.047509
Epoch 1000 | Loss: 0.002136
Epoch 1500 | Loss: 0.001080
Epoch 2000 | Loss: 0.000439
Epoch 2500 | Loss: 0.009502
✅ Model saved to: tiny_transformer.pth

==================================================
🎉 Generation test results
==================================================
Input: 'hel' → Output: 'hello  ' | Raw: <sos>hello<eos>
Input: 'wor' → Output: 'world  ' | Raw: <sos>world<eos>
Input: 'by ' → Output: 'bye    ' | Raw: <sos>bye<eos>
Input: 'goo' → Output: 'good   ' | Raw: <sos>good<eos>
Input: 'ple' → Output: 'please ' | Raw: <sos>please<eos>
Input: 'app' → Output: 'apple  ' | Raw: <sos>apple<eos>
Input: 'ye ' → Output: 'yes    ' | Raw: <sos>yes<eos>
Input: 'no ' → Output: 'no     ' | Raw: <sos>no<eos>
Input: 'tha' → Output: 'thanks ' | Raw: <sos>thanks<eos>
Input: 'nig' → Output: 'night  ' | Raw: <sos>night<eos>

Running the script a second time, the saved model is loaded and used for inference directly:

text
✅ Data preprocessing complete
Input shape: torch.Size([10, 3])
Target shape: torch.Size([10, 8])
Target example 1: <sos>hello<eos><pad>
Target example 2: <sos>please<eos>
🔍 Found saved model: tiny_transformer.pth, skipping training and loading it directly...
✅ Model loaded from tiny_transformer.pth

==================================================
🎉 Generation test results
==================================================
Input: 'hel' → Output: 'hello  ' | Raw: <sos>hello<eos>
Input: 'wor' → Output: 'world  ' | Raw: <sos>world<eos>
Input: 'by ' → Output: 'bye    ' | Raw: <sos>bye<eos>
Input: 'goo' → Output: 'good   ' | Raw: <sos>good<eos>
Input: 'ple' → Output: 'please ' | Raw: <sos>please<eos>
Input: 'app' → Output: 'apple  ' | Raw: <sos>apple<eos>
Input: 'ye ' → Output: 'yes    ' | Raw: <sos>yes<eos>
Input: 'no ' → Output: 'no     ' | Raw: <sos>no<eos>
Input: 'tha' → Output: 'thanks ' | Raw: <sos>thanks<eos>
Input: 'nig' → Output: 'night  ' | Raw: <sos>night<eos>

✅ You now have a real Transformer!

  • ✅ Encoder-Decoder structure
  • ✅ Multi-head attention (with masking)
  • ✅ Positional encoding
  • ✅ LayerNorm + residual connections
  • ✅ Autoregressive generation

Although it is tiny (d_model=64), it contains every core component of a Transformer. Small as it is, it has everything it needs, and it makes a solid first building block for understanding large models 🧱
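
🔍 Bonus: which characters is the attention looking at?

One of the bullets at the top claimed you can see which source characters the attention focuses on, but the script itself never prints attention weights. Below is a minimal sketch of one way to inspect them, assuming the script above has already trained (or loaded) model: it re-runs the first decoder layer by hand with the trained projection weights and returns that layer's encoder-decoder attention. The helper name cross_attention_map and the choice of decoder layer 0 are illustrative choices, not part of the original script.

python
import math
import torch

def cross_attention_map(model, src_text, tgt_text):
    """Recompute decoder layer 0's cross-attention weights for one (source, target) pair."""
    model.eval()
    with torch.no_grad():
        src = torch.tensor([[char2idx[c] for c in src_text] +
                            [char2idx['<pad>']] * (3 - len(src_text))], device=device)
        tgt = torch.tensor([[char2idx['<sos>']] + [char2idx[c] for c in tgt_text]], device=device)
        enc_out, src_mask = model.encode(src)
        # Run the decoder up to (but not including) layer 0's cross-attention
        # (dropout layers are no-ops in eval mode, so this matches decode()).
        x = model.pos_encoding(model.embedding(tgt))
        layer = model.decoder_layers[0]
        x = layer.ln1(x + layer.self_attn(x, x, x, model.make_tgt_mask(tgt)))
        # Redo the cross-attention score computation by hand to expose the weights.
        attn = layer.cross_attn
        q = attn.q_linear(x).view(1, -1, attn.num_heads, attn.d_k).transpose(1, 2)
        k = attn.k_linear(enc_out).view(1, -1, attn.num_heads, attn.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(attn.d_k)
        scores = scores.masked_fill(src_mask == 0, -1e9)
        return torch.softmax(scores, dim=-1)[0]   # [num_heads, tgt_len, src_len]

# Rows are the decoder positions (<sos>, h, e, l, l, o); columns are the source "hel".
print(cross_attention_map(model, "hel", "hello").mean(dim=0))   # average over heads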
