Pytorch图像去噪实战(十一):Diffusion扩散模型去噪入门,从噪声预测理解生成式图像恢复
一、问题场景:传统去噪模型能用,但上限开始明显
前面我们已经做了 DnCNN、UNet、ResUNet、Attention UNet、FFDNet、CBDNet、Noise2Noise、Noise2Void、SwinIR、Restormer。
这些模型有一个共同特点:
大多数都是直接学习 noisy -> clean,或者 noisy -> noise。
在普通图像去噪任务里,这样已经够用。
但当我处理一些复杂图像时,问题开始变明显:
- 高噪声图像细节恢复差
- 老照片去噪后纹理不自然
- 真实噪声图像容易残留伪影
- 强去噪后图像发糊
- 模型对未知噪声泛化不足
这时就会接触到一个更强的方向:Diffusion Model 扩散模型。
扩散模型不是简单做一次映射,而是学习一个逐步去噪过程。
二、Diffusion去噪和普通去噪有什么区别?
普通去噪模型:
text
noisy_image -> clean_image
Diffusion模型:
text
clean_image -> 不断加噪 -> pure noise
pure noise -> 逐步去噪 -> clean_image
在训练阶段,它学习的是:
给定某一步的带噪图像,预测其中的噪声。
也就是:
text
x_t -> noise
这和 DnCNN 的"预测噪声"思想有相似之处,但 Diffusion 更进一步,把噪声过程拆成了很多步。
三、核心思想:前向加噪与反向去噪
1. 前向过程
从干净图像 x0 开始,逐步加入噪声:
text
x0 -> x1 -> x2 -> ... -> xT
最后 xT 接近纯噪声。
2. 反向过程
模型学习从 xT 一步步恢复:
text
xT -> xT-1 -> ... -> x0
训练目标通常是预测噪声 epsilon。
四、工程目录结构
text
diffusion_denoise/
├── data/
│ └── train/
├── models/
│ └── simple_unet.py
├── diffusion.py
├── dataset.py
├── train.py
├── sample.py
└── utils.py
五、数据集准备
这里先做灰度图像去噪,方便理解扩散模型流程。
python
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class ImageDataset(Dataset):
def __init__(self, root_dir, image_size=64):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".png", ".jpeg"))
]
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
img = Image.open(self.paths[idx]).convert("L")
return self.transform(img)
六、扩散过程实现
diffusion.py
python
import torch
class GaussianDiffusion:
def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
self.timesteps = timesteps
self.device = device
self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
self.alphas = 1.0 - self.betas
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
def add_noise(self, x0, t):
noise = torch.randn_like(x0)
alpha_bar = self.alpha_bars[t].view(-1, 1, 1, 1)
noisy = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1.0 - alpha_bar) * noise
return noisy, noise
七、时间步编码
Diffusion模型必须知道当前是第几步噪声。
python
import torch
import torch.nn as nn
import math
class TimeEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.SiLU(),
nn.Linear(dim * 4, dim)
)
def forward(self, t):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
return self.mlp(emb)
八、简化版UNet噪声预测网络
models/simple_unet.py
python
import torch
import torch.nn as nn
class TimeEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(1, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
def forward(self, t):
t = t.float().view(-1, 1) / 1000.0
return self.net(t)
class SimpleDenoiseUNet(nn.Module):
def __init__(self, channels=1, base=64, time_dim=128):
super().__init__()
self.time_mlp = TimeEmbedding(time_dim)
self.conv1 = nn.Conv2d(channels, base, 3, padding=1)
self.conv2 = nn.Conv2d(base, base, 3, padding=1)
self.conv3 = nn.Conv2d(base, channels, 3, padding=1)
self.time_proj = nn.Linear(time_dim, base)
self.act = nn.SiLU()
def forward(self, x, t):
time_emb = self.time_mlp(t)
time_emb = self.time_proj(time_emb).view(x.size(0), -1, 1, 1)
h = self.act(self.conv1(x))
h = h + time_emb
h = self.act(self.conv2(h))
return self.conv3(h)
九、训练代码
train.py
python
import torch
from torch.utils.data import DataLoader
from dataset import ImageDataset
from diffusion import GaussianDiffusion
from models.simple_unet import SimpleDenoiseUNet
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = ImageDataset("data/train", image_size=64)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
model = SimpleDenoiseUNet().to(device)
diffusion = GaussianDiffusion(timesteps=1000, device=device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
criterion = torch.nn.MSELoss()
for epoch in range(1, 101):
model.train()
total_loss = 0
for x0 in loader:
x0 = x0.to(device)
t = torch.randint(0, diffusion.timesteps, (x0.size(0),), device=device)
xt, noise = diffusion.add_noise(x0, t)
pred_noise = model(xt, t)
loss = criterion(pred_noise, noise)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"diffusion_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
十、为什么Diffusion训练预测noise,而不是预测clean?
这是很多人第一次学扩散模型时最容易疑惑的地方。
如果直接预测 clean:
text
model(x_t, t) -> x0
模型在高噪声阶段很难恢复完整图像。
而预测 noise:
text
model(x_t, t) -> epsilon
训练目标更稳定,也更符合扩散模型的数学推导。
工程上看,预测噪声还有一个优点:
loss更稳定,模型更容易收敛。
十一、采样过程简化实现
下面写一个简化版采样逻辑,帮助理解反向去噪。
python
import torch
import torchvision.utils as vutils
from diffusion import GaussianDiffusion
from models.simple_unet import SimpleDenoiseUNet
@torch.no_grad()
def sample():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleDenoiseUNet().to(device)
model.load_state_dict(torch.load("diffusion_epoch_100.pth", map_location=device))
model.eval()
diffusion = GaussianDiffusion(timesteps=1000, device=device)
x = torch.randn(16, 1, 64, 64).to(device)
for i in reversed(range(diffusion.timesteps)):
t = torch.full((x.size(0),), i, device=device, dtype=torch.long)
pred_noise = model(x, t)
beta = diffusion.betas[i]
alpha = diffusion.alphas[i]
alpha_bar = diffusion.alpha_bars[i]
x = (1 / torch.sqrt(alpha)) * (
x - (beta / torch.sqrt(1 - alpha_bar)) * pred_noise
)
if i > 0:
noise = torch.randn_like(x)
x = x + torch.sqrt(beta) * noise
x = torch.clamp(x, 0.0, 1.0)
vutils.save_image(x.cpu(), "diffusion_samples.png", nrow=4)
if __name__ == "__main__":
sample()
十二、踩坑记录
坑1:时间步没有输入模型
Diffusion模型必须知道 t。
如果只输入 x_t,不输入 t,模型不知道当前噪声强度,训练会非常差。
坑2:学习率过大导致loss震荡
扩散模型训练比普通UNet更敏感。
建议:
python
lr = 2e-4
如果不稳定,降到:
python
lr = 1e-4
坑3:图像尺寸一开始不要太大
Diffusion训练成本高。
建议从:
text
64x64
开始,流程跑通后再放大。
十三、适合收藏总结
Diffusion去噪训练流程
- 读取干净图像
- 随机采样时间步 t
- 根据 t 给图像加噪
- 模型预测噪声
- 用真实噪声监督训练
- 推理时逐步反向去噪
避坑清单
- 必须输入时间步
- 训练目标建议预测noise
- 初期图像尺寸别太大
- 学习率不要过高
- 采样速度较慢是正常现象
十四、优化建议
可以继续升级:
- 更完整UNet结构
- 加Attention模块
- 使用DDIM加速采样
- 支持条件去噪
- 使用真实噪声数据微调
结尾总结
Diffusion模型的核心不是"一个更大的UNet",而是一套新的去噪建模方式:
把图像恢复拆成多个连续的小步骤,让模型逐步从噪声中恢复结构。
如果你已经理解 DnCNN 的残差噪声预测,那么学习 Diffusion 会更容易,因为它本质上也是在学噪声,只是把这个过程做得更细。
下一篇预告
Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程