Pytorch图像去噪实战(十三):DDIM加速扩散模型采样,让去噪从1000步降到50步
一、问题场景:DDPM效果能看,但采样实在太慢
上一篇我们把 DDPM 图像去噪工程搭起来了。
训练流程跑通后,很快会遇到一个非常现实的问题:
采样太慢。
DDPM一般需要从 T=1000 一步步反向去噪:
text
x1000 -> x999 -> ... -> x0
如果只是做实验还可以接受。
但在真实项目中,比如:
- 用户上传图片实时去噪
- 批量修复图片
- OCR预处理
- 在线图片增强
1000步采样基本不可接受。
这时就需要 DDIM。
二、DDIM解决什么问题?
DDIM的核心价值是:
用更少的采样步数完成近似去噪。
比如把:
text
1000步
减少到:
text
50步
甚至:
text
20步
虽然可能牺牲一点质量,但速度提升非常明显。
三、DDPM和DDIM的工程区别
DDPM采样每一步都加入随机噪声:
text
随机反向过程
DDIM可以使用确定性采样:
text
确定性反向过程
这意味着:
- 采样更快
- 结果更稳定
- 可以跳步采样
- 更适合工程部署
四、项目结构
text
ddim_denoise/
├── diffusion/
│ ├── ddpm.py
│ └── ddim.py
├── models/
│ └── unet.py
├── dataset.py
├── train.py
├── sample_ddpm.py
└── sample_ddim.py
DDIM不需要重新训练模型,可以复用DDPM训练好的噪声预测网络。
五、DDIM采样器实现
diffusion/ddim.py
python
import torch
class DDIMSampler:
def __init__(self, ddpm, ddim_steps=50):
self.ddpm = ddpm
self.ddim_steps = ddim_steps
self.time_steps = torch.linspace(
ddpm.timesteps - 1,
0,
ddim_steps
).long().to(ddpm.device)
@torch.no_grad()
def sample(self, model, shape):
device = self.ddpm.device
x = torch.randn(shape).to(device)
for i in range(len(self.time_steps) - 1):
t = self.time_steps[i]
t_next = self.time_steps[i + 1]
batch_t = torch.full((shape[0],), t, device=device, dtype=torch.long)
pred_noise = model(x, batch_t)
alpha_bar_t = self.ddpm.alpha_bars[t]
alpha_bar_next = self.ddpm.alpha_bars[t_next]
pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
pred_x0 = torch.clamp(pred_x0, 0.0, 1.0)
x = torch.sqrt(alpha_bar_next) * pred_x0 + torch.sqrt(1 - alpha_bar_next) * pred_noise
return x
六、DDIM采样脚本
sample_ddim.py
python
import torch
import torchvision.utils as vutils
from configs.train_config import TrainConfig
from diffusion.ddpm import DDPM
from diffusion.ddim import DDIMSampler
from models.unet import DDPMUNet
@torch.no_grad()
def sample_ddim():
cfg = TrainConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DDPMUNet(channels=cfg.channels).to(device)
model.load_state_dict(torch.load("checkpoints/ddpm_epoch_100.pth", map_location=device))
model.eval()
ddpm = DDPM(
timesteps=cfg.timesteps,
beta_start=cfg.beta_start,
beta_end=cfg.beta_end,
device=device
)
sampler = DDIMSampler(ddpm, ddim_steps=50)
samples = sampler.sample(
model,
shape=(16, cfg.channels, cfg.image_size, cfg.image_size)
)
samples = torch.clamp(samples, 0.0, 1.0)
vutils.save_image(samples.cpu(), "ddim_samples.png", nrow=4)
if __name__ == "__main__":
sample_ddim()
七、为什么DDIM可以跳步?
DDPM严格按照马尔可夫链逐步反推。
DDIM则使用一种非马尔可夫形式的采样路径。
工程上可以这样理解:
DDIM不是每一步都重新随机采样,而是根据当前预测的x0和噪声方向,直接跳到更早的时间步。
所以它可以从:
text
1000 -> 999 -> 998
变成:
text
1000 -> 980 -> 960
这就是速度提升的核心。
八、采样步数怎么选?
实际建议:
快速预览
python
ddim_steps = 20
适合训练中间快速看效果。
平衡质量和速度
python
ddim_steps = 50
这是比较常用的设置。
更高质量
python
ddim_steps = 100
速度慢一些,但质量更稳。
九、加入eta控制随机性
DDIM可以设置 eta 控制是否加入随机性。
简化理解:
- eta = 0:确定性采样
- eta > 0:加入随机性
入门建议先用:
python
eta = 0
因为结果更稳定,方便对比实验。
十、推理速度对比
实际工程中,采样速度差距非常明显。
| 方法 | 采样步数 | 速度 | 质量 |
|---|---|---|---|
| DDPM | 1000 | 慢 | 稳 |
| DDIM | 100 | 快很多 | 较稳 |
| DDIM | 50 | 推荐 | 平衡 |
| DDIM | 20 | 很快 | 略差 |
十一、踩坑记录
坑1:time_steps顺序写反
DDIM采样必须从大时间步到小时间步:
text
T -> 0
如果写成 0 到 T,结果会完全错。
坑2:pred_x0不做clamp
预测出的 x0 可能超出 0~1。
建议:
python
pred_x0 = torch.clamp(pred_x0, 0.0, 1.0)
否则容易出现过曝或发黑。
坑3:步数太少导致结构崩
20步速度快,但质量不一定稳定。
建议先用50步作为默认值。
十二、适合收藏总结
DDIM加速流程
- 训练DDPM噪声预测模型
- 构建DDIMSampler
- 从1000步中均匀选择少量时间步
- 根据预测noise估计x0
- 跳步完成采样
避坑清单
- 时间步顺序必须反向
- pred_x0建议clamp
- 20步适合预览,50步更稳
- DDIM不需要重新训练模型
- 采样器要和DDPM参数一致
十三、优化建议
可以继续优化:
- 加eta参数
- 使用非均匀时间步
- 加EMA权重
- 改进UNet结构
- 用条件输入做真实图像去噪
结尾总结
DDIM解决的是扩散模型工程落地中最实际的问题:
DDPM质量可以,但太慢。
通过DDIM,我们可以在不重新训练模型的情况下,把采样速度提升一个数量级。
如果你准备把Diffusion用于图像去噪项目,DDIM几乎是必学内容。
下一篇预告
Pytorch图像去噪实战(十四):条件扩散模型图像去噪,让Diffusion根据带噪图恢复干净图