深度分析字节最新研究cola-DLM第 06 章:分块因果 DiT 先验 —— 在隐空间里做 Flow Matching

第 06 章:分块因果 DiT 先验 ------ 在隐空间里做 Flow Matching

论文Continuous Latent Diffusion Language Model
项目地址ByteDance-Seed/Cola-DLM
源码modeling_cola_dit.py

核心困惑 :DiT 是怎么学习隐空间先验 p ψ ( z 0 ) p_\psi(z_0) pψ(z0) 的?分块因果注意力怎么实现?CFG 怎么工作?


一、DiT 在图像领域的成功

DiT(Diffusion Transformer)由 Scalable Diffusion Models with Transformers(Peebles & Xie, 2023)提出,用 Transformer 替代 UNet 作为扩散模型的骨干网络。核心思想:

  • 把图像切成 patch → 线性投影 → Transformer blocks → 线性投影 → unpatch
  • 用 AdaLN(Adaptive Layer Norm)注入时间步信息

Cola DLM 把这个思路从 2D 图像迁移到 1D 文本隐序列。


二、模型架构

2.1 整体结构

代码位置:modeling_cola_dit.py:536-690

复制代码
输入: txt (L_q_total, in_channels)
        │
        ▼
┌──────────────┐
│  PatchIn1D   │  patchify + 线性投影
└──────┬───────┘
       │ (L_q_total/patch_size, txt_dim)
       ▼
┌──────────────┐
│ TimestepEmb  │  sinusoidal → MLP → emb_dim
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ DiTBlock ×24 │  AdaLN + Attention + FFN
│ (分块因果)    │  per-sample KV cache
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ PatchOut1D   │  线性投影 + unpatch
└──────┬───────┘
       │
       ▼
输出: txt_sample (L_q_total, out_channels)

2.2 配置参数

参数 默认值 含义
txt_in_channels 16 输入隐空间维度(= VAE 的 latent_dim
txt_out_channels 16 输出隐空间维度
txt_dim 2048 Transformer 隐藏维度
emb_dim 2048 AdaLN 条件维度
heads 16 注意力头数
head_dim 128 每头维度
expand_ratio 4 FFN 扩展比
num_layers 24 Transformer 层数
patch_size 1 patchify 因子
rope_dim 96 RoPE 作用的通道数(< head_dim=128)
block_size 4 分块大小

总参数量:约 1.8B(24 层 × 16 头 × 128 head_dim = 2048 hidden dim)。


三、关键组件

3.1 PatchIn1D / PatchOut1D

代码位置:modeling_cola_dit.py:166-205

python 复制代码
class PatchIn1D(nn.Module):
    def __init__(self, in_channels, patch_size, dim):
        self.proj = nn.Linear(in_channels * patch_size, dim)

    def forward(self, txt, txt_shape):
        txt_shape_before_patchify = txt_shape
        if self.patch_size != 1:  # patch_size=1 时跳过,直接投影
            batch_list = _unflatten(txt, txt_shape)
            for i in range(len(batch_list)):
                batch_list[i] = rearrange(batch_list[i], "(T t) c -> T (t c)", t=self.patch_size)
            txt, txt_shape = _flatten(batch_list)
        txt = self.proj(txt)
        return txt, txt_shape, txt_shape_before_patchify

默认 patch_size=1,所以 rearrange 被完全跳过,只有线性投影生效。

3.2 TimestepEmbedding(AdaLN 条件化)

代码位置:modeling_cola_dit.py:135-158

python 复制代码
class TimestepEmbedding(nn.Module):
    def __init__(self, sinusoidal_dim, hidden_dim, output_dim):
        self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
        self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, timestep, device, dtype):
        emb = _get_sinusoidal_embedding(timestep, self.sinusoidal_dim)
        emb = self.act(self.proj_in(emb))
        emb = self.act(self.proj_hid(emb))
        emb = self.proj_out(emb)
        return emb

时间步 t t t 通过 sinusoidal embedding + MLP 注入到每个 Transformer block 的 AdaLN 中。

3.3 AdaLN(Adaptive Layer Norm)

代码位置:modeling_cola_dit.py:299-336

python 复制代码
class AdaLN(nn.Module):
    def forward(self, hid, emb, layer, mode, norm_layer=None, residual=None, **kwargs):
        emb = getattr(self, f"{layer}_{mode}")(emb)  # 线性投影
        if mode == "in":
            shift, scale = emb.chunk(2, dim=-1)
            return norm_layer(hid) * (1 + scale) + shift  # scale + shift
        if mode == "out":
            return hid * emb + residual  # gate + residual

每个 block 有两处 AdaLN:

  • mode="in":在 attention/FFN 之前,做 scale + shift
  • mode="out":在 attention/FFN 之后,做 gate + residual

3.4 MLP

