衔接前序 :本专栏第 35 篇完整手算了一轮训练迭代(前向→损失→反向→参数更新),第 36-38 篇分别实现了编码器、解码器和完整训练流程。但训练只是 Transformer 应用的一半------训练好的模型如何使用?推理过程的数学本质是什么?解码时那些"生成策略"(贪心搜索、束搜索、温度采样)背后的数学原理是什么?
本文回答这些问题。我们将从推理和训练的数学差异出发,用第 35 篇训练好的模型做手算推理示例,最后给出完整的推理代码实现。
一、引言:训练 vs 推理------两条完全不同的数学路径
在进入细节之前,先理解"训练"和"推理"在数学上的根本差异:
| 维度 | 训练 | 推理 |
|---|---|---|
| 目标 | 调整参数使损失最小化 | 用固定参数生成输出 |
| 梯度 | 需要反向传播计算梯度 | 不需要梯度,无需反向传播 |
| 数据流 | 一次输入整个序列 | 逐个生成 token,前一个输出是下一个输入 |
| 解码策略 | Teacher forcing(已知目标序列) | 多种策略(贪心、束搜索、采样) |
| Dropout | 开启(训练时随机丢弃) | 关闭(推理时确定性输出) |
| 计算图 | 构建完整计算图,保留中间结果用于反向 | 只做前向,不保留计算图 |
| 并行度 | 序列内完全并行 | 序列内串行(自回归限制) |
核心差异一句话:训练是一次性算完所有位置 → 求梯度 → 更新参数;推理是固定参数,逐个 token 生成,每步只产生一个 token。
二、推理的数学基础
2.1 条件概率的链式分解
Transformer 的推理本质是一个自回归条件概率模型 。给定源序列 X=(x1,x2,...,xm)X = (x_1, x_2, ..., x_m)X=(x1,x2,...,xm) 和已生成的前 t−1t-1t−1 个目标 token y1,y2,...,yt−1y_1, y_2, ..., y_{t-1}y1,y2,...,yt−1,模型预测第 ttt 个 token 的条件概率为:
P(yt∣y1,...,yt−1,X)=softmax(logitst)P(y_t \mid y_1, ..., y_{t-1}, X) = \text{softmax}(\text{logits}_t)P(yt∣y1,...,yt−1,X)=softmax(logitst)
其中 logitst\text{logits}_tlogitst 是模型最后一层输出投影得到的词汇表上的分数向量。
整个序列的生成概率分解为:
P(Y∣X)=∏t=1TP(yt∣y<t,X)P(Y \mid X) = \prod_{t=1}^{T} P(y_t \mid y_{<t}, X)P(Y∣X)=t=1∏TP(yt∣y<t,X)
这就是自回归解码的数学本质------一个条件概率的链式乘积。
在训练阶段,这个分解中的 y<ty_{<t}y<t 来自真实的标注数据(teacher forcing);在推理阶段,y<ty_{<t}y<t 来自模型自己的预测。
2.2 推理的计算图
推理的计算图比训练小得多:
训练计算图(保留中间结果):
输入 → Embed → PE → Self-Attn → FFN → LN → ... → 输出投影 → Softmax → 损失
│ │ │ │ │ │ │
└────────┴────────┴───────┴─────┘ └─────┬────┘
需要保留这些用于反向传播 需要保留用于梯度计算
推理计算图(一次性前向,不保留):
输入 → Embed → PE → Self-Attn → FFN → LN → ... → 输出投影 → Softmax → argmax/sample
└──→ 只取 argmax 或采样
关键点:推理时不需要保留任何中间激活值用于反向传播,因此推理的峰值显存显著低于训练。
2.3 推理不需要 Dropout
训练时 Dropout 会随机将一部分神经元置零:
FFNtrain(x)=Dropout(ReLU(xW1+b1))W2+b2\text{FFN}_{\text{train}}(x) = \text{Dropout}(\text{ReLU}(xW_1 + b_1))W_2 + b_2FFNtrain(x)=Dropout(ReLU(xW1+b1))W2+b2
推理时必须关闭 Dropout,否则输出具有随机性。但有一个缩放修正:推理时将 Dropout rate α\alphaα 乘以 (1−α)(1-\alpha)(1−α) 的系数,保持训练和推理时的输出分布一致。
PyTorch 的 model.eval() 自动处理这一切。
三、解码策略中的数学
这是推理中最具"数学味"的部分。从 logitst\text{logits}_tlogitst(一个长度为词表大小 VVV 的向量)到最终选定的 token,有多种数学变换。
3.1 Greedy Decoding(贪心解码)
最简单直接的策略:每步取概率最高的 token。
yt=argmaxw∈VP(w∣y<t,X)y_t = \arg\max_{w \in \mathcal{V}} P(w \mid y_{<t}, X)yt=argw∈VmaxP(w∣y<t,X)
即 logits\text{logits}logits 中值最大的位置对应的 token。
例子:logits = [0.1, 0.2, 1.5, 0.3, 2.8, 0.5]
softmax: P ≈ [0.02, 0.02, 0.08, 0.02, 0.84, 0.02]
argmax → token 4 (love)
问题:容易陷入局部最优。如果在某个位置选择了一个次优 token,后面的整个序列可能完全不同------但贪心策略无法回头。
3.2 Beam Search(束搜索)
束搜索维护 kkk 个候选序列(beam),每步从所有候选的 k×Vk \times Vk×V 个扩展中选择概率乘积最大的 kkk 个。
数学定义 :在每一步 ttt,对于当前 beam 集合 Bt−1\mathcal{B}{t-1}Bt−1(包含 kkk 个序列),每个序列 y<t(i)\mathbf{y}^{(i)}{<t}y<t(i) 的扩展概率为:
score(y<t(i)∘w)=1tα∑j=1t−1logP(yj(i)∣y<j(i),X)+logP(w∣y<t(i),X)\text{score}(\mathbf{y}^{(i)}{<t} \circ w) = \frac{1}{t^\alpha} \sum{j=1}^{t-1} \log P(y^{(i)}j \mid y^{(i)}{<j}, X) + \log P(w \mid \mathbf{y}^{(i)}_{<t}, X)score(y<t(i)∘w)=tα1j=1∑t−1logP(yj(i)∣y<j(i),X)+logP(w∣y<t(i),X)
其中 α\alphaα 是长度归一化系数 (α=0\alpha=0α=0 表示不归一化,α=1\alpha=1α=1 表示完全按长度平均)。长度归一化是必要的,因为概率乘积随序列长度指数衰减,不归一化会导致 beam 偏向更短的序列。
示例 (k=2k=2k=2):
Step 1: 所有 V 个 token 的概率
top-2: "i" (P=0.45), "love" (P=0.30)
Beam: [i], [love]
Step 2: 对 beam 中每个序列扩展 V 个
从 2×V 个候选中选 top-2:
[i, love] (score=0.36), [i, deep] (score=0.24)
Beam: [i, love], [i, deep]
Step 3: 继续扩展...
核心思想:别把鸡蛋放在一个篮子里
贪心解码的问题在于------每一步只选当前最好的,万一选错了,后面全错,而且没法回头。
束搜索的思路很简单:每一步多留几条路,最后选总得分最高的那条。
想象你在一个岔路口,贪心的人只看眼前最宽的路就走过去了;束搜索的人会同时走 3 条路,走几步之后发现哪条路通向死胡同,就放弃它,保留最有希望的路继续走。
束宽 k 是什么意思?
束宽(beam size)k 表示每一步保留多少个候选序列。
- k=1:就是贪心解码(只留 1 个候选)
- k=2:每一步保留 2 个最好的候选
- k=3:每一步保留 3 个最好的候选
- k 越大:搜索越全面,但计算量也越大(慢 k 倍)
手算示例:k=2 的束搜索过程
假设我们要翻译"我爱",词表只有 3 个词:i、love、deep。
Step 1:生成第一个词
模型给出每个词的概率:
| 候选词 | 概率 |
|---|---|
| i | 0.45 |
| love | 0.30 |
| deep | 0.25 |
贪心会直接选 i(概率最高)。但束搜索(k=2)保留 前 2 个:
候选 1: [i] 得分 = log(0.45) = -0.80
候选 2: [love] 得分 = log(0.30) = -1.20
Step 2:扩展每个候选
对候选 1 [i],模型预测下一个词的概率:
| 扩展 | 概率 |
|---|---|
| i, love | 0.50 |
| i, deep | 0.30 |
| i, i | 0.20 |
对候选 2 [love],模型预测下一个词的概率:
| 扩展 | 概率 |
|---|---|
| love, deep | 0.60 |
| love, i | 0.25 |
| love, love | 0.15 |
现在有 2 × 3 = 6 个候选,计算每个的总得分(概率乘积取对数):
| 候选序列 | 得分计算 | 总分 |
|---|---|---|
| i, love | log(0.45) + log(0.50) = -0.80 + (-0.69) | -1.49 |
| i, deep | log(0.45) + log(0.30) = -0.80 + (-1.20) | -2.00 |
| love, deep | log(0.30) + log(0.60) = -1.20 + (-0.51) | -1.71 |
| love, i | log(0.30) + log(0.25) = -1.20 + (-1.39) | -2.59 |
| i, i | log(0.45) + log(0.20) = -0.80 + (-1.61) | -2.41 |
| love, love | log(0.30) + log(0.15) = -1.20 + (-1.90) | -3.10 |
选总分最高的 前 2 个:
保留候选 1: [i, love] 得分 = -1.49
保留候选 2: [love, deep] 得分 = -1.71
Step 3:继续扩展
对 [i, love] 和 [love, deep] 继续扩展,重复上述过程,直到遇到 <eos> 或达到最大长度。
最终:从所有完整序列中选总分最高的那个。
为什么束搜索比贪心好?
回到上面的例子,假设正确答案是 [i, love]:
- 贪心 :Step 1 选了
i✓,Step 2 选了love✓ → 正确 ✅ - 但如果 Step 1 贪心选了
i,Step 2 模型预测love概率只有 0.30,deep概率 0.40,贪心就会选[i, deep]❌
束搜索(k=2)在 Step 2 保留了 [i, love] 和 [love, deep] 两个候选,即使 [i, love] 在 Step 2 不是概率最高的,它仍然有机会活到 Step 3。
长度归一化:为什么需要它?
概率乘积会随序列长度指数衰减------序列越长,概率乘积越小。如果不做处理,束搜索会偏向更短的序列。
例子:
- 序列 A(3 个词):0.5 × 0.5 × 0.5 = 0.125
- 序列 B(10 个词):0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8 × 0.8= 0.107
虽然 B 每个词的概率都很高,但乘积反而比 A 小。概率乘积是单调递减的,上面的例子只是为了说明:短序列的乘积天然比长序列大,因为乘的次数少。
所以束搜索用长度归一化来修正:
score=1tα∑j=1tlogP(yj)\text{score} = \frac{1}{t^\alpha} \sum_{j=1}^{t} \log P(y_j)score=tα1j=1∑tlogP(yj)
- α=0\alpha=0α=0:不归一化,偏向短序列
- α=1\alpha=1α=1:完全按长度平均,公平比较
- 实践中常用 α=0.6\alpha=0.6α=0.6 或 0.70.70.7
束搜索 vs 贪心:一张表看懂
| 方面 | 贪心解码 | 束搜索(k≥2) |
|---|---|---|
| 每步候选数 | 只看 1 个 | 看 k 个 |
| 搜索空间 | 每步 V 个 | 每步 k×V 个 |
| 能找到全局最优吗? | ❌ 不能 | ✅ 更有可能 |
| 速度 | 快 | 慢 k 倍 |
| 输出多样性 | 总是同一个结果 | 可得到 k 个不同结果 |
| 适合场景 | 实时要求高、任务简单 | 质量优先、机器翻译等 |
3.3 Temperature Sampling(温度采样)
温度 τ\tauτ 控制 softmax 分布的"尖锐程度":
Pτ(w∣y<t,X)=exp(logitsw/τ)∑j=1Vexp(logitsj/τ)P_\tau(w \mid y_{<t}, X) = \frac{\exp(\text{logits}w / \tau)}{\sum_{j=1}^{V} \exp(\text{logits}j / \tau)}Pτ(w∣y<t,X)=∑j=1Vexp(logitsj/τ)exp(logitsw/τ)
温度的作用:
温度 τ → 0: 分布趋近于 one-hot,等价于贪心解码
温度 τ = 1: 原始 softmax 分布
温度 τ → ∞: 分布趋近于均匀分布(完全随机)
数值示例:logits = 1.0, 2.0, 3.0
| 温度 | τ=0.5\tau=0.5τ=0.5 | τ=1.0\tau=1.0τ=1.0 | τ=2.0\tau=2.0τ=2.0 |
|---|---|---|---|
| exp(logit/τ) | e², e⁴, e⁶ | e¹, e², e³ | e^0.5, e¹, e^1.5 |
| 相对比例 | 1 : 7.4 : 54.6 | 1 : 2.7 : 7.4 | 1 : 1.6 : 2.7 |
| 概率分布 | 0.02, 0.12, 0.87 | 0.09, 0.24, 0.67 | 0.22, 0.34, 0.44 |
| 采样结果 | 几乎总是 token 2 | 大概率 token 2 | token 1/2 机会相近 |
低温(τ<1\tau < 1τ<1)让高概率 token 的优势更明显,生成更确定;高温(τ>1\tau > 1τ>1)让分布更平缓,生成更多样化。
3.4 Top-k / Top-p Sampling
Top-k :只从概率最高的 kkk 个 token 中采样:
Ptop-k(w)={P(w)∑j∈top-kP(j)if w∈top-k0otherwiseP_{\text{top-k}}(w) = \begin{cases} \frac{P(w)}{\sum_{j \in \text{top-k}} P(j)} & \text{if } w \in \text{top-k} \\ 0 & \text{otherwise} \end{cases}Ptop-k(w)={∑j∈top-kP(j)P(w)0if w∈top-kotherwise
Top-p (Nucleus) :从累计概率达到阈值 ppp 的最小 token 集合中采样:
V(p)=minn{n ∣ ∑i=1nP(i)≥p}\mathcal{V}^{(p)} = \min_{n} \left\{ n \;\big|\; \sum_{i=1}^{n} P_{(i)} \geq p \right\}V(p)=nmin{n i=1∑nP(i)≥p}
其中 P(i)P_{(i)}P(i) 是按概率从大到小排序后的第 iii 个 token。
数值示例:
logits = [5.0, 4.0, 1.0, 0.5, 0.2, 0.1] (V=6)
softmax: P = [0.497, 0.327, 0.083, 0.049, 0.037, 0.007]
Top-3: 只从 [token0(P=0.497), token1(0.327), token2(0.083)] 中采样
重归一化后 P' = [0.548, 0.361, 0.091]
Top-p=0.9:
排序: [0.497, 0.327, 0.083, 0.049, 0.037, 0.007]
累积: 0.497, 0.824, 0.907, 0.956, 0.993, 1.000
top-p=0.9 → 选前 3 个 (累计 0.907 ≥ 0.9)
(结果与 Top-3 类似,但在分布平坦时自适应调整)
Top-k vs Top-p 的对比:
| 策略 | 优点 | 缺点 |
|---|---|---|
| Top-k | 固定 kkk,每步计算量恒定 | 分布平坦时 kkk 可能太大;分布尖锐时可能太小 |
| Top-p | 自适应候选集大小 | 动态候选数导致每步计算量波动 |
实践中通常组合使用:先 Temperature 调整分布,再 Top-p 截断尾部,最后采样。
3.5 Repetition Penalty(重复惩罚)
为减少重复生成,对已出现的 token 的 logit 施加惩罚:
logitsw={logitsw/θif w∈Glogitswotherwise\text{logits}w = \begin{cases} \text{logits}w / \theta & \text{if } w \in \mathcal{G} \\ \text{logits}w & \text{otherwise} \end{cases}logitsw={logitsw/θlogitswif w∈Gotherwise
其中 G\mathcal{G}G 是已生成的 token 集合,θ≥1\theta \geq 1θ≥1 是惩罚系数。θ=1\theta=1θ=1 表示无惩罚,θ>1\theta>1θ>1 降低已出现 token 的概率。这个操作在 softmax 之前执行,因为它直接修改 logits。
3.6 Logit 处理管线
所有解码策略可以组织为一个统一的 logit 处理管线:
原始 logits (V 维)
│
▼
1. Repetition Penalty: logits[w] /= θ for w in generated set
│
▼
2. Temperature Scaling: logits /= τ
│
▼
3. Softmax: exp(logits) / sum(exp(logits))
│
▼
4. Top-k / Top-p 截断: 将候选集外的概率置零 → 重归一化
│
▼
5. 采样/argmax: 从最终概率分布中选择一个 token
│
▼
选定的 token id
四、手算数值示例:从训练到推理
这是本文的核心------用第 35 篇训练好的极小型 Transformer,完整演示推理过程。
4.1 问题设定
沿用第 35 篇的配置,方便对比:
翻译任务:"我", "爱", "深" → "i", "love", "deep"
模型配置:
| 参数 | 值 |
|---|---|
| dmodeld_{\text{model}}dmodel | 4 |
| hhh | 1(单头) |
| dkd_kdk | 4 |
| dffd_{ff}dff | 8 |
| NencN_{\text{enc}}Nenc | 1 |
| NdecN_{\text{dec}}Ndec | 1 |
| 源词表 | {<pad>:0, 我:1, 爱:2, 深:3} |
| 目标词表 | {<pad>:0, <sos>:1, <eos>:2, i:3, love:4, deep:5} |
训练状态 :经过第 35 篇的一轮 SGD 更新,仅输出投影层 WprojW_{\text{proj}}Wproj 被训练(QKV 等其他权重仍为初始身份矩阵)。
4.2 编码器前向(引用第 33/35 篇结果)
编码器输出 Xenc∈R3×4X_{\text{enc}} \in \mathbb{R}^{3 \times 4}Xenc∈R3×4(来自第 33 篇的数值计算):
Xenc=(1.279−0.7830.655−1.151−0.3091.529−0.758−0.4620.703−0.294−0.313−0.096)←"我"←"爱"←"深" X_{\text{enc}} = \begin{pmatrix} 1.279 & -0.783 & 0.655 & -1.151 \\ -0.309 & 1.529 & -0.758 & -0.462 \\ 0.703 & -0.294 & -0.313 & -0.096 \end{pmatrix} \begin{aligned} &\leftarrow \text{"我"} \\ &\leftarrow \text{"爱"} \\ &\leftarrow \text{"深"} \end{aligned} Xenc= 1.279−0.3090.703−0.7831.529−0.2940.655−0.758−0.313−1.151−0.462−0.096 ←"我"←"爱"←"深"
这个编码器输出在推理过程中只计算一次,所有解码步共享。
4.3 训练后的输出投影矩阵
第 35 篇训练后,WprojW_{\text{proj}}Wproj 从初始值(导致全部预测错误)更新为更优的值。我们关心的三个目标 token 的 logit 变化如下:
| Token | 初始 logit(位置0) | 训练后 logit(位置0) | 变化 |
|---|---|---|---|
| i | 0.5 | 0.743 | ↑ 48.6% |
| love | 0.9 ← 错误最高 | (下降) | ↓ |
| deep | 0.1 | (下降) | ↓ |
注意:第 35 篇的梯度计算显示,位置 0 上"love"(正确答案应为"i")的梯度是负的,因此经过 SGD 后,"love"的 logit 下降,"i"的 logit 上升。三个位置的完整 logit 变化见第 35 篇第 7.2 节。
4.4 推理的数值计算
与"训练阶段一次前向算出所有位置"不同,推理需要逐 token 生成。
Step 1:生成第一个 token
解码器输入:[<sos>](token id 1)
解码器前向(与第 34-35 篇相同的计算过程)得到第 0 位置的输出表示:
D0=1.0,0.0,1.0,0.0D_0 = 1.0, 0.0, 1.0, 0.0D0=1.0,0.0,1.0,0.0
(注:由于 QKV 权重仍是身份矩阵,且交叉注意力的 K、V 来自 XencX_{\text{enc}}Xenc,D0D_0D0 的值与第 35 篇中的训练前向一致)
通过训练好的 WprojW_{\text{proj}}Wproj 计算 logits:
logits=D0⋅Wproj⊤=logitpad,logitsos,logiteos,logiti,logitlove,logitdeep\text{logits} = D_0 \cdot W_{\text{proj}}^\top = \\text{logit}_{\\text{pad}}, \\text{logit}_{\\text{sos}}, \\text{logit}_{\\text{eos}}, \\text{logit}_i, \\text{logit}_{\\text{love}}, \\text{logit}_{\\text{deep}}logits=D0⋅Wproj⊤=logitpad,logitsos,logiteos,logiti,logitlove,logitdeep
使用第 35 篇训练后的值:
| Token | Logit | Exp(logit) | Softmax | 累计概率 |
|---|---|---|---|---|
<sos> |
--- | --- | --- | --- |
<eos> |
--- | --- | --- | --- |
| i | 0.743 | 2.102 | 0.335 | 0.335 |
| love | 0.682 | 1.978 | 0.315 | 0.650 |
| deep | 0.320 | 1.377 | 0.219 | 0.869 |
| 其他 | < 0.3 | < 1.35 | < 0.131 | 1.000 |
Greedy 选择 :i(logit=0.743 最高)
采样 (τ=1\tau=1τ=1):以 33.5% 概率选 i,31.5% 概率选 love,21.9% 概率选 deep
可以看到,经过一轮训练后,"i"从初始的第 2 高提升到了第 1 高------训练方向正确。
Step 2:生成第二个 token
解码器输入:[<sos>, i](前面生成的 token 作为输入)
解码器前向得到第 1 位置的输出表示:
D1=0.0,1.0,0.0,1.0D_1 = 0.0, 1.0, 0.0, 1.0D1=0.0,1.0,0.0,1.0
(值未变,因为 QKV 仍是身份矩阵,且因果掩码使位置 1 看不到位置 0 的信息以外的位置)
这里有一个微妙之处:因果掩码确保位置 1 能看到位置 0 。但我们的 QKV 是身份矩阵,所以位置 1 的 self-attention 只是对位置 0 和位置 1 的 embedding 做了加权平均。由于初始权重下这个加权平均不会改变向量的值,D1D_1D1 实际上与训练阶段仍然相同。
通过 WprojW_{\text{proj}}Wproj 计算 logits:
| Token | 训练后 Logit | Softmax |
|---|---|---|
| i | 0.100 | 0.153 |
| love | 0.248 | 0.177 |
| deep | 0.602 | 0.253 ⬅ 仍然最高 |
| 其他 | < 0.3 | < 0.14 |
Greedy 选择 :deep!❌ 模型仍然错误。
为什么? 尽管第 35 篇的 SGD 提升了 "love" 在这个位置的 logit(0.0 → 0.248),但 "deep" 的 logit 受其他位置的梯度牵制,下降幅度不够大(从 0.7 → 0.602),仍然最高。
这里揭示了一个重要洞察:一轮 SGD 训练不足以完全纠正所有预测。这是一个"梯度竞争"现象------同一个参数在不同位置受到不同方向的梯度拉力,最终取折中。需要多轮训练让模型收敛。第 35 篇第 7.1 节中提到了这一点:"一轮训练使三个目标 logit 全面上升------趋势正确,但还不足以完全纠正预测。"
Step 3:生成第三个 token
解码器输入:[<sos>, i, love](假设 step 2 我们用了采样而非 greedy,采样到了 "love")
当然如果 step 2 选了 "deep",那 step 3 就是别的了。我们这里假设纠正了 step 2(使用采样或人为干预),演示完整的 3 步生成。
解码器前向得到第 2 位置的输出:
D2=1.0,1.0,0.0,0.0D_2 = 1.0, 1.0, 0.0, 0.0D2=1.0,1.0,0.0,0.0
通过 WprojW_{\text{proj}}Wproj 计算 logits:
| Token | 训练后 Logit | Softmax |
|---|---|---|
| i | 0.443 | 0.174 |
| love | 0.522 | 0.191 |
| deep | 0.801 | 0.248 |
| 其他 | < 0.5 | < 0.15 |
Greedy 选择 :deep ✓ 正确。
4.5 推理全过程数值一览
| 步骤 | 输入 | 正确输出 | Greedy 选择 | 采样可能的输出 | 置信度 |
|---|---|---|---|---|---|
| 1 | <sos> |
i | i ✓ | i, love, deep | 33.5% |
| 2 | [sos], i |
love | deep ❌ | love, deep, i | 17.7% |
| 3 | [sos], i, love |
deep | deep ✓ | deep, love, eos | 24.8% |
关键结论 :经过一轮训练后,位置 1 的预测仍然错误------这完美解释了为什么训练需要多个 epoch。如果使用 beam search(k=2k=2k=2),模型可以保留 [i, love] 和 [i, deep] 两个候选,在 step 3 发现 [i, love, deep] 的总分更高,从而纠正 greedy 的错误。
4.6 解码策略对比(同一组 logits)
假设 Step 1 的 logits 为 [0.1, 0.0, 0.1, 0.743, 0.682, 0.320],观察不同策略的结果:
Greedy :argmax → token 3(i)------ 确定性输出
Temperature 采样:
| τ\tauτ | 概率分布 | 效果 |
|---|---|---|
| 0.5 | 0.01, 0.01, 0.01, 0.54, 0.37, 0.06 | 几乎一定选 i |
| 1.0 | 0.03, 0.03, 0.03, 0.34, 0.31, 0.22 | i 概率略高,但仍可能选 love |
| 2.0 | 0.08, 0.07, 0.08, 0.24, 0.23, 0.19 | 分布平坦,各种选择都有可能 |
Top-3 + 采样:
- 候选 token: 3(i), 4(love), 5(deep)
- 重归一化后: 0.38, 0.36, 0.26
- 避免了尾部噪声 token 的干扰
五、推理代码实现
下面将推理过程转化为 PyTorch 代码。我们基于第 36-38 篇的 transformer_modules.py 实现。
5.1 Logit 处理函数
python
import torch
import torch.nn.functional as F
def apply_repetition_penalty(logits, generated_ids, penalty=1.0):
"""对已生成的 token 施加重复惩罚。"""
if penalty == 1.0:
return logits
for token_id in generated_ids:
logits[token_id] /= penalty
return logits
def apply_temperature(logits, temperature=1.0):
"""温度缩放。"""
if temperature == 0:
# 温度趋近 0 ≈ 贪心
return logits # caller 处理
return logits / temperature
def apply_top_k(logits, k=50):
"""只保留概率最高的 k 个 token。"""
if k <= 0 or k >= logits.size(-1):
return logits
values, _ = torch.topk(logits, k, dim=-1)
threshold = values[..., -1].unsqueeze(-1)
logits[logits < threshold] = float('-inf')
return logits
def apply_top_p(logits, p=0.9):
"""保留累计概率达到 p 的最小 token 集合。"""
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 找到需要保留的 token 索引
sorted_mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= p
sorted_indices_to_remove = sorted_mask.clone()
sorted_indices_to_remove[..., 1:] = sorted_mask[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
# 将移除的 token 的 logit 设为 -inf
indices_to_remove = sorted_indices_to_remove.scatter(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = float('-inf')
return logits
def sample_from_logits(logits, temperature=1.0, top_k=0, top_p=1.0,
repetition_penalty=1.0, generated_ids=None):
"""完整的 logit 处理管线。"""
# 1. 重复惩罚
if generated_ids is not None and repetition_penalty > 1.0:
logits = apply_repetition_penalty(logits, generated_ids, repetition_penalty)
# 2. 温度缩放
if temperature > 0:
logits = apply_temperature(logits, temperature)
# 3. Probs
probs = F.softmax(logits, dim=-1)
# 4. Top-k
if top_k > 0:
logits = apply_top_k(logits, top_k)
probs = F.softmax(logits, dim=-1)
# 5. Top-p
if top_p < 1.0:
logits = apply_top_p(logits, top_p)
probs = F.softmax(logits, dim=-1)
# 6. 采样
if temperature == 0:
# 贪心
return torch.argmax(logits, dim=-1)
else:
return torch.multinomial(probs, num_samples=1)
5.2 生成循环:自回归推理
下面实现完整的 generate() 函数,它接收一个源序列,返回模型生成的目标序列:
python
@torch.no_grad() # 推理的关键:不追踪梯度
def generate(model, src_ids, max_len=50, bos_token_id=1, eos_token_id=2,
temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0):
"""
完整的 Transformer 推理函数。
参数:
model: 训练好的 Transformer 模型
src_ids: 源序列 token ids, shape (1, src_len)
max_len: 最大生成长度
temperature / top_k / top_p: 解码策略参数
返回:
generated_ids: 生成的 token id 序列
"""
model.eval() # 关闭 Dropout, BatchNorm 等训练行为
device = next(model.parameters()).device
src_ids = src_ids.to(device)
batch_size = src_ids.size(0)
# ========== 阶段 1:编码器前向(一次性) ==========
encoder_output = model.encode(src_ids) # (1, src_len, d_model)
# ========== 阶段 2:解码器自回归 ==========
# 从 <sos> 开始
generated = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
for step in range(max_len):
# 解码器前向:输入当前已生成的序列
decoder_output = model.decode(generated, encoder_output) # (1, step+1, d_model)
# 取最后位置的输出
last_hidden = decoder_output[:, -1:, :] # (1, 1, d_model)
# 输出投影 → logits
logits = model.output_projection(last_hidden) # (1, 1, vocab_size)
logits = logits.squeeze(1) # (1, vocab_size)
# 采样
next_token = sample_from_logits(
logits,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
generated_ids=generated[0].tolist(),
) # (1, 1)
# 拼接
generated = torch.cat([generated, next_token], dim=-1)
# 遇到 <eos> 停止
if next_token.item() == eos_token_id:
break
return generated[0] # 返回 1D 序列
5.3 关键细节:为什么用 @torch.no_grad()
这行注解是推理代码最重要的细节之一。它告诉 PyTorch 不要构建计算图,从而:
- 节省显存:不保留中间激活值
- 加速计算:省去了构建计算图的开销
- 语义正确:推理时不需要梯度
对比:
python
# 训练模式:会构建计算图,占用额外显存
output = model(x) # 保留中间值,loss.backward() 可用
with torch.no_grad():
# 推理模式:不构建计算图
output = model(x) # 只做前向,中间值用完即释放
5.4 完整推理示例
python
def translate_example():
"""用训练好的模型翻译"我爱深度学习"。"""
from transformer_modules import Transformer
# 加载训练好的模型(第 38 篇保存的权重)
model = Transformer(
src_vocab_size=1106, # 中文词表
tgt_vocab_size=54, # 英文词表
d_model=32,
n_heads=4,
d_ff=128,
n_layers=3,
)
model.load_state_dict(torch.load('transformer_trained.pt'))
model.cuda()
model.eval()
# 编码源句(来自 data_utils.py)
src_text = "我爱深度学习"
src_ids = encode_line(src_text, zh_token2idx, lang="zh")
src_tensor = torch.tensor([src_ids]).cuda()
# 生成翻译
result_ids = generate(
model,
src_tensor,
temperature=0.7, # 稍微降低随机性
top_k=40,
top_p=0.9,
repetition_penalty=1.2,
)
# 解码
result_text = decode_ids(result_ids.cpu().tolist(), en_idx2token)
print(f"源句: {src_text}")
print(f"翻译: {result_text}")
5.5 不同解码策略的效果对比
在训练好的模型上,对同一句"我爱深度学习"用不同策略生成:
| 策略 | 参数 | 生成结果 | 特点 |
|---|---|---|---|
| Greedy | τ=0\tau=0τ=0 | i love deep learning |
确定、但可能平庸 |
| 采样 | τ=1\tau=1τ=1 | i love deep learning |
有多种可能 |
| 采样 | τ=1.5\tau=1.5τ=1.5 | i adore profound study |
多样性高,但也可能跑偏 |
| Beam-3 | k=3k=3k=3 | i love deep learning |
最稳定,接近最优 |
| Top-p | p=0.9,τ=1p=0.9, \tau=1p=0.9,τ=1 | i love deep learning |
平衡多样性与质量 |
六、推理的性能分析
6.1 推理的计算量
对于一个长度为 LLL 的序列,推理的计算量:
编码器 (一次性):O(L2)O(L^2)O(L2) 注意力计算
解码器 (逐 token):第 ttt 步的计算量 O(t)O(t)O(t),总计算量 ∑t=1TO(t)=O(T2)\sum_{t=1}^{T} O(t) = O(T^2)∑t=1TO(t)=O(T2)
总推理计算量 :O(L2+T2)O(L^2 + T^2)O(L2+T2),其中 LLL 是源长度,TTT 是目标长度。
对比训练的每轮 计算量:O(L2+T2)O(L^2 + T^2)O(L2+T2)(前向)+ O(L2+T2)O(L^2 + T^2)O(L2+T2)(反向)= O(L2+T2)O(L^2 + T^2)O(L2+T2)
推理 = 编码器 O(L²) + 解码器 O(T²)
训练 = 编码器 O(L²) + 解码器 O(T²) + 反向传播 O(L²+T²)
≈ 推理 × 3 (训练一次前向+反向约等于 3 次推理前向)
6.2 推理与训练的显存对比
以 d_model=1024, n_heads=8, seq_len=1024 为例:
| 阶段 | 需要保留的张量 | 估算显存 |
|---|---|---|
| 训练前向 | QKV 投影、注意力分数、FFN 激活、LN 统计量等 | ~10 倍模型参数 |
| 训练反向 | 所有中间激活值 + 梯度 | ~20 倍模型参数 |
| 推理 | 只需当前激活值(用完即弃) | ~2 倍模型参数 |
所以推理的显存通常只有训练的 1/10 到 1/5。
6.3 推理的瓶颈
虽然推理显存需求低,但速度是瓶颈:
瓶颈 1:自回归串行性(本文的核心问题)
解码器每步只能生成一个 token,无法并行。生成 100 个 token 需要 100 次串行前向。
瓶颈 2:重复计算
回忆 Step 1 和 Step 2 的差异:
- Step 1:解码器处理
[<sos>] - Step 2:解码器处理
[<sos>, i]
在 Step 2 中,位置 0 的 K 和 V 与 Step 1 完全一样------但 Step 2 重新算了一遍。
这就是 KV Cache 要解决的问题------下一篇(第 41 篇)的主角。
推理的冗余(以位置 1 为例):
Step 1: Q₀=W_Q·x₀, K₀=W_K·x₀, V₀=W_V·x₀ ← 已计算
Step 2: Q₀=W_Q·x₀, K₀=W_K·x₀, V₀=W_V·x₀ ← 重复!
Q₁=W_Q·x₁, K₁=W_K·x₁, V₁=W_V·x₁ ← 新的,必须算
如果不缓存 :解码 TTT 步的计算量是 O(T2)O(T^2)O(T2)
如果缓存 :解码 TTT 步的计算量是 O(T)O(T)O(T)
对于 T=1000T=1000T=1000,这意味着 1000 倍的差距。
七、总结
推理 vs 训练的核心差异
- 不需要梯度 ------推理用
torch.no_grad(),不构建计算图 - 串行生成------自回归的数学本质是条件概率的链式分解,每一步依赖上一步的输出
- 解码策略------从 argmax 到 temperature 采样到 top-p 截断,所有策略都是 logits 后处理
- 显存远低于训练------约 1/5 到 1/10
- 有计算冗余------重复的 QKV 投影是下一篇文章的起点
推理过程的数学全景
源序列 [x₁, x₂, ..., xₘ]
│
▼ (一次性)
编码器 ──→ Encoder Output (m × d_model)
│
├─────────────────────────────────────────┐
│ 交叉注意力的 K, V │
▼ ▼
解码器 解码器(下一个step)
输入: [<sos>] 输入: [<sos>, y₁]
│ │
▼ ▼
Decoder Output Decoder Output
│ │
▼ ▼
Output Proj Output Proj
│ │
▼ (logit 处理管线) ▼
logits → temperature → top-k/p → argmax/sample ← 重复直到 <eos>
│
▼
y₁ = "i" ───────────────────────────→ 作为 step 2 输入一部分