深度分析字节最新研究cola-DLM 第 07 章:推理流水线逐行拆解 —— 从 prompt 到生成文本

第 07 章:推理流水线逐行拆解 ------ 从 prompt 到生成文本

论文Continuous Latent Diffusion Language Model
项目地址ByteDance-Seed/Cola-DLM
源码inference.py

核心困惑generate_task_repaint_inference 这 454 行代码到底在做什么?


一、宏观流程

推理严格对应论文式 2.2.4--2.2.6 的三步流程:

复制代码
输入 prompt: "The capital of France is"
        │
        ▼
  ┌─────────────┐
  │ Step 1: 分词  │  tokenizer.encode(prompt)
  │ + block 对齐  │  pad 到 patch_size*block_size 的倍数
  └──────┬──────┘
         │
         ▼
  ┌─────────────┐
  │ Step 2: 前缀  │  z_pre = vae.encode(input_ids)
  │ 编码         │  + latent label 分类
  └──────┬──────┘
         │
         ▼
  ┌─────────────────────────────┐
  │ Step 3: 分块先验传输          │  for block in blocks:
  │ (DiT + CFG + Euler ODE)     │    noise → DiT → z_0
  └──────┬──────────────────────┘
         │
         ▼
  ┌─────────────┐
  │ Step 4: 条件  │  logits = vae.decode(z_0)
  │ 解码 + 采样   │  tokens = sample(logits)
  └──────┬──────┘
         │
         ▼
  输出: "Paris"

二、Step 1:分词与 block 对齐

代码位置:inference.py:367-397

python 复制代码
chunk = patch_size * block_size  # 1 * 4 = 4
for item in prompts:
    ids = tokenizer.encode(prompt_str).ids

    # 对齐到 chunk 的整数倍
    p_pad_len = (chunk - len(ids) % chunk) % chunk
    t_labels = [1] * len(ids) + [3] * p_pad_len  # 1=prompt, 3=pad
    ids = ids + [pad_token_id] * p_pad_len

latent label 分类

  • 1 = prompt token(需要被编码)
  • 2 = 待生成 token
  • 3 = 硬 pad(对齐用,后续被裁掉)

三、Step 2:前缀编码

代码位置:inference.py:404-408

python 复制代码
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    enc = vae.encode(input_ids_list)
    latents_list = [((lat - shift) * scale).float() for lat in enc.latents_list]
  • vae.encode() 返回 TextVAEEncoderOutput,包含 latents_list(每样本的隐向量)
  • 隐向量做缩放 (z - shift) * scale,转为 fp32

四、Step 3:分块先验传输(核心循环)

4.1 噪声初始化

代码位置:inference.py:578-606

python 复制代码
# 默认:全局随机噪声
txt = torch.randn(batch_size * block_size, latent_dim, device=device)

# 可选:per-sample 确定性噪声(通过环境变量启用)
if per_sample_seed_env:
    for b, item in enumerate(prompts):
        g = torch.Generator(device=device)
        g.manual_seed(base_seed + sid_int * 1_000 + step * 10_000_000)
        noise_3d[b] = torch.randn(block_size, latent_dim, device=device, generator=g)

4.2 Euler ODE 求解

代码位置:inference.py:608-654

python 复制代码
timesteps = torch.linspace(int(T), 0, timestep_num + 1)  # [1000, 937.5, ..., 0]

for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
    ts_batch = torch.full((txt.shape[0],), t_curr, device=device)
    dt = (float(t_curr) - float(t_next)) / max(T, 1.0)

    # 条件前向(用 KV cache)
    drift_cond = dit(txt=txt_bf16, txt_shape=txt_shape_cum,
                     txt_q_shape=txt_q_shape, timestep=ts_bf16,
                     update_kv=False, use_kv_cache=True).txt_sample

    # 无条件前向(不用 cache)
    drift_uncond = dit(txt=txt_bf16, txt_shape=txt_q_shape,
                       txt_q_shape=txt_q_shape, timestep=ts_bf16,
                       update_kv=False, use_kv_cache=False).txt_sample

    # CFG 融合
    s = cfg_scale_first_block if step == 0 else guidance_scale
    drift = s * (drift_cond - drift_uncond) + drift_uncond

    # Euler 更新: z_{t-Δ} = z_t - (Δ/T) * v_ψ
    txt_next = txt - drift * dt
    txt = txt_next

关键细节:每个 block 需要做 16 步 × 2 次(cond + uncond)= 32 次 DiT 前向。

4.3 KV Cache 提交

代码位置:inference.py:690-700

python 复制代码
# 把刚去噪的 block 提交到 DiT KV cache
txt_bf16 = txt.to(torch.bfloat16)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = dit(txt=txt_bf16, txt_shape=txt_shape_cum,
            txt_q_shape=txt_q_shape,
            timestep=torch.zeros(...),
            update_kv=True,   # ← 关键:提交到 cache
            use_kv_cache=True)

五、Step 4:条件解码 + 采样

代码位置:inference.py:656-675

python 复制代码
with torch.autocast("cuda", dtype=torch.bfloat16):
    decoded = vae.decode(z=txt, txt_shape=txt_shape_cum,
                         txt_q_shape=txt_q_shape, update_kv=True)

decoded_logits = decoded.view(batch_size, block_size * patch_size, -1)
one_block_ids = sample_with_strategies(
    decoded_logits, generated_ids=context_ids,
    temperature=temperature, top_k=top_k, top_p=top_p,
    repetition_penalty=repetition_penalty,
)

5.1 采样策略

代码位置:inference.py:211-266

