从零实现Transformer:第 9 部分 - 推理(Inference )

从零实现Transformer:第 9 部分 - 推理(Inference )

推理流程

  1. 定义测试序列
    给模型一个待复制的输入:[SOS,1,7,0,9,5,EOS](和训练数据格式完全一致)。
  2. 文本转张量(格式转换)
    模型不认识文字,只认识数字:
    把文字转成数字ID(tokenize)
    填充到固定长度(和训练对齐)
    变成 PyTorch 张量(形状:[1, 序列长度],1代表1个样本)
  3. 生成源掩码
    告诉模型:填充的PAD是无效内容,不要关注
  4. 获取首尾标记ID
    解码必须以 [SOS] 开头,以 [EOS] 结束,提前准备好它们的数字ID。
  5. 调用贪心解码函数
    把所有准备好的张量传给解码函数,让模型自动生成序列
  6. 张量转回文本
    把模型输出的数字ID,还原成人类能看懂的文字,过滤掉无用的PAD。

文本转张量 → 编码器编码一次 → 解码器从 SOS 开始 → 循环逐词贪心生成 → 遇到 EOS 停止 → 张量转回文本

整个推理分为 3大阶段

  1. 外部准备:把人类能看懂的文本 → 模型能看懂的张量
  2. 解码 :调用 greedy_decode,编码器编码1次,解码器循环逐词生成
  3. 结果还原:把模型输出的张量 → 人类能看懂的文本

贪心解码(Greedy Decoding)

在 Transformer / 大模型逐词生成 的时候:
每一步,只选当前概率最大的那个词,直接定下,不看未来、不考虑其他可能 ,这就叫贪心解码

假设模型现在要选词,候选和概率是:
5:40%
7:30%
0:20%
9:10%

贪心做法

直接选 概率最高的 5 ,不管后面会不会组合出更好的句子,只顾当下最优

就像走路:

每一步只选眼前最近的路,不全局规划,走到哪算哪。

自回归生成时:

  1. 模型每一步输出一个词表概率分布
  2. 直接取概率最大的 token 作为下一个词
  3. 把这个词塞回输入,继续生成下一个
  4. 直到出现 EOS 或达到最大长度
python 复制代码
def greedy_decode(...):
    # 1. 模型切换推理模式
    model.eval()
    # 2. 数据移到GPU/CPU
    src = src.to(device)
    src_mask = src_mask.to(device)

    # 3. 关闭梯度计算(推理不需要反向传播)
    with torch.no_grad():
        # 4. 【编码器只执行1次!】对输入序列编码,得到全局语义特征
        encoder_output = model.encode(src, src_mask)

        # 5. 初始化解码器输入:只有[SOS](形状:[1,1])
        decoder_input = torch.tensor([[sos_id]], device=device)

        # 6. 【循环自回归生成】逐词生成,直到最大长度/遇到EOS
        for _ in range(max_len - 1):
            # 6.1 生成目标掩码:禁止解码器看到未来的token
            tgt_mask = create_tgt_mask(decoder_input, PAD_ID)
            # 6.2 解码器推理:输入已生成的序列 + 编码器的语义特征
            decoder_output = model.decode(encoder_output, src_mask, decoder_input, tgt_mask)
            # 6.3 只取最后一个词的输出(预测下一个词)
            logits = model.project(decoder_output[:, -1:])
            # 6.4 贪心选择:选概率最高的token
            _, next_token_id = torch.max(logits, dim=-1)
            # 6.5 把新生成的token拼接到输入(序列长度+1)
            decoder_input = torch.cat([decoder_input, next_token_id], dim=1)
            # 6.6 终止条件:生成EOS就停止
            if next_token_id.item() == eos_id:
                break

    # 7. 返回最终生成的完整序列
    return decoder_input

内部流程

1. 推理模式准备

model.eval():关闭Dropout、BatchNorm等训练层,保证推理结果稳定。

数据移设备:把张量放到和模型一样的设备(CPU/GPU)。

2. 编码器:只执行1次!

输入序列是固定的,编码器只需要编码一次,得到输入序列的语义特征,全程复用。

3. 解码器初始化

解码器必须从 [SOS] 开始生成,这是训练时约定的规则。

4. 循环生成(核心:自回归贪心)

循环里每一轮只生成1个词

  1. 目标掩码:防止解码器偷看未来还没生成的词(生成任务的铁律)。
  2. 解码器推理:用已经生成的序列,结合编码器的语义,计算下一个词的概率。
  3. 取最后一个词:只需要最新的隐状态预测下一个词,前面的历史不用管。
  4. 贪心选择:选概率最高的词(这就是贪心解码)。
  5. 拼接输入:把新生成的词加到解码器输入里,下一轮继续生成。
  6. 提前终止 :生成 [EOS] 就结束,不用跑完所有循环。

