第一章:研究问题与核心思想
dmd2思想解读移步 从 DMD 到 DMD2:搞懂扩散模型的 "提速革命"-CSDN博客
1.1 论文解决的核心问题
扩散模型虽然能生成高质量图像,但需要数十到上百步的迭代采样,推理成本极高。Distribution Matching Distillation (DMD) 是一种将扩散模型蒸馏为单步生成器的方法,其核心思想是:让学生生成器的输出分布匹配教师扩散模型的分布,而不是强求一一对应的关系 。
然而,原始 DMD 存在三个关键局限:
- 需要昂贵的 ODE 对数据集 :必须用教师模型以多步确定性采样生成大量噪声-图像对,用于回归损失计算
- 学生质量受限于教师采样路径 :回归损失使学生过度模仿教师的采样轨迹
- 训练不稳定 :移除回归损失后,假判别器无法准确估计生成样本分布
1.2 DMD2 的解决方案
DMD2 通过以下技术创新解决上述问题:
|------------------|-----------------------|---------------------------------------------------|
| 论文贡献 | 解决的问题 | 代码实现位置 |
| 消除回归损失 + 双时间尺度更新 | 避免 ODE 对数据集成本,解决训练不稳定 | train_sd.py:330 (dfake_gen_update_ratio) |
| 集成 GAN 损失 | 让学生直接在真实数据上学习,突破教师上限 | sd_guidance.py:369-395 (GAN classifier) |
| 多步采样训练模拟 | 解决训练-推理输入分布不匹配 | sd_unified_model.py:128-162 (backward simulation) |
第二章:从数据视角理解训练范式转变
2.1 传统蒸馏方法的困境
传统扩散模型蒸馏(如 Consistency Distillation)需要预先构建噪声-图像对数据集 。在 DMD2 的代码中,我们可以看到这种传统方法的遗留实现:
ODE 预训练阶段 (可选)的实现位于 /DMD2/main/train_sd_ode.py:
# 第 233-246 行:使用预生成的 ODE 对进行训练`
`ode_dict =` `next(self.ode_dataloader)`
`ode_noises = ode_dict['latents']` `# 教师生成的噪声`
`ode_images = ode_dict['images']` `# 教师以多步采样生成的图像`
`# 使用 LPIPS 感知损失进行回归`
`loss = torch.mean(`
` self.lpips_loss_func(`
` ode_pred_image *` `0.5` `+` `0.5,`
` ode_gt_image *` `0.5` `+` `0.5`
`).float()`
`)`
`
这种方法的问题在于:对于 SDXL 这样的大模型,生成数万张高质量图像需要巨大的计算开销。更重要的是,学生模型被限制在教师已经探索过的采样路径上,无法超越教师。
2.2 DMD2 的数据策略革新
DMD2 彻底改变了数据策略:不再依赖 ODE 对,而是直接使用真实图像数据 。这在代码中体现为两个数据加载器的并行使用:
文件: /DMD2/main/train_sd.py
# 第 128-138 行:双数据加载器设计`
`dataloader = torch.utils.data.DataLoader(dataset,` `...)` `# 文本 prompts`
`self.dataloader = cycle(dataloader)`
`real_dataset = SDImageDatasetLMDB(args.real_image_path,` `...)` `# 真实图像`
`real_dataloader = torch.utils.data.DataLoader(real_dataset,` `...)`
`self.real_dataloader = cycle(real_dataloader)`
`
数据来源与处理链路 :
LAION 数据集`
` │`
` ├──→ 文本 Prompts → captions_laion_score6.25.pkl`
` │ │`
` │ └──→ SDTextDataset (utils.py:222-277)`
` │ 用于条件生成`
` │`
` └──→ 图像 → VAE 编码 → sdxl_vae_latents_laion_500k_lmdb/`
` │`
` └──→ SDImageDatasetLMDB (sd_image_dataset.py)`
` 用于 GAN 判别器训练`
`
VAE Latents 预处理 位于 DMD2/main/sdxl/generate_vae_latents.py:
# 第 85-88 行:将图像压缩到潜在空间`
`with torch.no_grad():`
` latents = vae.encode(batch_images).latent_dist.sample()` `* vae.config.scaling_factor`
`latent_list.append(latents.half().cpu().numpy())`
`
这种设计使得 DMD2 可以在真实数据 而非教师生成的数据上训练判别器,这是论文"集成 GAN 损失"贡献的数据基础。
第三章:模型架构------双网络博弈设计
3.1 架构设计理念
DMD2 的架构设计体现了生成对抗 与分布匹配 的结合。整个系统包含两个核心网络:
┌─────────────────────────────────────────────────────────────┐`
`│ DMD2 系统架构 │`
`├─────────────────────────────────────────────────────────────┤`
`│ │`
`│ Generator (学生) Guidance (判别器系统) │`
`│ ┌─────────────┐ ┌─────────────────────────┐ │`
`│ │ │ 生成图像 │ Fake UNet (可训练) │ │`
`│ │ UNet 学生 │─────────────→│ + GAN Classifier │ │`
`│ │ │ │ │ │`
`│ └─────────────┘ │ Real UNet (冻结教师) │ │`
`│ ↑ │ │ │`
`│ │ 梯度更新 └─────────────────────────┘ │`
`│ │ │ │`
`│ DMD Loss + GAN Loss │ │`
`│ │ │ │`
`│ └──────────────────────────────┘ │`
`│ 分布匹配梯度 │`
`└─────────────────────────────────────────────────────────────┘`
`
3.2 统一模型实现
文件: DMD2/main/sd_unified_model.py
# 第 12-18 行:统一模型封装两个子网络`
`class` `SDUniModel(nn.Module):`
`def` `__init__(self, args, accelerator):`
`super().__init__()`
` self.guidance_model = SDGuidance(args, accelerator)` `# 判别器系统`
`# Generator 在下方初始化...`
`
Generator 初始化 (第 38-75 行)展示了 DMD2 对不同微调策略的支持:
# 从预训练 SDXL UNet 初始化学生模型`
`self.feedforward_model = UNet2DConditionModel.from_pretrained(`
` args.model_id, subfolder="unet"`
`).float()`
`# 可选:LoRA 微调(减少显存占用)`
`if args.generator_lora:`
` lora_config = LoraConfig(`
` r=args.lora_rank,` `# 默认 64`
` target_modules=["to_q",` `"to_k",` `"to_v",` `"to_out.0",` `...],`
` lora_alpha=args.lora_alpha # 默认 8`
`)`
` self.feedforward_model.add_adapter(lora_config)`
`
3.3 Guidance 系统------判别器的双重角色
文件: /DMD2/main/sd_guidance.py
SDGuidance 类实现了论文中"假判别器"的概念,但它的设计比传统 GAN 判别器更加复杂:
# 第 45-58 行:双 UNet 设计`
`self.real_unet = UNet2DConditionModel.from_pretrained(...) # 冻结的教师模型`
`self.real_unet.requires_grad_(False)`
`self.fake_unet = UNet2DConditionModel.from_pretrained(...) # 可训练的判别器`
`self.fake_unet.requires_grad_(True)`
`
为什么需要两个 UNet? 这涉及到论文的核心数学推导:
- real_unet:代表真实数据分布的得分函数(score function),即 \\nabla \\log p_{\\text{real}}
- fake_unet:估计生成数据分布的得分函数,即 \\nabla \\log p_{\\text{fake}}
分布匹配的目标是让 p_{\\text{fake}} 逼近 p_{\\text{real}},这需要两个得分函数的协作。
3.4 GAN 判别器------突破教师上限的关键
论文的创新点之一是集成 GAN 损失,让生成器直接在真实数据上获得反馈。这在代码中体现为一个额外的分类器分支:
文件: /DMD2/main/sd_guidance.py
# 第 108-134 行:GAN 分类器网络`
`self.cls_pred_branch = nn.Sequential(`
` nn.Conv2d(kernel_size=4, in_channels=1280, out_channels=1280, stride=2, padding=1),`
` nn.GroupNorm(num_groups=32, num_channels=1280),`
` nn.SiLU(),`
` # ... 多层卷积下采样`
` nn.Conv2d(kernel_size=1, in_channels=1280, out_channels=1, stride=1, padding=0),`
`)`
`
关键设计 :分类器直接使用 fake_unet 的瓶颈层特征(bottleneck features),而不是单独设计一个判别器网络。这种设计既节省参数,又能复用 UNet 的语义理解能力。
# 第 156-165 行:分类器前向传播`
`def` `compute_cls_logits(self, image, text_embedding, unet_added_conditions):`
`# 获取 UNet 中间层特征`
` rep = self.fake_unet.forward(image, timesteps, text_embedding, classify_mode=True)`
` rep = rep[-1].float()` `# 取瓶颈层特征`
` logits = self.cls_pred_branch(rep).squeeze(dim=[2,` `3])`
`return logits`
`
第四章:训练过程------双时间尺度更新规则的实现
4.1 论文核心洞察:为什么需要不同的更新频率?
论文发现,移除回归损失后训练会变得不稳定。根本原因是假判别器无法准确估计生成样本的分布 。解决方案是采用双时间尺度更新规则 :
- 判别器需要更频繁地更新,以准确跟踪生成分布的变化
- 生成器更新频率较低,避免在判别器不稳定时进行优化
4.2 代码实现
文件: DMD2/main/train_sd.py
# 第 330 行:控制更新频率`
`COMPUTE_GENERATOR_GRADIENT = self.step % self.dfake_gen_update_ratio ==` `0`
`
dfake_gen_update_ratio 参数(默认为 5,见实验脚本)表示:每更新 5 次判别器,才更新 1 次生成器 。
4.3 单步训练的完整流程
文件: /DMD2/main/train_sd.py 第 322-419 行
让我用论文的语言解释这段代码:
def` `train_one_step(self):`
`# ===== 第一步:从噪声生成图像 =====`
`# 这是 Generator 的前向过程:z → G(z) → x_fake`
` noise = torch.randn(batch_size, latent_channel, latent_resolution,` `...)`
`# ===== 第二步:Generator Turn(当满足更新频率时)=====`
`if COMPUTE_GENERATOR_GRADIENT:`
` generator_loss_dict, generator_log_dict = self.model(`
` noise, text_embedding,` `...,`
` compute_generator_gradient=True,`
` generator_turn=True`
`)`
`# 论文中的总损失:L_G = L_DM + λ_GAN * L_GAN`
` generator_loss = generator_loss_dict["loss_dm"]` `* args.dm_loss_weight`
`if self.gen_cls_loss:`
` generator_loss += generator_loss_dict["gen_cls_loss"]` `* gen_cls_loss_weight`
`# 更新生成器`
` self.accelerator.backward(generator_loss)`
` self.optimizer_generator.step()`
`# ===== 第三步:Guidance Turn(每步都执行)=====`
`# 更新判别器以准确估计 p_fake`
` guidance_loss_dict, guidance_log_dict = self.model(`
` noise, text_embedding,` `...,`
` generator_turn=False,`
` guidance_turn=True`
`)`
` guidance_loss = guidance_loss_dict["loss_fake_mean"]` `# Dfake Loss`
`if args.cls_on_clean_image:`
` guidance_loss += guidance_loss_dict["guidance_cls_loss"]` `* guidance_cls_loss_weight`
` self.accelerator.backward(guidance_loss)`
` self.optimizer_guidance.step()`
`
4.4 分布匹配损失的数学实现
文件: DMD2/main/sd_guidance.py 第 168-255 行
这是论文公式 (3)-(5) 的代码实现:
def` `compute_distribution_matching_loss(self, latents, text_embedding,` `...):`
`# 论文公式:从生成图像出发,加噪到随机时间步`
` timesteps = torch.randint(self.min_step, self.max_step+1,` `[batch_size],` `...)`
` noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)`
`# 论文公式:教师 UNet 预测真实得分方向`
` pred_real_noise = predict_noise(self.real_unet, noisy_latents,` `...)`
` pred_real_image = get_x0_from_noise(noisy_latents, pred_real_noise,` `...)`
`# 论文公式:学生 UNet 预测当前得分方向`
` pred_fake_noise = predict_noise(self.fake_unet, noisy_latents,` `...)`
` pred_fake_image = get_x0_from_noise(noisy_latents, pred_fake_noise,` `...)`
`# 论文公式 (5):分布匹配梯度`
`# grad = ∇log p_real - ∇log p_fake ≈ (x - x̂_real) - (x - x̂_fake)`
` p_real =` `(latents - pred_real_image)`
` p_fake =` `(latents - pred_fake_image)`
` grad =` `(p_real - p_fake)` `/ torch.abs(p_real).mean(...)`
`# 通过 MSE 损失将梯度反传给生成器`
`# 论文使用"隐式梯度"技巧:L = 0.5 * ||x - (x - grad)||²`
` loss =` `0.5` `* F.mse_loss(original_latents,` `(original_latents-grad).detach(),` `...)`
`
关键理解 :这里 latents 是生成器输出的图像,通过加噪、预测、去噪的过程,我们得到了"如何调整这张图像使其更接近真实分布"的梯度方向。这个梯度通过一个巧妙的 MSE 损失反向传播给生成器。
4.5 Dfake Loss------让判别器学会估计生成分布
文件: DMD2/main/sd_guidance.py 第 257-310 行
为了让 fake_unet 准确估计生成分布,DMD2 使用标准的去噪得分匹配目标:
def` `compute_loss_fake(self, latents, text_embedding,` `...):`
` latents = latents.detach()` `# 重要:不传梯度给生成器`
`# 随机采样时间步(覆盖整个扩散过程)`
` timesteps = torch.randint(0, self.num_train_timesteps,` `[batch_size],` `...)`
` noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)`
`# 训练 fake_unet 预测噪声`
` fake_noise_pred = predict_noise(self.fake_unet, noisy_latents,` `...)`
`# 简单的 MSE 损失:让判别器学会预测噪声`
` loss_fake = torch.mean((fake_noise_pred - noise)**2)`
`
4.6 GAN Loss------在真实数据上学习
文件: /DMD2/main/sd_guidance.py 第 369-395 行
def` `compute_guidance_clean_cls_loss(self, real_image, fake_image,` `...):`
`# 真实图像 → 判别为真`
` pred_realism_on_real = self.compute_cls_logits(real_image.detach(),` `...)`
`# 生成图像 → 判别为假`
` pred_realism_on_fake = self.compute_cls_logits(fake_image.detach(),` `...)`
`# 二分类损失(使用 softplus 而非 sigmoid + BCE,更稳定)`
` classification_loss = F.softplus(pred_realism_on_fake).mean()` `+ \`
` F.softplus(-pred_realism_on_real).mean()`
`
论文意义 :这个损失让生成器能够直接从真实图像中获得学习信号,而不是仅依赖教师模型的间接指导。这是 DMD2 能够在某些指标上超越教师模型 的关键原因。
第五章:多步采样------解决训练-推理不匹配问题
5.1 问题的提出
单步生成器虽然快速,但质量有限。如果想让生成器支持多步采样(如 4 步),会面临一个关键问题:
- 训练时 :生成器从纯噪声 z_T 一步生成 x_0
- 推理时 :生成器需要逐步去噪 z_T → z_{T-1} → ... → x_0
输入分布的不一致会导致性能下降。
5.2 DMD2 的解决方案:反向模拟训练
文件: DMD2/main/sd_unified_model.py
# 第 128-162 行:反向采样模拟`
`@torch.no_grad()`
`def` `sample_backward(self, noisy_image, real_text_embedding, real_pooled_text_embedding):`
`# 随机选择一个中间时间步`
` selected_step = torch.randint(low=0, high=self.num_denoising_step, size=(1,),` `...)`
`# 模拟推理时的多步采样过程`
` generated_image = noisy_image`
`for constant in self.denoising_step_list[:selected_step]:`
` current_timesteps = torch.ones(batch_size,` `...)` `* constant`
`# 用生成器进行一步预测`
` generated_noise = self.feedforward_model(`
` generated_image, current_timesteps, real_text_embedding,` `...`
`).sample`
`# 预测干净图像`
` generated_image = get_x0_from_noise(generated_image, generated_noise,` `...)`
`# 加噪到下一个时间步(模拟推理过程)`
` next_timestep = current_timesteps - self.timestep_interval`
` generated_image = self.noise_scheduler.add_noise(`
` generated_image, torch.randn_like(generated_image), next_timestep`
`)`
`return generated_image, return_timesteps`
`
核心思想 :训练时模拟推理的多步过程,让生成器看到它在推理时会遇到的输入分布。这相当于数据增强的一种形式。
第六章:并行训练与工程实现
6.1 为什么选择 FSDP?
SDXL 的 UNet 有约 2.6B 参数,单卡显存难以容纳。DMD2 需要同时训练两个 UNet (Generator 和 Fake UNet),显存压力更大。
FSDP 配置 位于 /DMD2/fsdp_configs/fsdp_1node_debug.yaml:
distributed_type: FSDP`
`fsdp_config:`
`fsdp_auto_wrap_policy: SIZE_BASED_WRAP`
`fsdp_min_num_params:` `50000000` `# 50M 参数以上的模块单独分片`
`fsdp_sharding_strategy:` `1` `# FULL_SHARD:参数、梯度、优化器状态都分片`
`fsdp_state_dict_type: SHARDED_STATE_DICT`
`fsdp_sync_module_states:` `true` `# 确保各节点初始化一致`
`
6.2 FSDP 初始化同步问题
文件: DMD2/main/train_sd.py 第 164-184 行
if self.fsdp:`
`# 问题:FSDP hybrid_shard 模式下,不同节点的参数初始化可能不一致`
`# 解决方案:主进程保存初始参数,然后所有进程加载相同的参数`
` generator_path = os.path.join(args.output_path,` `f"checkpoint_model_{self.step:06d}",` `"pytorch_model.bin")`
`if accelerator.is_main_process:`
` torch.save(self.model.feedforward_model.state_dict(), generator_path)`
` accelerator.wait_for_everyone()`
` self.model.feedforward_model.load_state_dict(torch.load(generator_path,` `...))`
`
6.3 混合精度训练策略
文件: DMD2/main/sd_unified_model.py
# 第 107-111 行:混合精度上下文`
`self.network_context_manager = torch.autocast(`
` device_type="cuda", dtype=torch.bfloat16`
`)` `if self.use_fp16 else NoOpContext()`
`
重要注意事项 (代码注释第 107-108 行):
# "SDXL's original VAE doesn't work with half precision"`
`# SDXL 的 VAE 必须保持 float32,否则会出现数值问题`
`
6.4 优化器与学习率策略
文件: /DMD2/main/train_sd.py 第 191-216 行
# 两个独立的 AdamW 优化器,分别用于生成器和判别器`
`self.optimizer_generator = torch.optim.AdamW(`
`[p for p in self.model.feedforward_model.parameters()` `if p.requires_grad],`
` lr=args.generator_lr,` `# 论文推荐 5e-7`
` betas=(0.9,` `0.999),`
` weight_decay=0.01`
`)`
`self.optimizer_guidance = torch.optim.AdamW(`
`[p for p in self.model.guidance_model.parameters()` `if p.requires_grad],`
` lr=args.guidance_lr,` `# 论文推荐 5e-7`
`...`
`)`
`# Constant with warmup 调度器`
`self.scheduler_generator = get_scheduler(`
`"constant_with_warmup",`
` num_warmup_steps=args.warmup_step,` `# 500 步预热`
`...`
`)`
`
第七章:推理流程------从训练到部署
7.1 推理架构的简化
训练时的双网络架构在推理时大幅简化:只需要 Generator ,Guidance 网络完全丢弃。
文件: DMD2/demo/text_to_image_sdxl.py
# 第 94-103 行:加载生成器`
`def create_generator(self, args):`
` generator = UNet2DConditionModel.from_pretrained(`
` args.model_id, subfolder="unet"`
` ).to(self.DTYPE)`
` # 加载训练好的权重`
` state_dict = torch.load(args.checkpoint_path, map_location="cpu")`
` generator.load_state_dict(state_dict, strict=True)`
` generator.requires_grad_(False) # 推理时冻结`
` return generator`
`
7.2 单步 vs 多步采样
文件: /DMD2/demo/text_to_image_sdxl.py 第 142-178 行
def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):`
` # 论文推荐的时间步配置`
` if self.num_step == 1:`
` all_timesteps = [399] # 单步:从 t=399 开始`
` step_interval = 0`
` elif self.num_step == 4:`
` all_timesteps = [999, 749, 499, 249] # 四步:均匀间隔`
` step_interval = 250`
` for constant in all_timesteps:`
` # UNet 前向传播(预测噪声)`
` eval_images = self.model(noise, current_timesteps, prompt_embed, ...).sample`
` # 从噪声预测 x0(DDIM 公式)`
` eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps)`
` # 加噪到下一个时间步(继续迭代)`
` if constant != all_timesteps[-1]:`
` noise = self.scheduler.add_noise(eval_images, torch.randn_like(eval_images), next_timestep)`
` # VAE 解码到像素空间`
` eval_images = self.vae.decode(eval_images / scaling_factor, ...)`
`
关键理解 :
- conditioning_timestep=399 是单步采样的最优起点(论文通过实验确定)
- 4 步采样使用均匀间隔 999, 749, 499, 249
- 每一步都是"预测噪声 → 转换为 x0 → 加噪"的迭代过程
7.3 与 Diffusers 生态集成
论文提供的预训练权重可以直接加载到标准的 Diffusers Pipeline 中:
# README 中的示例代码`
`from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler`
`# 加载 DMD2 微调的 UNet`
`unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet")`
`unet.load_state_dict(torch.load("dmd2_sdxl_4step_unet_fp16.bin"))`
`# 创建标准 Pipeline`
`pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet)`
`pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)` `# 使用 LCM 调度器`
`# 生成图像`
`image = pipe(prompt="a photo of a cat", num_inference_steps=4,`
` guidance_scale=0, timesteps=[999,` `749,` `499,` `249]).images[0]`
`
第八章:项目目录结构------代码组织的逻辑
/DMD2/`
`│`
`├── main/ # 核心训练代码`
`│ ├── train_sd.py # ★ SDXL 主训练入口`
`│ │ └── Trainer 类实现双时间尺度更新、数据加载、模型保存`
`│ │`
`│ ├── sd_unified_model.py # ★ 统一模型:Generator + Guidance 封装`
`│ │ └── 实现前向传播、多步采样模拟、条件输入构建`
`│ │`
`│ ├── sd_guidance.py # ★ Guidance 系统:分布匹配的核心`
`│ │ ├── compute_distribution_matching_loss() → DMD Loss`
`│ │ ├── compute_loss_fake() → Dfake Loss`
`│ │ └── compute_guidance_clean_cls_loss() → GAN Loss`
`│ │`
`│ ├── train_sd_ode.py # ODE 预训练(可选阶段)`
`│ │ └── 使用预生成数据对进行 LPIPS 回归`
`│ │`
`│ ├── sdxl/ # SDXL 特定组件`
`│ │ ├── sdxl_text_encoder.py # 双 CLIP 编码器`
`│ │ ├── sdxl_ode_dataset.py # ODE 数据加载`
`│ │ ├── generate_vae_latents.py # 预处理真实图像`
`│ │ └── generate_noise_image_pairs_laion_sdxl.py # 生成 ODE 对`
`│ │`
`│ ├── edm/ # ImageNet-64x64 实现`
`│ │ ├── train_edm.py # ★ EDM 训练入口`
`│ │ ├── edm_guidance.py # EDM 版本的 Guidance`
`│ │ └── edm_network.py # EDMPrecond 网络配置`
`│ │`
`│ └── data/ # 数据处理工具`
`│ ├── lmdb_dataset.py # 高效数据加载`
`│ └── create_lmdb_iterative.py # 数据格式转换`
`│`
`├── demo/ # 推理演示`
`│ ├── text_to_image_sdxl.py # ★ SDXL Gradio Demo`
`│ └── imagenet_example.py # ImageNet 生成示例`
`│`
`├── experiments/ # 实验配置脚本`
`│ ├── sdxl/ # SDXL 超参数配置`
`│ │ └── sdxl_cond399_*.sh # 单步/多步训练脚本`
`│ └── imagenet/ # ImageNet 配置`
`│`
`├── third_party/edm/ # NVIDIA EDM 参考实现`
`│ └── training/networks.py # EDMPrecond 网络定义`
`│`
`└── fsdp_configs/ # FSDP 并行配置`
` └── fsdp_1node_debug.yaml # 分片策略配置`
`
第九章:从论文到代码的映射总结
|--------------|-------------------------------------------------|----------------------------------------------|
| 论文方法 | 代码位置 | 关键参数 |
| 消除回归损失 | train_sd.py 不使用 ODE 数据 | --gan_alone 标志 |
| 双时间尺度更新 | train_sd.py:330 | dfake_gen_update_ratio=5 |
| 分布匹配损失 | sd_guidance.py:168-255 | dm_loss_weight=1.0 |
| GAN 损失集成 | sd_guidance.py:369-395 | gen_cls_loss_weight=5e-3 |
| 多步采样模拟 | sd_unified_model.py:128-162 | denoising=True, num_denoising_step=4 |
| 条件时间步选择 | train_sd.py:661, demo/text_to_image_sdxl.py:148 | conditioning_timestep=399 |
| CFG 引导 | sd_guidance.py:9-38 | real_guidance_scale=8, fake_guidance_scale=1 |
这份分析将 DMD2 的代码实现与论文思想紧密结合,展示了每个代码模块背后的研究动机和数学原理。从数据策略的革新(放弃 ODE 对),到架构设计的权衡(双 UNet + GAN Classifier),再到训练过程的关键细节(双时间尺度更新),DMD2 的代码完整地体现了论文的技术创新。