Pytorch图像去噪实战(十四):条件扩散模型图像去噪,让Diffusion根据带噪图恢复干净图
一、问题场景:普通Diffusion能生成图,但不能直接修复指定图片
前面我们实现了 DDPM 和 DDIM。
但如果你仔细看,会发现之前的采样方式是:
text
从纯噪声开始生成图像
这更像是生成任务。
而真实图像去噪任务通常是:
text
给定一张带噪图,输出它对应的干净图
也就是说,我们不是要随机生成图片,而是要修复指定图片。
这时普通无条件Diffusion就不够用了,需要引入:
条件扩散模型 Conditional Diffusion
二、条件扩散去噪的核心思想
普通Diffusion输入:
text
x_t, t
条件Diffusion输入:
text
x_t, noisy_condition, t
其中:
- x_t:扩散过程中的 noisy clean image
- noisy_condition:真实带噪图
- t:时间步
模型学习:
text
predict noise from x_t with condition
也就是让模型在反向去噪时参考原始带噪图。
三、为什么需要condition?
如果没有condition,模型生成的是随机干净图,不一定和输入图片内容一致。
加入condition后,模型知道:
- 图像结构是什么
- 边缘在哪里
- 文字位置在哪里
- 物体轮廓在哪里
因此它可以围绕输入图像做恢复,而不是凭空生成。
四、工程结构
text
conditional_diffusion_denoise/
├── data/
│ └── train/
├── models/
│ └── conditional_unet.py
├── diffusion/
│ └── ddpm.py
├── dataset.py
├── train.py
├── infer.py
└── utils.py
五、数据集构造
训练时我们有 clean 图,然后人工加噪得到 condition。
python
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class ConditionalDenoiseDataset(Dataset):
def __init__(self, root_dir, image_size=64):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".png", ".jpeg"))
]
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
clean = Image.open(self.paths[index]).convert("L")
clean = self.transform(clean)
sigma = random.choice([15, 25, 35, 50])
noise = torch.randn_like(clean) * sigma / 255.0
noisy_condition = torch.clamp(clean + noise, 0.0, 1.0)
return noisy_condition, clean
六、条件UNet模型
核心改动非常简单:
把 x_t 和 noisy_condition 在通道维度拼接。
如果是灰度图:
text
x_t: 1通道
condition: 1通道
concat后: 2通道
models/conditional_unet.py
python
import torch
import torch.nn as nn
class TimeEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(1, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
def forward(self, t):
t = t.float().view(-1, 1) / 1000.0
return self.net(t)
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels, time_dim):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.time_proj = nn.Linear(time_dim, out_channels)
self.shortcut = nn.Identity()
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
self.act = nn.SiLU()
def forward(self, x, t_emb):
h = self.act(self.conv1(x))
time = self.time_proj(t_emb).view(x.size(0), -1, 1, 1)
h = h + time
h = self.conv2(self.act(h))
return h + self.shortcut(x)
class ConditionalUNet(nn.Module):
def __init__(self, image_channels=1, base=64, time_dim=128):
super().__init__()
self.time_mlp = TimeEmbedding(time_dim)
in_channels = image_channels * 2
self.down1 = ResBlock(in_channels, base, time_dim)
self.down2 = ResBlock(base, base * 2, time_dim)
self.pool = nn.MaxPool2d(2)
self.mid = ResBlock(base * 2, base * 2, time_dim)
self.up = nn.ConvTranspose2d(base * 2, base, 2, 2)
self.up_block = ResBlock(base * 2, base, time_dim)
self.out = nn.Conv2d(base, image_channels, 3, padding=1)
def forward(self, xt, condition, t):
t_emb = self.time_mlp(t)
x = torch.cat([xt, condition], dim=1)
d1 = self.down1(x, t_emb)
d2 = self.down2(self.pool(d1), t_emb)
mid = self.mid(d2, t_emb)
u = self.up(mid)
u = torch.cat([u, d1], dim=1)
u = self.up_block(u, t_emb)
return self.out(u)
七、训练代码
python
import torch
from torch.utils.data import DataLoader
from dataset import ConditionalDenoiseDataset
from diffusion.ddpm import DDPM
from models.conditional_unet import ConditionalUNet
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = ConditionalDenoiseDataset("data/train", image_size=64)
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
model = ConditionalUNet().to(device)
diffusion = DDPM(
timesteps=1000,
beta_start=1e-4,
beta_end=0.02,
device=device
)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
criterion = torch.nn.MSELoss()
for epoch in range(1, 101):
model.train()
total_loss = 0
for condition, clean in loader:
condition = condition.to(device)
clean = clean.to(device)
t = torch.randint(0, diffusion.timesteps, (clean.size(0),), device=device)
xt, noise = diffusion.q_sample(clean, t)
pred_noise = model(xt, condition, t)
loss = criterion(pred_noise, noise)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"conditional_diffusion_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
八、推理代码
推理时输入一张真实 noisy image 作为 condition。
python
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from diffusion.ddpm import DDPM
from models.conditional_unet import ConditionalUNet
@torch.no_grad()
def infer():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConditionalUNet().to(device)
model.load_state_dict(torch.load("conditional_diffusion_epoch_100.pth", map_location=device))
model.eval()
diffusion = DDPM(
timesteps=1000,
beta_start=1e-4,
beta_end=0.02,
device=device
)
img = Image.open("test_noisy.png").convert("L")
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor()
])
condition = transform(img).unsqueeze(0).to(device)
x = torch.randn_like(condition)
for t in reversed(range(diffusion.timesteps)):
batch_t = torch.full((1,), t, device=device, dtype=torch.long)
pred_noise = model(x, condition, batch_t)
beta = diffusion.betas[t]
alpha = diffusion.alphas[t]
alpha_bar = diffusion.alpha_bars[t]
x = (1 / torch.sqrt(alpha)) * (
x - (beta / torch.sqrt(1 - alpha_bar)) * pred_noise
)
if t > 0:
x = x + torch.sqrt(beta) * torch.randn_like(x)
x = torch.clamp(x, 0.0, 1.0)
vutils.save_image(x.cpu(), "conditional_denoised.png")
if __name__ == "__main__":
infer()
九、为什么条件图不能直接作为初始x?
很多人第一次写条件扩散时,会想:
text
直接从 noisy image 开始反向去噪不就行了?
但标准条件扩散里,反向过程的变量 x 是目标 clean 的扩散状态,而 noisy image 是条件信息。
两者角色不同:
- x:当前正在生成的 clean image 状态
- condition:引导恢复的输入图
如果混在一起,模型训练和推理分布会不一致。
十、和普通UNet去噪相比有什么优势?
普通UNet:
text
noisy -> clean
条件Diffusion:
text
noise state + noisy condition -> clean distribution
优势在于:
- 更适合复杂噪声
- 可以生成更自然细节
- 对强噪声恢复潜力更高
缺点也明显:
- 训练更慢
- 推理更慢
- 工程复杂度更高
十一、踩坑记录
坑1:condition没有拼接进模型
如果模型只输入 xt 和 t,那就是无条件生成,不是图像去噪。
坑2:condition和clean尺寸不一致
训练时 condition 和 clean 必须尺寸一致。
建议在 dataset 中统一 resize。
坑3:采样太慢
条件Diffusion同样有1000步采样问题。
建议后续结合DDIM。
十二、适合收藏总结
条件Diffusion去噪流程
- 从clean构造noisy condition
- 对clean执行扩散加噪
- 模型输入 xt + condition + t
- 模型预测noise
- 推理时用condition引导反向去噪
避坑清单
- condition必须输入模型
- clean和condition尺寸一致
- x和condition角色不要混
- 推理成本较高
- 建议结合DDIM加速
十三、优化建议
可以继续做:
- 条件DDIM采样
- 加强UNet结构
- 使用Restormer作为条件网络
- 支持RGB图像
- 用真实噪声数据微调
结尾总结
条件扩散模型把Diffusion从"随机生成图像"推进到"指定图像恢复"。
它的核心价值是:
既保留扩散模型强大的生成能力,又让模型受输入带噪图约束。
如果你要把Diffusion用于真正的图像去噪任务,条件扩散是必须掌握的一步。
下一篇预告
Pytorch图像去噪实战(十五):彩色RGB图像去噪实战,从灰度模型升级到真实图片处理