5. 返回结果

返回生成好的数字ID序列,交给外部代码转成文本。

Logits 与 Token ID的关系

投影层负责将解码器输出的 d_model 维隐状态向量,通过线性变换映射到目标词汇表维度;其直接输出为各候选单词的原始预测分值(Logits),该分值无数值范围约束,不满足概率的归一化条件,并非合法概率分布。后续需对 Logits 施加 Softmax 归一化运算,才可将其转换为取值介于 ([0,1])([0,1])([0,1])、总和为 1 的词汇条件概率分布,表征当前位置生成各个单词的置信度。

这里是贪心解码,所以没有施加 Softmax。

Logits 以英译汉任务 为例子

第一步:先定「中文目标词表」

人为提前规定好:数组下标 = Token ID

数组下标(索引) Token ID 对应汉字/特殊符号
0 0 <PAD> 填充符
1 1 <SOS> 句子开始
2 2 <EOS> 句子结束
3 3
4 4
5 5

词表总大小 vocab_size = 6

所以投影层最后输出维度一定是 6

第二步:模型输出的原始 Logits(某一个生成时刻)

解码器算出隐向量 → 过投影层线性变换

直接得到 6 个原始打分 Logits(可正可负、没有大小限制、不是概率)

举例某一时刻:

复制代码
Logits = [-2.1,  -5.3,  -6.0,  1.2,  0.8,  3.5]

对应关系严格对齐:

Logits[0] → ID=0 <PAD> 的打分

Logits[1] → ID=1 <SOS> 的打分

Logits[2] → ID=2 <EOS> 的打分

Logits[3] → ID=3「我」的打分

Logits[4] → ID=4「和」的打分

Logits[5] → ID=5「你」的打分

第三步:Logits 过 Softmax 变成合法概率(贪心解码省略)

对上面一组 logits 做 Softmax,得到概率分布(总和=1,数值 0~1):

直接给整理好易懂结果:

复制代码
Prob = [0.000011,  0.000129,  0.000064,  0.08559,  0.057373,  0.853688]

能直观看出:
第5个位置概率最高 = 0.853688


第四步:argmax 取索引

对概率分布执行 torch.argmax()

得到:索引 = 5

第五步:索引5就是中文Token ID

因为从模型设计之初就做了强制一一映射

Logits 数组的每一位位置下标 ,和目标词表预先定义的 Token ID 编号完全顺序对齐

所以:索引 5 → 直接对应 Token ID = 5,再查表 → ID5 = 汉字「你」

  1. 投影层输出一组和词表长度一致的 Logits原始打分
  2. 经Softmax转为各汉字的生成概率
  3. 取概率最大的数组索引
  4. 该索引天生等价于预先定义好的目标Token ID
  5. 用ID反向查表,解码出最终可读汉字

完整代码

cpp 复制代码
import torch
import torch.nn as nn
import math
import random
import copy
import time
from torch.utils.data import Dataset, DataLoader

# ====================== 1. Transformer 基础组件 ======================
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)
        return self.dropout(x)

def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    
    return torch.matmul(attn, value), attn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        output, self.attention_weights = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)

class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x: torch.Tensor):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(std ** 2 + self.eps)
        return self.gamma * normalized + self.beta

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.linear_2(self.dropout(self.activation(self.linear_1(x))))

class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))

# ====================== 2. Encoder & Decoder 模块 ======================
class EncoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention, 
                 feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention,
                 cross_attention_block: MultiHeadAttention, feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, tgt_mask))
        x = self.residual_connections[1](x, lambda x_res: 
            self.cross_attention_block(x_res, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

class Decoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.proj(x)

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)

    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(enc_out, src_mask, tgt, tgt_mask)
        return self.project(dec_out)

# ====================== 3. 数据处理模块 ======================
# 特殊标记
PAD_TOKEN = '[PAD]'
SOS_TOKEN = '[SOS]'
EOS_TOKEN = '[EOS]'
# 词汇表:特殊标记 + 数字0-9
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}
VOCAB_SIZE = len(VOCAB)
PAD_ID = token_to_id[PAD_TOKEN]

# 生成复制任务数据集
def generate_copy_task_data(num_examples: int, min_len: int, max_len: int):
    data = []
    content_vocab = [token for token in VOCAB if token not in [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]]
    for _ in range(num_examples):
        seq_len = random.randint(min_len, max_len)
        sequence = [random.choice(content_vocab) for _ in range(seq_len)]
        src = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        tgt = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        data.append({'src': src, 'tgt': tgt})
    return data

