UNet超分 效果测试

使用的数据集:BSDS300,下载地址:https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/

训练集(train)图片 200 张,测试集(test)图片 100 张,图像大小为 321x481。

以下UNet使用了亚像素卷积(Sub-pixel Convolution)来进行放大。

python 复制代码
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image

# =================================
# 简易 UNet(**最小改动:加入 tail 上采样模块**)
# =================================
class SimpleUNet(nn.Module):
    """Small two-level U-Net followed by a 4x sub-pixel upsampling tail.

    The encoder/decoder run at the input resolution (with two 2x poolings
    in between); the ``tail`` then enlarges the decoded features by 4x using
    two conv + ``PixelShuffle(2)`` stages, so a ``[B, 3, H, W]`` input yields
    a ``[B, 3, 4H, 4W]`` output (e.g. 64x64 -> 256x256).

    Inputs whose height or width is not a multiple of 4 are padded on the
    bottom/right by edge replication before the network runs and the output
    is cropped back to exactly ``4H x 4W``; without this the skip-connection
    concatenations fail because pooling and transposed convolution no longer
    round-trip to the same spatial size.
    """

    # Two MaxPool2d(2) stages: spatial dims must be divisible by this.
    _DOWN_FACTOR = 4
    # Two PixelShuffle(2) stages in the tail.
    _UP_FACTOR = 4

    def __init__(self):
        super().__init__()
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(2)

        self.mid = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )

        self.up2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec2 = nn.Sequential(
            # 128 input channels = 64 (upsampled) + 64 (skip from enc2)
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU()
        )

        self.up1 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.dec1 = nn.Sequential(
            # 64 input channels = 32 (upsampled) + 32 (skip from enc1)
            nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()
        )

        # Sub-pixel convolution: the conv predicts, in low-resolution space,
        # r*r values per output channel, and PixelShuffle rearranges
        # [B, C*r*r, H, W] -> [B, C, H*r, W*r].
        self.tail = nn.Sequential(
            nn.Conv2d(32, 12, 3, padding=1),   # 12 = 3 * (2^2)
            nn.PixelShuffle(2),                # H, W -> 2x, channels -> 3

            nn.Conv2d(3, 12, 3, padding=1),
            nn.PixelShuffle(2)                 # second 2x -> 4x in total
        )

        # The former 1x1 output conv (self.out = nn.Conv2d(32, 3, 1)) was
        # removed on purpose; the tail now produces the RGB output.

    def forward(self, x):
        """Return the 4x super-resolved image for a ``[B, 3, H, W]`` batch."""
        height, width = x.shape[-2:]
        pad_h = -height % self._DOWN_FACTOR
        pad_w = -width % self._DOWN_FACTOR
        needs_pad = bool(pad_h or pad_w)
        if needs_pad:
            # Pad order for the last two dims is (left, right, top, bottom).
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        c1 = self.enc1(x)
        p1 = self.pool1(c1)
        c2 = self.enc2(p1)
        p2 = self.pool2(c2)
        m = self.mid(p2)
        u2 = self.up2(m)
        u2 = torch.cat([u2, c2], 1)
        d2 = self.dec2(u2)
        u1 = self.up1(d2)
        u1 = torch.cat([u1, c1], 1)
        d1 = self.dec1(u1)

        out = self.tail(d1)
        if needs_pad:
            # Drop the rows/columns that only exist because of the padding.
            out = out[..., : self._UP_FACTOR * height, : self._UP_FACTOR * width]
        return out

# =================================
# PSNR / SSIM
# =================================
def calc_psnr(sr, hr):
    """Return the PSNR (in dB) between ``sr`` and ``hr`` as a 0-dim tensor.

    The peak value is taken to be 1.0, and the MSE is computed over every
    element of the two tensors at once (i.e. one value for the whole batch).
    The MSE is floored at 1e-10 so that identical inputs give a finite
    100 dB instead of ``inf``, which would otherwise make any average over
    batches infinite.
    """
    mse = F.mse_loss(sr, hr)
    return 10 * torch.log10(1.0 / mse.clamp_min(1e-10))

