Pytorch图像去噪实战(八):Noise2Void盲点网络图像去噪实战,只有单张带噪图也能训练

Pytorch图像去噪实战(八):Noise2Void盲点网络图像去噪实战,只有单张带噪图也能训练


一、问题场景:我只有一批带噪图,没有clean,也没有成对noisy

上一篇我们讲了 Noise2Noise,它不需要干净图,但需要同一场景下两张独立带噪图。

但真实项目里还有更极端的情况:

我只有一批带噪图,每张图只有一份,没有clean,也没有 paired noisy。

比如:

  • 老照片扫描
  • 历史档案
  • 用户上传图片
  • 网络爬取低质量图片
  • 单次医学采集图像

这时 Noise2Noise 也不适用了。

那么还能不能训练去噪模型?

答案是可以,思路就是:Noise2Void / Blind-Spot Network


二、Noise2Void核心思想

Noise2Void的核心非常巧妙:

不让模型看到当前像素本身,而是让它根据周围像素预测当前像素。

因为自然图像的结构具有空间相关性,而随机噪声通常不具备这种稳定相关性。

举例:

一个像素周围都是文字边缘,那么当前像素大概率也是边缘附近。

但当前像素上的随机噪点无法从周围稳定预测出来。

因此模型会倾向学习图像结构,而不是噪声。


三、训练方式:随机遮挡像素

训练时,我们从 noisy image 中随机选择一些像素点,把它们替换掉。

模型输入被遮挡后的图像,目标是原图对应位置的像素值。

流程:

text 复制代码
noisy image
-> 随机mask部分像素
-> 模型预测
-> 只在mask位置计算loss

这就是盲点训练的基本思想。


四、工程目录结构

复制代码
noise2void_denoise/
├── data/
│   └── train/
├── models/
│   └── blind_unet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

五、数据集实现

这里数据集只返回 noisy image,不需要 clean。

dataset.py

python 复制代码
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class NoisyOnlyDataset(Dataset):
    def __init__(self, root_dir, image_size=256):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")
        img = self.transform(img)
        return img

六、构造盲点Mask

我们需要随机选择一些像素作为训练目标。

python 复制代码
import torch


def create_mask(x, mask_ratio=0.05):
    mask = torch.rand_like(x) < mask_ratio
    return mask.float()

然后对输入图像进行扰动:

python 复制代码
def apply_mask(x, mask):
    noise = torch.rand_like(x)
    masked_x = x * (1 - mask) + noise * mask
    return masked_x

注意:

loss只在mask位置计算,而不是整张图。


七、模型实现:轻量UNet

models/blind_unet.py

python 复制代码
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class BlindUNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.pool = nn.MaxPool2d(2)

        self.enc1 = ConvBlock(1, 64)
        self.enc2 = ConvBlock(64, 128)
        self.bottleneck = ConvBlock(128, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = ConvBlock(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = ConvBlock(128, 64)

        self.out = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))

        b = self.bottleneck(self.pool(e2))

        d2 = self.up2(b)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        return self.out(d1)

八、训练代码

train.py

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import NoisyOnlyDataset
from models.blind_unet import BlindUNet


def create_mask(x, mask_ratio=0.05):
    return (torch.rand_like(x) < mask_ratio).float()


def apply_mask(x, mask):
    random_pixels = torch.rand_like(x)
    return x * (1 - mask) + random_pixels * mask


