Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程
一、问题场景:扩散模型能跑,但工程代码很容易写乱
上一篇我们从最小实现理解了 Diffusion 的核心逻辑。
但如果真正放到项目里,会很快遇到问题:
- beta schedule 写在训练脚本里,后续不好改
- 采样逻辑和训练逻辑混在一起
- 模型保存与恢复不规范
- 训练参数不可复现
- 后续无法扩展 DDIM、条件去噪、彩色图像
很多人学扩散模型时,能写出一个 demo,但很难整理成工程。
这一篇我们重点做一件事:
把 DDPM 图像去噪流程整理成一个可复现、可扩展的工程结构。
二、DDPM核心训练目标
DDPM训练目标仍然是预测噪声:
text
epsilon_theta(x_t, t) ≈ epsilon
训练时:
- 从数据集中取 clean image x0
- 随机采样时间步 t
- 根据 t 给 x0 加噪得到 xt
- 模型输入 xt 和 t
- 模型预测 noise
- 使用 MSELoss 训练
三、推荐工程结构
text
ddpm_denoise/
├── configs/
│ └── train_config.py
├── data/
│ └── train/
├── models/
│ └── unet.py
├── diffusion/
│ └── ddpm.py
├── dataset.py
├── train.py
├── sample.py
└── utils.py
这个结构相比简单 demo 有几个好处:
- 模型独立
- 扩散过程独立
- 配置独立
- 训练和采样分离
- 后续扩展方便
四、配置文件
configs/train_config.py
python
class TrainConfig:
image_size = 64
channels = 1
batch_size = 32
num_workers = 4
epochs = 100
lr = 2e-4
timesteps = 1000
beta_start = 1e-4
beta_end = 0.02
save_interval = 10
data_dir = "data/train"
save_dir = "checkpoints"
配置单独抽出来,最大的好处是:
实验参数不会散落在代码里。
后面复现实验时非常重要。
五、数据集代码
dataset.py
python
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class ImageFolderDataset(Dataset):
def __init__(self, root_dir, image_size=64, channels=1):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".jpeg", ".png"))
]
if channels == 1:
self.mode = "L"
else:
self.mode = "RGB"
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
img = Image.open(self.paths[index]).convert(self.mode)
return self.transform(img)
六、DDPM扩散类封装
diffusion/ddpm.py
python
import torch
class DDPM:
def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
self.timesteps = timesteps
self.device = device
self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
self.alphas = 1.0 - self.betas
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
self.sqrt_one_minus_alpha_bars = torch.sqrt(1.0 - self.alpha_bars)
def q_sample(self, x0, t, noise=None):
if noise is None:
noise = torch.randn_like(x0)
sqrt_alpha_bar = self.sqrt_alpha_bars[t].view(-1, 1, 1, 1)
sqrt_one_minus = self.sqrt_one_minus_alpha_bars[t].view(-1, 1, 1, 1)
xt = sqrt_alpha_bar * x0 + sqrt_one_minus * noise
return xt, noise
@torch.no_grad()
def p_sample(self, model, x, t):
beta = self.betas[t]
alpha = self.alphas[t]
alpha_bar = self.alpha_bars[t]
batch_t = torch.full((x.size(0),), t, device=x.device, dtype=torch.long)
pred_noise = model(x, batch_t)
mean = (1 / torch.sqrt(alpha)) * (
x - (beta / torch.sqrt(1.0 - alpha_bar)) * pred_noise
)
if t > 0:
noise = torch.randn_like(x)
return mean + torch.sqrt(beta) * noise
return mean
七、UNet噪声预测模型
models/unet.py
python
import torch
import torch.nn as nn
class TimeEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(1, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
def forward(self, t):
t = t.float().view(-1, 1) / 1000.0
return self.net(t)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, time_dim):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.time_proj = nn.Linear(time_dim, out_channels)
self.shortcut = nn.Identity()
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
self.act = nn.SiLU()
def forward(self, x, t_emb):
h = self.act(self.conv1(x))
time = self.time_proj(t_emb).view(x.size(0), -1, 1, 1)
h = h + time
h = self.conv2(self.act(h))
return h + self.shortcut(x)
class DDPMUNet(nn.Module):
def __init__(self, channels=1, base=64, time_dim=128):
super().__init__()
self.time_mlp = TimeEmbedding(time_dim)
self.down1 = ResidualBlock(channels, base, time_dim)
self.down2 = ResidualBlock(base, base * 2, time_dim)
self.pool = nn.MaxPool2d(2)
self.mid = ResidualBlock(base * 2, base * 2, time_dim)
self.up = nn.ConvTranspose2d(base * 2, base, 2, 2)
self.up_block = ResidualBlock(base * 2, base, time_dim)
self.out = nn.Conv2d(base, channels, 3, padding=1)
def forward(self, x, t):
t_emb = self.time_mlp(t)
d1 = self.down1(x, t_emb)
d2 = self.down2(self.pool(d1), t_emb)
mid = self.mid(d2, t_emb)
u = self.up(mid)
u = torch.cat([u, d1], dim=1)
u = self.up_block(u, t_emb)
return self.out(u)
八、训练脚本
train.py
python
import os
import torch
from torch.utils.data import DataLoader
from configs.train_config import TrainConfig
from dataset import ImageFolderDataset
from models.unet import DDPMUNet
from diffusion.ddpm import DDPM
def train():
cfg = TrainConfig()
os.makedirs(cfg.save_dir, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = ImageFolderDataset(
root_dir=cfg.data_dir,
image_size=cfg.image_size,
channels=cfg.channels
)
loader = DataLoader(
dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers
)
model = DDPMUNet(channels=cfg.channels).to(device)
diffusion = DDPM(
timesteps=cfg.timesteps,
beta_start=cfg.beta_start,
beta_end=cfg.beta_end,
device=device
)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
criterion = torch.nn.MSELoss()
for epoch in range(1, cfg.epochs + 1):
model.train()
total_loss = 0
for x0 in loader:
x0 = x0.to(device)
t = torch.randint(0, cfg.timesteps, (x0.size(0),), device=device)
xt, noise = diffusion.q_sample(x0, t)
pred_noise = model(xt, t)
loss = criterion(pred_noise, noise)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(loader)
print(f"Epoch [{epoch}/{cfg.epochs}], Loss: {avg_loss:.6f}")
if epoch % cfg.save_interval == 0:
path = os.path.join(cfg.save_dir, f"ddpm_epoch_{epoch}.pth")
torch.save(model.state_dict(), path)
if __name__ == "__main__":
train()
九、采样脚本
sample.py
python
import torch
import torchvision.utils as vutils
from configs.train_config import TrainConfig
from models.unet import DDPMUNet
from diffusion.ddpm import DDPM
@torch.no_grad()
def sample():
cfg = TrainConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DDPMUNet(channels=cfg.channels).to(device)
model.load_state_dict(torch.load("checkpoints/ddpm_epoch_100.pth", map_location=device))
model.eval()
diffusion = DDPM(
timesteps=cfg.timesteps,
beta_start=cfg.beta_start,
beta_end=cfg.beta_end,
device=device
)
x = torch.randn(16, cfg.channels, cfg.image_size, cfg.image_size).to(device)
for t in reversed(range(cfg.timesteps)):
x = diffusion.p_sample(model, x, t)
x = torch.clamp(x, 0.0, 1.0)
vutils.save_image(x.cpu(), "ddpm_sample.png", nrow=4)
if __name__ == "__main__":
sample()
十、为什么要做工程拆分?
很多扩散模型代码一开始写在一个文件里,能跑,但很难维护。
工程拆分带来的好处:
- diffusion类可复用
- UNet可替换
- config方便调参
- train和sample互不干扰
- 后续DDIM可以直接扩展
这也是从"能跑demo"到"能做项目"的关键一步。
十一、踩坑记录
坑1:采样结果全是噪声
常见原因:
- 模型训练不够
- 时间步输入错误
- beta schedule太激进
- 采样公式写错
建议先用小数据集验证过拟合能力。
坑2:loss下降但采样效果差
DDPM的loss下降不代表马上能生成好图。
采样质量通常需要更多训练轮数。
坑3:训练太慢
DDPM采样慢是正常现象,因为要从 T 逐步采样。
后续可以使用 DDIM 或减少 timesteps。
十二、适合收藏总结
DDPM工程化流程
- 配置文件管理参数
- Dataset加载图像
- DDPM类负责加噪和采样
- UNet预测噪声
- train.py训练模型
- sample.py生成结果
避坑清单
- 不要把所有代码写一个文件
- 时间步必须正确传入
- beta schedule要稳定
- 采样结果差不一定是loss问题
- 先用小尺寸图跑通
十三、优化建议
后续可以继续做:
- DDIM加速采样
- 条件Diffusion去噪
- 彩色图像支持
- EMA模型权重
- 混合精度训练
结尾总结
DDPM不是一个单独模型,而是一套完整的扩散训练和采样框架。
如果你只是写一个demo,很容易跑通;但如果要长期做系列实验,就必须从一开始整理好工程结构。
这一篇的重点不是追求最强效果,而是把DDPM搭成一个稳定可复现的项目骨架。
下一篇预告
Pytorch图像去噪实战(十三):DDIM加速采样,让扩散模型去噪从1000步降到50步