Dataset: BSDS300, download page: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/
200 training images and 100 test images, each 321x481 pixels.
The UNet below uses sub-pixel convolution (PixelShuffle) to upscale the output.
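As a quick illustration of what PixelShuffle does, here is a minimal standalone sketch (not part of the training script):

```python
import torch
import torch.nn as nn

# PixelShuffle(r) rearranges [B, C*r^2, H, W] into [B, C, H*r, W*r]:
# each group of r^2 channels becomes one r x r block of sub-pixels.
x = torch.randn(1, 12, 64, 64)   # 12 = 3 * 2^2
y = nn.PixelShuffle(2)(x)
print(y.shape)                   # torch.Size([1, 3, 128, 128])
```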
```python
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
# =================================
# Simple UNet (minimal change: add a tail upsampling module)
# =================================
class SimpleUNet(nn.Module):
def __init__(self):
super().__init__()
self.enc1 = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()
)
self.pool1 = nn.MaxPool2d(2)
self.enc2 = nn.Sequential(
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
)
self.pool2 = nn.MaxPool2d(2)
self.mid = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
)
self.up2 = nn.ConvTranspose2d(128, 64, 2, 2)
self.dec2 = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1), nn.ReLU()
)
self.up1 = nn.ConvTranspose2d(64, 32, 2, 2)
self.dec1 = nn.Sequential(
nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()
)
self.tail = nn.Sequential(
nn.Conv2d(32, 12, 3, padding=1), # 12 = 3 * (2^2)
            nn.PixelShuffle(2),  # [B, C*r^2, H, W] -> [B, C, H*r, W*r]: H,W -> 2x, channels 12 -> 3
            # The model predicts, in low-resolution space, what each sub-pixel block should look
            # like, then rearranges those blocks into the larger image; this is sub-pixel
            # convolution (Sub-pixel Convolution).
            nn.Conv2d(3, 12, 3, padding=1),
            nn.PixelShuffle(2)  # another 2x -> 4x in total
)
        # note: the original 1x1 output head (self.out = nn.Conv2d(32, 3, 1)) was removed; the tail above now produces the 3-channel output
def forward(self, x):
c1 = self.enc1(x)
p1 = self.pool1(c1)
c2 = self.enc2(p1)
p2 = self.pool2(c2)
m = self.mid(p2)
u2 = self.up2(m)
u2 = torch.cat([u2, c2], 1)
d2 = self.dec2(u2)
u1 = self.up1(d2)
u1 = torch.cat([u1, c1], 1)
d1 = self.dec1(u1)
        # ---------- return the upscaled result (256x256) ----------
return self.tail(d1)
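    # Shape trace for a 64x64 LR input:
    #   x [B,3,64,64] -> c1 [B,32,64,64] -> p1 [B,32,32,32] -> c2 [B,64,32,32] -> p2 [B,64,16,16]
    #   m [B,128,16,16] -> up2+cat -> d2 [B,64,32,32] -> up1+cat -> d1 [B,32,64,64]
    #   tail: [B,12,64,64] -> PixelShuffle -> [B,3,128,128] -> [B,12,128,128] -> PixelShuffle -> [B,3,256,256]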
# =================================
# PSNR / SSIM
# =================================
def calc_psnr(sr, hr):
mse = F.mse_loss(sr, hr)
return 10 * torch.log10(1.0 / mse)
def calc_ssim(sr, hr):
C1, C2 = 0.01**2, 0.03**2
mu1, mu2 = sr.mean(), hr.mean()
s1, s2 = sr.var(), hr.var()
s12 = ((sr - mu1) * (hr - mu2)).mean()
return ((2*mu1*mu2+C1)*(2*s12+C2))/((mu1**2+mu2**2+C1)*(s1+s2+C2))
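# Note: calc_psnr assumes pixel values in [0, 1] (peak = 1.0), which matches ToTensor's output.
# calc_ssim uses global means/variances over the whole batch instead of the standard 11x11
# Gaussian windows, so it is only a rough, fast approximation of SSIM.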
# =================================
# Dataset (minimal change: LR is no longer upsampled back to 256)
# =================================
import imgaug.augmenters as iaa
import numpy as np
class BSDDataset(Dataset):
def __init__(self, folder, augment=True):
self.files = sorted(glob.glob(os.path.join(folder, "*.jpg")))
self.augment = augment
        # ========= realistic SR degradation (simulates real-world LR degradation) =========
self.degrade = iaa.Sequential([
            iaa.GaussianBlur((0, 1.2)),                         # optical blur -> lens blur
            iaa.AdditiveGaussianNoise(scale=(0, 0.03 * 255)),   # sensor noise -> camera noise
            iaa.JpegCompression(compression=(40, 100)),         # compression artifacts -> JPEG blocking
])
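        # imgaug augmenters operate on HxWxC uint8 numpy arrays (JpegCompression in particular
        # requires uint8), hence the np.array / Image.fromarray round trips in __getitem__.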
self.to_tensor = transforms.ToTensor()
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
        # ---------- load the source image ----------
img = Image.open(self.files[idx]).convert("RGB")
img = np.array(img)
        # ---------- build the high-resolution image HR (fixed 256x256) ----------
hr = Image.fromarray(img).resize((256, 256), Image.BICUBIC)
hr_np = np.array(hr)
        # ---------- optionally apply degradation (applied to the HR image) ----------
if self.augment:
            lr_np = self.degrade(image=hr_np)  # apply realistic degradation to HR
else:
lr_np = hr_np
        # ---------- build LR (downsample to 64x64 only, do NOT upsample back to 256) ----------
        # key change: keep LR at its true low resolution and let the model do the upscaling
lr = Image.fromarray(lr_np).resize((64, 64), Image.BICUBIC)
        # ---------- convert to tensors ----------
hr = self.to_tensor(hr) # [3, 256, 256]
lr = self.to_tensor(lr) # [3, 64, 64]
return lr, hr
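    # Pair construction: HR is a clean 256x256 bicubic resize of the source image, LR is the
    # degraded HR downsampled to 64x64, so the model learns 4x super-resolution under
    # realistic blur / noise / JPEG degradation.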
# =================================
# Dataset download instructions
# =================================
def show_download_tip():
    print("\nPlease download the BSDS300 dataset first:")
    print("1) Visit: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/ ")
    print("2) After extraction the directory layout should look like:")
print(" BSDS300/images/train/*.jpg")
print(" BSDS300/images/test/*.jpg\n")
# =================================
# Train / Val
# =================================
def train_one_epoch(model, loader, opt, loss_fn, device):
model.train()
total = 0
for lr, hr in loader:
lr, hr = lr.to(device), hr.to(device)
sr = model(lr)
loss = loss_fn(sr, hr)
opt.zero_grad()
loss.backward()
opt.step()
total += loss.item()
return total / len(loader)
def validate(model, loader, loss_fn, device):
model.eval()
total, psnr_t, ssim_t = 0, 0, 0
with torch.no_grad():
for lr, hr in loader:
lr, hr = lr.to(device), hr.to(device)
sr = model(lr)
total += loss_fn(sr, hr).item()
psnr_t += calc_psnr(sr, hr).item()
ssim_t += calc_ssim(sr, hr).item()
n = len(loader)
return total/n, psnr_t/n, ssim_t/n
# =================================
# Main
# =================================
def main():
if not os.path.exists("BSDS300"):
show_download_tip()
return
device = "cuda" if torch.cuda.is_available() else "cpu"
train_dir = "BSDS300/images/train"
test_dir = "BSDS300/images/test"
train_set = BSDDataset(train_dir)
val_set = BSDDataset(test_dir)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False)
model = SimpleUNet().to(device)
opt = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.L1Loss()
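    # L1 loss is a common choice for image SR; compared with MSE it tends to give slightly sharper results.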
epochs = 300
for epoch in range(epochs):
train_loss = train_one_epoch(model, train_loader, opt, loss_fn, device)
val_loss, psnr, ssim = validate(model, val_loader, loss_fn, device)
print(f"Epoch {epoch+1}/{epochs} | "
f"Train Loss: {train_loss:.4f} | "
f"Val Loss: {val_loss:.4f} | "
f"PSNR: {psnr:.2f} | "
f"SSIM: {ssim:.4f}")
    # ---------- save the model ----------
os.makedirs("checkpoints", exist_ok=True)
model_path = "checkpoints/unet_sr.pth"
torch.save(model.state_dict(), model_path)
print(f"\nModel saved to {model_path}")
    # Save examples (LR is upsampled back to 256 for visualization only)
model.eval()
with torch.no_grad():
lr, hr = next(iter(val_loader))
lr, hr = lr.to(device), hr.to(device)
sr = model(lr)
print("shape, lr:{}, sr:{}, hr:{}".format(lr.shape, sr.shape, hr.shape))
        # ---------- minimal change: upsample LR back to 256 for visualization (does not affect training) ----------
lr_vis = F.interpolate(lr, scale_factor=4.0, mode='bilinear', align_corners=False)
save_image(torch.cat([lr_vis.cpu()[:3], sr.cpu()[:3], hr.cpu()[:3]], 0),
"sr_results.png", nrow=3, normalize=True)
print("\nSaved example as sr_results.png")
if __name__ == "__main__":
main()
```

Results after training for epochs = 300: the first row shows the input images, the second row the super-resolved outputs, and the third row the original images. The inputs were produced from the originals by downscaling plus simulated optical blur, sensor noise, and compression artifacts.
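To super-resolve a new image with the trained model, here is a minimal inference sketch (assumed to run in the same file so that SimpleUNet is in scope; checkpoints/unet_sr.pth is the checkpoint saved by the script above, while input.jpg and sr_output.png are placeholder paths):

```python
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleUNet().to(device)  # the class defined above
model.load_state_dict(torch.load("checkpoints/unet_sr.pth", map_location=device))
model.eval()

# "input.jpg" is a placeholder; resize to 64x64 to match the training setup.
lr = Image.open("input.jpg").convert("RGB").resize((64, 64), Image.BICUBIC)
lr = transforms.ToTensor()(lr).unsqueeze(0).to(device)  # [1, 3, 64, 64]
with torch.no_grad():
    sr = model(lr)                                       # [1, 3, 256, 256]
save_image(sr.clamp(0, 1), "sr_output.png")
```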

This is a demo on the CIFAR-10 dataset. Here the LR image is made by downsampling to 16x16 and resizing back to 32x32, so input and output share the same 32x32 size and the network ends in a plain 1x1 convolution instead of a PixelShuffle tail.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
# -----------------
# Simple UNet
# -----------------
class SimpleUNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()
)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
)
self.pool2 = nn.MaxPool2d(2)
self.middle = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
)
self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.dec2 = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1), nn.ReLU()
)
self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
self.dec1 = nn.Sequential(
nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()
)
self.out = nn.Conv2d(32, 3, 1)
def forward(self, x):
c1 = self.conv1(x)
p1 = self.pool1(c1)
c2 = self.conv2(p1)
p2 = self.pool2(c2)
m = self.middle(p2)
u2 = self.up2(m)
u2 = torch.cat([u2, c2], dim=1)
d2 = self.dec2(u2)
u1 = self.up1(d2)
u1 = torch.cat([u1, c1], dim=1)
d1 = self.dec1(u1)
return self.out(d1)
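    # Note: in this demo the input and output are both 32x32 CIFAR images, so the network does
    # restoration of the blurred (down/up-resized) input rather than true upscaling; a plain
    # 1x1 output convolution is enough and no PixelShuffle tail is needed.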
# -----------------
# Metrics
# -----------------
def calc_psnr(sr, hr):
mse = F.mse_loss(sr, hr)
return 10 * torch.log10(1.0 / mse)
def calc_ssim(sr, hr):
C1 = 0.01 ** 2
C2 = 0.03 ** 2
mu_x = sr.mean()
mu_y = hr.mean()
sigma_x = sr.var()
sigma_y = hr.var()
sigma_xy = ((sr - mu_x) * (hr - mu_y)).mean()
return ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
(mu_x**2 + mu_y**2 + C1) * (sigma_x + sigma_y + C2)
)
# -----------------
# Dataset
# -----------------
transform_hr = transforms.ToTensor()
transform_lr = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(16),
transforms.Resize(32),
transforms.ToTensor()
])
class SRDataset(datasets.CIFAR10):
def __getitem__(self, idx):
img, _ = super().__getitem__(idx)
hr = transform_hr(img)
lr = transform_lr(hr)
return lr, hr
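# Note: without a transform argument, CIFAR10.__getitem__ returns a PIL image; transform_hr
# converts it to a [3, 32, 32] tensor and transform_lr builds the blurred LR by resizing
# 32 -> 16 -> 32 (bilinear by default in torchvision's Resize).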
# -----------------
# Train / Validation Functions
# -----------------
def train_one_epoch(model, loader, optimizer, loss_fn, device):
model.train()
total_loss = 0
for lr, hr in loader:
lr, hr = lr.to(device), hr.to(device)
sr = model(lr)
loss = loss_fn(sr, hr)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(loader)
def validate(model, loader, loss_fn, device):
model.eval()
total_loss = 0
psnr_total = 0
ssim_total = 0
with torch.no_grad():
for lr, hr in loader:
lr, hr = lr.to(device), hr.to(device)
sr = model(lr)
total_loss += loss_fn(sr, hr).item()
psnr_total += calc_psnr(sr, hr).item()
ssim_total += calc_ssim(sr, hr).item()
n = len(loader)
return total_loss / n, psnr_total / n, ssim_total / n
# -----------------
# Main
# -----------------
def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
full_ds = SRDataset(root="./data", train=True, download=True)
train_size = int(0.9 * len(full_ds))
val_size = len(full_ds) - train_size
train_ds, val_ds = random_split(full_ds, [train_size, val_size])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)
model = SimpleUNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.L1Loss()
epochs = 3
for epoch in range(epochs):
train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
val_loss, psnr, ssim = validate(model, val_loader, loss_fn, device)
print(f"Epoch {epoch+1}/{epochs} | "
f"Train Loss: {train_loss:.4f} | "
f"Val Loss: {val_loss:.4f} | "
f"PSNR: {psnr:.2f} | "
f"SSIM: {ssim:.4f}")
# -----------------
# Save examples
# -----------------
model.eval()
with torch.no_grad():
lr, hr = next(iter(train_loader))
lr, hr = lr.to(device), hr.to(device)
sr = model(lr)
lr_show = lr.cpu()[:3]
sr_show = sr.cpu()[:3]
hr_show = hr.cpu()[:3]
save_image(torch.cat([lr_show, sr_show, hr_show], dim=0),
"sr_results.png", nrow=3, normalize=True)
print("Saved example as sr_results.png")
if __name__ == "__main__":
    main()
```