深度学习的数学原理(四十)—— Transformer 推理全过程

衔接前序 :本专栏第 35 篇完整手算了一轮训练迭代(前向→损失→反向→参数更新),第 36-38 篇分别实现了编码器、解码器和完整训练流程。但训练只是 Transformer 应用的一半------训练好的模型如何使用?推理过程的数学本质是什么?解码时那些"生成策略"(贪心搜索、束搜索、温度采样)背后的数学原理是什么?

本文回答这些问题。我们将从推理和训练的数学差异出发,用第 35 篇训练好的模型做手算推理示例,最后给出完整的推理代码实现。

一、引言:训练 vs 推理------两条完全不同的数学路径

在进入细节之前,先理解"训练"和"推理"在数学上的根本差异:

维度 训练 推理
目标 调整参数使损失最小化 用固定参数生成输出
梯度 需要反向传播计算梯度 不需要梯度,无需反向传播
数据流 一次输入整个序列 逐个生成 token,前一个输出是下一个输入
解码策略 Teacher forcing(已知目标序列) 多种策略(贪心、束搜索、采样)
Dropout 开启(训练时随机丢弃) 关闭(推理时确定性输出)
计算图 构建完整计算图,保留中间结果用于反向 只做前向,不保留计算图
并行度 序列内完全并行 序列内串行(自回归限制)

核心差异一句话:训练是一次性算完所有位置 → 求梯度 → 更新参数;推理是固定参数,逐个 token 生成,每步只产生一个 token。

二、推理的数学基础

2.1 条件概率的链式分解

Transformer 的推理本质是一个自回归条件概率模型 。给定源序列 X=(x1,x2,...,xm)X = (x_1, x_2, ..., x_m)X=(x1,x2,...,xm) 和已生成的前 t−1t-1t−1 个目标 token y1,y2,...,yt−1y_1, y_2, ..., y_{t-1}y1,y2,...,yt−1,模型预测第 ttt 个 token 的条件概率为:

P(yt∣y1,...,yt−1,X)=softmax(logitst)P(y_t \mid y_1, ..., y_{t-1}, X) = \text{softmax}(\text{logits}_t)P(yt∣y1,...,yt−1,X)=softmax(logitst)

其中 logitst\text{logits}_tlogitst 是模型最后一层输出投影得到的词汇表上的分数向量

整个序列的生成概率分解为:

P(Y∣X)=∏t=1TP(yt∣y<t,X)P(Y \mid X) = \prod_{t=1}^{T} P(y_t \mid y_{<t}, X)P(Y∣X)=t=1∏TP(yt∣y<t,X)

这就是自回归解码的数学本质------一个条件概率的链式乘积。

在训练阶段,这个分解中的 y<ty_{<t}y<t 来自真实的标注数据(teacher forcing);在推理阶段,y<ty_{<t}y<t 来自模型自己的预测。

2.2 推理的计算图

推理的计算图比训练小得多

复制代码
训练计算图(保留中间结果):
输入 → Embed → PE → Self-Attn → FFN → LN → ... → 输出投影 → Softmax → 损失
         │        │        │       │     │                │          │
         └────────┴────────┴───────┴─────┘                └─────┬────┘
             需要保留这些用于反向传播                   需要保留用于梯度计算

推理计算图(一次性前向,不保留):
输入 → Embed → PE → Self-Attn → FFN → LN → ... → 输出投影 → Softmax → argmax/sample
                                                                       └──→ 只取 argmax 或采样

关键点:推理时不需要保留任何中间激活值用于反向传播,因此推理的峰值显存显著低于训练。

2.3 推理不需要 Dropout

训练时 Dropout 会随机将一部分神经元置零:

FFNtrain(x)=Dropout(ReLU(xW1+b1))W2+b2\text{FFN}_{\text{train}}(x) = \text{Dropout}(\text{ReLU}(xW_1 + b_1))W_2 + b_2FFNtrain(x)=Dropout(ReLU(xW1+b1))W2+b2

推理时必须关闭 Dropout,否则输出具有随机性。但有一个缩放修正:推理时将 Dropout rate α\alphaα 乘以 (1−α)(1-\alpha)(1−α) 的系数,保持训练和推理时的输出分布一致。

PyTorch 的 model.eval() 自动处理这一切。


三、解码策略中的数学

这是推理中最具"数学味"的部分。从 logitst\text{logits}_tlogitst(一个长度为词表大小 VVV 的向量)到最终选定的 token,有多种数学变换。

3.1 Greedy Decoding(贪心解码)

最简单直接的策略:每步取概率最高的 token。

yt=arg⁡max⁡w∈VP(w∣y<t,X)y_t = \arg\max_{w \in \mathcal{V}} P(w \mid y_{<t}, X)yt=argw∈VmaxP(w∣y<t,X)