代码位置:modeling_cola_dit.py:344-352

python 复制代码
class MLP(nn.Module):
    def __init__(self, dim, expand_ratio):
        self.proj_in = nn.Linear(dim, dim * expand_ratio)
        self.act = nn.GELU("tanh")  # 注意:不是 SwiGLU!
        self.proj_out = nn.Linear(dim * expand_ratio, dim)

注意 :DiT 用 GELU tanh ,而 VAE 用 SwiGLU。这是一个设计选择差异。


四、分块因果注意力

4.1 ColaDiTAttention

代码位置:modeling_cola_dit.py:360-463

python 复制代码
class ColaDiTAttention(nn.Module):
    def __init__(self, txt_dim, heads, head_dim, qk_bias, qk_norm_eps, rope_dim):
        self.proj_qkv = nn.Linear(txt_dim, inner_dim * 3)
        self.norm_q = nn.LayerNorm(head_dim)  # QK-norm
        self.norm_k = nn.LayerNorm(head_dim)
        self.rope = TextRotaryEmbedding(dim=rope_dim)  # rope_dim=96 < head_dim=128

RoPE 只作用于部分通道rope_dim=96,而 head_dim=128。这意味着 128 个通道中有 96 个有位置编码,32 个没有。

4.2 KV Cache 管理

代码位置:modeling_cola_dit.py:420-442

python 复制代码
# per-sample KV cache
self._k_cache: Optional[list[torch.Tensor]] = None
self._v_cache: Optional[list[torch.Tensor]] = None

# forward 中的逻辑:
if update_kv:  # 提交新 block 到 cache
    self._k_cache = [torch.cat([c, n], dim=0) for c, n in zip(self._k_cache, new_ks)]
    full_k = torch.cat(self._k_cache, dim=0)
elif use_kv_cache and self._k_cache is not None:  # 读 cache
    full_k = torch.cat([torch.cat([c, n], dim=0) for c, n in zip(self._k_cache, new_ks)], dim=0)
else:  # 无 cache
    full_k = txt_k

4.3 注意力计算

代码位置:modeling_cola_dit.py:381-397

python 复制代码
def slow_attn(self, query, key, value, attn_mask=None):
    d_head = query.shape[-1]
    device_type = "cuda" if query.is_cuda else query.device.type
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        scale = 1.0 / (d_head ** 0.5)
        attn = query.mul(scale) @ key.transpose(-2, -1)  # 显式构造完整矩阵
        if attn_mask is not None:
            attn = attn + attn_mask.to(attn.dtype)
        attn_weight = attn.softmax(dim=-1)
        attn_out = attn_weight @ value
    return attn_out

注意 :没有使用 Flash Attention,而是显式构造完整的 ( L q , L k ) (L_q, L_k) (Lq,Lk) 注意力矩阵。这是当前实现的一个重要限制。


五、CFG(Classifier-Free Guidance)

5.1 原理

CFG 是扩散模型的标准技巧:同时做条件生成和无条件生成,然后按比例混合。

v ^ = v uncond + s ⋅ ( v cond − v uncond ) \hat{v} = v_{\text{uncond}} + s \cdot (v_{\text{cond}} - v_{\text{uncond}}) v^=vuncond+s⋅(vcond−vuncond)

其中 s s s 是 guidance scale(默认 7.0)。

5.2 代码实现

代码位置:inference.py:621-648

python 复制代码
# 条件前向:用 KV cache(能看到 prompt 和历史 block)
drift_cond = dit(txt=txt_bf16, txt_shape=txt_shape_cum,
                 txt_q_shape=txt_q_shape, timestep=ts_bf16,
                 update_kv=False, use_kv_cache=True).txt_sample

# 无条件前向:不用 cache(只能看到当前 block)
drift_uncond = dit(txt=txt_bf16, txt_shape=txt_q_shape,
                   txt_q_shape=txt_q_shape, timestep=ts_bf16,
                   update_kv=False, use_kv_cache=False).txt_sample

# CFG 融合
s = cfg_scale_first_block if step == 0 else guidance_scale
drift = s * (drift_cond - drift_uncond) + drift_uncond

5.3 短 prompt 的 CFG 退化

代码位置:inference.py:531-554

当 prompt 短于 block_size 时,第一个生成 block 的前缀 KV cache 为空,条件和无条件前向数学上相同。此时 CFG 会放大 bf16 噪声:

python 复制代码
cfg_scale_first_block = torch.tensor(
    [guidance_scale if pl > 0 else 1.0 for pl in prefix_lens],
    device=device, dtype=torch.bfloat16,
).repeat_interleave(block_size).unsqueeze(-1)

空 prefix 的样本自动将 guidance_scale 降为 1.0。


六、ColaDiTBlock

代码位置:modeling_cola_dit.py:471-519

每个 block 的前向流程:

python 复制代码
def forward(self, txt, *, txt_shape, txt_q_shape, emb, ...):
    # 1. AdaLN + Attention
    txt_msa = self.ada(txt, emb=emb, layer="msa", mode="in", norm_layer=self.msa_norm)
    txt_msa = self.msa(txt_msa, txt_shape=txt_shape, txt_q_shape=txt_q_shape, ...)
    txt = self.ada(txt_msa, emb=emb, layer="msa", mode="out", residual=txt)

    # 2. AdaLN + FFN
    txt_mlp = self.ada(txt, emb=emb, layer="mlp", mode="in", norm_layer=self.mlp_norm)
    txt_mlp = self.mlp(txt_mlp)
    txt = self.ada(txt_mlp, emb=emb, layer="mlp", mode="out", residual=txt)
    return txt

七、Stage 2 联合训练目标

论文式 2.2.3 给出 Stage 2 的损失:

L stage2 = λ VAE ⋅ L VAAE + λ FM ⋅ L FM + λ ref ⋅ E [ KL ( q ϕ ( z 0 ∣ x ) ∥ q ϕ ref ( z 0 ∣ x ) ) ] \mathcal{L}{\text{stage2}} = \lambda{\text{VAE}} \cdot \mathcal{L}{\text{VAAE}} + \lambda{\text{FM}} \cdot \mathcal{L}{\text{FM}} + \lambda{\text{ref}} \cdot \mathbb{E}[\text{KL}(q_\phi(z_0|x) \| q_{\phi_{\text{ref}}}(z_0|x))] Lstage2=λVAE⋅LVAAE+λFM⋅LFM+λref⋅E[KL(qϕ(z0∣x)∥qϕref(z0∣x))]

作用
L VAE \mathcal{L}_{\text{VAE}} LVAE 保持 VAE 的重构能力
L FM \mathcal{L}_{\text{FM}} LFM 训练 DiT 先验(Flow Matching)
reference KL 防止 VAE 的隐空间漂移(对齐冻结的参考编码器)

八、面试追问清单

基础(⭐)

  1. DiT 和 UNet 作为扩散模型骨干的区别是什么?
  2. AdaLN 是如何注入时间步信息的?
  3. CFG 的 guidance scale 对生成质量有什么影响?

进阶(⭐⭐)

  1. 为什么 DiT 的 RoPE 只作用于 96/128 个通道?
  2. per-sample KV cache 和标准 KV cache 有什么区别?
  3. Stage 2 的 reference-encoder KL 正则为什么能防止隐空间漂移?

专家(⭐⭐⭐)

  1. DiT 用 GELU tanh 而 VAE 用 SwiGLU,这个差异会影响什么?
  2. rope_theta=10000(DiT)vs rope_theta=500000(VAE)的位置编码频率差异意味着什么?
  3. 如果把 block_size 从 4 改为 16,DiT 的注意力模式会怎么变化?

九、下期预告

下一章我们将逐行拆解推理流水线------从 prompt 输入到文本输出的完整过程,包括 tokenization、前缀编码、分块先验传输、条件解码和采样策略。


系列导航

第 01 章 · 第 02 章 · 第 03 章 · 第 04 章 · 第 05 章

第 06 章:分块因果 DiT 先验 ← 你在这里

第 07 章 · 第 08 章 · 第 09 章 · 第 10 章


作者Yunzenn

相关推荐
comcoo15 小时前
OpenClaw 本地部署避坑指南|环境配置 + 故障排查全流程
运维·人工智能·openclaw安装包·open claw部署
云飞云共享云桌面15 小时前
企业降本增效新思路:SolidWorks共享部署实战经验分享
运维·服务器·网络·人工智能·3d·自动化
AI周红伟15 小时前
Windows 支持 Hermes Agent 吗:原生 Windows 安装 + WSL2 路径完整指南
数据库·人工智能·windows·阿里云·职场和发展·计算机外设
Rocky Ding*15 小时前
深入浅出讲解ERNIE-Image图像创作大模型
论文阅读·人工智能·深度学习·机器学习·ai作画·aigc·ai-native
boonya15 小时前
AI Coding落地生产的真实困境与可执行操作指南
人工智能·落地生产·困境
xier_ran15 小时前
【infra之路】Transformer 核心计算流
人工智能·深度学习·transformer
huangdong_15 小时前
电商图片智能分类算法:主图/属性图/详情图自动识别技术
人工智能·分类·数据挖掘
电商API_1800790524715 小时前
价格波动预警|用API实时监控淘宝京东商品价格,实现自动化竞品调价与捡漏
大数据·运维·数据库·人工智能·数据挖掘·自动化
美狐美颜sdk15 小时前
直播APP开发如何实现美颜功能?低成本美颜SDK方案推荐
android·人工智能·ios·第三方美颜sdk·视频美颜sdk