def masked_l1_loss(pred, target, mask):
    loss = torch.abs(pred - target) * mask
    return loss.sum() / (mask.sum() + 1e-8)


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = NoisyOnlyDataset("data/train")
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

    model = BlindUNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(1, 81):
        model.train()
        total_loss = 0

        for noisy in loader:
            noisy = noisy.to(device)

            mask = create_mask(noisy, mask_ratio=0.05)
            masked_input = apply_mask(noisy, mask)

            pred = model(masked_input)
            loss = masked_l1_loss(pred, noisy, mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"noise2void_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

九、推理代码

推理时不需要mask,直接输入完整noisy image。

python 复制代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from models.blind_unet import BlindUNet


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BlindUNet().to(device)
model.load_state_dict(torch.load("noise2void_epoch_80.pth", map_location=device))
model.eval()

img = Image.open("test_noisy.png").convert("L")

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

noisy = transform(img).unsqueeze(0).to(device)

with torch.no_grad():
    pred = model(noisy)
    pred = torch.clamp(pred, 0.0, 1.0)

vutils.save_image(pred.cpu(), "noise2void_result.png")

十、为什么loss只在mask区域计算?

这是Noise2Void最关键的地方。

如果在整张图上计算loss:

python 复制代码
loss = L1(pred, noisy)

模型很容易学成 identity mapping,也就是直接复制输入。

这样不会去噪。

正确做法是:

python 复制代码
loss = L1(pred[mask], noisy[mask])

也就是只让模型预测被遮挡的位置。


十一、mask_ratio怎么选?

我实际测试后,建议:

text 复制代码
0.03 ~ 0.10

如果 mask 太少:

  • 训练信号弱
  • 收敛慢

如果 mask 太多:

  • 输入破坏严重
  • 图像结构变差

一般可以从:

python 复制代码
mask_ratio = 0.05

开始。


十二、踩坑记录

坑1:模型直接学会复制输入

原因:

  • 没有遮挡当前像素
  • loss在整张图上计算

解决:

  • 随机mask
  • 只在mask位置计算loss

坑2:输出过度平滑

原因:

  • mask比例太大
  • 模型太浅
  • 数据纹理复杂

解决:

  • 降低mask_ratio
  • 使用更强UNet
  • 加边缘损失

坑3:固定噪声去不掉

Noise2Void适合随机噪声。

如果噪声是固定条纹、周期纹理,模型可能会把它当成结构。

解决:

  • 对固定噪声单独建模
  • 加频域滤波
  • 增加多样化训练数据

十三、效果验证

Noise2Void的优势是数据要求低,但效果通常弱于有监督模型。

方法 是否需要clean 是否需要成对noisy 效果
UNet监督训练 需要 不需要
Noise2Noise 不需要 需要 中上
Noise2Void 不需要 不需要 中等

十四、适合收藏总结

Noise2Void完整流程

  1. 准备noisy-only数据
  2. 随机mask部分像素
  3. 用周围信息预测mask像素
  4. 只在mask区域计算loss
  5. 推理时直接输入完整图像

避坑清单

  • loss不能算整张图
  • mask比例不要太大
  • 适合随机噪声
  • 固定噪声效果有限
  • 推理时不需要mask

十五、优化建议

可以继续优化:

  • 用盲点卷积替代随机mask
  • 加多尺度UNet
  • 加频域约束
  • 加边缘损失
  • 结合Noise2Noise增强训练

结尾总结

Noise2Void真正解决的是一个很现实的问题:

没有干净图,也没有成对噪声图,只有一批脏图,能不能训练去噪模型?

答案是可以,但要理解它的限制。

它不是万能模型,但在真实数据难以标注的场景中,是非常值得收藏和尝试的一类自监督去噪方法。


下一篇预告

Pytorch图像去噪实战(九):SwinIR图像去噪实战,用Transformer提升纹理恢复能力

相关推荐
才兄说1 小时前
机器人二次开发机器狗巡检?路径覆盖率100%
python
梦想很大很大1 小时前
让 AI 成为“报表配置员”:BI 低代码平台的 Schema 实践路径
前端·人工智能·低代码
隔壁大炮1 小时前
Day07-RNN层(循环网络层)
人工智能·pytorch·python·rnn·深度学习·神经网络·计算机视觉
itzixiao1 小时前
L1-066 猫是液体(5分)[java][python]
java·开发语言·python·算法
zhoutongsheng1 小时前
如何解决ORA-01078参数文件错误_pfile与spfile互相创建恢复
jvm·数据库·python
小饕1 小时前
从 Word2Vec 到多模态:词嵌入技术的演进全景
人工智能·算法·机器学习
上海云盾第一敬业销售1 小时前
生成式AI催生深度伪造攻击,WAF如何识别“假流量“?
人工智能
ykjhr_3d1 小时前
数字工具AI智能学伴,助力教育数字化转型
大数据·人工智能·ai·ai人工智能·华锐视点·华锐云空间
Lightning-py1 小时前
Python 配置日志(Logging)
开发语言·python