💡 The Transformer Lifecycle, Illustrated: A Visual Walkthrough of Training, Autoregressive Generation, and Beam Search


This article takes a deep dive into how Transformer models are trained and how they perform inference. Using visual illustrations and complete code implementations, it systematically covers the training process, the principle of autoregressive generation, and the Beam Search optimization strategy.

1. The Transformer Training Process

1.1 Training Flow Overview

The Transformer is trained with teacher forcing: the encoder reads the entire source sentence, the decoder receives the target sequence shifted right by one position, and the model is optimized with a cross-entropy loss against the target shifted left.

1.2 Key Training Components

python
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
# Custom dataset
class TranslationDataset(Dataset):
    def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab):
        self.src_enc = [[src_vocab[word] for word in sent.split()] for sent in src_sentences]
        # Wrap each target with <sos>/<eos> so the shifted input/output pairs below are well defined
        self.tgt_enc = [[tgt_vocab['<sos>']] + [tgt_vocab[word] for word in sent.split()] + [tgt_vocab['<eos>']]
                        for sent in tgt_sentences]
        
    def __len__(self):
        return len(self.src_enc)
    
    def __getitem__(self, idx):
        return torch.tensor(self.src_enc[idx]), torch.tensor(self.tgt_enc[idx])
# Training loop
def train_transformer(model, dataloader, epochs=10, lr=0.001):
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore the <pad> token (index 0)
    optimizer = Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        total_loss = 0
        for src, tgt in dataloader:
            # Teacher forcing: decoder input is the target shifted right, the label shifted left
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            
            # Causal mask so each position only attends to earlier positions
            tgt_mask = torch.tril(torch.ones(tgt_input.size(1), tgt_input.size(1)))
            
            # Forward pass
            pred = model(src, tgt_input, tgt_mask=tgt_mask)
            
            # Compute the loss (flatten the sequence dimension)
            loss = criterion(
                pred.reshape(-1, pred.size(-1)), 
                tgt_output.reshape(-1)
            )
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping (prevents exploding gradients)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # Parameter update
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(dataloader):.4f}")
    
    return model
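
TranslationDataset returns variable-length tensors, so the DataLoader only works out of the box with batch_size=1 (as used in Section 5.2). For larger batches, sequences have to be padded dynamically per batch. A minimal sketch using torch.nn.utils.rnn.pad_sequence (the helper name pad_collate is my own):

python
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    """Pad every (src, tgt) pair in a batch to the longest sequence, using 0 as the <pad> index."""
    src_seqs, tgt_seqs = zip(*batch)
    src_batch = pad_sequence(src_seqs, batch_first=True, padding_value=0)
    tgt_batch = pad_sequence(tgt_seqs, batch_first=True, padding_value=0)
    return src_batch, tgt_batch

# Usage: DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=pad_collate)
# A padding mask such as (src_batch != 0) can then be passed to the model as src_mask.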

1.3 Visualizing the Training Process

python
import matplotlib.pyplot as plt
# Simulated training loss values
epochs = 10
train_loss = [3.2, 2.1, 1.5, 1.2, 0.9, 0.7, 0.6, 0.5, 0.45, 0.4]
plt.figure(figsize=(10, 5))
plt.plot(range(1, epochs+1), train_loss, 'o-')
plt.title('Transformer Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.xticks(range(1, epochs+1))
plt.show()

Key elements of training:

  • Data batching: dynamic padding and mask generation (see the collate sketch after the training code above)

  • Teacher forcing: the ground-truth target sequence is fed to the decoder during training

  • Gradient clipping: prevents exploding gradients

  • Learning rate scheduling: warmup followed by decay (a sketch follows this list)
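
The warmup-then-decay schedule from the original Transformer paper can be expressed with torch.optim.lr_scheduler.LambdaLR. A minimal sketch (the helper name and the warmup_steps value are illustrative):

python
from torch.optim.lr_scheduler import LambdaLR

def make_warmup_scheduler(optimizer, d_model=128, warmup_steps=4000):
    """lr factor = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    def lr_lambda(step):
        step = max(step, 1)  # avoid division by zero on the very first call
        return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    return LambdaLR(optimizer, lr_lambda)

# Usage: build the optimizer with lr=1.0 so the schedule fully controls the rate,
# then call scheduler.step() after every optimizer.step() in the training loop.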

2. Transformer Inference: Autoregressive Generation

2.1 How Autoregressive Generation Works

At inference time the target sequence is unknown, so the decoder produces it one token at a time: each step feeds all tokens generated so far back into the decoder, the next token is chosen from the output distribution, and generation stops when an <eos> token is produced or the maximum length is reached.

2.2 Greedy Decoding Implementation

python
def greedy_decode(model, src, src_vocab, tgt_vocab, max_len=20):
    """贪婪解码算法"""
    model.eval()
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # source padding mask, shape (B, 1, 1, S)
    
    # Initialize the decoder input with the start-of-sequence token
    tgt = torch.ones(1, 1).fill_(tgt_vocab['<sos>']).long()
    
    # Encoder forward pass (computed once)
    with torch.no_grad():
        encoder_output = model.encoder(src, src_mask)
    
    # Generate the sequence token by token
    for i in range(max_len):
        # Causal mask (lower triangular); nonzero entries mark positions that may be attended to
        tgt_mask = torch.tril(torch.ones(i + 1, i + 1))
        
        # Decoder forward pass over the tokens generated so far
        with torch.no_grad():
            output = model.decoder(
                tgt, 
                encoder_output, 
                src_mask, 
                tgt_mask
            )
            
            # Take the most probable token at the last position
            pred_token = output.argmax(dim=-1)[:, -1].item()
            
            # Append it to the sequence
            tgt = torch.cat([tgt, torch.tensor([[pred_token]])], dim=1)
            
            # Stop once the end-of-sequence token is produced
            if pred_token == tgt_vocab['<eos>']:
                break
    
    # Convert ids back to text (tgt_vocab_inv maps id -> token, defined in Section 5.2)
    decoded_tokens = [tgt_vocab_inv[idx] for idx in tgt[0].tolist()]
    return ' '.join(decoded_tokens[1:-1])  # strip the <sos>/<eos> tokens
# Example usage (model, src_vocab and tgt_vocab are defined in Section 5)
src_sentence = "I love machine learning"
src_tokens = [src_vocab.get(word, 0) for word in src_sentence.split()]  # unknown words fall back to index 0 in this toy vocabulary
src_tensor = torch.tensor([src_tokens])
translation = greedy_decode(model, src_tensor, src_vocab, tgt_vocab)
print(f"源句: {src_sentence}")
print(f"翻译: {translation}")

2.3 Visualizing Autoregressive Generation

python
# Simulated generation steps (one column per decoding step)
generation_steps = [
    ["<sos>", "", "", "", ""],
    ["<sos>", "Je", "", "", ""],
    ["<sos>", "Je", "t'aime", "", ""],
    ["<sos>", "Je", "t'aime", "l'apprentissage", ""],
    ["<sos>", "Je", "t'aime", "l'apprentissage", "automatique"],
    ["<sos>", "Je", "t'aime", "l'apprentissage", "automatique<eos>"]
]
# Visualization
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title('Autoregressive Generation Process')
ax.set_xlabel('Generation step')
ax.set_ylabel('Sequence position')
ax.set_xticks(range(len(generation_steps)))
ax.set_yticks(range(len(generation_steps[0])))
# Draw the token grid
for i, step in enumerate(generation_steps):
    for j, token in enumerate(step):
        ax.text(i, j, token, ha='center', va='center', 
                bbox=dict(boxstyle='round', facecolor='lightblue' if token else 'white'))
        
        # Connector lines between steps
        if i > 0 and j < len(generation_steps[i-1]) and generation_steps[i-1][j]:
            ax.plot([i-1, i], [j, j], 'k-', lw=1)
            if j < len(step)-1 and step[j+1]:
                ax.plot([i, i], [j, j+1], 'k-', lw=1)
plt.grid(False)
plt.show()

3. The Beam Search Optimization Algorithm

3.1 How Beam Search Works

Greedy decoding commits to the single most probable token at every step and can miss globally better outputs. Beam Search instead keeps the beam_size highest-scoring partial sequences, extends each with its top candidate tokens, and re-ranks all candidates by cumulative log-probability at every step.

3.2 Complete Beam Search Implementation

python
def beam_search_decode(model, src, src_vocab, tgt_vocab, beam_size=3, max_len=20):
    """Beam Search解码算法"""
    model.eval()
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # source padding mask, shape (B, 1, 1, S)
    
    # Encoder forward pass (computed once)
    with torch.no_grad():
        encoder_output = model.encoder(src, src_mask)
    
    # Initialize the beam
    start_token = tgt_vocab['<sos>']
    end_token = tgt_vocab['<eos>']
    
    # Each beam entry: (sequence, cumulative log-probability, finished flag)
    beams = [(torch.tensor([[start_token]]), 0.0, False)]
    
    # Generate step by step
    for step in range(max_len):
        all_candidates = []
        
        # Expand every beam
        for seq, score, done in beams:
            # Finished sequences are carried over unchanged
            if done:
                all_candidates.append((seq, score, True))
                continue
                
            # Causal mask for the target sequence
            tgt_mask = torch.tril(torch.ones(seq.size(1), seq.size(1)))
            
            # Decoder forward pass
            with torch.no_grad():
                output = model.decoder(
                    seq, 
                    encoder_output, 
                    src_mask, 
                    tgt_mask
                )
                log_probs = torch.log_softmax(output[:, -1], dim=-1)
                topk_probs, topk_tokens = log_probs.topk(beam_size, dim=-1)
            
            # Create new candidates from the top-k next tokens
            for i in range(beam_size):
                token = topk_tokens[0, i].item()
                new_score = score + topk_probs[0, i].item()
                new_seq = torch.cat([seq, torch.tensor([[token]])], dim=1)
                new_done = (token == end_token) or done
                all_candidates.append((new_seq, new_score, new_done))
        
        # Sort by cumulative score and keep the top beam_size candidates
        # (a length penalty is often applied here; see the sketch at the end of Section 3)
        ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
        beams = ordered[:beam_size]
        
        # Stop once every beam has finished
        if all(done for _, _, done in beams):
            break
    
    # Pick the best sequence
    best_seq = beams[0][0].squeeze().tolist()
    decoded_tokens = [tgt_vocab_inv[idx] for idx in best_seq]
    return ' '.join(decoded_tokens[1:-1])  # strip the <sos>/<eos> tokens
# Example usage
translation_beam = beam_search_decode(model, src_tensor, src_vocab, tgt_vocab, beam_size=3)
print(f"Beam Search翻译: {translation_beam}")

3.3 Visualizing Beam Search

python
# Simulated Beam Search tree
beam_tree = {
    "root": {"seq": ["<sos>"], "prob": 0.0},
    "A": {"parent": "root", "seq": ["<sos>", "Je"], "prob": -0.2},
    "B": {"parent": "root", "seq": ["<sos>", "I"], "prob": -1.5},
    "C": {"parent": "root", "seq": ["<sos>", "Nous"], "prob": -2.0},
    "A1": {"parent": "A", "seq": ["<sos>", "Je", "t'aime"], "prob": -0.5},
    "A2": {"parent": "A", "seq": ["<sos>", "Je", "suis"], "prob": -1.8},
    "A3": {"parent": "A", "seq": ["<sos>", "Je", "adore"], "prob": -1.2},
    "A1a": {"parent": "A1", "seq": ["<sos>", "Je", "t'aime", "l'IA"], "prob": -0.7},
    "A1b": {"parent": "A1", "seq": ["<sos>", "Je", "t'aime", "les"], "prob": -1.5},
}
# Visualization
plt.figure(figsize=(12, 8))
ax = plt.gca()
ax.set_title('Beam Search Tree (beam_size=3)')
ax.set_axis_off()
# Node positions (x = generation step)
positions = {
    "root": (0, 0),
    "A": (1, 1), "B": (1, 0), "C": (1, -1),
    "A1": (2, 1.5), "A2": (2, 1), "A3": (2, 0.5),
    "A1a": (3, 1.7), "A1b": (3, 1.3)
}
# Draw edges from parent to child
for node, info in beam_tree.items():
    if node != "root":
        parent = info["parent"]
        x1, y1 = positions[parent]
        x2, y2 = positions[node]
        ax.plot([x1, x2], [y1, y2], 'k-', lw=1)
        
        # Draw the node with its sequence and score
        seq_text = ' '.join(info["seq"])
        prob_text = f"{info['prob']:.1f}"
        ax.text(x2, y2, f"{seq_text}\n{prob_text}", 
                ha='center', va='center', 
                bbox=dict(boxstyle='round', facecolor='lightgreen' if node.startswith('A1') else 'lightblue'))
    
# Mark the final selection
ax.text(positions["A1a"][0]+0.1, positions["A1a"][1], "★", 
        fontsize=20, color='gold', ha='center', va='center')
plt.xlim(-0.5, 4)
plt.ylim(-1.5, 2)
plt.show()

Key Beam Search parameters:

  • beam_size: number of candidate sequences kept per step; larger beams improve quality at a higher compute cost

  • max_len: upper bound on the length of the generated sequence

  • Length normalization: ranking by raw cumulative log-probability favors short outputs, so scores are often divided by a length penalty (a sketch follows this list)
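
As a refinement not included in the implementation above, candidates can be ranked with a GNMT-style length penalty instead of the raw cumulative log-probability. A minimal sketch (the function name and the alpha value are illustrative):

python
def length_normalized_score(log_prob_sum, length, alpha=0.7):
    """GNMT-style length penalty: larger alpha pushes the ranking toward longer hypotheses."""
    penalty = ((5 + length) / 6) ** alpha
    return log_prob_sum / penalty

# In beam_search_decode, the ranking line would become:
# ordered = sorted(all_candidates, key=lambda c: length_normalized_score(c[1], c[0].size(1)), reverse=True)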

4. Training vs. Inference: Key Differences

4.1 Core Differences

python
def compare_train_inference():
    """Summarize the differences between training and inference."""
    print("Training mode:")
    print("- Teacher forcing: the full target sequence is used as decoder input")
    print("- Parallel computation: all positions are processed at once")
    print("- Gradient updates: parameters are optimized by backpropagation")
    print("- High compute cost: every position must be evaluated")
    
    print("\nInference mode:")
    print("- Autoregressive generation: the sequence is produced step by step")
    print("- Sequential dependency: each step depends on previous outputs")
    print("- No gradients: only forward passes are required")
    print("- Search strategies: greedy decoding, Beam Search, etc.")
# Run the comparison
compare_train_inference()

4.2 Performance Optimization Strategies

Inference cost is dominated by the sequential decoding loop. Common optimizations include KV caching, mixed-precision inference, and post-training quantization, each of which is covered in Section 6.

5. A Complete Transformer Implementation

5.1 Model Definition

python
import math
import copy
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # Linear projections, then split into heads
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Compute scaled attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Apply the mask (positions where mask == 0 are blocked)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)
        
        # Weighted sum of the values
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        return self.W_o(output), attn_weights
class PositionWiseFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        # Self-attention + residual connection
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward network + residual connection
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        return x
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # Masked self-attention
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Encoder-decoder (cross) attention
        cross_output, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_output))
        
        # Feed-forward network
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_output))
        return x
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len=100):
        super().__init__()
        self.encoder_embed = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embed = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.position_encoding = self.create_position_encoding(max_seq_len, d_model)
        
        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        
        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        
        # Output projection
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
    
    def create_position_encoding(self, max_len, d_model):
        """创建位置编码矩阵"""
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe
    
    def encoder(self, src, src_mask=None):
        """Encode the source sequence; exposed so the decoding functions in Sections 2-3 can call model.encoder(...)."""
        # Embedding + positional encoding
        enc_output = self.encoder_embed(src) + self.position_encoding[:src.size(1), :]
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)
        return enc_output
    
    def decoder(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """Decode against the encoder output; used by greedy and Beam Search decoding."""
        # Embedding + positional encoding
        dec_output = self.decoder_embed(tgt) + self.position_encoding[:tgt.size(1), :]
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, encoder_output, src_mask, tgt_mask)
        # Output projection to vocabulary logits
        return self.fc_out(dec_output)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Training path: encode once, decode the full (teacher-forced) target in parallel
        return self.decoder(tgt, self.encoder(src, src_mask), src_mask, tgt_mask)

5.2 End-to-End Flow: From Training to Inference

python
# 1. Data preparation (toy vocabularies)
src_vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "I": 3, "love": 4, "machine": 5, "learning": 6}
tgt_vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "Je": 3, "t'aime": 4, "l'apprentissage": 5, "automatique": 6}
tgt_vocab_inv = {v: k for k, v in tgt_vocab.items()}
# 2. Build the dataset
src_sentences = ["I love machine learning"]
tgt_sentences = ["Je t'aime l'apprentissage automatique"]
dataset = TranslationDataset(src_sentences, tgt_sentences, src_vocab, tgt_vocab)
dataloader = DataLoader(dataset, batch_size=1)
# 3. Initialize the model
model = Transformer(
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    d_model=128,
    num_heads=8,
    num_layers=3,
    d_ff=512
)
# 4. Train the model
model = train_transformer(model, dataloader, epochs=10, lr=0.0001)
# 5. Inference with the decoding functions from Sections 2 and 3
src_tensor = torch.tensor([[src_vocab["I"], src_vocab["love"], src_vocab["machine"], src_vocab["learning"]]])
greedy_result = greedy_decode(model, src_tensor, src_vocab, tgt_vocab)
beam_result = beam_search_decode(model, src_tensor, src_vocab, tgt_vocab, beam_size=3)
print(f"贪婪解码结果: {greedy_result}")
print(f"Beam Search结果: {beam_result}")

6. Advanced Inference Optimization Techniques

6.1 KV Cache Optimization

python
class DecoderWithCache(nn.Module):
    """Decoder with a KV cache (conceptual sketch).

    Note: this assumes each layer exposes k_proj / v_proj / attention and a d_model
    attribute, which the DecoderLayer above does not; it illustrates the caching idea
    rather than plugging directly into the Section 5 model.
    """
    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        # deepcopy so the stacked layers do not share parameters
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.cache = None
    
    def init_cache(self, batch_size, max_len):
        """Pre-allocate per-layer key/value buffers."""
        self.cache = [{
            'k': torch.zeros(batch_size, max_len, self.layers[0].d_model),
            'v': torch.zeros(batch_size, max_len, self.layers[0].d_model)
        } for _ in range(len(self.layers))]
    
    def forward(self, x, encoder_output, step=0):
        """Cached inference forward pass: only the newest token's K/V are computed per step."""
        if self.cache is None:
            self.init_cache(x.size(0), 100)  # lazily initialize the cache
        
        for i, layer in enumerate(self.layers):
            # Write the current step's key/value into the cache
            self.cache[i]['k'][:, step:step+1] = layer.k_proj(x)
            self.cache[i]['v'][:, step:step+1] = layer.v_proj(x)
            
            # Attend over all cached keys/values up to the current step
            k = self.cache[i]['k'][:, :step+1]
            v = self.cache[i]['v'][:, :step+1]
            x = layer.attention(x, k, v)
            
            # Cross-attention, feed-forward and layer norms would follow here...
        
        return x

6.2 Mixed-Precision Inference

python
from torch.cuda.amp import autocast
def generate_with_amp(model, src, tgt_input):
    """Mixed-precision (FP16) inference; requires a CUDA device."""
    model.eval()
    with torch.no_grad():
        with autocast():
            output = model(src, tgt_input)
    return output

6.3 Quantized Inference

python
# Post-training dynamic quantization: nn.Linear weights are converted to int8
quantized_model = torch.quantization.quantize_dynamic(
    model, 
    {nn.Linear}, 
    dtype=torch.qint8
)
# Save the quantized model's weights
torch.save(quantized_model.state_dict(), "quantized_transformer.pth")

Key Takeaways

Core training loop:

python
for epoch in range(epochs):
    for batch in dataloader:
        # Forward pass
        pred = model(src, tgt_input)
        loss = criterion(pred, tgt_output)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

Autoregressive generation steps:

pseudocode
while not end_condition:
    input = current_sequence
    output = model(input)
    next_token = argmax(output[-1])
    current_sequence = current_sequence + next_token

Beam Search pseudocode:

pseudocode
Initialize: beams = [(<sos>, 0.0)]
for step in range(max_len):
    candidates = []
    for beam in beams:
        expansions = beam extended with its top_k next tokens
        candidates += expansions
    beams = the k highest-scoring candidates
return beams[0]  # best sequence

Performance optimization comparison: KV caching avoids recomputing attention over tokens that have already been generated, mixed precision reduces memory use and speeds up GPU inference, and dynamic quantization shrinks the model for faster CPU inference.

By mastering the full Transformer training and inference pipeline, you will be able to develop and deploy large language models efficiently, laying a solid foundation for building real-world AI applications.
