从零实现Transformer:第 9 部分 - 推理(Inference )
推理流程
- 定义测试序列
给模型一个待复制的输入:[SOS,1,7,0,9,5,EOS](和训练数据格式完全一致)。 - 文本转张量(格式转换)
模型不认识文字,只认识数字:
把文字转成数字ID(tokenize)
填充到固定长度(和训练对齐)
变成 PyTorch 张量(形状:[1, 序列长度],1代表1个样本) - 生成源掩码
告诉模型:填充的PAD是无效内容,不要关注。 - 获取首尾标记ID
解码必须以[SOS]开头,以[EOS]结束,提前准备好它们的数字ID。 - 调用贪心解码函数
把所有准备好的张量传给解码函数,让模型自动生成序列。 - 张量转回文本
把模型输出的数字ID,还原成人类能看懂的文字,过滤掉无用的PAD。
文本转张量 → 编码器编码一次 → 解码器从 SOS 开始 → 循环逐词贪心生成 → 遇到 EOS 停止 → 张量转回文本
整个推理分为 3大阶段:
- 外部准备:把人类能看懂的文本 → 模型能看懂的张量
- 解码 :调用
greedy_decode,编码器编码1次,解码器循环逐词生成 - 结果还原:把模型输出的张量 → 人类能看懂的文本
贪心解码(Greedy Decoding)
在 Transformer / 大模型逐词生成 的时候:
每一步,只选当前概率最大的那个词,直接定下,不看未来、不考虑其他可能 ,这就叫贪心解码。
假设模型现在要选词,候选和概率是:
5:40%
7:30%
0:20%
9:10%
贪心做法 :
直接选 概率最高的 5 ,不管后面会不会组合出更好的句子,只顾当下最优。
就像走路:
每一步只选眼前最近的路,不全局规划,走到哪算哪。
自回归生成时:
- 模型每一步输出一个词表概率分布
- 直接取概率最大的 token 作为下一个词
- 把这个词塞回输入,继续生成下一个
- 直到出现
EOS或达到最大长度
python
def greedy_decode(...):
# 1. 模型切换推理模式
model.eval()
# 2. 数据移到GPU/CPU
src = src.to(device)
src_mask = src_mask.to(device)
# 3. 关闭梯度计算(推理不需要反向传播)
with torch.no_grad():
# 4. 【编码器只执行1次!】对输入序列编码,得到全局语义特征
encoder_output = model.encode(src, src_mask)
# 5. 初始化解码器输入:只有[SOS](形状:[1,1])
decoder_input = torch.tensor([[sos_id]], device=device)
# 6. 【循环自回归生成】逐词生成,直到最大长度/遇到EOS
for _ in range(max_len - 1):
# 6.1 生成目标掩码:禁止解码器看到未来的token
tgt_mask = create_tgt_mask(decoder_input, PAD_ID)
# 6.2 解码器推理:输入已生成的序列 + 编码器的语义特征
decoder_output = model.decode(encoder_output, src_mask, decoder_input, tgt_mask)
# 6.3 只取最后一个词的输出(预测下一个词)
logits = model.project(decoder_output[:, -1:])
# 6.4 贪心选择:选概率最高的token
_, next_token_id = torch.max(logits, dim=-1)
# 6.5 把新生成的token拼接到输入(序列长度+1)
decoder_input = torch.cat([decoder_input, next_token_id], dim=1)
# 6.6 终止条件:生成EOS就停止
if next_token_id.item() == eos_id:
break
# 7. 返回最终生成的完整序列
return decoder_input
内部流程
1. 推理模式准备
model.eval():关闭Dropout、BatchNorm等训练层,保证推理结果稳定。
数据移设备:把张量放到和模型一样的设备(CPU/GPU)。
2. 编码器:只执行1次!
输入序列是固定的,编码器只需要编码一次,得到输入序列的语义特征,全程复用。
3. 解码器初始化
解码器必须从 [SOS] 开始生成,这是训练时约定的规则。
4. 循环生成(核心:自回归贪心)
循环里每一轮只生成1个词:
- 目标掩码:防止解码器偷看未来还没生成的词(生成任务的铁律)。
- 解码器推理:用已经生成的序列,结合编码器的语义,计算下一个词的概率。
- 取最后一个词:只需要最新的隐状态预测下一个词,前面的历史不用管。
- 贪心选择:选概率最高的词(这就是贪心解码)。
- 拼接输入:把新生成的词加到解码器输入里,下一轮继续生成。
- 提前终止 :生成
[EOS]就结束,不用跑完所有循环。
5. 返回结果
返回生成好的数字ID序列,交给外部代码转成文本。
Logits 与 Token ID的关系
投影层负责将解码器输出的 d_model 维隐状态向量,通过线性变换映射到目标词汇表维度;其直接输出为各候选单词的原始预测分值(Logits),该分值无数值范围约束,不满足概率的归一化条件,并非合法概率分布。后续需对 Logits 施加 Softmax 归一化运算,才可将其转换为取值介于 ([0,1])([0,1])([0,1])、总和为 1 的词汇条件概率分布,表征当前位置生成各个单词的置信度。
这里是贪心解码,所以没有施加 Softmax。
Logits 以英译汉任务 为例子
第一步:先定「中文目标词表」
人为提前规定好:数组下标 = Token ID
| 数组下标(索引) | Token ID | 对应汉字/特殊符号 |
|---|---|---|
| 0 | 0 | <PAD> 填充符 |
| 1 | 1 | <SOS> 句子开始 |
| 2 | 2 | <EOS> 句子结束 |
| 3 | 3 | 我 |
| 4 | 4 | 和 |
| 5 | 5 | 你 |
词表总大小 vocab_size = 6
所以投影层最后输出维度一定是 6
第二步:模型输出的原始 Logits(某一个生成时刻)
解码器算出隐向量 → 过投影层线性变换
直接得到 6 个原始打分 Logits(可正可负、没有大小限制、不是概率)
举例某一时刻:
Logits = [-2.1, -5.3, -6.0, 1.2, 0.8, 3.5]
对应关系严格对齐:
Logits[0] → ID=0 <PAD> 的打分
Logits[1] → ID=1 <SOS> 的打分
Logits[2] → ID=2 <EOS> 的打分
Logits[3] → ID=3「我」的打分
Logits[4] → ID=4「和」的打分
Logits[5] → ID=5「你」的打分
第三步:Logits 过 Softmax 变成合法概率(贪心解码省略)
对上面一组 logits 做 Softmax,得到概率分布(总和=1,数值 0~1):
直接给整理好易懂结果:
Prob = [0.000011, 0.000129, 0.000064, 0.08559, 0.057373, 0.853688]
能直观看出:
第5个位置概率最高 = 0.853688
第四步:argmax 取索引
对概率分布执行 torch.argmax()
得到:索引 = 5
第五步:索引5就是中文Token ID
因为从模型设计之初就做了强制一一映射 :
Logits 数组的每一位位置下标 ,和目标词表预先定义的 Token ID 编号完全顺序对齐
所以:索引 5 → 直接对应 Token ID = 5,再查表 → ID5 = 汉字「你」
即
- 投影层输出一组和词表长度一致的 Logits原始打分
- 经Softmax转为各汉字的生成概率
- 取概率最大的数组索引
- 该索引天生等价于预先定义好的目标Token ID
- 用ID反向查表,解码出最终可读汉字
完整代码
cpp
import torch
import torch.nn as nn
import math
import random
import copy
import time
from torch.utils.data import Dataset, DataLoader
# ====================== 1. Transformer 基础组件 ======================
class InputEmbeddings(nn.Module):
def __init__(self, d_model: int, vocab_size: int):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
def forward(self, x):
return self.embedding(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, seq_len: int, dropout: float):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(seq_len, d_model)
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)
return self.dropout(x)
def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
d_k = query.shape[-1]
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask, float('-inf'))
attn = torch.softmax(scores, dim=-1)
if dropout is not None:
attn = dropout(attn)
return torch.matmul(attn, value), attn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.attention_weights = None
def forward(self, query, key, value, mask=None):
batch_size = query.shape[0]
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
output, self.attention_weights = scaled_dot_product_attention(Q, K, V, mask, self.dropout)
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
return self.W_o(output)
class LayerNormalization(nn.Module):
def __init__(self, features: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(features))
self.beta = nn.Parameter(torch.zeros(features))
def forward(self, x: torch.Tensor):
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True, unbiased=False)
normalized = (x - mean) / torch.sqrt(std ** 2 + self.eps)
return self.gamma * normalized + self.beta
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(d_model, d_ff)
self.linear_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.activation = nn.ReLU()
def forward(self, x):
return self.linear_2(self.dropout(self.activation(self.linear_1(x))))
class ResidualConnection(nn.Module):
def __init__(self, features: int, dropout: float):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.norm = LayerNormalization(features)
def forward(self, x, sublayer):
return self.norm(x + self.dropout(sublayer(x)))
# ====================== 2. Encoder & Decoder 模块 ======================
class EncoderBlock(nn.Module):
def __init__(self, features: int, self_attention_block: MultiHeadAttention,
feed_forward_block: PositionwiseFeedForward, dropout: float):
super().__init__()
self.self_attention_block = self_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
def forward(self, x, src_mask):
x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, src_mask))
x = self.residual_connections[1](x, self.feed_forward_block)
return x
class Encoder(nn.Module):
def __init__(self, features: int, layers: nn.ModuleList):
super().__init__()
self.layers = layers
self.norm = LayerNormalization(features)
def forward(self, x, mask):
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class DecoderBlock(nn.Module):
def __init__(self, features: int, self_attention_block: MultiHeadAttention,
cross_attention_block: MultiHeadAttention, feed_forward_block: PositionwiseFeedForward, dropout: float):
super().__init__()
self.self_attention_block = self_attention_block
self.cross_attention_block = cross_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])
def forward(self, x, encoder_output, src_mask, tgt_mask):
x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, tgt_mask))
x = self.residual_connections[1](x, lambda x_res:
self.cross_attention_block(x_res, encoder_output, encoder_output, src_mask))
x = self.residual_connections[2](x, self.feed_forward_block)
return x
class Decoder(nn.Module):
def __init__(self, features: int, layers: nn.ModuleList):
super().__init__()
self.layers = layers
self.norm = LayerNormalization(features)
def forward(self, x, encoder_output, src_mask, tgt_mask):
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return self.norm(x)
class ProjectionLayer(nn.Module):
def __init__(self, d_model: int, vocab_size: int):
super().__init__()
self.proj = nn.Linear(d_model, vocab_size)
def forward(self, x):
return self.proj(x)
class Transformer(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.src_pos = src_pos
self.tgt_pos = tgt_pos
self.projection_layer = projection_layer
def encode(self, src, src_mask):
src = self.src_embed(src)
src = self.src_pos(src)
return self.encoder(src, src_mask)
def decode(self, encoder_output, src_mask, tgt, tgt_mask):
tgt = self.tgt_embed(tgt)
tgt = self.tgt_pos(tgt)
return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
def project(self, x):
return self.projection_layer(x)
def forward(self, src, tgt, src_mask, tgt_mask):
enc_out = self.encode(src, src_mask)
dec_out = self.decode(enc_out, src_mask, tgt, tgt_mask)
return self.project(dec_out)
# ====================== 3. 数据处理模块 ======================
# 特殊标记
PAD_TOKEN = '[PAD]'
SOS_TOKEN = '[SOS]'
EOS_TOKEN = '[EOS]'
# 词汇表:特殊标记 + 数字0-9
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}
VOCAB_SIZE = len(VOCAB)
PAD_ID = token_to_id[PAD_TOKEN]
# 生成复制任务数据集
def generate_copy_task_data(num_examples: int, min_len: int, max_len: int):
data = []
content_vocab = [token for token in VOCAB if token not in [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]]
for _ in range(num_examples):
seq_len = random.randint(min_len, max_len)
sequence = [random.choice(content_vocab) for _ in range(seq_len)]
src = [SOS_TOKEN] + sequence + [EOS_TOKEN]
tgt = [SOS_TOKEN] + sequence + [EOS_TOKEN]
data.append({'src': src, 'tgt': tgt})
return data
# 序列编码与填充
def tokenize_sequence(sequence, token_to_id_map):
return [token_to_id_map[token] for token in sequence]
def pad_sequence(sequence_ids, max_len, pad_id):
padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
return padded_ids[:max_len]
# 掩码生成
def create_src_mask(src_ids, pad_id):
return (src_ids == pad_id).unsqueeze(1).unsqueeze(2)
def create_tgt_mask(tgt_ids, pad_id):
batch_size, tgt_seq_len = tgt_ids.shape
tgt_padding_mask = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)
look_ahead_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt_ids.device), diagonal=1).bool()
look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)
return tgt_padding_mask | look_ahead_mask
# 自定义数据集
class CopyTaskDataset(Dataset):
def __init__(self, data, max_len, pad_id):
self.data = data
self.max_len = max_len
self.pad_id = pad_id
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
src_ids = pad_sequence(tokenize_sequence(item['src'], token_to_id), self.max_len, self.pad_id)
tgt_ids = tokenize_sequence(item['tgt'], token_to_id)
decoder_input_ids = pad_sequence(tgt_ids[:-1], self.max_len, self.pad_id)
label_ids = pad_sequence(tgt_ids[1:], self.max_len, self.pad_id)
return {
"src_ids": torch.tensor(src_ids, dtype=torch.long),
"decoder_input_ids": torch.tensor(decoder_input_ids, dtype=torch.long),
"label_ids": torch.tensor(label_ids, dtype=torch.long)
}
# ====================== 4. 超参数配置(集中管理) ======================
CONFIG = {
"num_examples": 1000, # 数据量
"min_seq_len": 5, # 最小序列长度
"max_seq_len": 10, # 最大序列长度
"batch_size": 64, # 批次大小
"d_model": 128, # 模型维度
"num_layers": 3, # 编码器/解码器层数
"num_heads": 4, # 注意力头数
"d_ff": 512, # 前馈网络维度
"dropout": 0.1, # Dropout概率
"lr": 1e-3, # 学习率
"epochs": 40, # 训练轮数
}
MAX_PADDED_LEN = CONFIG["max_seq_len"] + 2 # 含首尾标记的最大长度
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ====================== 5. 模型初始化 ======================
def build_transformer(config):
# 嵌入层
src_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
tgt_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
# 位置编码
src_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
tgt_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
# 注意力与前馈网络
attention = MultiHeadAttention(config["d_model"], config["num_heads"], config["dropout"])
ff = PositionwiseFeedForward(config["d_model"], config["d_ff"], config["dropout"])
# 编码器
encoder_blocks = nn.ModuleList([
EncoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(ff), config["dropout"])
for _ in range(config["num_layers"])
])
encoder = Encoder(config["d_model"], encoder_blocks)
# 解码器
decoder_blocks = nn.ModuleList([
DecoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(attention),
copy.deepcopy(ff), config["dropout"])
for _ in range(config["num_layers"])
])
decoder = Decoder(config["d_model"], decoder_blocks)
# 投影层
projection = ProjectionLayer(config["d_model"], VOCAB_SIZE)
# 组装模型
transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection)
# 参数初始化
for p in transformer.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return transformer
# ====================== 6. 训练函数 ======================
def train_one_epoch(model, dataloader, loss_fn, optimizer, pad_id):
model.train()
total_loss = 0
start_time = time.time()
for batch_idx, batch in enumerate(dataloader):
# 数据移至设备
src_ids = batch['src_ids'].to(device)
decoder_input_ids = batch['decoder_input_ids'].to(device)
label_ids = batch['label_ids'].to(device)
# 生成掩码
src_mask = create_src_mask(src_ids, pad_id).to(device)
tgt_mask = create_tgt_mask(decoder_input_ids, pad_id).to(device)
# 前向传播
optimizer.zero_grad()
logits = model(src_ids, decoder_input_ids, src_mask, tgt_mask)
# 计算损失
loss = loss_fn(logits.reshape(-1, logits.size(-1)), label_ids.reshape(-1))
# 反向传播
loss.backward()
optimizer.step()
total_loss += loss.item()
# 打印进度
if (batch_idx + 1) % 10 == 0:
elapsed = time.time() - start_time
print(f" 批次 {batch_idx+1}/{len(dataloader)} | 损失: {loss.item():.4f} | 耗时: {elapsed:.2f}s")
start_time = time.time()
return total_loss / len(dataloader)
# ====================== 通用贪心解码函数 ======================
def greedy_decode(model: Transformer, src: torch.Tensor, src_mask: torch.Tensor, max_len: int, sos_id: int, eos_id: int, device: torch.device):
model.eval()
src = src.to(device)
src_mask = src_mask.to(device)
with torch.no_grad():
encoder_output = model.encode(src, src_mask)
decoder_input = torch.tensor([[sos_id]], dtype=torch.long, device=device)
for _ in range(max_len - 1):
tgt_mask = create_tgt_mask(decoder_input, PAD_ID).to(device)
decoder_output = model.decode(encoder_output, src_mask, decoder_input, tgt_mask)
logits = model.project(decoder_output[:, -1:])
_, next_token_id = torch.max(logits, dim=-1)
decoder_input = torch.cat([decoder_input, next_token_id], dim=1)
if next_token_id.item() == eos_id:
break
return decoder_input
# ====================== 8. 主程序:数据加载 + 训练 + 测试 ======================
if __name__ == "__main__":
# 1. 生成数据集
print("===== 生成复制任务数据集 =====")
raw_data = generate_copy_task_data(
CONFIG["num_examples"],
CONFIG["min_seq_len"],
CONFIG["max_seq_len"]
)
# 数据加载器
train_dataset = CopyTaskDataset(raw_data, MAX_PADDED_LEN, PAD_ID)
train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
# 2. 初始化模型、优化器、损失函数
model = build_transformer(CONFIG).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"], betas=(0.9, 0.98), eps=1e-9)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
# 3. 开始训练
print(f"\n===== 开始训练 | 设备: {device} =====")
for epoch in range(1, CONFIG["epochs"] + 1):
epoch_start = time.time()
avg_loss = train_one_epoch(model, train_dataloader, loss_fn, optimizer, PAD_ID)
duration = time.time() - epoch_start
print(f"\n【第{epoch}/{CONFIG['epochs']}轮】平均损失: {avg_loss:.4f} | 总耗时: {duration:.2f}s")
print("-" * 60)
# 4. 保存模型
torch.save(model.state_dict(), "transformer_copy_task.pth")
print("\n模型已保存为: transformer_copy_task.pth")
# ====================== 推理测试(使用greedy_decode) ======================
print("\n===== 推理测试(复制任务) =====")
# 1. 定义测试输入
test_seq = [SOS_TOKEN, '1', '7', '0', '9', '5', EOS_TOKEN]
print(f"输入序列: {test_seq}")
# 2. 数据预处理(tokenize + padding + 转张量)
src_ids_inf = torch.tensor(
[pad_sequence(tokenize_sequence(test_seq, token_to_id), MAX_PADDED_LEN, PAD_ID)],
dtype=torch.long
)
# 3. 生成源掩码
src_mask_inf = create_src_mask(src_ids_inf, PAD_ID)
# 4. 获取SOS/EOS ID
sos_id = token_to_id[SOS_TOKEN]
eos_id = token_to_id[EOS_TOKEN]
# 5. 调用通用贪心解码函数
generated_ids = greedy_decode(
model=model,
src=src_ids_inf,
src_mask=src_mask_inf,
max_len=MAX_PADDED_LEN,
sos_id=sos_id,
eos_id=eos_id,
device=device
)
# 6. 后处理:ID转文本
generated_tokens = [id_to_token[idx] for idx in generated_ids[0].cpu().numpy() if idx != PAD_ID]
print(f"模型输出: {generated_tokens}")