# 序列编码与填充
def tokenize_sequence(sequence, token_to_id_map):
    return [token_to_id_map[token] for token in sequence]

def pad_sequence(sequence_ids, max_len, pad_id):
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]

# 掩码生成
def create_src_mask(src_ids, pad_id):
    return (src_ids == pad_id).unsqueeze(1).unsqueeze(2)

def create_tgt_mask(tgt_ids, pad_id):
    batch_size, tgt_seq_len = tgt_ids.shape
    tgt_padding_mask = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)
    look_ahead_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt_ids.device), diagonal=1).bool()
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)
    return tgt_padding_mask | look_ahead_mask

# 自定义数据集
class CopyTaskDataset(Dataset):
    def __init__(self, data, max_len, pad_id):
        self.data = data
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        src_ids = pad_sequence(tokenize_sequence(item['src'], token_to_id), self.max_len, self.pad_id)
        tgt_ids = tokenize_sequence(item['tgt'], token_to_id)
        
        decoder_input_ids = pad_sequence(tgt_ids[:-1], self.max_len, self.pad_id)
        label_ids = pad_sequence(tgt_ids[1:], self.max_len, self.pad_id)

        return {
            "src_ids": torch.tensor(src_ids, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input_ids, dtype=torch.long),
            "label_ids": torch.tensor(label_ids, dtype=torch.long)
        }

# ====================== 4. 超参数配置(集中管理) ======================
CONFIG = {
    "num_examples": 1000,    # 数据量
    "min_seq_len": 5,        # 最小序列长度
    "max_seq_len": 10,       # 最大序列长度
    "batch_size": 64,        # 批次大小
    "d_model": 128,          # 模型维度
    "num_layers": 3,         # 编码器/解码器层数
    "num_heads": 4,          # 注意力头数
    "d_ff": 512,             # 前馈网络维度
    "dropout": 0.1,          # Dropout概率
    "lr": 1e-3,              # 学习率
    "epochs": 40,            # 训练轮数
}
MAX_PADDED_LEN = CONFIG["max_seq_len"] + 2  # 含首尾标记的最大长度
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====================== 5. 模型初始化 ======================
def build_transformer(config):
    # 嵌入层
    src_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
    tgt_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
    # 位置编码
    src_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
    tgt_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
    # 注意力与前馈网络
    attention = MultiHeadAttention(config["d_model"], config["num_heads"], config["dropout"])
    ff = PositionwiseFeedForward(config["d_model"], config["d_ff"], config["dropout"])
    # 编码器
    encoder_blocks = nn.ModuleList([
        EncoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(ff), config["dropout"]) 
        for _ in range(config["num_layers"])
    ])
    encoder = Encoder(config["d_model"], encoder_blocks)
    # 解码器
    decoder_blocks = nn.ModuleList([
        DecoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(attention), 
                     copy.deepcopy(ff), config["dropout"]) 
        for _ in range(config["num_layers"])
    ])
    decoder = Decoder(config["d_model"], decoder_blocks)
    # 投影层
    projection = ProjectionLayer(config["d_model"], VOCAB_SIZE)
    # 组装模型
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection)
    # 参数初始化
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer

# ====================== 6. 训练函数 ======================
def train_one_epoch(model, dataloader, loss_fn, optimizer, pad_id):
    model.train()
    total_loss = 0
    start_time = time.time()

    for batch_idx, batch in enumerate(dataloader):
        # 数据移至设备
        src_ids = batch['src_ids'].to(device)
        decoder_input_ids = batch['decoder_input_ids'].to(device)
        label_ids = batch['label_ids'].to(device)

        # 生成掩码
        src_mask = create_src_mask(src_ids, pad_id).to(device)
        tgt_mask = create_tgt_mask(decoder_input_ids, pad_id).to(device)

        # 前向传播
        optimizer.zero_grad()
        logits = model(src_ids, decoder_input_ids, src_mask, tgt_mask)
        
        # 计算损失
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), label_ids.reshape(-1))
        # 反向传播
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # 打印进度
        if (batch_idx + 1) % 10 == 0:
            elapsed = time.time() - start_time
            print(f"  批次 {batch_idx+1}/{len(dataloader)} | 损失: {loss.item():.4f} | 耗时: {elapsed:.2f}s")
            start_time = time.time()

    return total_loss / len(dataloader)

