模型推理prefill和decode过程

用一个完整、清晰、技术准确但避免冗余公式 的例子,系统讲解 LLM(以 LLaMA 为例)在 vLLM 框架下如何完成 prefill 和 decode 的全过程,包括:

  • 隐藏状态生成
  • Q/K/V 的计算
  • KV 缓存机制
  • 注意力分数、softmax、加权求和的含义
  • prefill 与 decode 的区别与衔接

示例输入

用户输入:

"hello, who are you"

假设 tokenizer 将其分为 5 个 token
[t₁=hello, t₂=,, t₃=who, t₄=are, t₅=you]

模型目标:生成回答,如 "I am an AI."

我们聚焦 某一层的一个注意力头(head_dim = 128),说明从输入到生成第一个输出 token("I")的完整流程。


第一阶段:Prefill(预填充)

目标

处理完整 prompt,计算并缓存所有 token 的 Key(K)和 Value(V),并预测第一个输出 token。


步骤 1:生成每个 token 的隐藏状态

对每个 token ID:

  1. 查 embedding 表 → 得到基础向量(如 embed(hello) ∈ ℝ¹²⁸
  2. 加入位置信息(使用 RoPE)→ 得到带位置的隐藏状态:
    • ( x_1, x_2, x_3, x_4, x_5 \in \mathbb{R}^{128} )

这些 ( x_i ) 是当前 attention 层的输入。


步骤 2:计算 Q、K、V(并行)

对每个 ( x_i ),分别乘以该层的三个权重矩阵:

q_i = x_i W_Q,\\quad k_i = x_i W_K,\\quad v_i = x_i W_V

得到:

  • Queries: ( Q = [q_1, q_2, q_3, q_4, q_5] )
  • Keys: ( K = [k_1, k_2, k_3, k_4, k_5] )
  • Values: ( V = [v_1, v_2, v_3, v_4, v_5] )

✅ 所有计算一次性完成(并行),因为输入已知。


步骤 3:为最后一个 token(t₅="you")计算注意力输出

我们要用它来预测下一个词,所以重点看 ( q_5 )。

  1. 计算相似度(点积)

    ( q_5 ) 与每个 ( k_j )(j=1..5)做点积,得到原始分数:

    s_j = q_5 \\cdot k_j

  2. 缩放

    除以 ( \sqrt{128} ) 防止数值过大 → ( \hat{s}_j = s_j / \sqrt{128} )

  3. Softmax 归一化

    将 ( \hat{s}_j ) 转为概率分布:

    a_j = \\frac{\\exp(\\hat{s}*j)}{\\sum* {i=1}\^5 \\exp(\\hat{s}_i)}

    • 结果:( [a_1, a_2, a_3, a_4, a_5] ),总和为 1
    • 含义:模型认为当前位置应从各历史 token 中"关注"多少信息
  4. 加权求和

    用权重 ( a_j ) 对 Value 向量加权平均:

    o_5 = a_1 v_1 + a_2 v_2 + a_3 v_3 + a_4 v_4 + a_5 v_5

    • 输出 ( o_5 \in \mathbb{R}^{128} ) 是融合上下文后的表示

步骤 4:缓存 K 和 V

将整个 ( K = [k_1..k_5] ) 和 ( V = [v_1..v_5] ) 存入 KV Cache

在 vLLM 中,这些被写入一个或多个内存块(block),由 block table 管理,支持非连续存储。


步骤 5:生成第一个输出 token

  • ( o_5 ) 经过多头拼接、FFN、最终投影等,得到 logits 向量(维度=词表大小,如 32000)
  • 对 logits 做 softmax → 概率分布
  • 采样(如 greedy)→ 得到最高概率的 token:"I"(记为 t₆)

Prefill 阶段结束。KV Cache 已包含 prompt 的全部 K/V。


第二阶段:Decode(解码)

目标

逐个生成后续 token,每次只处理一个新 token,复用历史 KV。


Decode Step 1:生成 t₇(在 "I" 之后)

输入:t₆ = "I"
  1. 生成隐藏状态

    • 查 embedding + RoPE(位置=5)→ ( x_6 \in \mathbb{R}^{128} )
  2. 计算当前 Q/K/V

    • ( q_6 = x_6 W_Q )
    • ( k_6 = x_6 W_K )
    • ( v_6 = x_6 W_V )
  3. 读取历史 KV

    • 从 KV Cache 读出:( [k_1..k_5], [v_1..v_5] )
  4. 拼接完整上下文

    • Full Keys: ( [k_1, k_2, k_3, k_4, k_5, k_6] )
    • Full Values: ( [v_1, v_2, v_3, v_4, v_5, v_6] )
  5. 计算注意力(以 ( q_6 ) 为中心)

    • 点积:( q_6 \cdot k_j ) for j=1..6
    • 缩放 → softmax → 权重 ( [a_1..a_6] )
    • 加权求和:( o_6 = \sum_{j=1}^6 a_j v_j )
  6. 更新 KV Cache

    • 将 ( k_6, v_6 ) 追加到缓存(现在长度=6)
  7. 生成下一个 token

    • ( o_6 ) → logits → softmax → 采样 → "am"(t₇)

Decode Step 2+:继续生成

重复上述过程:

  • 输入 "am" → 计算 q₇/k₇/v₇
  • 读 cache(k₁--k₆, v₁--v₆)
  • 拼接 → attention → 输出 o₇
  • 更新 cache(加入 k₇/v₇)
  • 采样 → "an" → "AI" → "." → <eos>

每一步都只新增一个 token 的 K/V ,但 attention 始终基于完整历史上下文


关键对比总结

项目 Prefill 阶段 Decode 阶段
输入 完整 prompt(多 token) 单个新 token
Q/K/V 计算 并行计算所有 只计算当前 token
KV 来源 全部来自当前输入 当前新算 + 历史从 cache 读
Attention 范围 每个 qᵢ 看 k₁--kᵢ qₜ 看 k₁--kₜ(含 prompt + 已生成)
KV Cache 操作 批量写入 逐个追加
计算特点 计算密集、高并行 内存密集、低并行、依赖缓存

vLLM 的优化亮点

  • PagedAttention:KV Cache 被分页存储(如每页 16 token),避免内存碎片
  • Continuous Batching:多个请求的 decode 步可动态合并成 batch,提升 GPU 利用率
  • 高效缓存管理:Block Manager 自动分配/释放 blocks,支持 swap to CPU

总结:整个链条

复制代码
token ID 
   ↓ (embedding + RoPE)
hidden state x_i 
   ↓ (× W_Q, W_K, W_V)
q_i, k_i, v_i 
   ↓ (q 与所有可见 k 点积 → 缩放)
raw scores 
   ↓ (softmax)
attention weights a_j 
   ↓ (∑ a_j · v_j)
output vector o_i 
   ↓ (多头拼接 + FFN + ...)
logits → next token

Q&A

第一步:每个 token 的隐藏状态是如何生成的?

1. Token Embedding

每个 token ID 通过 embedding lookup 转为向量:

  • embed(1540) → e₁ ∈ ℝ¹²⁸
  • embed(11) → e₂ ∈ ℝ¹²⁸
  • embed(2345) → e₃ ∈ ℝ¹²⁸
  • embed(456) → e₄ ∈ ℝ¹²⁸
  • embed(789) → e₅ ∈ ℝ¹²⁸

✅ 这些 embedding 向量是模型训练好的参数,每个 token 对应一个固定向量(在推理时查表即可)。

2. 加入位置编码(Positional Encoding)

因为 Transformer 本身没有顺序概念,需加入位置信息。

  • 使用 RoPE(Rotary Position Embedding)(LLaMA 采用的方式)
  • 对每个位置 ii ,将 eiei 旋转一定角度,得到带位置信息的向量

例如:

  • x1=RoPE(e1,pos=0)x1=RoPE(e1,pos=0)
  • x2=RoPE(e2,pos=1)x2=RoPE(e2,pos=1)
  • ...
  • x5=RoPE(e5,pos=4)x5=RoPE(e5,pos=4)

🔸 结果: x1,x2,...,x5∈R128x1​,x2​,...,x5​∈R128 ------ 这就是每个 token 的隐藏状态(hidden state),作为当前 attention 层的输入。
⚠️ 注意:在完整模型中,这些 xixi​ 可能已经经过前面若干层(如 LayerNorm + previous attention/FFN),但为简化,我们假设这是第一层的输入。


第二步:如何通过线性变换得到 Q、K、V?

每个隐藏状态 xi∈R128xi​∈R128 分别乘以三个可学习的权重矩阵

  • WQ∈R128×128WQ∈R128×128
  • WK∈R128×128WK∈R128×128
  • WV∈R128×128WV∈R128×128

💡 这些矩阵是模型训练时学出来的,在推理时固定不变。

计算示例(以 token 1 为例):

q1=x1⋅WQ⇒q1∈R128q1​=x1​⋅WQ​⇒q1​∈R128

k1=x1⋅WK⇒k1∈R128k1​=x1​⋅WK​⇒k1​∈R128

v1=x1⋅WV⇒v1∈R128v1​=x1​⋅WV​⇒v1​∈R128

对所有 5 个 token 并行计算:

  • Q=[q1;q2;q3;q4;q5]∈R5×128Q=[q1;q2;q3;q4;q5]∈R5×128
  • K=[k1;k2;k3;k4;k5]∈R5×128K=[k1;k2;k3;k4;k5]∈R5×128
  • V=[v1;v2;v3;v4;v5]∈R5×128V=[v1;v2;v3;v4;v5]∈R5×128

✅ 这就是 Q、K、V 矩阵的来源:每个 token 的隐藏状态分别做三次不同的线性映射


第三步:相似度计算(点积)

目标:让每个 Query 去"匹配"所有它能看到的 Key,看哪些历史 token 更相关。

以最后一个 token("you",即位置 5)为例:

它的 Query 是 q5∈R128q5​∈R128

它要和所有 k1k1​ 到 k5k5​ 计算相似度(因为是 causal,不能看未来)。

计算点积(未缩放):

sj=q5⋅kj=∑d=1128q5(d)⋅kj(d)for j=1,2,3,4,5sj​=q5​⋅kj​=d=1∑128​q5(d)​⋅kj(d)​for j=1,2,3,4,5

假设结果为(虚构数值):

  • s1=12.3s1=12.3
  • s2=8.7s2=8.7
  • s3=15.1s3=15.1
  • s4=14.9s4=14.9
  • s5=16.0s5=16.0
缩放(Scale by √d):

为防止点积过大导致 softmax 梯度消失,除以 128≈11.31128​≈11.31

s^j=sj/128s^j​=sj​/128​

得到:

  • s^=[1.09,0.77,1.33,1.32,1.41]s^=[1.09,0.77,1.33,1.32,1.41]

✅ 这些就是未归一化的注意力分数,表示 "q₅ 与每个 kⱼ 的匹配程度"。


第四步:Softmax 函数在做什么?

Softmax 把这些分数转换成概率分布(总和为 1),表示"应该从每个历史位置取多少信息"。

aj=exp⁡(s^j)∑i=15exp⁡(s^i)aj​=∑i=15​exp(s^i​)exp(s^j​)​

继续用上面的数值(近似计算):

  • exp⁡(s^)≈[2.97,2.16,3.78,3.74,4.10]exp(s^)≈[2.97,2.16,3.78,3.74,4.10]
  • 总和 ≈ 16.75
  • 所以:
    • a1≈2.97/16.75≈0.177a1≈2.97/16.75≈0.177
    • a2≈0.129a2≈0.129
    • a3≈0.226a3≈0.226
    • a4≈0.223a4≈0.223
    • a5≈0.245a5≈0.245

✅ 这些 ajaj​ 就是注意力权重:模型认为当前位置("you")最关注自己(a₅ 最大),其次关注 "who" 和 "are"。


第五步:加权求和(得到注意力输出)

用这些权重对 Value 向量做加权平均:

o5=∑j=15aj⋅vjo5​=j=1∑5​aj​⋅vj​

即:

o5=0.177⋅v1+0.129⋅v2+0.226⋅v3+0.223⋅v4+0.245⋅v5o5​=0.177⋅v1​+0.129⋅v2​+0.226⋅v3​+0.223⋅v4​+0.245⋅v5​

结果是一个新的向量 o5∈R128o5​∈R128 ,它融合了所有历史 token 的信息,且更侧重于相关部分。

✅ 这个 o5o5​ 就是该注意力头对第 5 个 token 的输出。


后续步骤(简述)

  1. 如果是多头注意力(32 头),会把 32 个头的输出拼接起来(32 × 128 = 4096)
  2. 经过 output projection 矩阵 WO∈R4096×4096WO∈R4096×4096 得到最终 attention 输出
  3. 再经过残差连接、LayerNorm、FFN 等,最终得到 logits
  4. logits 经过 softmax → 概率分布 → 采样出下一个 token(如 "I")

Prefill vs Decode 的关键区别(在此例中)

步骤 Prefill(处理 "hello...you") Decode(生成 "I")
输入 5 个 token 同时输入 1 个 token("I")
Q/K/V 计算 并行计算 5 组 只计算 1 组(q₆, k₆, v₆)
K/V 来源 全部来自当前输入 k₆/v₆ 新算 + k₁--k₅/v₁--v₅ 从 cache 读
Attention 范围 q₅ 看 k₁--k₅ q₆ 看 k₁--k₆
KV Cache 写入 k₁--k₅, v₁--v₅ 追加 k₆, v₆
相关推荐
青铜弟弟2 小时前
数据同化 - 机器学习(DA-ML)融合的区域玉米估产模型构建:步骤、原理与细节解析
机器学习·文献学习·物理机制
X54先生(人文科技)2 小时前
20260212_Meta-CreationPower_Development_Log(启蒙灯塔起源团队开发日志)
人工智能·机器学习·架构·团队开发·零知识证明
ViiTor_AI2 小时前
视频字幕怎么去除?5 种方法删除硬编码字幕与软字幕(CapCut 实操)
人工智能·计算机视觉·音视频
咚咚王者2 小时前
人工智能之视觉领域 计算机视觉 第三章 NumPy 与图像矩阵
人工智能·计算机视觉·numpy
天天进步20152 小时前
赋予 AI “手”的能力:使用 OpenClaw 自动化执行 Shell 脚本与浏览器任务
人工智能
百度智能云技术站2 小时前
百度百舸 Day0 完成昆仑芯和智谱 GLM-5 适配,实现「发布即可用」
人工智能·开源·vllm·百度百舸
曦云沐2 小时前
第六篇:LangChain 1.0 消息系统与 Prompt 工程:从入门到精通的完整教程
人工智能·langchain·prompt·大模型开发框架
格林威2 小时前
Baumer相机玻璃纤维布经纬密度测量:用于复合材料工艺控制的 6 个核心方法,附 OpenCV+Halcon 实战代码!
人工智能·opencv·计算机视觉·视觉检测·工业相机·智能相机·堡盟相机