即 logits\text{logits}logits 中值最大的位置对应的 token。

复制代码
例子:logits = [0.1, 0.2, 1.5, 0.3, 2.8, 0.5]
      softmax: P ≈ [0.02, 0.02, 0.08, 0.02, 0.84, 0.02]
      argmax → token 4 (love)

问题:容易陷入局部最优。如果在某个位置选择了一个次优 token,后面的整个序列可能完全不同------但贪心策略无法回头。

3.2 Beam Search(束搜索)

束搜索维护 kkk 个候选序列(beam),每步从所有候选的 k×Vk \times Vk×V 个扩展中选择概率乘积最大的 kkk 个。

数学定义 :在每一步 ttt,对于当前 beam 集合 Bt−1\mathcal{B}{t-1}Bt−1(包含 kkk 个序列),每个序列 y<t(i)\mathbf{y}^{(i)}{<t}y<t(i) 的扩展概率为:

score(y<t(i)∘w)=1tα∑j=1t−1log⁡P(yj(i)∣y<j(i),X)+log⁡P(w∣y<t(i),X)\text{score}(\mathbf{y}^{(i)}{<t} \circ w) = \frac{1}{t^\alpha} \sum{j=1}^{t-1} \log P(y^{(i)}j \mid y^{(i)}{<j}, X) + \log P(w \mid \mathbf{y}^{(i)}_{<t}, X)score(y<t(i)∘w)=tα1j=1∑t−1logP(yj(i)∣y<j(i),X)+logP(w∣y<t(i),X)

其中 α\alphaα 是长度归一化系数 (α=0\alpha=0α=0 表示不归一化,α=1\alpha=1α=1 表示完全按长度平均)。长度归一化是必要的,因为概率乘积随序列长度指数衰减,不归一化会导致 beam 偏向更短的序列。

示例 (k=2k=2k=2):

复制代码
Step 1: 所有 V 个 token 的概率
        top-2: "i" (P=0.45), "love" (P=0.30)
        Beam: [i], [love]

Step 2: 对 beam 中每个序列扩展 V 个
        从 2×V 个候选中选 top-2:
        [i, love] (score=0.36), [i, deep] (score=0.24)
        Beam: [i, love], [i, deep]

Step 3: 继续扩展...

核心思想:别把鸡蛋放在一个篮子里

贪心解码的问题在于------每一步只选当前最好的,万一选错了,后面全错,而且没法回头

束搜索的思路很简单:每一步多留几条路,最后选总得分最高的那条

想象你在一个岔路口,贪心的人只看眼前最宽的路就走过去了;束搜索的人会同时走 3 条路,走几步之后发现哪条路通向死胡同,就放弃它,保留最有希望的路继续走。

束宽 k 是什么意思?

束宽(beam size)k 表示每一步保留多少个候选序列。

  • k=1:就是贪心解码(只留 1 个候选)
  • k=2:每一步保留 2 个最好的候选
  • k=3:每一步保留 3 个最好的候选
  • k 越大:搜索越全面,但计算量也越大(慢 k 倍)

手算示例:k=2 的束搜索过程

假设我们要翻译"我爱",词表只有 3 个词:ilovedeep

Step 1:生成第一个词

模型给出每个词的概率:

候选词 概率
i 0.45
love 0.30
deep 0.25

贪心会直接选 i(概率最高)。但束搜索(k=2)保留 前 2 个

复制代码
候选 1: [i]      得分 = log(0.45) = -0.80
候选 2: [love]   得分 = log(0.30) = -1.20

Step 2:扩展每个候选

对候选 1 [i],模型预测下一个词的概率:

扩展 概率
i, love 0.50
i, deep 0.30
i, i 0.20

对候选 2 [love],模型预测下一个词的概率:

扩展 概率
love, deep 0.60
love, i 0.25
love, love 0.15

现在有 2 × 3 = 6 个候选,计算每个的总得分(概率乘积取对数):

候选序列 得分计算 总分
i, love log(0.45) + log(0.50) = -0.80 + (-0.69) -1.49
i, deep log(0.45) + log(0.30) = -0.80 + (-1.20) -2.00
love, deep log(0.30) + log(0.60) = -1.20 + (-0.51) -1.71
love, i log(0.30) + log(0.25) = -1.20 + (-1.39) -2.59
i, i log(0.45) + log(0.20) = -0.80 + (-1.61) -2.41
love, love log(0.30) + log(0.15) = -1.20 + (-1.90) -3.10

选总分最高的 前 2 个

复制代码
保留候选 1: [i, love]    得分 = -1.49
保留候选 2: [love, deep] 得分 = -1.71