python 复制代码
def sample_with_strategies(logits, generated_ids=None, temperature=0.8,
                           top_k=50, top_p=0.95, repetition_penalty=1.1):
    # 1. 重复惩罚
    if repetition_penalty != 1.0:
        score = torch.gather(logits, 1, target_gen_ids)
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        logits.scatter_(1, target_gen_ids, score)

    # 2. 温度缩放
    if temperature != 1.0:
        logits = logits / temperature

    # 3. Top-k 过滤
    if top_k > 0:
        top_k_values, _ = torch.topk(logits, top_k)
        logits = torch.where(logits < min_values, float("-inf"), logits)

    # 4. Top-p (nucleus) 过滤
    if 0 < top_p < 1.0:
        # 按概率累积截断
        cumulative_probs = torch.softmax(sorted_logits).cumsum()
        ...

    # 5. 采样
    probs = F.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

六、Prompt 模板

代码位置:inference.py:80-203

每个 benchmark 任务有特定的 few-shot 模板:

python 复制代码
def apply_prompt_template(task, context, question, answer, choices):
    if task == "lambada":
        return question  # 直接用原文
    elif task == "squad":
        return f"Context: ...\nQuestion: {question}\nAnswer:"
    elif task == "mmlu":
        return f"Question: {question}\n{choices_text}\nAnswer:"
    # ... 每个任务有不同的模板

注意:模板对生成质量影响很大。论文中的准确率是在特定模板下测得的。


七、循环终止条件

代码位置:inference.py:682-706

python 复制代码
# EOS 检测
for b in range(one_block_ids.shape[0]):
    if eos_token_id is not None and eos_token_id in one_block_ids[b]:
        eos_status[b] = True
if eos_status.all():
    stop_flag = True

# max_new_tokens 检测
if step * block_size * patch_size >= max_new_tokens:
    stop_flag = True

八、完整流程图

复制代码
prompt: "The capital of France is"
    │
    ▼
分词: [464, 3139, 286, 4881, 318]  (5 tokens)
    │
    ▼
block 对齐: [464, 3139, 286, 4881, 318, 100277, 100277, 100277]  (→ 8 tokens, 2 blocks)
labels:     [1,    1,    1,   1,    1,   3,       3,       3    ]
    │
    ▼
VAE encode: z_pre shape = (8, 16)  → 裁掉 label-3 尾部 → (5, 16)
    │
    ▼
latent label: [1, 1, 1, 1, 1]  (5 个 prompt latent)
prefix:       z_pre[:4]  (前 4 个,block 对齐)
first block:  z_pre[4:] + random pad  (第 5 个 + 3 个噪声)
    │
    ▼
DiT prefix prefetch: 把 prefix 写入 KV cache
    │
    ▼
Block 1 生成:
  noise ~ N(0, I)  →  16 步 Euler ODE  →  z_0^(1)  →  VAE decode  →  logits  →  "Paris"
  提交 z_0^(1) 到 KV cache
    │
    ▼
Block 2 生成(如果需要更多 token):
  noise ~ N(0, I)  →  16 步 Euler ODE  →  z_0^(2)  →  VAE decode  →  logits  →  ...
    │
    ▼
输出: "Paris"

九、面试追问清单

基础(⭐)

  1. 推理的三步流程分别对应论文的哪些公式?
  2. latent label 的 1/2/3 分别代表什么?
  3. 为什么需要 block 对齐?

进阶(⭐⭐)

  1. CFG 的条件前向和无条件前向有什么区别?
  2. per-sample noise seed 的作用是什么?
  3. 为什么第一个 block 的 CFG scale 可能被设为 1.0?

专家(⭐⭐⭐)

  1. 每个 block 的 16 步 Euler ODE 求解能否减少步数?trade-off 是什么?
  2. KV cache 的 update_kv 和 use_kv_cache 的组合逻辑是什么?
  3. 如果 prompt 长度正好是 block_size 的整数倍,推理流程有什么简化?

十、下期预告

下一章我们将从工程角度评析 Cola DLM 的实现------哪些设计值得学习,哪些需要改进,以及服务化部署的短板。


系列导航

第 01 章 · 第 02 章 · 第 03 章 · 第 04 章 · 第 05 章 · 第 06 章

第 07 章:推理流水线逐行拆解 ← 你在这里

第 08 章 · 第 09 章 · 第 10 章


作者Yunzenn

相关推荐
火山引擎开发者社区13 分钟前
龙虾突然“罢工”?别慌,我们派出了“AI 医生”
人工智能
醒醒该学习了!17 分钟前
Prompt提示词——文生文提示词的设计策略(理论篇)
prompt
NQBJT17 分钟前
青鸾云步:基于 Cordova 的 AI 导盲机器人 APP 全栈开发实战
人工智能·app·导盲·轮足机器人·青鸾云步
深兰科技1 小时前
韩国KAIST AI半导体高管项目代表团到访深兰科技,聚焦AI算力与智能产业合作机会
人工智能·机器人·symfony·ai算力·深兰科技·韩国科学技术院·kaist
快乐on9仔1 小时前
NLP学习(一)transformers之pipeline体验
人工智能·深度学习
冬奇Lab1 小时前
Agent系列(六):记忆管理——让 Agent 记住重要的事
人工智能·agent
冬奇Lab1 小时前
一天一个开源项目(第113篇):notebooklm-py - 把 Google NotebookLM 变成可编程 API,还能接入 Claude Code
人工智能·google·开源
字节跳动开源2 小时前
Viking AI 搜索 CLI 正式发布:会说话,就能做搜索推荐
数据库·人工智能·开源
阿杰技术2 小时前
AI 编程助手落地实战:从提效到重构的全场景指南
人工智能·重构