Pytorch图像去噪实战(十四):条件扩散模型图像去噪,让Diffusion根据带噪图恢复干净图

Pytorch图像去噪实战(十四):条件扩散模型图像去噪,让Diffusion根据带噪图恢复干净图


一、问题场景:普通Diffusion能生成图,但不能直接修复指定图片

前面我们实现了 DDPM 和 DDIM。

但如果你仔细看,会发现之前的采样方式是:

text 复制代码
从纯噪声开始生成图像

这更像是生成任务。

而真实图像去噪任务通常是:

text 复制代码
给定一张带噪图,输出它对应的干净图

也就是说,我们不是要随机生成图片,而是要修复指定图片。

这时普通无条件Diffusion就不够用了,需要引入:

条件扩散模型 Conditional Diffusion


二、条件扩散去噪的核心思想

普通Diffusion输入:

text 复制代码
x_t, t

条件Diffusion输入:

text 复制代码
x_t, noisy_condition, t

其中:

  • x_t:扩散过程中的 noisy clean image
  • noisy_condition:真实带噪图
  • t:时间步

模型学习:

text 复制代码
predict noise from x_t with condition

也就是让模型在反向去噪时参考原始带噪图。


三、为什么需要condition?

如果没有condition,模型生成的是随机干净图,不一定和输入图片内容一致。

加入condition后,模型知道:

  • 图像结构是什么
  • 边缘在哪里
  • 文字位置在哪里
  • 物体轮廓在哪里

因此它可以围绕输入图像做恢复,而不是凭空生成。


四、工程结构

text 复制代码
conditional_diffusion_denoise/
├── data/
│   └── train/
├── models/
│   └── conditional_unet.py
├── diffusion/
│   └── ddpm.py
├── dataset.py
├── train.py
├── infer.py
└── utils.py

五、数据集构造

训练时我们有 clean 图,然后人工加噪得到 condition。

python 复制代码
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ConditionalDenoiseDataset(Dataset):
    def __init__(self, root_dir, image_size=64):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        clean = Image.open(self.paths[index]).convert("L")
        clean = self.transform(clean)

        sigma = random.choice([15, 25, 35, 50])
        noise = torch.randn_like(clean) * sigma / 255.0
        noisy_condition = torch.clamp(clean + noise, 0.0, 1.0)

        return noisy_condition, clean

六、条件UNet模型

核心改动非常简单:

把 x_t 和 noisy_condition 在通道维度拼接。

如果是灰度图:

text 复制代码
x_t: 1通道
condition: 1通道
concat后: 2通道

models/conditional_unet.py

python 复制代码
import torch
import torch.nn as nn


class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, t):
        t = t.float().view(-1, 1) / 1000.0
        return self.net(t)


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.time_proj = nn.Linear(time_dim, out_channels)

        self.shortcut = nn.Identity()
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

        self.act = nn.SiLU()

    def forward(self, x, t_emb):
        h = self.act(self.conv1(x))

        time = self.time_proj(t_emb).view(x.size(0), -1, 1, 1)
        h = h + time

        h = self.conv2(self.act(h))

        return h + self.shortcut(x)


class ConditionalUNet(nn.Module):
    def __init__(self, image_channels=1, base=64, time_dim=128):
        super().__init__()

        self.time_mlp = TimeEmbedding(time_dim)

        in_channels = image_channels * 2

        self.down1 = ResBlock(in_channels, base, time_dim)
        self.down2 = ResBlock(base, base * 2, time_dim)

        self.pool = nn.MaxPool2d(2)

        self.mid = ResBlock(base * 2, base * 2, time_dim)

        self.up = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.up_block = ResBlock(base * 2, base, time_dim)

        self.out = nn.Conv2d(base, image_channels, 3, padding=1)

    def forward(self, xt, condition, t):
        t_emb = self.time_mlp(t)

        x = torch.cat([xt, condition], dim=1)

        d1 = self.down1(x, t_emb)
        d2 = self.down2(self.pool(d1), t_emb)

        mid = self.mid(d2, t_emb)

        u = self.up(mid)
        u = torch.cat([u, d1], dim=1)
        u = self.up_block(u, t_emb)

        return self.out(u)

七、训练代码

python 复制代码
import torch
from torch.utils.data import DataLoader

from dataset import ConditionalDenoiseDataset
from diffusion.ddpm import DDPM
from models.conditional_unet import ConditionalUNet


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = ConditionalDenoiseDataset("data/train", image_size=64)
    loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

    model = ConditionalUNet().to(device)

    diffusion = DDPM(
        timesteps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        device=device
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    criterion = torch.nn.MSELoss()

    for epoch in range(1, 101):
        model.train()
        total_loss = 0

        for condition, clean in loader:
            condition = condition.to(device)
            clean = clean.to(device)

            t = torch.randint(0, diffusion.timesteps, (clean.size(0),), device=device)

            xt, noise = diffusion.q_sample(clean, t)

            pred_noise = model(xt, condition, t)

            loss = criterion(pred_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"conditional_diffusion_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

八、推理代码

推理时输入一张真实 noisy image 作为 condition。

python 复制代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils

from diffusion.ddpm import DDPM
from models.conditional_unet import ConditionalUNet


@torch.no_grad()
def infer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ConditionalUNet().to(device)
    model.load_state_dict(torch.load("conditional_diffusion_epoch_100.pth", map_location=device))
    model.eval()

    diffusion = DDPM(
        timesteps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        device=device
    )

    img = Image.open("test_noisy.png").convert("L")

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor()
    ])

    condition = transform(img).unsqueeze(0).to(device)

    x = torch.randn_like(condition)

    for t in reversed(range(diffusion.timesteps)):
        batch_t = torch.full((1,), t, device=device, dtype=torch.long)

        pred_noise = model(x, condition, batch_t)

        beta = diffusion.betas[t]
        alpha = diffusion.alphas[t]
        alpha_bar = diffusion.alpha_bars[t]

        x = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * pred_noise
        )

        if t > 0:
            x = x + torch.sqrt(beta) * torch.randn_like(x)

    x = torch.clamp(x, 0.0, 1.0)
    vutils.save_image(x.cpu(), "conditional_denoised.png")


if __name__ == "__main__":
    infer()

九、为什么条件图不能直接作为初始x?

很多人第一次写条件扩散时,会想:

text 复制代码
直接从 noisy image 开始反向去噪不就行了?

但标准条件扩散里,反向过程的变量 x 是目标 clean 的扩散状态,而 noisy image 是条件信息。

两者角色不同:

  • x:当前正在生成的 clean image 状态
  • condition:引导恢复的输入图

如果混在一起,模型训练和推理分布会不一致。


十、和普通UNet去噪相比有什么优势?

普通UNet:

text 复制代码
noisy -> clean

条件Diffusion:

text 复制代码
noise state + noisy condition -> clean distribution

优势在于:

  • 更适合复杂噪声
  • 可以生成更自然细节
  • 对强噪声恢复潜力更高

缺点也明显:

  • 训练更慢
  • 推理更慢
  • 工程复杂度更高

十一、踩坑记录

坑1:condition没有拼接进模型

如果模型只输入 xt 和 t,那就是无条件生成,不是图像去噪。


坑2:condition和clean尺寸不一致

训练时 condition 和 clean 必须尺寸一致。

建议在 dataset 中统一 resize。


坑3:采样太慢

条件Diffusion同样有1000步采样问题。

建议后续结合DDIM。


十二、适合收藏总结

条件Diffusion去噪流程

  1. 从clean构造noisy condition
  2. 对clean执行扩散加噪
  3. 模型输入 xt + condition + t
  4. 模型预测noise
  5. 推理时用condition引导反向去噪

避坑清单

  • condition必须输入模型
  • clean和condition尺寸一致
  • x和condition角色不要混
  • 推理成本较高
  • 建议结合DDIM加速

十三、优化建议

可以继续做:

  • 条件DDIM采样
  • 加强UNet结构
  • 使用Restormer作为条件网络
  • 支持RGB图像
  • 用真实噪声数据微调

结尾总结

条件扩散模型把Diffusion从"随机生成图像"推进到"指定图像恢复"。

它的核心价值是:

既保留扩散模型强大的生成能力,又让模型受输入带噪图约束。

如果你要把Diffusion用于真正的图像去噪任务,条件扩散是必须掌握的一步。


下一篇预告

Pytorch图像去噪实战(十五):彩色RGB图像去噪实战,从灰度模型升级到真实图片处理

相关推荐
li星野5 小时前
FastAPI 项目加入 WebSocket 支持
python·websocket·fastapi
tangweiguo030519875 小时前
LangGraph 入门:多智能体工作流实战(阿里云百炼)
人工智能·python·langchain
饭后一颗花生米5 小时前
AI算力选型全景指南:从入门到旗舰的硬核实操
人工智能
Yue栎廷5 小时前
邪修:Markdown加粗语法**本土化改造
前端·javascript·人工智能
2301_815279525 小时前
实战分享LangChain WebUI 部署智能客服:从零搭建到生产环境优化
人工智能·langchain
三维频道5 小时前
柔性材料3D数字化:蓝光扫描在内衣胸垫设计与质检中的应用
人工智能·3d·逆向工程·蓝光3d扫描仪·服装数字化·内衣设计·柔性材料检测
科研前沿5 小时前
镜像视界浙江科技有限公司的核心引擎关键技术有哪些?
人工智能·数码相机·计算机视觉
帅次5 小时前
Android AI 面试速刷版
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·数据分析
生物信息与育种5 小时前
全基因组重测序及群体遗传与进化分析技术服务指南
人工智能·深度学习·算法·数据分析·r语言