Step 3:继续扩展

[i, love][love, deep] 继续扩展,重复上述过程,直到遇到 <eos> 或达到最大长度。

最终:从所有完整序列中选总分最高的那个。

为什么束搜索比贪心好?

回到上面的例子,假设正确答案是 [i, love]

  • 贪心 :Step 1 选了 i ✓,Step 2 选了 love ✓ → 正确 ✅
  • 但如果 Step 1 贪心选了 i,Step 2 模型预测 love 概率只有 0.30,deep 概率 0.40,贪心就会选 [i, deep]

束搜索(k=2)在 Step 2 保留了 [i, love][love, deep] 两个候选,即使 [i, love] 在 Step 2 不是概率最高的,它仍然有机会活到 Step 3。

长度归一化:为什么需要它?

概率乘积会随序列长度指数衰减------序列越长,概率乘积越小。如果不做处理,束搜索会偏向更短的序列

例子

  • 序列 A(3 个词):0.5 × 0.5 × 0.5 = 0.125
  • 序列 B(10 个词):0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8= 0.107

虽然 B 每个词的概率都很高,但乘积反而比 A 小。概率乘积是单调递减的,上面的例子只是为了说明:短序列的乘积天然比长序列大,因为乘的次数少。

所以束搜索用长度归一化来修正:

score=1tα∑j=1tlog⁡P(yj)\text{score} = \frac{1}{t^\alpha} \sum_{j=1}^{t} \log P(y_j)score=tα1j=1∑tlogP(yj)

  • α=0\alpha=0α=0:不归一化,偏向短序列
  • α=1\alpha=1α=1:完全按长度平均,公平比较
  • 实践中常用 α=0.6\alpha=0.6α=0.6 或 0.70.70.7

束搜索 vs 贪心:一张表看懂

方面 贪心解码 束搜索(k≥2)
每步候选数 只看 1 个 看 k 个
搜索空间 每步 V 个 每步 k×V 个
能找到全局最优吗? ❌ 不能 ✅ 更有可能
速度 慢 k 倍
输出多样性 总是同一个结果 可得到 k 个不同结果
适合场景 实时要求高、任务简单 质量优先、机器翻译等

3.3 Temperature Sampling(温度采样)

温度 τ\tauτ 控制 softmax 分布的"尖锐程度":

Pτ(w∣y<t,X)=exp⁡(logitsw/τ)∑j=1Vexp⁡(logitsj/τ)P_\tau(w \mid y_{<t}, X) = \frac{\exp(\text{logits}w / \tau)}{\sum_{j=1}^{V} \exp(\text{logits}j / \tau)}Pτ(w∣y<t,X)=∑j=1Vexp(logitsj/τ)exp(logitsw/τ)

温度的作用

复制代码
温度 τ → 0:     分布趋近于 one-hot,等价于贪心解码
温度 τ = 1:     原始 softmax 分布
温度 τ → ∞:     分布趋近于均匀分布(完全随机)

数值示例:logits = 1.0, 2.0, 3.0

温度 τ=0.5\tau=0.5τ=0.5 τ=1.0\tau=1.0τ=1.0 τ=2.0\tau=2.0τ=2.0
exp(logit/τ) e², e⁴, e⁶ e¹, e², e³ e^0.5, e¹, e^1.5
相对比例 1 : 7.4 : 54.6 1 : 2.7 : 7.4 1 : 1.6 : 2.7
概率分布 0.02, 0.12, 0.87 0.09, 0.24, 0.67 0.22, 0.34, 0.44
采样结果 几乎总是 token 2 大概率 token 2 token 1/2 机会相近

低温(τ<1\tau < 1τ<1)让高概率 token 的优势更明显,生成更确定;高温(τ>1\tau > 1τ>1)让分布更平缓,生成更多样化。

3.4 Top-k / Top-p Sampling

Top-k :只从概率最高的 kkk 个 token 中采样:

Ptop-k(w)={P(w)∑j∈top-kP(j)if w∈top-k0otherwiseP_{\text{top-k}}(w) = \begin{cases} \frac{P(w)}{\sum_{j \in \text{top-k}} P(j)} & \text{if } w \in \text{top-k} \\ 0 & \text{otherwise} \end{cases}Ptop-k(w)={∑j∈top-kP(j)P(w)0if w∈top-kotherwise

Top-p (Nucleus) :从累计概率达到阈值 ppp 的最小 token 集合中采样:

V(p)=min⁡n{n  ∣  ∑i=1nP(i)≥p}\mathcal{V}^{(p)} = \min_{n} \left\{ n \;\big|\; \sum_{i=1}^{n} P_{(i)} \geq p \right\}V(p)=nmin{n i=1∑nP(i)≥p}

其中 P(i)P_{(i)}P(i) 是按概率从大到小排序后的第 iii 个 token。

数值示例

复制代码
logits = [5.0, 4.0, 1.0, 0.5, 0.2, 0.1]  (V=6)
softmax: P = [0.497, 0.327, 0.083, 0.049, 0.037, 0.007]

Top-3:  只从 [token0(P=0.497), token1(0.327), token2(0.083)] 中采样
        重归一化后 P' = [0.548, 0.361, 0.091]

Top-p=0.9:
        排序: [0.497, 0.327, 0.083, 0.049, 0.037, 0.007]
        累积: 0.497, 0.824, 0.907, 0.956, 0.993, 1.000
        top-p=0.9 → 选前 3 个 (累计 0.907 ≥ 0.9)
        (结果与 Top-3 类似,但在分布平坦时自适应调整)

Top-k vs Top-p 的对比

策略 优点 缺点
Top-k 固定 kkk,每步计算量恒定 分布平坦时 kkk 可能太大;分布尖锐时可能太小
Top-p 自适应候选集大小 动态候选数导致每步计算量波动

实践中通常组合使用:先 Temperature 调整分布,再 Top-p 截断尾部,最后采样。

3.5 Repetition Penalty(重复惩罚)

为减少重复生成,对已出现的 token 的 logit 施加惩罚:

logitsw={logitsw/θif w∈Glogitswotherwise\text{logits}w = \begin{cases} \text{logits}w / \theta & \text{if } w \in \mathcal{G} \\ \text{logits}w & \text{otherwise} \end{cases}logitsw={logitsw/θlogitswif w∈Gotherwise

其中 G\mathcal{G}G 是已生成的 token 集合,θ≥1\theta \geq 1θ≥1 是惩罚系数。θ=1\theta=1θ=1 表示无惩罚,θ>1\theta>1θ>1 降低已出现 token 的概率。这个操作在 softmax 之前执行,因为它直接修改 logits。

3.6 Logit 处理管线

所有解码策略可以组织为一个统一的 logit 处理管线:

复制代码
原始 logits (V 维)
    │
    ▼
1. Repetition Penalty: logits[w] /= θ  for w in generated set
    │
    ▼
2. Temperature Scaling: logits /= τ
    │
    ▼
3. Softmax: exp(logits) / sum(exp(logits))
    │
    ▼
4. Top-k / Top-p 截断: 将候选集外的概率置零 → 重归一化
    │
    ▼
5. 采样/argmax: 从最终概率分布中选择一个 token
    │
    ▼
选定的 token id

四、手算数值示例:从训练到推理

这是本文的核心------用第 35 篇训练好的极小型 Transformer,完整演示推理过程。

4.1 问题设定

沿用第 35 篇的配置,方便对比:

翻译任务"我", "爱", "深""i", "love", "deep"

模型配置

参数
dmodeld_{\text{model}}dmodel 4
hhh 1(单头)
dkd_kdk 4
dffd_{ff}dff 8
NencN_{\text{enc}}Nenc 1
NdecN_{\text{dec}}Ndec 1
源词表 {<pad>:0, 我:1, 爱:2, 深:3}
目标词表 {<pad>:0, <sos>:1, <eos>:2, i:3, love:4, deep:5}

训练状态 :经过第 35 篇的一轮 SGD 更新,仅输出投影层 WprojW_{\text{proj}}Wproj 被训练(QKV 等其他权重仍为初始身份矩阵)。

4.2 编码器前向(引用第 33/35 篇结果)

编码器输出 Xenc∈R3×4X_{\text{enc}} \in \mathbb{R}^{3 \times 4}Xenc∈R3×4(来自第 33 篇的数值计算):

Xenc=(1.279−0.7830.655−1.151−0.3091.529−0.758−0.4620.703−0.294−0.313−0.096)←"我"←"爱"←"深" X_{\text{enc}} = \begin{pmatrix} 1.279 & -0.783 & 0.655 & -1.151 \\ -0.309 & 1.529 & -0.758 & -0.462 \\ 0.703 & -0.294 & -0.313 & -0.096 \end{pmatrix} \begin{aligned} &\leftarrow \text{"我"} \\ &\leftarrow \text{"爱"} \\ &\leftarrow \text{"深"} \end{aligned} Xenc= 1.279−0.3090.703−0.7831.529−0.2940.655−0.758−0.313−1.151−0.462−0.096 ←"我"←"爱"←"深"

这个编码器输出在推理过程中只计算一次,所有解码步共享。

4.3 训练后的输出投影矩阵

第 35 篇训练后,WprojW_{\text{proj}}Wproj 从初始值(导致全部预测错误)更新为更优的值。我们关心的三个目标 token 的 logit 变化如下:

Token 初始 logit(位置0) 训练后 logit(位置0) 变化
i 0.5 0.743 ↑ 48.6%
love 0.9 ← 错误最高 (下降)
deep 0.1 (下降)

注意:第 35 篇的梯度计算显示,位置 0 上"love"(正确答案应为"i")的梯度是负的,因此经过 SGD 后,"love"的 logit 下降,"i"的 logit 上升。三个位置的完整 logit 变化见第 35 篇第 7.2 节。

4.4 推理的数值计算

与"训练阶段一次前向算出所有位置"不同,推理需要逐 token 生成

Step 1:生成第一个 token

解码器输入:[<sos>](token id 1)

解码器前向(与第 34-35 篇相同的计算过程)得到第 0 位置的输出表示:

D0=1.0,0.0,1.0,0.0D_0 = 1.0, 0.0, 1.0, 0.0D0=1.0,0.0,1.0,0.0

(注:由于 QKV 权重仍是身份矩阵,且交叉注意力的 K、V 来自 XencX_{\text{enc}}Xenc,D0D_0D0 的值与第 35 篇中的训练前向一致)

通过训练好的 WprojW_{\text{proj}}Wproj 计算 logits:

logits=D0⋅Wproj⊤=logitpad,logitsos,logiteos,logiti,logitlove,logitdeep\text{logits} = D_0 \cdot W_{\text{proj}}^\top = \\text{logit}_{\\text{pad}}, \\text{logit}_{\\text{sos}}, \\text{logit}_{\\text{eos}}, \\text{logit}_i, \\text{logit}_{\\text{love}}, \\text{logit}_{\\text{deep}}logits=D0⋅Wproj⊤=logitpad,logitsos,logiteos,logiti,logitlove,logitdeep

使用第 35 篇训练后的值:

Token Logit Exp(logit) Softmax 累计概率
<sos> --- --- --- ---
<eos> --- --- --- ---
i 0.743 2.102 0.335 0.335
love 0.682 1.978 0.315 0.650
deep 0.320 1.377 0.219 0.869
其他 < 0.3 < 1.35 < 0.131 1.000

Greedy 选择i(logit=0.743 最高)

采样 (τ=1\tau=1τ=1):以 33.5% 概率选 i,31.5% 概率选 love,21.9% 概率选 deep

可以看到,经过一轮训练后,"i"从初始的第 2 高提升到了第 1 高------训练方向正确。

Step 2:生成第二个 token

解码器输入:[<sos>, i](前面生成的 token 作为输入)

解码器前向得到第 1 位置的输出表示:

D1=0.0,1.0,0.0,1.0D_1 = 0.0, 1.0, 0.0, 1.0D1=0.0,1.0,0.0,1.0

(值未变,因为 QKV 仍是身份矩阵,且因果掩码使位置 1 看不到位置 0 的信息以外的位置)

这里有一个微妙之处:因果掩码确保位置 1 能看到位置 0 。但我们的 QKV 是身份矩阵,所以位置 1 的 self-attention 只是对位置 0 和位置 1 的 embedding 做了加权平均。由于初始权重下这个加权平均不会改变向量的值,D1D_1D1 实际上与训练阶段仍然相同。

通过 WprojW_{\text{proj}}Wproj 计算 logits:

Token 训练后 Logit Softmax
i 0.100 0.153
love 0.248 0.177
deep 0.602 0.253 ⬅ 仍然最高
其他 < 0.3 < 0.14

Greedy 选择deep!❌ 模型仍然错误。

为什么? 尽管第 35 篇的 SGD 提升了 "love" 在这个位置的 logit(0.0 → 0.248),但 "deep" 的 logit 受其他位置的梯度牵制,下降幅度不够大(从 0.7 → 0.602),仍然最高。

这里揭示了一个重要洞察:一轮 SGD 训练不足以完全纠正所有预测。这是一个"梯度竞争"现象------同一个参数在不同位置受到不同方向的梯度拉力,最终取折中。需要多轮训练让模型收敛。第 35 篇第 7.1 节中提到了这一点:"一轮训练使三个目标 logit 全面上升------趋势正确,但还不足以完全纠正预测。"

Step 3:生成第三个 token

解码器输入:[<sos>, i, love](假设 step 2 我们用了采样而非 greedy,采样到了 "love")

当然如果 step 2 选了 "deep",那 step 3 就是别的了。我们这里假设纠正了 step 2(使用采样或人为干预),演示完整的 3 步生成。

解码器前向得到第 2 位置的输出:

D2=1.0,1.0,0.0,0.0D_2 = 1.0, 1.0, 0.0, 0.0D2=1.0,1.0,0.0,0.0

通过 WprojW_{\text{proj}}Wproj 计算 logits:

Token 训练后 Logit Softmax
i 0.443 0.174
love 0.522 0.191
deep 0.801 0.248
其他 < 0.5 < 0.15

Greedy 选择deep ✓ 正确。

4.5 推理全过程数值一览

步骤 输入 正确输出 Greedy 选择 采样可能的输出 置信度
1 <sos> i i i, love, deep 33.5%
2 [sos], i love deep ❌ love, deep, i 17.7%
3 [sos], i, love deep deep deep, love, eos 24.8%

关键结论 :经过一轮训练后,位置 1 的预测仍然错误------这完美解释了为什么训练需要多个 epoch。如果使用 beam search(k=2k=2k=2),模型可以保留 [i, love][i, deep] 两个候选,在 step 3 发现 [i, love, deep] 的总分更高,从而纠正 greedy 的错误。

4.6 解码策略对比(同一组 logits)

假设 Step 1 的 logits 为 [0.1, 0.0, 0.1, 0.743, 0.682, 0.320],观察不同策略的结果:

Greedyargmax → token 3(i)------ 确定性输出

Temperature 采样

τ\tauτ 概率分布 效果
0.5 0.01, 0.01, 0.01, 0.54, 0.37, 0.06 几乎一定选 i
1.0 0.03, 0.03, 0.03, 0.34, 0.31, 0.22 i 概率略高,但仍可能选 love
2.0 0.08, 0.07, 0.08, 0.24, 0.23, 0.19 分布平坦,各种选择都有可能

Top-3 + 采样

  • 候选 token: 3(i), 4(love), 5(deep)
  • 重归一化后: 0.38, 0.36, 0.26
  • 避免了尾部噪声 token 的干扰

五、推理代码实现

下面将推理过程转化为 PyTorch 代码。我们基于第 36-38 篇的 transformer_modules.py 实现。

5.1 Logit 处理函数

python 复制代码
import torch
import torch.nn.functional as F

def apply_repetition_penalty(logits, generated_ids, penalty=1.0):
    """对已生成的 token 施加重复惩罚。"""
    if penalty == 1.0:
        return logits
    for token_id in generated_ids:
        logits[token_id] /= penalty
    return logits

def apply_temperature(logits, temperature=1.0):
    """温度缩放。"""
    if temperature == 0:
        # 温度趋近 0 ≈ 贪心
        return logits  # caller 处理
    return logits / temperature

def apply_top_k(logits, k=50):
    """只保留概率最高的 k 个 token。"""
    if k <= 0 or k >= logits.size(-1):
        return logits
    values, _ = torch.topk(logits, k, dim=-1)
    threshold = values[..., -1].unsqueeze(-1)
    logits[logits < threshold] = float('-inf')
    return logits

def apply_top_p(logits, p=0.9):
    """保留累计概率达到 p 的最小 token 集合。"""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    
    # 找到需要保留的 token 索引
    sorted_mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= p
    sorted_indices_to_remove = sorted_mask.clone()
    sorted_indices_to_remove[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    
    # 将移除的 token 的 logit 设为 -inf
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits[indices_to_remove] = float('-inf')
    return logits

def sample_from_logits(logits, temperature=1.0, top_k=0, top_p=1.0, 
                       repetition_penalty=1.0, generated_ids=None):
    """完整的 logit 处理管线。"""
    # 1. 重复惩罚
    if generated_ids is not None and repetition_penalty > 1.0:
        logits = apply_repetition_penalty(logits, generated_ids, repetition_penalty)
    
    # 2. 温度缩放
    if temperature > 0:
        logits = apply_temperature(logits, temperature)
    
    # 3. Probs
    probs = F.softmax(logits, dim=-1)
    
    # 4. Top-k
    if top_k > 0:
        logits = apply_top_k(logits, top_k)
        probs = F.softmax(logits, dim=-1)
    
    # 5. Top-p
    if top_p < 1.0:
        logits = apply_top_p(logits, top_p)
        probs = F.softmax(logits, dim=-1)
    
    # 6. 采样
    if temperature == 0:
        # 贪心
        return torch.argmax(logits, dim=-1)
    else:
        return torch.multinomial(probs, num_samples=1)

5.2 生成循环:自回归推理

下面实现完整的 generate() 函数,它接收一个源序列,返回模型生成的目标序列:

python 复制代码
@torch.no_grad()  # 推理的关键:不追踪梯度
def generate(model, src_ids, max_len=50, bos_token_id=1, eos_token_id=2,
             temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0):
    """
    完整的 Transformer 推理函数。
    
    参数:
        model: 训练好的 Transformer 模型
        src_ids: 源序列 token ids, shape (1, src_len)
        max_len: 最大生成长度
        temperature / top_k / top_p: 解码策略参数
    
    返回:
        generated_ids: 生成的 token id 序列
    """
    model.eval()  # 关闭 Dropout, BatchNorm 等训练行为
    device = next(model.parameters()).device
    
    src_ids = src_ids.to(device)
    batch_size = src_ids.size(0)
    
    # ========== 阶段 1:编码器前向(一次性) ==========
    encoder_output = model.encode(src_ids)  # (1, src_len, d_model)
    
    # ========== 阶段 2:解码器自回归 ==========
    # 从 <sos> 开始
    generated = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
    
    for step in range(max_len):
        # 解码器前向:输入当前已生成的序列
        decoder_output = model.decode(generated, encoder_output)  # (1, step+1, d_model)
        
        # 取最后位置的输出
        last_hidden = decoder_output[:, -1:, :]  # (1, 1, d_model)
        
        # 输出投影 → logits
        logits = model.output_projection(last_hidden)  # (1, 1, vocab_size)
        logits = logits.squeeze(1)  # (1, vocab_size)
        
        # 采样
        next_token = sample_from_logits(
            logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            generated_ids=generated[0].tolist(),
        )  # (1, 1)
        
        # 拼接
        generated = torch.cat([generated, next_token], dim=-1)
        
        # 遇到 <eos> 停止
        if next_token.item() == eos_token_id:
            break
    
    return generated[0]  # 返回 1D 序列

5.3 关键细节:为什么用 @torch.no_grad()

这行注解是推理代码最重要的细节之一。它告诉 PyTorch 不要构建计算图,从而:

  1. 节省显存:不保留中间激活值
  2. 加速计算:省去了构建计算图的开销
  3. 语义正确:推理时不需要梯度

对比:

python 复制代码
# 训练模式:会构建计算图,占用额外显存
output = model(x)  # 保留中间值,loss.backward() 可用

with torch.no_grad():
    # 推理模式:不构建计算图
    output = model(x)  # 只做前向,中间值用完即释放

5.4 完整推理示例

python 复制代码
def translate_example():
    """用训练好的模型翻译"我爱深度学习"。"""
    from transformer_modules import Transformer
    
    # 加载训练好的模型(第 38 篇保存的权重)
    model = Transformer(
        src_vocab_size=1106,  # 中文词表
        tgt_vocab_size=54,    # 英文词表
        d_model=32,
        n_heads=4,
        d_ff=128,
        n_layers=3,
    )
    model.load_state_dict(torch.load('transformer_trained.pt'))
    model.cuda()
    model.eval()
    
    # 编码源句(来自 data_utils.py)
    src_text = "我爱深度学习"
    src_ids = encode_line(src_text, zh_token2idx, lang="zh")
    src_tensor = torch.tensor([src_ids]).cuda()
    
    # 生成翻译
    result_ids = generate(
        model, 
        src_tensor,
        temperature=0.7,     # 稍微降低随机性
        top_k=40,            
        top_p=0.9,
        repetition_penalty=1.2,
    )
    
    # 解码
    result_text = decode_ids(result_ids.cpu().tolist(), en_idx2token)
    print(f"源句: {src_text}")
    print(f"翻译: {result_text}")

5.5 不同解码策略的效果对比

在训练好的模型上,对同一句"我爱深度学习"用不同策略生成:

策略 参数 生成结果 特点
Greedy τ=0\tau=0τ=0 i love deep learning 确定、但可能平庸
采样 τ=1\tau=1τ=1 i love deep learning 有多种可能
采样 τ=1.5\tau=1.5τ=1.5 i adore profound study 多样性高,但也可能跑偏
Beam-3 k=3k=3k=3 i love deep learning 最稳定,接近最优
Top-p p=0.9,τ=1p=0.9, \tau=1p=0.9,τ=1 i love deep learning 平衡多样性与质量

六、推理的性能分析

6.1 推理的计算量

对于一个长度为 LLL 的序列,推理的计算量:

编码器 (一次性):O(L2)O(L^2)O(L2) 注意力计算

解码器 (逐 token):第 ttt 步的计算量 O(t)O(t)O(t),总计算量 ∑t=1TO(t)=O(T2)\sum_{t=1}^{T} O(t) = O(T^2)∑t=1TO(t)=O(T2)

总推理计算量 :O(L2+T2)O(L^2 + T^2)O(L2+T2),其中 LLL 是源长度,TTT 是目标长度。

对比训练的每轮 计算量:O(L2+T2)O(L^2 + T^2)O(L2+T2)(前向)+ O(L2+T2)O(L^2 + T^2)O(L2+T2)(反向)= O(L2+T2)O(L^2 + T^2)O(L2+T2)

复制代码
推理 = 编码器 O(L²) + 解码器 O(T²)
训练 = 编码器 O(L²) + 解码器 O(T²) + 反向传播 O(L²+T²)
     ≈ 推理 × 3 (训练一次前向+反向约等于 3 次推理前向)

6.2 推理与训练的显存对比

以 d_model=1024, n_heads=8, seq_len=1024 为例:

阶段 需要保留的张量 估算显存
训练前向 QKV 投影、注意力分数、FFN 激活、LN 统计量等 ~10 倍模型参数
训练反向 所有中间激活值 + 梯度 ~20 倍模型参数
推理 只需当前激活值(用完即弃) ~2 倍模型参数

所以推理的显存通常只有训练的 1/10 到 1/5

6.3 推理的瓶颈

虽然推理显存需求低,但速度是瓶颈:

瓶颈 1:自回归串行性(本文的核心问题)

解码器每步只能生成一个 token,无法并行。生成 100 个 token 需要 100 次串行前向。

瓶颈 2:重复计算

回忆 Step 1 和 Step 2 的差异:

  • Step 1:解码器处理 [<sos>]
  • Step 2:解码器处理 [<sos>, i]

在 Step 2 中,位置 0 的 K 和 V 与 Step 1 完全一样------但 Step 2 重新算了一遍。

这就是 KV Cache 要解决的问题------下一篇(第 41 篇)的主角。

复制代码
推理的冗余(以位置 1 为例):
Step 1: Q₀=W_Q·x₀,  K₀=W_K·x₀,  V₀=W_V·x₀    ← 已计算
Step 2: Q₀=W_Q·x₀,  K₀=W_K·x₀,  V₀=W_V·x₀    ← 重复!
        Q₁=W_Q·x₁,  K₁=W_K·x₁,  V₁=W_V·x₁    ← 新的,必须算

如果不缓存 :解码 TTT 步的计算量是 O(T2)O(T^2)O(T2)

如果缓存 :解码 TTT 步的计算量是 O(T)O(T)O(T)

对于 T=1000T=1000T=1000,这意味着 1000 倍的差距。


七、总结

推理 vs 训练的核心差异

  1. 不需要梯度 ------推理用 torch.no_grad(),不构建计算图
  2. 串行生成------自回归的数学本质是条件概率的链式分解,每一步依赖上一步的输出
  3. 解码策略------从 argmax 到 temperature 采样到 top-p 截断,所有策略都是 logits 后处理
  4. 显存远低于训练------约 1/5 到 1/10
  5. 有计算冗余------重复的 QKV 投影是下一篇文章的起点

推理过程的数学全景

复制代码
源序列 [x₁, x₂, ..., xₘ]
    │
    ▼ (一次性)
编码器 ──→ Encoder Output (m × d_model)
              │
              ├─────────────────────────────────────────┐
              │ 交叉注意力的 K, V                         │
              ▼                                          ▼
             解码器                                 解码器(下一个step)
  输入: [<sos>]                             输入: [<sos>, y₁]
             │                                         │
             ▼                                         ▼
         Decoder Output                           Decoder Output
             │                                         │
             ▼                                         ▼
        Output Proj                                Output Proj
             │                                         │
             ▼  (logit 处理管线)                       ▼
  logits → temperature → top-k/p → argmax/sample    ← 重复直到 <eos>
             │
             ▼
          y₁ = "i" ───────────────────────────→ 作为 step 2 输入一部分
相关推荐
Bingorl1 小时前
机器学习之集成学习
人工智能·机器学习·集成学习
weixin_468466851 小时前
SURF 图像特征提取算法新手实战指南
图像处理·人工智能·算法·机器视觉·surf·sift
盛夏光年爱学习1 小时前
Agentic RAG 深度解析:让 Agent 自己决定要不要检索、检索几次,这才是 RAG 的正确打开方式
人工智能
weiwin1231 小时前
MAF入门(3 下):多轮对话进阶——清除历史、注入 System、截断策略
人工智能·agent
Coder小相1 小时前
LangChain 1.0 第五篇 - Tool与MCP让Agent拥有行动力
人工智能·langchain·ai编程
太华1 小时前
学习AI Agent编程-第五天-LlamaIndex - 将Nodes生成索引并存储
人工智能
太华1 小时前
学习AI Agent编程-第三天-LlamaIndex - 如何将PDF文件正确转成Document
人工智能
jiayong231 小时前
AI架构师面试问题与解答 - 深度学习架构篇
人工智能·深度学习
unclejet1 小时前
颠覆传统开发!AI根治软件工程技术债务顽疾
大数据·人工智能·软件工程