# ====================== 通用贪心解码函数 ======================
def greedy_decode(model: Transformer, src: torch.Tensor, src_mask: torch.Tensor, max_len: int, sos_id: int, eos_id: int, device: torch.device):
    model.eval()
    src = src.to(device)
    src_mask = src_mask.to(device)

    with torch.no_grad():
        encoder_output = model.encode(src, src_mask)
        decoder_input = torch.tensor([[sos_id]], dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            tgt_mask = create_tgt_mask(decoder_input, PAD_ID).to(device)
            decoder_output = model.decode(encoder_output, src_mask, decoder_input, tgt_mask)
            logits = model.project(decoder_output[:, -1:])
            _, next_token_id = torch.max(logits, dim=-1)
            decoder_input = torch.cat([decoder_input, next_token_id], dim=1)

            if next_token_id.item() == eos_id:
                break

    return decoder_input

# ====================== 8. 主程序:数据加载 + 训练 + 测试 ======================
if __name__ == "__main__":
    # 1. 生成数据集
    print("===== 生成复制任务数据集 =====")
    raw_data = generate_copy_task_data(
        CONFIG["num_examples"],
        CONFIG["min_seq_len"],
        CONFIG["max_seq_len"]
    )
    # 数据加载器
    train_dataset = CopyTaskDataset(raw_data, MAX_PADDED_LEN, PAD_ID)
    train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)

    # 2. 初始化模型、优化器、损失函数
    model = build_transformer(CONFIG).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"], betas=(0.9, 0.98), eps=1e-9)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # 3. 开始训练
    print(f"\n===== 开始训练 | 设备: {device} =====")
    for epoch in range(1, CONFIG["epochs"] + 1):
        epoch_start = time.time()
        avg_loss = train_one_epoch(model, train_dataloader, loss_fn, optimizer, PAD_ID)
        duration = time.time() - epoch_start
        print(f"\n【第{epoch}/{CONFIG['epochs']}轮】平均损失: {avg_loss:.4f} | 总耗时: {duration:.2f}s")
        print("-" * 60)

    # 4. 保存模型
    torch.save(model.state_dict(), "transformer_copy_task.pth")
    print("\n模型已保存为: transformer_copy_task.pth")

    # ====================== 推理测试(使用greedy_decode) ======================
    print("\n===== 推理测试(复制任务) =====")
    # 1. 定义测试输入
    test_seq = [SOS_TOKEN, '1', '7', '0', '9', '5', EOS_TOKEN]
    print(f"输入序列: {test_seq}")

    # 2. 数据预处理(tokenize + padding + 转张量)
    src_ids_inf = torch.tensor(
        [pad_sequence(tokenize_sequence(test_seq, token_to_id), MAX_PADDED_LEN, PAD_ID)],
        dtype=torch.long
    )
    # 3. 生成源掩码
    src_mask_inf = create_src_mask(src_ids_inf, PAD_ID)
    # 4. 获取SOS/EOS ID
    sos_id = token_to_id[SOS_TOKEN]
    eos_id = token_to_id[EOS_TOKEN]

    # 5. 调用通用贪心解码函数
    generated_ids = greedy_decode(
        model=model,
        src=src_ids_inf,
        src_mask=src_mask_inf,
        max_len=MAX_PADDED_LEN,
        sos_id=sos_id,
        eos_id=eos_id,
        device=device
    )

    # 6. 后处理:ID转文本
    generated_tokens = [id_to_token[idx] for idx in generated_ids[0].cpu().numpy() if idx != PAD_ID]
    print(f"模型输出: {generated_tokens}")
相关推荐
All The Way North-2 小时前
AdamW 深度解析:从数学原理到 PyTorch 实现,对比分析AdamW与Adam
transformer·优化器·数学原理·adam·权重衰减·adamw·对比分析
机器学习之心2 小时前
多工况车速数据集训练BiLSTM-Attention用于车速预测,输出未来多个时间步车速,MATLAB代码
matlab·attention·bilstm·车速预测
小何code3 小时前
人工智能【第24篇】BERT模型详解:预训练语言模型的里程碑
自然语言处理·bert·transformer·预训练模型
kishu_iOS&AI5 小时前
NLP - Transformer原理解析
人工智能·自然语言处理·transformer
名字不好奇5 小时前
大模型如何理解上下文:Attention 机制详解
人工智能·llm·transformer
牧子川13 小时前
009-Transformer-Architecture
人工智能·深度学习·transformer
这张生成的图像能检测吗21 小时前
(论文速读)DSFormer:用于高光谱图像分类的双选择融合变压器网络
人工智能·深度学习·计算机视觉·transformer
dfsj660111 天前
第九章:Transformer 架构
深度学习·架构·transformer
高洁011 天前
知识图谱与检索增强的实战结合
人工智能·深度学习·数据挖掘·transformer·知识图谱