UNet Super-Resolution: Quality Test

Dataset used: BSDS300 (download page): https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/

The train split contains 200 images and the test split 100 images; image size is 321x481.

The UNet below uses sub-pixel convolution (PixelShuffle) for upscaling.
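
As a quick shape check (a minimal sketch, separate from the training script below): nn.PixelShuffle(r) rearranges a [B, C*r^2, H, W] tensor into [B, C, H*r, W*r].

python
import torch
import torch.nn as nn

# PixelShuffle(2) folds each group of 4 channels into a 2x2 spatial block:
# [B, C*r^2, H, W] -> [B, C, H*r, W*r]
x = torch.randn(1, 12, 64, 64)   # 12 = 3 * 2^2
y = nn.PixelShuffle(2)(x)
print(y.shape)                   # torch.Size([1, 3, 128, 128])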

python
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image

# =================================
# Simple UNet (minimal change: add a tail upsampling module)
# =================================
class SimpleUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(2)

        self.mid = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )

        self.up2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU()
        )

        self.up1 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()
        )

        self.tail = nn.Sequential(
            nn.Conv2d(32, 12, 3, padding=1),    # 12 = 3 * (2^2)
            nn.PixelShuffle(2),                 # H,W -> 2x, channels -> 3: [B, C*(r^2), H, W] -> [B, C, H*r, W*r]
            # The model predicts, in low-resolution space, what each sub-pixel block should look like,
            # then rearranges those blocks into a larger image. This is sub-pixel convolution.

            nn.Conv2d(3, 12, 3, padding=1),
            nn.PixelShuffle(2)                  # another 2x -> 4x in total
        )

        # the original 1x1 output conv (self.out = nn.Conv2d(32, 3, 1)) was removed to avoid confusion

    def forward(self, x):
        c1 = self.enc1(x)
        p1 = self.pool1(c1)
        c2 = self.enc2(p1)
        p2 = self.pool2(c2)
        m = self.mid(p2)
        u2 = self.up2(m)
        u2 = torch.cat([u2, c2], 1)
        d2 = self.dec2(u2)
        u1 = self.up1(d2)
        u1 = torch.cat([u1, c1], 1)
        d1 = self.dec1(u1)
        # ---------- return the upscaled result (256x256) ----------
        return self.tail(d1)

# =================================
# PSNR / SSIM
# =================================
def calc_psnr(sr, hr):
    # assumes pixel values in [0, 1]
    mse = F.mse_loss(sr, hr)
    return 10 * torch.log10(1.0 / mse)

def calc_ssim(sr, hr):
    # simplified global SSIM: statistics over the whole tensor, not the usual windowed version
    C1, C2 = 0.01**2, 0.03**2
    mu1, mu2 = sr.mean(), hr.mean()
    s1, s2 = sr.var(), hr.var()
    s12 = ((sr - mu1) * (hr - mu2)).mean()
    return ((2*mu1*mu2+C1)*(2*s12+C2))/((mu1**2+mu2**2+C1)*(s1+s2+C2))

# =================================
# Dataset (minimal change: the LR image is no longer upsampled back to 256)
# =================================
import imgaug.augmenters as iaa
import numpy as np

class BSDDataset(Dataset):
    def __init__(self, folder, augment=True):
        self.files = sorted(glob.glob(os.path.join(folder, "*.jpg")))
        self.augment = augment

        # ========= realistic SR degradation (simulates real-world LR degradation) =========
        self.degrade = iaa.Sequential([
            iaa.GaussianBlur((0, 1.2)),                         # optical blur (lens blur)
            iaa.AdditiveGaussianNoise(scale=(0, 0.03 * 255)),   # sensor noise (camera noise)
            iaa.JpegCompression(compression=(40, 100)),         # compression artifacts (JPEG)
        ])

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # ---------- read the original image ----------
        img = Image.open(self.files[idx]).convert("RGB")
        img = np.array(img)

        # ---------- build the HR image (fixed 256x256) ----------
        hr = Image.fromarray(img).resize((256, 256), Image.BICUBIC)
        hr_np = np.array(hr)

        # ---------- optionally apply degradation (to the HR image) ----------
        if self.augment:
            lr_np = self.degrade(image=hr_np)   # realistic degradation of the HR image
        else:
            lr_np = hr_np

        # ---------- build the LR image (downsample to 64x64 only, do NOT upsample back to 256) ----------
        # Key change: keep the LR image at its true low resolution and let the model do the upscaling
        lr = Image.fromarray(lr_np).resize((64, 64), Image.BICUBIC)

        # ---------- convert to tensors ----------
        hr = self.to_tensor(hr)   # [3, 256, 256]
        lr = self.to_tensor(lr)   # [3, 64, 64]

        return lr, hr

# =================================
# Dataset download instructions
# =================================
def show_download_tip():
    print("\nPlease download the BSDS300 dataset first:")
    print("1) Visit: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/")
    print("2) After extracting, the directory layout should look like:")
    print("   BSDS300/images/train/*.jpg")
    print("   BSDS300/images/test/*.jpg\n")

# =================================
# Train / Val
# =================================
def train_one_epoch(model, loader, opt, loss_fn, device):
    model.train()
    total = 0
    for lr, hr in loader:
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)
        loss = loss_fn(sr, hr)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total += loss.item()
    return total / len(loader)

def validate(model, loader, loss_fn, device):
    model.eval()
    total, psnr_t, ssim_t = 0, 0, 0
    with torch.no_grad():
        for lr, hr in loader:
            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)
            total += loss_fn(sr, hr).item()
            psnr_t += calc_psnr(sr, hr).item()
            ssim_t += calc_ssim(sr, hr).item()
    n = len(loader)
    return total/n, psnr_t/n, ssim_t/n

# =================================
# Main
# =================================
def main():
    if not os.path.exists("BSDS300"):
        show_download_tip()
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_dir = "BSDS300/images/train"
    test_dir  = "BSDS300/images/test"

    train_set = BSDDataset(train_dir)
    val_set   = BSDDataset(test_dir)

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)

    model = SimpleUNet().to(device)
    opt = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.L1Loss()

    epochs = 300
    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, opt, loss_fn, device)
        val_loss, psnr, ssim = validate(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"PSNR: {psnr:.2f} | "
              f"SSIM: {ssim:.4f}")
        
    # ---------- save the model ----------
    os.makedirs("checkpoints", exist_ok=True)
    model_path = "checkpoints/unet_sr.pth"
    torch.save(model.state_dict(), model_path)
    print(f"\nModel saved to {model_path}")

    # Save examples (for visualization only, the LR input is upsampled back to 256)
    model.eval()
    with torch.no_grad():
        lr, hr = next(iter(val_loader))
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)
        print("shape, lr:{}, sr:{}, hr:{}".format(lr.shape, sr.shape, hr.shape))

        # ---------- minimal change: upsample LR back to 256 for visualization (does not affect training) ----------
        lr_vis = F.interpolate(lr, scale_factor=4.0, mode='bilinear', align_corners=False)

        save_image(torch.cat([lr_vis.cpu()[:3], sr.cpu()[:3], hr.cpu()[:3]], 0),
                   "sr_results.png", nrow=3, normalize=True)

    print("\nSaved example as sr_results.png")

if __name__ == "__main__":
    main()

Training results with epochs = 300: the first group is the input image, the second group is the super-resolved output, and the third group is the original image. The input is produced from the original by downscaling plus optical blur, sensor noise, and compression artifacts.
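
The training script saves its weights to checkpoints/unet_sr.pth. A minimal inference sketch (assuming the SimpleUNet class from the script above is in scope, and using a hypothetical low-resolution file input.jpg) could look like this:

python
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

# assumption: SimpleUNet is the class defined in the training script above
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleUNet().to(device)
model.load_state_dict(torch.load("checkpoints/unet_sr.pth", map_location=device))
model.eval()

# "input.jpg" is a hypothetical input; the model was trained on 64x64 LR -> 256x256 HR
lr = Image.open("input.jpg").convert("RGB").resize((64, 64), Image.BICUBIC)
lr = transforms.ToTensor()(lr).unsqueeze(0).to(device)   # [1, 3, 64, 64]

with torch.no_grad():
    sr = model(lr)                                        # [1, 3, 256, 256]

save_image(sr.clamp(0, 1), "sr_single.png")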

The following is a demo of the same idea on the CIFAR-10 dataset.

python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image

# -----------------
# Simple UNet
# -----------------
class SimpleUNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(2)

        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )

        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU()
        )

        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()
        )

        self.out = nn.Conv2d(32, 3, 1)

    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)

        c2 = self.conv2(p1)
        p2 = self.pool2(c2)

        m = self.middle(p2)

        u2 = self.up2(m)
        u2 = torch.cat([u2, c2], dim=1)
        d2 = self.dec2(u2)

        u1 = self.up1(d2)
        u1 = torch.cat([u1, c1], dim=1)
        d1 = self.dec1(u1)

        return self.out(d1)

# -----------------
# Metrics
# -----------------
def calc_psnr(sr, hr):
    mse = F.mse_loss(sr, hr)
    return 10 * torch.log10(1.0 / mse)

def calc_ssim(sr, hr):
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    mu_x = sr.mean()
    mu_y = hr.mean()
    sigma_x = sr.var()
    sigma_y = hr.var()
    sigma_xy = ((sr - mu_x) * (hr - mu_y)).mean()
    return ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
        (mu_x**2 + mu_y**2 + C1) * (sigma_x + sigma_y + C2)
    )

# -----------------
# Dataset
# -----------------
transform_hr = transforms.ToTensor()
# LR is built by downscaling to 16x16 and resizing back to 32x32 (default bilinear),
# so input and output share the same size and this UNet needs no upsampling tail
transform_lr = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(16),
    transforms.Resize(32),
    transforms.ToTensor()
])

class SRDataset(datasets.CIFAR10):
    def __getitem__(self, idx):
        img, _ = super().__getitem__(idx)
        hr = transform_hr(img)
        lr = transform_lr(hr)
        return lr, hr

# -----------------
# Train / Validation Functions
# -----------------
def train_one_epoch(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for lr, hr in loader:
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)

        loss = loss_fn(sr, hr)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)

def validate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    psnr_total = 0
    ssim_total = 0

    with torch.no_grad():
        for lr, hr in loader:
            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)

            total_loss += loss_fn(sr, hr).item()
            psnr_total += calc_psnr(sr, hr).item()
            ssim_total += calc_ssim(sr, hr).item()

    n = len(loader)
    return total_loss / n, psnr_total / n, ssim_total / n

# -----------------
# Main
# -----------------
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)

    full_ds = SRDataset(root="./data", train=True, download=True)
    train_size = int(0.9 * len(full_ds))
    val_size = len(full_ds) - train_size

    train_ds, val_ds = random_split(full_ds, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

    model = SimpleUNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.L1Loss()

    epochs = 3
    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        val_loss, psnr, ssim = validate(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"PSNR: {psnr:.2f} | "
              f"SSIM: {ssim:.4f}")

    # -----------------
    # Save examples
    # -----------------
    model.eval()
    with torch.no_grad():
        lr, hr = next(iter(train_loader))
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)

        lr_show = lr.cpu()[:3]
        sr_show = sr.cpu()[:3]
        hr_show = hr.cpu()[:3]

        save_image(torch.cat([lr_show, sr_show, hr_show], dim=0),
                   "sr_results.png", nrow=3, normalize=True)

    print("Saved example as sr_results.png")

if __name__ == "__main__":
    main()
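
To put the reported PSNR in context, here is a rough baseline sketch (my addition, not part of the original demo): measure the PSNR of the degraded input itself against the HR target on the CIFAR-10 test split; a useful model should beat this number. It assumes the SRDataset class and calc_psnr function defined above are in scope.

python
from torch.utils.data import DataLoader

# assumption: SRDataset and calc_psnr come from the CIFAR demo above
test_ds = SRDataset(root="./data", train=False, download=True)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

total, n = 0.0, 0
for lr, hr in test_loader:
    total += calc_psnr(lr, hr).item()   # LR is already 32x32, so it can be compared directly
    n += 1
print(f"Baseline PSNR (degraded input vs. HR): {total / n:.2f} dB")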