def calc_ssim(sr, hr):
    """Return a global (single-window) SSIM between ``sr`` and ``hr``.

    Means, variances and the covariance are taken over all elements of the
    tensors at once rather than over local sliding windows, so this is a
    coarse approximation of the usual windowed SSIM. The stabilising
    constants are ``0.01**2`` and ``0.03**2``, which presumably assumes
    values in [0, 1].

    Variance and covariance both use the population (divide-by-N)
    estimator. Mixing an unbiased variance with a biased covariance makes
    SSIM(x, x) come out slightly below 1 and yields NaN for single-element
    inputs.
    """
    C1, C2 = 0.01**2, 0.03**2
    mu1, mu2 = sr.mean(), hr.mean()
    d1, d2 = sr - mu1, hr - mu2
    s1, s2 = (d1 * d1).mean(), (d2 * d2).mean()
    s12 = (d1 * d2).mean()
    luminance_contrast = (2 * mu1 * mu2 + C1) * (2 * s12 + C2)
    normaliser = (mu1**2 + mu2**2 + C1) * (s1 + s2 + C2)
    return luminance_contrast / normaliser

# =================================
# 数据集(**最小改动:LR 不再被再上采回 256**)
# =================================
import imgaug.augmenters as iaa
import numpy as np

class BSDDataset(Dataset):
    """Pairs of (LR 64x64, HR 256x256) tensors built from a folder of JPEGs.

    Each image is resized to a fixed 256x256 HR target. The LR input is
    produced from that HR image by an optional random degradation (blur,
    additive noise, JPEG compression) followed by a bicubic downscale to
    64x64. The LR image is deliberately NOT upsampled back to 256x256, so
    the model itself has to perform the 4x enlargement.

    Args:
        folder: Directory that is searched (non-recursively) for ``*.jpg``.
        augment: When true, apply the random degradation pipeline before
            downscaling; when false the LR image is a plain bicubic
            downscale of the HR image.
    """

    def __init__(self, folder, augment=True):
        self.files = sorted(glob.glob(os.path.join(folder, "*.jpg")))
        self.augment = augment

        # Degradations meant to mimic a real low-quality capture.
        self.degrade = iaa.Sequential([
            iaa.GaussianBlur((0, 1.2)),                        # optical / lens blur
            iaa.AdditiveGaussianNoise(scale=(0, 0.03 * 255)),  # sensor noise
            iaa.JpegCompression(compression=(40, 100)),        # compression artifacts
        ])

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` tensors for the image at position ``idx``."""
        # Open inside a context manager so the file handle is released
        # promptly; convert()/resize() produce images that no longer depend
        # on the open file.
        with Image.open(self.files[idx]) as src:
            hr = src.convert("RGB").resize((256, 256), Image.BICUBIC)

        if self.augment:
            # Degrade the HR image first, then shrink it.
            lr_source = Image.fromarray(self.degrade(image=np.array(hr)))
        else:
            lr_source = hr

        # Key design choice: keep LR at its true low resolution (64x64).
        lr = lr_source.resize((64, 64), Image.BICUBIC)

        hr = self.to_tensor(hr)   # [3, 256, 256]
        lr = self.to_tensor(lr)   # [3, 64, 64]

        return lr, hr

# =================================
# 下载数据说明
# =================================
def show_download_tip():
    """Print instructions for obtaining and unpacking the BSD300 dataset."""
    tip_lines = (
        "\n请先下载 BSD300 数据集:",
        "1) 访问:https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/ ",
        "2) 解压后目录类似:",
        "   BSDS300/images/train/*.jpg",
        "   BSDS300/images/test/*.jpg\n",
    )
    # One print call; joining with newlines reproduces the per-line output.
    print("\n".join(tip_lines))

# =================================
# Train / Val
# =================================
def train_one_epoch(model, loader, opt, loss_fn, device):
    """Train ``model`` for one pass over ``loader``.

    Returns the mean of the per-batch loss values (each batch counts
    equally, regardless of its size).
    """
    model.train()
    batch_losses = []
    for low_res, high_res in loader:
        low_res = low_res.to(device)
        high_res = high_res.to(device)

        batch_loss = loss_fn(model(low_res), high_res)

        opt.zero_grad()
        batch_loss.backward()
        opt.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

def validate(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns a tuple ``(mean loss, mean PSNR, mean SSIM)`` where each mean is
    taken over batches.
    """
    model.eval()
    loss_sum = 0
    psnr_sum = 0
    ssim_sum = 0
    with torch.no_grad():
        for low_res, high_res in loader:
            low_res = low_res.to(device)
            high_res = high_res.to(device)
            pred = model(low_res)

            loss_sum += loss_fn(pred, high_res).item()
            psnr_sum += calc_psnr(pred, high_res).item()
            ssim_sum += calc_ssim(pred, high_res).item()
    num_batches = len(loader)
    return loss_sum / num_batches, psnr_sum / num_batches, ssim_sum / num_batches

# =================================
# Main
# =================================
def main():
    """Train the 4x SR U-Net on BSDS300, save the weights and a sample grid.

    If the ``BSDS300`` directory is missing, print download instructions and
    return without training.
    """
    if not os.path.exists("BSDS300"):
        show_download_tip()
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_loader = DataLoader(
        BSDDataset("BSDS300/images/train"), batch_size=4, shuffle=True
    )
    val_loader = DataLoader(
        BSDDataset("BSDS300/images/test"), batch_size=4, shuffle=False
    )

    model = SimpleUNet().to(device)
    opt = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.L1Loss()

    epochs = 300
    for epoch_no in range(1, epochs + 1):
        train_loss = train_one_epoch(model, train_loader, opt, loss_fn, device)
        val_loss, psnr, ssim = validate(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch_no}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"PSNR: {psnr:.2f} | "
              f"SSIM: {ssim:.4f}")

    # ---------- save the trained weights ----------
    os.makedirs("checkpoints", exist_ok=True)
    model_path = "checkpoints/unet_sr.pth"
    torch.save(model.state_dict(), model_path)
    print(f"\nModel saved to {model_path}")

    # ---------- save a visual comparison from the first validation batch ----------
    model.eval()
    with torch.no_grad():
        lr, hr = next(iter(val_loader))
        lr = lr.to(device)
        hr = hr.to(device)
        sr = model(lr)
        print("shape, lr:{}, sr:{}, hr:{}".format(lr.shape, sr.shape, hr.shape))

        # LR is enlarged only so it can sit in the same grid as SR and HR;
        # this does not affect training.
        lr_vis = F.interpolate(lr, scale_factor=4.0, mode='bilinear', align_corners=False)

        # Grid rows: upscaled LR, model output, ground truth (3 images each).
        panels = [lr_vis.cpu()[:3], sr.cpu()[:3], hr.cpu()[:3]]
        save_image(torch.cat(panels, 0), "sr_results.png", nrow=3, normalize=True)

    print("\nSaved example as sr_results.png")


if __name__ == "__main__":
    main()

epochs = 300 的训练结果:第一组(第一行)是输入图像,第二组(第二行)是超分后的输出,第三组(第三行)是原图像。输入图像是先对原图做光学模糊、传感器噪声、JPEG 压缩伪影处理,再缩小尺寸得到的。

下面是用 CIFAR-10 数据集做的一个 demo:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image

# -----------------
# Simple UNet
# -----------------
class SimpleUNet(nn.Module):
    """Small two-level U-Net that maps a ``[B, 3, H, W]`` image to the same size.

    Two conv + 2x max-pool stages go down, a bottleneck conv sits in the
    middle, and two transposed-conv stages come back up with skip
    connections, followed by a 1x1 conv to 3 output channels.

    Inputs whose height or width is not a multiple of 4 are padded on the
    bottom/right by edge replication before the network runs and the output
    is cropped back to ``H x W``; without this the skip-connection
    concatenations fail because the pooled and re-upsampled feature maps no
    longer match the encoder feature maps in size.
    """

    # Two MaxPool2d(2) stages: spatial dims must be divisible by this.
    _DOWN_FACTOR = 4

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(2)

        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )

        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = nn.Sequential(
            # 128 input channels = 64 (upsampled) + 64 (skip from conv2)
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU()
        )

        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = nn.Sequential(
            # 64 input channels = 32 (upsampled) + 32 (skip from conv1)
            nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()
        )

        self.out = nn.Conv2d(32, 3, 1)

    def forward(self, x):
        """Return the restored image, same spatial size as ``x``."""
        height, width = x.shape[-2:]
        pad_h = -height % self._DOWN_FACTOR
        pad_w = -width % self._DOWN_FACTOR
        needs_pad = bool(pad_h or pad_w)
        if needs_pad:
            # Pad order for the last two dims is (left, right, top, bottom).
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        c1 = self.conv1(x)
        p1 = self.pool1(c1)

        c2 = self.conv2(p1)
        p2 = self.pool2(c2)

        m = self.middle(p2)

        u2 = self.up2(m)
        u2 = torch.cat([u2, c2], dim=1)
        d2 = self.dec2(u2)

        u1 = self.up1(d2)
        u1 = torch.cat([u1, c1], dim=1)
        d1 = self.dec1(u1)

        out = self.out(d1)
        if needs_pad:
            # Drop the rows/columns that only exist because of the padding.
            out = out[..., :height, :width]
        return out

# -----------------
# Metrics
# -----------------
def calc_psnr(sr, hr):
    """Return the PSNR (in dB) between ``sr`` and ``hr`` as a 0-dim tensor.

    The peak value is taken to be 1.0 and the MSE is computed over all
    elements of both tensors together. A floor of 1e-10 on the MSE caps the
    result at 100 dB, so a perfect reconstruction does not produce ``inf``
    and break averages accumulated over batches.
    """
    min_mse = 1e-10
    mse = F.mse_loss(sr, hr).clamp_min(min_mse)
    return 10 * torch.log10(1.0 / mse)

def calc_ssim(sr, hr):
    """Return a global (single-window) SSIM between ``sr`` and ``hr``.

    All statistics are computed over every element of the tensors at once
    instead of over local windows, so the value is only a rough stand-in for
    the usual windowed SSIM. The constants ``0.01**2`` and ``0.03**2``
    presumably assume a [0, 1] value range.

    Variances are computed with the same divide-by-N estimator as the
    covariance. With the default unbiased variance the two disagree, which
    makes SSIM(x, x) fall short of 1 and gives NaN for one-element inputs.
    """
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    mu_x = sr.mean()
    mu_y = hr.mean()
    dev_x = sr - mu_x
    dev_y = hr - mu_y
    sigma_x = (dev_x * dev_x).mean()
    sigma_y = (dev_y * dev_y).mean()
    sigma_xy = (dev_x * dev_y).mean()
    return ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
        (mu_x**2 + mu_y**2 + C1) * (sigma_x + sigma_y + C2)
    )

# -----------------
# Dataset
# -----------------
# HR target: the dataset image converted to a tensor, nothing else.
transform_hr = transforms.ToTensor()

# LR input: built from the HR tensor by shrinking to 16 and enlarging back
# to 32, so (for 32x32 CIFAR images) the LR input has the same size as the
# HR target but has lost detail.
transform_lr = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(16),
        transforms.Resize(32),
        transforms.ToTensor(),
    ]
)

class SRDataset(datasets.CIFAR10):
    """CIFAR-10 wrapper yielding ``(lr, hr)`` image pairs instead of labels."""

    def __getitem__(self, idx):
        """Return the degraded and the original tensor for sample ``idx``."""
        # The class label is not needed for super-resolution.
        image = super().__getitem__(idx)[0]
        target = transform_hr(image)
        return transform_lr(target), target

# -----------------
# Train / Validation Functions
# -----------------
def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one training pass over ``loader`` and return the mean batch loss."""
    model.train()
    running = 0
    for batch in loader:
        inputs, targets = (tensor.to(device) for tensor in batch)

        step_loss = loss_fn(model(inputs), targets)
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()

        running += step_loss.item()
    num_batches = len(loader)
    return running / num_batches

def validate(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns ``(mean loss, mean PSNR, mean SSIM)``, each averaged over the
    number of batches.
    """
    model.eval()
    totals = [0, 0, 0]  # loss, PSNR, SSIM

    with torch.no_grad():
        for batch in loader:
            inputs, targets = (tensor.to(device) for tensor in batch)
            outputs = model(inputs)

            totals[0] += loss_fn(outputs, targets).item()
            totals[1] += calc_psnr(outputs, targets).item()
            totals[2] += calc_ssim(outputs, targets).item()

    num_batches = len(loader)
    return totals[0] / num_batches, totals[1] / num_batches, totals[2] / num_batches

# -----------------
# Main
# -----------------
def main():
    """Train the U-Net on CIFAR-10 for a few epochs and save a sample grid.

    The CIFAR-10 training split is downloaded to ``./data`` if needed and
    divided 90/10 into training and validation subsets. After training, a
    3x3 grid (rows: LR input, model output, HR target) is written to
    ``sr_results.png``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)

    full_ds = SRDataset(root="./data", train=True, download=True)
    train_size = int(0.9 * len(full_ds))
    val_size = len(full_ds) - train_size

    train_ds, val_ds = random_split(full_ds, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

    model = SimpleUNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.L1Loss()

    epochs = 3
    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        val_loss, psnr, ssim = validate(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"PSNR: {psnr:.2f} | "
              f"SSIM: {ssim:.4f}")

    # -----------------
    # Save examples
    # -----------------
    model.eval()
    with torch.no_grad():
        # Use held-out validation images: examples drawn from the training
        # loader would show data the model was fitted on (and would differ
        # on every run because that loader shuffles).
        lr, hr = next(iter(val_loader))
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)

        lr_show = lr.cpu()[:3]
        sr_show = sr.cpu()[:3]
        hr_show = hr.cpu()[:3]

        save_image(torch.cat([lr_show, sr_show, hr_show], dim=0),
                   "sr_results.png", nrow=3, normalize=True)

    print("Saved example as sr_results.png")


if __name__ == "__main__":
    main()
相关推荐
NAGNIP15 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab16 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab16 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP20 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年20 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼20 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS20 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区21 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈1 天前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang1 天前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx