Deep Learning for Image Super-Resolution: A Comprehensive Guide from Beginner to Expert
This article systematically surveys deep-learning-based single image super-resolution (SISR), covering the problem definition, network architecture design, loss functions, evaluation metrics, state-of-the-art methods, and practical applications, complete with PyTorch implementations, to help you master this core image-processing technique.
1. What Is Image Super-Resolution
1.1 Problem Definition
Image super-resolution (SR) is the task of reconstructing a high-resolution (HR) image from a low-resolution (LR) input.
The super-resolution task at a glance:
┌─────────────┐                ┌─────────────────────┐
│             │                │                     │
│    64×64    │   SR network   │       256×256       │
│  LR image   │  ──────────→   │      HR image       │
│             │   ×4 upscale   │                     │
└─────────────┘                └─────────────────────┘
Goal: recover sharp textures, edges, and details
1.2 Mathematical Formulation
The image degradation process is usually modeled as:
Degradation model:
I_LR = D(I_HR; θ_D)
where:
D(I_HR; θ_D) = (I_HR ⊗ κ) ↓_s + n
Notation:
- I_HR: the original high-resolution image
- κ: blur kernel (e.g., a Gaussian kernel)
- ⊗: convolution
- ↓_s: downsampling with scale factor s
- n: additive noise (usually white Gaussian noise)
The super-resolution task:
Given I_LR, reconstruct I_SR ≈ I_HR
I_SR = F(I_LR; θ_F)
where F is the super-resolution model and θ_F are its parameters
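A minimal sketch that maps this equation directly onto code (assuming a Gaussian κ, stride-s decimation for ↓_s, and Gaussian n; the kernel width and noise level below are illustrative choices):
python
import numpy as np
from scipy.ndimage import gaussian_filter

def degrade(hr, s=4, kernel_sigma=1.2, noise_sigma=5.0):
    """I_LR = (I_HR ⊗ κ) ↓_s + n, for an H×W×C uint8 image."""
    # I_HR ⊗ κ: Gaussian blur applied per channel
    blurred = gaussian_filter(hr.astype(np.float64), sigma=(kernel_sigma, kernel_sigma, 0))
    # ↓_s: keep every s-th pixel
    decimated = blurred[::s, ::s]
    # + n: additive white Gaussian noise
    noisy = decimated + np.random.normal(0, noise_sigma, decimated.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)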
1.3 Why Is Super-Resolution Hard?
Super-resolution is an ill-posed problem:
┌─────────────────────────────────────────────────────────────┐
│                                                             │
│   One low-resolution image may correspond to infinitely     │
│   many high-resolution images                               │
│                                                             │
│   HR₁ ─┐                                                    │
│   HR₂ ─┼──→ downsample ──→ LR                               │
│   HR₃ ─┘                                                    │
│   ...                                                       │
│                                                             │
│   Information lost during downsampling cannot be            │
│   perfectly recovered                                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Challenges:
1. High-frequency details (textures, edges) are lost during downsampling
2. The missing details must be "guessed" from limited information
3. Different scenes require different prior knowledge
1.4 Application Scenarios
Practical applications of super-resolution:
┌─────────────────────────────────────────────────────────────────────┐
│                                                                     │
│  Surveillance       │ Enhance blurry footage to aid face recognition│
│  Medical imaging    │ Sharpen CT/MRI scans to support diagnosis     │
│  Remote sensing     │ Enhance satellite imagery for land-cover analysis │
│  Video enhancement  │ Convert old or low-resolution video to HD     │
│  Mobile photography │ Digital zoom, night-mode enhancement          │
│  Games/film         │ HD remasters of classic games and movies      │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
2. Common Datasets and Evaluation Metrics
2.1 Benchmark Datasets
┌─────────────────────────────────────────────────────────────────┐
│                Common Super-Resolution Datasets                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Training sets:                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ DIV2K    │ 1000 high-quality 2K images, most widely used│   │
│  │ Flickr2K │ 2650 2K images, often merged with DIV2K      │   │
│  │ ImageNet │ Large-scale dataset, used for pre-training   │   │
│  │ T91      │ 91 images, a small early-era dataset         │   │
│  │ BSDS500  │ 500 natural images                           │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Test sets:                                                     │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Set5     │ 5 classic test images                        │   │
│  │ Set14    │ 14 test images                               │   │
│  │ BSD100   │ 100 natural images                           │   │
│  │ Urban100 │ 100 urban architecture images, edge-rich     │   │
│  │ Manga109 │ 109 Japanese manga images                    │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
2.2 Degradation Modes
python
"""
常用的图像退化方式
"""
import cv2
import numpy as np
from scipy.ndimage import gaussian_filter
def bicubic_degradation(hr_image, scale=4):
"""
BI模式:双三次下采样(最常用)
直接使用bicubic插值进行下采样
"""
h, w = hr_image.shape[:2]
lr_image = cv2.resize(hr_image, (w//scale, h//scale),
interpolation=cv2.INTER_CUBIC)
return lr_image
def blur_downsample_degradation(hr_image, scale=3, kernel_size=7, sigma=1.6):
"""
BD模式:模糊 + 下采样
先用高斯核模糊,再下采样
"""
# 高斯模糊
blurred = cv2.GaussianBlur(hr_image, (kernel_size, kernel_size), sigma)
# 下采样
h, w = blurred.shape[:2]
lr_image = cv2.resize(blurred, (w//scale, h//scale),
interpolation=cv2.INTER_CUBIC)
return lr_image
def downsample_noise_degradation(hr_image, scale=3, noise_level=30):
"""
DN模式:下采样 + 噪声
先下采样,再加高斯噪声
"""
# 下采样
h, w = hr_image.shape[:2]
lr_image = cv2.resize(hr_image, (w//scale, h//scale),
interpolation=cv2.INTER_CUBIC)
# 添加高斯噪声
noise = np.random.normal(0, noise_level, lr_image.shape)
lr_image = np.clip(lr_image + noise, 0, 255).astype(np.uint8)
return lr_image
def complex_degradation(hr_image, scale=4, blur_sigma=1.5, noise_sigma=10,
jpeg_quality=70):
"""
复杂退化模式(更接近真实世界)
模糊 → 下采样 → 噪声 → JPEG压缩
"""
# 1. 模糊
blurred = cv2.GaussianBlur(hr_image, (21, 21), blur_sigma)
# 2. 下采样
h, w = blurred.shape[:2]
downsampled = cv2.resize(blurred, (w//scale, h//scale),
interpolation=cv2.INTER_CUBIC)
# 3. 添加噪声
noise = np.random.normal(0, noise_sigma, downsampled.shape)
noisy = np.clip(downsampled + noise, 0, 255).astype(np.uint8)
# 4. JPEG压缩
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
_, encoded = cv2.imencode('.jpg', noisy, encode_param)
lr_image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
return lr_image
2.3 Evaluation Metrics
PSNR (Peak Signal-to-Noise Ratio)
python
import numpy as np
import torch

def calculate_psnr(img1, img2, max_val=255.0):
    """
    Compute PSNR (Peak Signal-to-Noise Ratio):
    PSNR = 10 * log10(MAX² / MSE)
    Higher PSNR means higher fidelity. As a rough guide:
    - PSNR < 30 dB: poor quality
    - 30-40 dB: acceptable quality
    - PSNR > 40 dB: very good quality
    """
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    psnr = 10 * np.log10((max_val ** 2) / mse)
    return psnr

def calculate_psnr_torch(img1, img2, max_val=1.0):
    """PyTorch version of the PSNR computation."""
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    psnr = 10 * torch.log10((max_val ** 2) / mse)
    return psnr.item()
SSIM (Structural Similarity)
python
def calculate_ssim(img1, img2, window_size=11, data_range=255.0):
    """
    Compute SSIM (Structural Similarity Index).
    SSIM compares three components:
    1. Luminance
    2. Contrast
    3. Structure
    SSIM = (2·μx·μy + C1)(2·σxy + C2) / ((μx² + μy² + C1)(σx² + σy² + C2))
    SSIM lies in [-1, 1]; closer to 1 is better.
    Note: the constants must be scaled by the data range L, i.e.
    C1 = (0.01·L)² and C2 = (0.03·L)², and the box filter used here is a
    simple approximation of the Gaussian window in the original SSIM.
    """
    from scipy.ndimage import uniform_filter
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    # Local means
    mu1 = uniform_filter(img1, window_size)
    mu2 = uniform_filter(img2, window_size)
    # Local variances and covariance
    sigma1_sq = uniform_filter(img1 ** 2, window_size) - mu1 ** 2
    sigma2_sq = uniform_filter(img2 ** 2, window_size) - mu2 ** 2
    sigma12 = uniform_filter(img1 * img2, window_size) - mu1 * mu2
    # SSIM map
    numerator = (2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return np.mean(ssim_map)
Perceptual quality metrics
python
"""
感知质量评估指标
PSNR/SSIM关注像素级差异,但不能完全反映人眼视觉感受
感知质量指标更关注图像的视觉效果
"""
# LPIPS (Learned Perceptual Image Patch Similarity)
# pip install lpips
import lpips
def calculate_lpips(img1, img2, net='alex'):
"""
计算LPIPS
使用预训练网络(如AlexNet、VGG)提取特征
计算特征空间的距离
LPIPS越低越好
"""
loss_fn = lpips.LPIPS(net=net)
# 转换为torch tensor,范围[-1, 1]
img1_tensor = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1
img2_tensor = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1
with torch.no_grad():
distance = loss_fn(img1_tensor, img2_tensor)
return distance.item()
# NIQE (Natural Image Quality Evaluator) - 无参考评估
def calculate_niqe(img):
"""
计算NIQE(无需参考图像)
基于自然图像统计特性
NIQE越低越好
需要安装:pip install pyiqa
"""
import pyiqa
niqe_metric = pyiqa.create_metric('niqe')
img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() / 255.0
score = niqe_metric(img_tensor)
return score.item()
3. Upsampling Methods
3.1 Traditional Interpolation
python
import torch
import torch.nn as nn
import torch.nn.functional as F

class InterpolationUpsampler:
    """
    Traditional interpolation-based upsampling.
    """
    @staticmethod
    def nearest_upsample(x, scale_factor):
        """
        Nearest-neighbor interpolation.
        Simple and fast, but produces blocky artifacts.
        """
        return F.interpolate(x, scale_factor=scale_factor, mode='nearest')

    @staticmethod
    def bilinear_upsample(x, scale_factor):
        """
        Bilinear interpolation.
        Smooth results, but they can look blurry.
        """
        return F.interpolate(x, scale_factor=scale_factor,
                             mode='bilinear', align_corners=False)

    @staticmethod
    def bicubic_upsample(x, scale_factor):
        """
        Bicubic interpolation.
        Smoother than bilinear at slightly higher cost.
        """
        return F.interpolate(x, scale_factor=scale_factor,
                             mode='bicubic', align_corners=False)
3.2 Transposed Convolution
python
class TransposedConvUpsampler(nn.Module):
    """
    Transposed-convolution upsampling (also called deconvolution).
    A learnable upsampling method.
    Idea:
    - Insert padding around/between the input feature-map elements
    - Then apply a standard convolution
    Drawback: prone to checkerboard artifacts.
    """
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        # Transposed convolution with
        # kernel_size = 2 * scale_factor
        # stride = scale_factor
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=2 * scale_factor,
            stride=scale_factor,
            padding=scale_factor // 2
        )

    def forward(self, x):
        return self.deconv(x)

class TransposedConvUpsamplerV2(nn.Module):
    """
    Improved transposed convolution (fewer checkerboard artifacts):
    a small kernel followed by a smoothing convolution.
    """
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=scale_factor,
            stride=scale_factor,
            padding=0
        )
        # Follow-up convolution for smoothing
        self.conv = nn.Conv2d(out_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        x = self.deconv(x)
        x = self.conv(x)
        return x
3.3 Sub-Pixel Convolution
python
class SubPixelUpsampler(nn.Module):
    """
    Sub-pixel convolution upsampling (PixelShuffle).
    Idea:
    1. A convolution first expands the channels (C → C·r²)
    2. The pixels are then rearranged (pixel shuffle)
    Advantages:
    - Efficient: most computation happens at low resolution
    - No checkerboard artifacts
    - Currently the most popular upsampling method
    H×W×(C·r²) → (H·r)×(W·r)×C
    """
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        # Convolution to expand the channel count
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * (scale_factor ** 2),
            kernel_size=3,
            padding=1
        )
        # Pixel rearrangement
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        return x

def pixel_shuffle_demo():
    """
    Demonstration of how PixelShuffle works.
    """
    print("PixelShuffle:")
    print("=" * 50)
    # Input [1, 4, 2, 2] (C=4, H=W=2) with scale_factor=2 → output [1, 1, 4, 4].
    # At each spatial location, the 4 channel values [a, b, c, d]
    # are rearranged into a 2×2 spatial block:
    # [[a, b],
    #  [c, d]]
    x = torch.arange(1, 17).float().view(1, 4, 2, 2)
    print(f"Input shape: {x.shape}")  # [1, 4, 2, 2]
    ps = nn.PixelShuffle(2)
    y = ps(x)
    print(f"Output shape: {y.shape}")  # [1, 1, 4, 4]
    print("\nChannels shrink to 1/r² of the input; spatial size grows by r")

pixel_shuffle_demo()
3.4 Comparison of Upsampling Strategies
┌─────────────────────────────────────────────────────────────────┐
│                 Upsampling Strategy Comparison                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Pre-upsampling:                                                │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ LR ──→ interpolate up ──→ deep network ──→ HR           │   │
│  │ Pros: simple and direct                                 │   │
│  │ Cons: expensive (convolutions run in HR space)          │   │
│  │ Representatives: SRCNN, VDSR                            │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Post-upsampling:                                               │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ LR ──→ deep network ──→ upsampling layer ──→ HR         │   │
│  │ Pros: efficient (convolutions run in LR space)          │   │
│  │ Cons: the upsampling-layer design is critical           │   │
│  │ Representatives: ESPCN, EDSR, RCAN                      │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Progressive upsampling:                                        │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ LR ──→ network+×2 ──→ network+×2 ──→ HR (×4)            │   │
│  │ Pros: step-by-step reconstruction, more stable          │   │
│  │ Cons: a more complex network                            │   │
│  │ Representatives: LapSRN, ProSR                          │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
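A minimal sketch contrasting the first two strategies (the tiny body networks and channel counts here are illustrative placeholders, not any paper's architecture):
python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreUpsamplingSR(nn.Module):
    """SRCNN-style: interpolate first, then convolve at HR resolution (costly)."""
    def __init__(self, scale=4):
        super().__init__()
        self.scale = scale
        self.body = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, padding=1))

    def forward(self, lr):
        x = F.interpolate(lr, scale_factor=self.scale,
                          mode='bicubic', align_corners=False)
        return self.body(x)  # every convolution runs on large feature maps

class PostUpsamplingSR(nn.Module):
    """ESPCN-style: convolve at LR resolution, upsample once at the end (efficient)."""
    def __init__(self, scale=4):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True))
        self.up = nn.Sequential(
            nn.Conv2d(64, 3 * scale ** 2, 3, padding=1), nn.PixelShuffle(scale))

    def forward(self, lr):
        return self.up(self.body(lr))  # convolutions stay in LR space

lr = torch.randn(1, 3, 48, 48)
print(PreUpsamplingSR()(lr).shape)   # [1, 3, 192, 192]
print(PostUpsamplingSR()(lr).shape)  # [1, 3, 192, 192]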
4. Loss Functions
4.1 Pixel-Level Losses
python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PixelLoss(nn.Module):
    """
    Pixel-level loss functions: directly measure differences in pixel values.
    """
    def __init__(self, loss_type='l1'):
        super().__init__()
        self.loss_type = loss_type

    def forward(self, pred, target):
        if self.loss_type == 'l1':
            # L1 loss (MAE): more robust to outliers
            return F.l1_loss(pred, target)
        elif self.loss_type == 'l2':
            # L2 loss (MSE): directly related to PSNR,
            # but tends to produce over-smoothed results
            return F.mse_loss(pred, target)
        elif self.loss_type == 'charbonnier':
            # Charbonnier loss: a smooth, everywhere-differentiable
            # approximation of L1
            eps = 1e-6
            diff = pred - target
            return torch.mean(torch.sqrt(diff ** 2 + eps ** 2))
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

class CharbonnierLoss(nn.Module):
    """
    Charbonnier loss (a differentiable approximation of L1):
    L_char = sqrt((pred - target)² + ε²)
    It behaves like L1 when |x| is much larger than ε,
    and is differentiable at 0 (where L1 is not).
    """
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        loss = torch.sqrt(diff ** 2 + self.eps ** 2)
        return torch.mean(loss)
4.2 Perceptual Loss
python
import torchvision.models as models

class VGGPerceptualLoss(nn.Module):
    """
    VGG perceptual loss.
    Extracts high-level features with a pretrained VGG network and
    measures the distance in feature space.
    Pros:
    - Focuses on semantic and structural similarity
    - Produces visually more pleasing images
    Cons:
    - May introduce color shifts
    - PSNR may drop
    """
    def __init__(self, layer_weights=None, use_input_norm=True):
        super().__init__()
        # Load pretrained VGG19
        vgg = models.vgg19(pretrained=True).features
        # Layers to extract:
        # conv1_2, conv2_2, conv3_4, conv4_4, conv5_4
        self.layer_indices = [2, 7, 16, 25, 34]
        # Default layer weights
        if layer_weights is None:
            self.layer_weights = [0.1, 0.1, 1.0, 1.0, 1.0]
        else:
            self.layer_weights = layer_weights
        # Split VGG into stages
        self.stages = nn.ModuleList()
        prev_idx = 0
        for idx in self.layer_indices:
            self.stages.append(nn.Sequential(*list(vgg.children())[prev_idx:idx + 1]))
            prev_idx = idx + 1
        # Freeze the parameters
        for param in self.parameters():
            param.requires_grad = False
        # Input normalization (ImageNet mean and std)
        self.use_input_norm = use_input_norm
        self.register_buffer(
            'mean',
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            'std',
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )

    def forward(self, pred, target):
        """
        Compute the perceptual loss.
        """
        if self.use_input_norm:
            pred = (pred - self.mean) / self.std
            target = (target - self.mean) / self.std
        loss = 0.0
        pred_feat = pred
        target_feat = target
        for stage, weight in zip(self.stages, self.layer_weights):
            pred_feat = stage(pred_feat)
            target_feat = stage(target_feat)
            # L1 distance in feature space
            loss += weight * F.l1_loss(pred_feat, target_feat)
        return loss

class ContentStyleLoss(nn.Module):
    """
    Content loss + style loss.
    Content loss: L2 distance between features.
    Style loss: distance between Gram matrices (captures texture).
    """
    def __init__(self, content_weight=1.0, style_weight=0.1):
        super().__init__()
        self.content_weight = content_weight
        self.style_weight = style_weight
        # VGG feature extractor
        vgg = models.vgg19(pretrained=True).features[:16]
        self.vgg = vgg
        for param in self.vgg.parameters():
            param.requires_grad = False

    def gram_matrix(self, x):
        """
        Compute the Gram matrix G = F @ F^T, which captures the
        correlations between feature channels (texture information).
        """
        b, c, h, w = x.size()
        features = x.view(b, c, h * w)
        gram = torch.bmm(features, features.transpose(1, 2))
        return gram / (c * h * w)

    def forward(self, pred, target):
        pred_feat = self.vgg(pred)
        target_feat = self.vgg(target)
        # Content loss
        content_loss = F.mse_loss(pred_feat, target_feat)
        # Style loss
        pred_gram = self.gram_matrix(pred_feat)
        target_gram = self.gram_matrix(target_feat)
        style_loss = F.mse_loss(pred_gram, target_gram)
        return self.content_weight * content_loss + self.style_weight * style_loss
4.3 Adversarial Loss
python
class GANLoss(nn.Module):
    """
    GAN loss: push generated SR images to look like real HR images
    to the discriminator.
    """
    def __init__(self, gan_type='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.gan_type = gan_type
        self.real_label = real_label
        self.fake_label = fake_label
        if gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_type == 'wgan':
            self.loss = None  # Wasserstein distance
        elif gan_type == 'hinge':
            self.loss = None  # hinge loss
        else:
            raise ValueError(f"Unknown GAN type: {gan_type}")

    def get_target_tensor(self, pred, is_real):
        """Build the target-label tensor."""
        if is_real:
            return torch.full_like(pred, self.real_label)
        else:
            return torch.full_like(pred, self.fake_label)

    def forward(self, pred, is_real):
        """
        Compute the GAN loss.
        Args:
            pred: discriminator output
            is_real: whether the sample is treated as real
        """
        if self.gan_type in ['vanilla', 'lsgan']:
            target = self.get_target_tensor(pred, is_real)
            return self.loss(pred, target)
        elif self.gan_type == 'wgan':
            if is_real:
                return -pred.mean()
            else:
                return pred.mean()
        elif self.gan_type == 'hinge':
            # Note: this hinge form is the discriminator's objective;
            # a hinge generator typically minimizes -pred.mean() instead.
            if is_real:
                return F.relu(1.0 - pred).mean()
            else:
                return F.relu(1.0 + pred).mean()

class RelativisticGANLoss(nn.Module):
    """
    Relativistic GAN loss (used in ESRGAN).
    Instead of judging "real vs. fake" in absolute terms, it judges
    "which one is more realistic":
    D_Ra(x_r, x_f) = σ(C(x_r) - E[C(x_f)])
    D_Ra(x_f, x_r) = σ(C(x_f) - E[C(x_r)])
    """
    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, real_pred, fake_pred, is_discriminator):
        """
        Args:
            real_pred: discriminator predictions on real images
            fake_pred: discriminator predictions on generated images
            is_discriminator: whether this is the discriminator's loss
        """
        # Relative predictions
        real_relative = real_pred - fake_pred.mean()
        fake_relative = fake_pred - real_pred.mean()
        if is_discriminator:
            # Discriminator: real images should look more real than fakes
            real_loss = self.loss(real_relative, torch.ones_like(real_relative))
            fake_loss = self.loss(fake_relative, torch.zeros_like(fake_relative))
            return (real_loss + fake_loss) / 2
        else:
            # Generator: fakes should look more real than real images
            real_loss = self.loss(real_relative, torch.zeros_like(real_relative))
            fake_loss = self.loss(fake_relative, torch.ones_like(fake_relative))
            return (real_loss + fake_loss) / 2
4.4 Combined Loss
python
class SRLoss(nn.Module):
    """
    Combined super-resolution loss: a weighted sum of several losses.
    """
    def __init__(self, pixel_weight=1.0, perceptual_weight=0.1,
                 adversarial_weight=0.01):
        super().__init__()
        self.pixel_weight = pixel_weight
        self.perceptual_weight = perceptual_weight
        self.adversarial_weight = adversarial_weight
        # Pixel loss
        self.pixel_loss = nn.L1Loss()
        # Perceptual loss
        if perceptual_weight > 0:
            self.perceptual_loss = VGGPerceptualLoss()
        # Adversarial loss
        if adversarial_weight > 0:
            self.gan_loss = GANLoss(gan_type='vanilla')

    def forward(self, pred, target, discriminator=None):
        """
        Compute the total loss.
        """
        losses = {}
        # Pixel loss
        losses['pixel'] = self.pixel_loss(pred, target)
        total_loss = self.pixel_weight * losses['pixel']
        # Perceptual loss
        if self.perceptual_weight > 0:
            losses['perceptual'] = self.perceptual_loss(pred, target)
            total_loss += self.perceptual_weight * losses['perceptual']
        # Adversarial loss (generator side: fakes should be judged real)
        if self.adversarial_weight > 0 and discriminator is not None:
            fake_pred = discriminator(pred)
            losses['adversarial'] = self.gan_loss(fake_pred, is_real=True)
            total_loss += self.adversarial_weight * losses['adversarial']
        losses['total'] = total_loss
        return total_loss, losses
5. Classic Network Architectures
5.1 SRCNN: The Pioneering Work
python
class SRCNN(nn.Module):
    """
    SRCNN - the first deep-learning super-resolution model (2014).
    A three-layer structure mirroring the three steps of traditional methods:
    1. Patch extraction
    2. Non-linear mapping
    3. Reconstruction
    Input: a bicubic-upscaled image
    Output: a high-resolution image
    """
    def __init__(self, num_channels=3, feature_dim=64, mapping_dim=32):
        super().__init__()
        # Feature extraction
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, feature_dim, kernel_size=9, padding=4),
            nn.ReLU(inplace=True)
        )
        # Non-linear mapping
        self.mapping = nn.Sequential(
            nn.Conv2d(feature_dim, mapping_dim, kernel_size=1),
            nn.ReLU(inplace=True)
        )
        # Reconstruction
        self.reconstruction = nn.Conv2d(mapping_dim, num_channels, kernel_size=5, padding=2)

    def forward(self, x):
        """
        Forward pass; x should already be bicubic-upscaled.
        """
        feat = self.feature_extraction(x)
        mapped = self.mapping(feat)
        out = self.reconstruction(mapped)
        return out

def srcnn_example():
    """SRCNN usage example."""
    model = SRCNN()
    # Upscale with bicubic first, then feed the network
    lr_image = torch.randn(1, 3, 64, 64)
    lr_upscaled = F.interpolate(lr_image, scale_factor=4,
                                mode='bicubic', align_corners=False)
    sr_image = model(lr_upscaled)
    print(f"Input (upscaled): {lr_upscaled.shape}")
    print(f"Output: {sr_image.shape}")
5.2 VDSR: Deeper Networks with Residual Learning
python
class VDSR(nn.Module):
    """
    VDSR - Very Deep Super Resolution (2016).
    Key innovations:
    1. A much deeper network (20 layers)
    2. Global residual learning (learn the residual, not the full image)
    3. Gradient clipping to tame exploding gradients
    4. Multi-scale training (one model handles multiple scale factors)
    The network learns the residual R = HR - LR_upscaled;
    the output is LR_upscaled + R.
    """
    def __init__(self, num_channels=1, num_features=64, num_layers=20):
        super().__init__()
        layers = []
        # First layer
        layers.append(nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
        # Middle layers
        for _ in range(num_layers - 2):
            layers.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        # Last layer
        layers.append(nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass with global residual learning:
        output = input + network output (the residual).
        """
        residual = self.network(x)
        return x + residual
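The gradient clipping VDSR relies on is not visible in the class itself; a minimal training-step sketch with clipping might look like the following (the optimizer settings, clip value, and dummy tensors are illustrative, not the paper's exact recipe):
python
import torch
import torch.nn.functional as F

model = VDSR(num_channels=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

def vdsr_train_step(lr_upscaled, hr):
    optimizer.zero_grad()
    # The global residual connection makes this equivalent to
    # regressing the residual HR - LR_upscaled
    loss = F.mse_loss(model(lr_upscaled), hr)
    loss.backward()
    # Clip gradients to stabilize the very deep network (clip value is illustrative)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.4)
    optimizer.step()
    return loss.item()

lr_up = torch.randn(8, 1, 64, 64)  # bicubic-upscaled luminance patches (dummy data)
hr = torch.randn(8, 1, 64, 64)
print(vdsr_train_step(lr_up, hr))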
5.3 Residual Block Design
python
class BasicBlock(nn.Module):
    """
    Basic residual block:
    Conv → ReLU → Conv + skip connection
    """
    def __init__(self, num_features, kernel_size=3):
        super().__init__()
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size // 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out + residual

class ResidualBlock(nn.Module):
    """
    Standard residual block (without BN).
    EDSR found that in super-resolution, BN layers consume a lot of
    memory without improving performance.
    """
    def __init__(self, num_features, kernel_size=3, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size // 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        # Residual scaling (stabilizes training)
        out = out * self.res_scale
        return out + residual

class ResidualDenseBlock(nn.Module):
    """
    Residual dense block (RDB), used in RDN.
    Combines residual learning with dense connections:
    each layer receives the features of all preceding layers.
    """
    def __init__(self, num_features, growth_rate=32, num_layers=5):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_channels = num_features + i * growth_rate
            self.layers.append(nn.Sequential(
                nn.Conv2d(in_channels, growth_rate, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
        # Local feature fusion
        self.lff = nn.Conv2d(
            num_features + num_layers * growth_rate,
            num_features,
            kernel_size=1
        )

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            # Dense connection: concatenate all previous features
            out = layer(torch.cat(features, dim=1))
            features.append(out)
        # Local feature fusion
        out = self.lff(torch.cat(features, dim=1))
        # Residual connection
        return out + x
5.4 EDSR: Deep Residual Network Without BN
python
class EDSR(nn.Module):
    """
    EDSR - Enhanced Deep Residual Networks (2017).
    Key improvements:
    1. BN layers removed (saves memory, improves performance)
    2. Residual scaling (stabilizes very deep training)
    3. A wider network (256 channels)
    4. Post-upsampling (computationally efficient)
    """
    def __init__(self, num_channels=3, num_features=256, num_blocks=32,
                 scale_factor=4, res_scale=0.1):
        super().__init__()
        self.scale_factor = scale_factor
        # Head: feature extraction
        self.head = nn.Conv2d(num_channels, num_features, 3, padding=1)
        # Body: stacked residual blocks
        body = []
        for _ in range(num_blocks):
            body.append(ResidualBlock(num_features, res_scale=res_scale))
        body.append(nn.Conv2d(num_features, num_features, 3, padding=1))
        self.body = nn.Sequential(*body)
        # Tail: upsampling
        self.upsample = self._make_upsample(num_features, scale_factor)
        # Output layer
        self.tail = nn.Conv2d(num_features, num_channels, 3, padding=1)

    def _make_upsample(self, num_features, scale_factor):
        """Build the upsampling module."""
        layers = []
        if scale_factor == 2 or scale_factor == 4:
            for _ in range(scale_factor // 2):
                layers.append(nn.Conv2d(num_features, num_features * 4, 3, padding=1))
                layers.append(nn.PixelShuffle(2))
        elif scale_factor == 3:
            layers.append(nn.Conv2d(num_features, num_features * 9, 3, padding=1))
            layers.append(nn.PixelShuffle(3))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Feature extraction
        head_feat = self.head(x)
        # Residual learning
        body_feat = self.body(head_feat)
        body_feat = body_feat + head_feat  # global residual
        # Upsampling
        upsampled = self.upsample(body_feat)
        # Output
        out = self.tail(upsampled)
        return out
5.5 RCAN: Channel Attention Network
python
class ChannelAttention(nn.Module):
    """
    Channel attention (CA) module.
    Adaptively reweights channels so that the network focuses on the
    most informative feature channels.
    Structure:
    global average pooling → FC → ReLU → FC → Sigmoid
    """
    def __init__(self, num_features, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(num_features, num_features // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features // reduction, num_features, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Aggregate global information
        y = self.avg_pool(x)
        # Channel weights
        y = self.fc(y)
        # Reweight the channels
        return x * y

class RCAB(nn.Module):
    """
    Residual channel attention block (RCAB):
    Conv → ReLU → Conv → CA → + skip
    """
    def __init__(self, num_features, reduction=16, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_features, num_features, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, 3, padding=1)
        self.ca = ChannelAttention(num_features, reduction)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.ca(out)
        out = out * self.res_scale
        return out + residual

class ResidualGroup(nn.Module):
    """
    Residual group (RG): several RCABs plus a short skip connection.
    """
    def __init__(self, num_features, num_rcab=20, reduction=16):
        super().__init__()
        modules = []
        for _ in range(num_rcab):
            modules.append(RCAB(num_features, reduction))
        modules.append(nn.Conv2d(num_features, num_features, 3, padding=1))
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        return self.body(x) + x

class RCAN(nn.Module):
    """
    RCAN - Residual Channel Attention Networks (2018).
    Key innovations:
    1. Channel attention mechanism
    2. Residual group structure
    3. Long and short skip connections
    """
    def __init__(self, num_channels=3, num_features=64, num_groups=10,
                 num_rcab=20, reduction=16, scale_factor=4):
        super().__init__()
        # Head
        self.head = nn.Conv2d(num_channels, num_features, 3, padding=1)
        # Body: residual groups
        body = []
        for _ in range(num_groups):
            body.append(ResidualGroup(num_features, num_rcab, reduction))
        body.append(nn.Conv2d(num_features, num_features, 3, padding=1))
        self.body = nn.Sequential(*body)
        # Upsampling
        self.upsample = self._make_upsample(num_features, scale_factor)
        # Tail
        self.tail = nn.Conv2d(num_features, num_channels, 3, padding=1)

    def _make_upsample(self, num_features, scale_factor):
        layers = []
        if scale_factor in [2, 4, 8]:
            for _ in range(int(np.log2(scale_factor))):
                layers.append(nn.Conv2d(num_features, num_features * 4, 3, padding=1))
                layers.append(nn.PixelShuffle(2))
        elif scale_factor == 3:
            layers.append(nn.Conv2d(num_features, num_features * 9, 3, padding=1))
            layers.append(nn.PixelShuffle(3))
        return nn.Sequential(*layers)

    def forward(self, x):
        head_feat = self.head(x)
        body_feat = self.body(head_feat) + head_feat
        upsampled = self.upsample(body_feat)
        out = self.tail(upsampled)
        return out
6. GAN-Based Super-Resolution
6.1 SRGAN
python
class SRResNet(nn.Module):
    """
    SRResNet - the generator of SRGAN.
    A deep network built on residual blocks.
    """
    def __init__(self, num_channels=3, num_features=64, num_blocks=16, scale_factor=4):
        super().__init__()
        # First convolution
        self.conv1 = nn.Sequential(
            nn.Conv2d(num_channels, num_features, 9, padding=4),
            nn.PReLU()
        )
        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlockBN(num_features) for _ in range(num_blocks)]
        )
        # Second convolution
        self.conv2 = nn.Sequential(
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.BatchNorm2d(num_features)
        )
        # Upsampling
        self.upsample = self._make_upsample(num_features, scale_factor)
        # Output convolution
        self.conv3 = nn.Conv2d(num_features, num_channels, 9, padding=4)

    def _make_upsample(self, num_features, scale_factor):
        layers = []
        for _ in range(int(np.log2(scale_factor))):
            layers.extend([
                nn.Conv2d(num_features, num_features * 4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU()
            ])
        return nn.Sequential(*layers)

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.res_blocks(feat1)
        feat2 = self.conv2(feat2) + feat1
        upsampled = self.upsample(feat2)
        out = self.conv3(upsampled)
        return out

class ResidualBlockBN(nn.Module):
    """Residual block with BN (as used in SRGAN)."""
    def __init__(self, num_features):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.BatchNorm2d(num_features),
            nn.PReLU(),
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.BatchNorm2d(num_features)
        )

    def forward(self, x):
        return x + self.conv_block(x)

class Discriminator(nn.Module):
    """
    SRGAN discriminator: a VGG-style classifier that decides whether
    its input is a real HR image or a generated SR image.
    """
    def __init__(self, input_shape=(3, 96, 96)):
        super().__init__()
        in_channels, in_height, in_width = input_shape

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, 3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        self.features = nn.Sequential(*layers)
        # Spatial size of the final feature map (four stride-2 stages)
        ds_size = in_height // 2 ** 4
        self.classifier = nn.Sequential(
            nn.Linear(512 * ds_size * ds_size, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        features = self.features(x)
        features = features.view(features.size(0), -1)
        validity = self.classifier(features)
        return validity
6.2 ESRGAN
python
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block (RRDB), the core module of ESRGAN.
    More powerful than SRResNet's residual block.
    Structure: 3 RDBs + residual connection.
    """
    def __init__(self, num_features, growth_rate=32, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        self.rdb1 = ResidualDenseBlock(num_features, growth_rate)
        self.rdb2 = ResidualDenseBlock(num_features, growth_rate)
        self.rdb3 = ResidualDenseBlock(num_features, growth_rate)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return x + self.res_scale * out

class RRDBNet(nn.Module):
    """
    RRDB Network - the generator of ESRGAN.
    Key improvements:
    1. BN layers removed
    2. RRDB modules (stronger feature extraction)
    3. A relativistic discriminator (on the loss side)
    """
    def __init__(self, num_channels=3, num_features=64, num_blocks=23,
                 growth_rate=32, scale_factor=4):
        super().__init__()
        # First convolution
        self.conv_first = nn.Conv2d(num_channels, num_features, 3, padding=1)
        # RRDB blocks
        self.rrdb_blocks = nn.Sequential(
            *[RRDB(num_features, growth_rate) for _ in range(num_blocks)]
        )
        # Body convolution
        self.conv_body = nn.Conv2d(num_features, num_features, 3, padding=1)
        # Upsampling (two ×2 PixelShuffle stages → ×4)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_features, num_features * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_features, num_features * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Final convolutions
        self.conv_last = nn.Sequential(
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_features, num_channels, 3, padding=1)
        )

    def forward(self, x):
        feat_first = self.conv_first(x)
        feat_body = self.rrdb_blocks(feat_first)
        feat_body = self.conv_body(feat_body)
        feat = feat_first + feat_body
        upsampled = self.upsample(feat)
        out = self.conv_last(upsampled)
        return out
7. Lightweight Super-Resolution Networks
7.1 FSRCNN: Fast Super-Resolution
python
class FSRCNN(nn.Module):
    """
    FSRCNN - Fast Super-Resolution CNN (2016).
    Key improvements:
    1. Post-upsampling (convolutions run in LR space)
    2. Hourglass structure (channels shrink, then expand)
    3. Transposed-convolution upsampling
    More than 40× faster than SRCNN.
    """
    def __init__(self, scale_factor=4, num_channels=1, d=56, s=12, m=4):
        """
        Args:
            d: channels of the feature-extraction layer
            s: channels of the shrinking layer
            m: number of mapping layers
        """
        super().__init__()
        # Feature extraction
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=2),
            nn.PReLU()
        )
        # Shrinking
        self.shrinking = nn.Sequential(
            nn.Conv2d(d, s, kernel_size=1),
            nn.PReLU()
        )
        # Non-linear mapping
        mapping = []
        for _ in range(m):
            mapping.extend([
                nn.Conv2d(s, s, kernel_size=3, padding=1),
                nn.PReLU()
            ])
        self.mapping = nn.Sequential(*mapping)
        # Expanding
        self.expanding = nn.Sequential(
            nn.Conv2d(s, d, kernel_size=1),
            nn.PReLU()
        )
        # Transposed-convolution upsampling
        self.deconv = nn.ConvTranspose2d(
            d, num_channels,
            kernel_size=9,
            stride=scale_factor,
            padding=4,
            output_padding=scale_factor - 1
        )

    def forward(self, x):
        feat = self.feature_extraction(x)
        shrunk = self.shrinking(feat)
        mapped = self.mapping(shrunk)
        expanded = self.expanding(mapped)
        out = self.deconv(expanded)
        return out
7.2 IMDN: Information Multi-Distillation Network
python
class IMDModule(nn.Module):
    """
    Information multi-distillation module (the core of IMDN).
    Extracts features progressively, retaining ("distilling") part of
    the features at each step.
    """
    def __init__(self, in_channels, distillation_rate=0.25):
        super().__init__()
        self.distilled_channels = int(in_channels * distillation_rate)
        self.remaining_channels = int(in_channels - self.distilled_channels)
        self.c1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.c2 = nn.Conv2d(self.remaining_channels, in_channels, 3, padding=1)
        self.c3 = nn.Conv2d(self.remaining_channels, in_channels, 3, padding=1)
        self.c4 = nn.Conv2d(self.remaining_channels, self.distilled_channels, 3, padding=1)
        self.act = nn.LeakyReLU(0.05, inplace=True)
        # Fusion layer
        self.fusion = nn.Conv2d(self.distilled_channels * 4, in_channels, 1)

    def forward(self, x):
        out1 = self.act(self.c1(x))
        distilled1, remaining1 = torch.split(out1, [self.distilled_channels, self.remaining_channels], dim=1)
        out2 = self.act(self.c2(remaining1))
        distilled2, remaining2 = torch.split(out2, [self.distilled_channels, self.remaining_channels], dim=1)
        out3 = self.act(self.c3(remaining2))
        distilled3, remaining3 = torch.split(out3, [self.distilled_channels, self.remaining_channels], dim=1)
        distilled4 = self.act(self.c4(remaining3))
        # Concatenate all distilled features
        out = torch.cat([distilled1, distilled2, distilled3, distilled4], dim=1)
        out = self.fusion(out)
        return out + x

class IMDN(nn.Module):
    """
    IMDN - Information Multi-Distillation Network (2019).
    A lightweight super-resolution network:
    roughly 715K parameters with strong performance.
    """
    def __init__(self, num_channels=3, num_features=64, num_blocks=6, scale_factor=4):
        super().__init__()
        # Feature extraction
        self.conv_first = nn.Conv2d(num_channels, num_features, 3, padding=1)
        # IMD modules
        self.imdbs = nn.ModuleList([
            IMDModule(num_features) for _ in range(num_blocks)
        ])
        # Feature fusion
        self.fusion = nn.Conv2d(num_features * num_blocks, num_features, 1)
        # Upsampling
        self.upsample = SubPixelUpsampler(num_features, num_features, scale_factor)
        # Output
        self.conv_last = nn.Conv2d(num_features, num_channels, 3, padding=1)

    def forward(self, x):
        feat_first = self.conv_first(x)
        features = []
        feat = feat_first
        for imdb in self.imdbs:
            feat = imdb(feat)
            features.append(feat)
        # Concatenate the outputs of all blocks
        feat_cat = torch.cat(features, dim=1)
        feat_fused = self.fusion(feat_cat) + feat_first
        upsampled = self.upsample(feat_fused)
        out = self.conv_last(upsampled)
        return out
8. Complete Training Pipeline
8.1 Dataset Preparation
python
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class SRDataset(Dataset):
    """
    Super-resolution training dataset.
    """
    def __init__(self, hr_dir, scale_factor=4, patch_size=96, augment=True):
        """
        Args:
            hr_dir: directory of high-resolution images
            scale_factor: upscaling factor
            patch_size: HR patch size
            augment: whether to apply data augmentation
        """
        self.hr_dir = hr_dir
        self.scale_factor = scale_factor
        self.patch_size = patch_size
        self.augment = augment
        self.lr_patch_size = patch_size // scale_factor
        # Collect all image paths
        self.image_paths = []
        for f in os.listdir(hr_dir):
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
                self.image_paths.append(os.path.join(hr_dir, f))
        # Preprocessing
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def random_crop(self, hr_image):
        """Randomly crop an HR patch."""
        w, h = hr_image.size
        # Make sure a full patch fits
        x = np.random.randint(0, w - self.patch_size + 1)
        y = np.random.randint(0, h - self.patch_size + 1)
        hr_patch = hr_image.crop((x, y, x + self.patch_size, y + self.patch_size))
        return hr_patch

    def augment_patch(self, hr_patch, lr_patch):
        """Data augmentation."""
        # Random horizontal flip
        if np.random.random() < 0.5:
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
        # Random vertical flip
        if np.random.random() < 0.5:
            hr_patch = hr_patch.transpose(Image.FLIP_TOP_BOTTOM)
            lr_patch = lr_patch.transpose(Image.FLIP_TOP_BOTTOM)
        # Random rotation by a multiple of 90°
        if np.random.random() < 0.5:
            angle = np.random.choice([90, 180, 270])
            hr_patch = hr_patch.rotate(angle)
            lr_patch = lr_patch.rotate(angle)
        return hr_patch, lr_patch

    def __getitem__(self, idx):
        # Load the HR image
        hr_image = Image.open(self.image_paths[idx]).convert('RGB')
        # Random crop
        hr_patch = self.random_crop(hr_image)
        # Downsample to create the LR patch
        lr_patch = hr_patch.resize(
            (self.lr_patch_size, self.lr_patch_size),
            Image.BICUBIC
        )
        # Data augmentation
        if self.augment:
            hr_patch, lr_patch = self.augment_patch(hr_patch, lr_patch)
        # Convert to tensors
        hr_tensor = self.to_tensor(hr_patch)
        lr_tensor = self.to_tensor(lr_patch)
        return lr_tensor, hr_tensor

class SRTestDataset(Dataset):
    """Test dataset (no cropping)."""
    def __init__(self, lr_dir, hr_dir=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_paths = sorted([
            os.path.join(lr_dir, f) for f in os.listdir(lr_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ])
        if hr_dir:
            self.hr_paths = sorted([
                os.path.join(hr_dir, f) for f in os.listdir(hr_dir)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ])
        else:
            self.hr_paths = None
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.lr_paths)

    def __getitem__(self, idx):
        lr_image = Image.open(self.lr_paths[idx]).convert('RGB')
        lr_tensor = self.to_tensor(lr_image)
        if self.hr_paths:
            hr_image = Image.open(self.hr_paths[idx]).convert('RGB')
            hr_tensor = self.to_tensor(hr_image)
            return lr_tensor, hr_tensor, os.path.basename(self.lr_paths[idx])
        return lr_tensor, os.path.basename(self.lr_paths[idx])
8.2 Trainer
python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class SRTrainer:
    """
    Super-resolution trainer.
    """
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Build the model
        self.model = self._build_model().to(self.device)
        # Optimizer
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config['lr'],
            betas=(0.9, 0.999)
        )
        # Learning-rate schedule
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=config['lr_milestones'],
            gamma=0.5
        )
        # Loss function
        self.criterion = self._build_loss()
        # TensorBoard
        self.writer = SummaryWriter(config['log_dir'])
        # Best metric so far
        self.best_psnr = 0

    def _build_model(self):
        """Build the model."""
        model_name = self.config.get('model', 'edsr')
        if model_name == 'edsr':
            return EDSR(
                num_features=self.config.get('num_features', 64),
                num_blocks=self.config.get('num_blocks', 16),
                scale_factor=self.config['scale_factor']
            )
        elif model_name == 'rcan':
            return RCAN(
                num_features=self.config.get('num_features', 64),
                num_groups=self.config.get('num_groups', 10),
                scale_factor=self.config['scale_factor']
            )
        else:
            raise ValueError(f"Unknown model: {model_name}")

    def _build_loss(self):
        """Build the loss function."""
        loss_type = self.config.get('loss', 'l1')
        if loss_type == 'l1':
            return nn.L1Loss()
        elif loss_type == 'l2':
            return nn.MSELoss()
        elif loss_type == 'charbonnier':
            return CharbonnierLoss()
        else:
            raise ValueError(f"Unknown loss: {loss_type}")

    def train_epoch(self, train_loader, epoch):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for batch_idx, (lr, hr) in enumerate(pbar):
            lr = lr.to(self.device)
            hr = hr.to(self.device)
            # Forward pass
            sr = self.model(lr)
            loss = self.criterion(sr, hr)
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            # Update the progress bar
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            # Log to TensorBoard
            global_step = epoch * len(train_loader) + batch_idx
            self.writer.add_scalar('train/loss', loss.item(), global_step)
        return total_loss / len(train_loader)

    @torch.no_grad()
    def validate(self, val_loader):
        """Validation."""
        self.model.eval()
        total_psnr = 0
        total_ssim = 0
        count = 0
        for lr, hr, _ in val_loader:
            lr = lr.to(self.device)
            hr = hr.to(self.device)
            sr = self.model(lr)
            # Compute the metrics
            for i in range(sr.size(0)):
                # Clamp to [0, 1] before converting to the 0-255 range
                sr_np = sr[i].clamp(0, 1).cpu().numpy().transpose(1, 2, 0) * 255
                hr_np = hr[i].cpu().numpy().transpose(1, 2, 0) * 255
                total_psnr += calculate_psnr(sr_np, hr_np)
                # SSIM computed on a single channel here for simplicity
                total_ssim += calculate_ssim(sr_np[..., 0], hr_np[..., 0])
                count += 1
        avg_psnr = total_psnr / count
        avg_ssim = total_ssim / count
        return avg_psnr, avg_ssim

    def train(self, train_loader, val_loader, num_epochs):
        """Full training loop."""
        for epoch in range(1, num_epochs + 1):
            # Train
            train_loss = self.train_epoch(train_loader, epoch)
            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}")
            # Update the learning rate
            self.scheduler.step()
            # Validate
            if epoch % self.config['val_interval'] == 0:
                psnr, ssim = self.validate(val_loader)
                print(f"Validation: PSNR = {psnr:.2f}, SSIM = {ssim:.4f}")
                self.writer.add_scalar('val/psnr', psnr, epoch)
                self.writer.add_scalar('val/ssim', ssim, epoch)
                # Save the best model
                if psnr > self.best_psnr:
                    self.best_psnr = psnr
                    self.save_checkpoint('best.pth', epoch)
                    print(f"New best model! PSNR = {psnr:.2f}")
            # Periodic checkpoints
            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(f'epoch_{epoch}.pth', epoch)

    def save_checkpoint(self, filename, epoch):
        """Save a checkpoint."""
        os.makedirs(self.config['checkpoint_dir'], exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_psnr': self.best_psnr
        }, os.path.join(self.config['checkpoint_dir'], filename))

    def load_checkpoint(self, path):
        """Load a checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_psnr = checkpoint.get('best_psnr', 0)
        return checkpoint.get('epoch', 0)
8.3 Training Script
python
def main():
    """Entry point."""
    config = {
        # Data
        'train_hr_dir': './data/DIV2K/train_HR',
        'val_lr_dir': './data/Set5/LR_bicubic/X4',
        'val_hr_dir': './data/Set5/HR',
        # Model
        'model': 'edsr',
        'num_features': 64,
        'num_blocks': 16,
        'scale_factor': 4,
        # Training
        'batch_size': 16,
        'patch_size': 96,
        'num_epochs': 300,
        'lr': 1e-4,
        'lr_milestones': [100, 200],
        'loss': 'l1',
        # Misc
        'val_interval': 10,
        'save_interval': 50,
        'log_dir': './logs',
        'checkpoint_dir': './checkpoints',
        'num_workers': 4
    }
    # Build the datasets
    train_dataset = SRDataset(
        config['train_hr_dir'],
        scale_factor=config['scale_factor'],
        patch_size=config['patch_size']
    )
    val_dataset = SRTestDataset(
        config['val_lr_dir'],
        config['val_hr_dir']
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False
    )
    # Build the trainer
    trainer = SRTrainer(config)
    # Start training
    trainer.train(train_loader, val_loader, config['num_epochs'])
    print("Training completed!")

if __name__ == '__main__':
    main()
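Once training is done, inference amounts to loading the checkpoint and running the model on a full LR image. A minimal sketch (the paths and the ×4 EDSR configuration are illustrative, and it assumes the checkpoint format saved by SRTrainer above):
python
@torch.no_grad()
def upscale_image(checkpoint_path, image_path, output_path, scale_factor=4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The architecture must match the trained configuration
    model = EDSR(num_features=64, num_blocks=16, scale_factor=scale_factor).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # Load the LR image, super-resolve it, and save the result
    lr_image = Image.open(image_path).convert('RGB')
    lr_tensor = transforms.ToTensor()(lr_image).unsqueeze(0).to(device)
    sr_tensor = model(lr_tensor).clamp(0, 1).squeeze(0).cpu()
    transforms.ToPILImage()(sr_tensor).save(output_path)

upscale_image('./checkpoints/best.pth', 'input_lr.png', 'output_sr.png')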
9. Frontier Research Directions
9.1 Real-World Super-Resolution
┌─────────────────────────────────────────────────────────────────┐
│           Challenges of Real-World Super-Resolution             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  The problem:                                                   │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Traditional assumption: LR = Bicubic(HR)                │   │
│  │ Real degradations include blur, noise, compression,     │   │
│  │ sensor noise, ...                                       │   │
│  │ Models trained on bicubic data perform poorly on real   │   │
│  │ images                                                  │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Solutions:                                                     │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ 1. More complex degradation models                      │   │
│  │    blur + downsampling + noise + JPEG compression       │   │
│  │                                                         │   │
│  │ 2. Blind super-resolution                               │   │
│  │    estimate the degradation instead of assuming a       │   │
│  │    known kernel                                         │   │
│  │                                                         │   │
│  │ 3. Real-world datasets                                  │   │
│  │    RealSR: real image pairs shot at different focal     │   │
│  │    lengths                                              │   │
│  │                                                         │   │
│  │ 4. Unsupervised learning                                │   │
│  │    training methods that need no paired data            │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
9.2 Transformer-Based Super-Resolution
Recent applications of Transformers to super-resolution:
IPT (2021):
- Pre-trained image-processing Transformer
- Large-scale ImageNet pre-training
- Very large (115M parameters)
SwinIR (2021):
- Built on the Swin Transformer
- Local window attention + shifted windows (see the sketch after this list)
- Moderate size (11.8M parameters)
- State of the art on multiple benchmarks
ESRT (2021):
- Lightweight Transformer
- Only 751K parameters
- A good efficiency/performance trade-off
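As a rough sketch of the local-window attention these models build on (a heavily simplified single-layer toy, not SwinIR's actual implementation; the window size, feature dimension, and head count are illustrative):
python
import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """Toy local-window self-attention: tokens attend only within w×w windows."""
    def __init__(self, dim=64, window=8, heads=4):
        super().__init__()
        self.window = window
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):  # x: [B, C, H, W], H and W divisible by window
        B, C, H, W = x.shape
        w = self.window
        # Partition into non-overlapping windows: [B*num_windows, w*w, C]
        x = x.view(B, C, H // w, w, W // w, w)
        x = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, w * w, C)
        x, _ = self.attn(x, x, x)  # attention restricted to each window
        # Reverse the partition back to [B, C, H, W]
        x = x.reshape(B, H // w, W // w, w, w, C)
        return x.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)

feat = torch.randn(1, 64, 32, 32)
print(WindowAttention()(feat).shape)  # [1, 64, 32, 32]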
9.3 Domain-Specific Applications
Face super-resolution:
- Exploit facial priors (landmarks, parsing maps)
- Identity-preserving losses
- Preservation of facial attributes
Medical image super-resolution:
- CT/MRI enhancement
- Super-resolution of 3D volumetric data
- Preserving diagnostically relevant information
Remote-sensing super-resolution:
- Handling very large images
- Multispectral/hyperspectral data
- Temporal data fusion
Video super-resolution:
- Exploiting temporal redundancy
- Optical-flow alignment (see the warping sketch after this list)
- Real-time processing requirements
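For the optical-flow alignment mentioned above, here is a minimal sketch of warping a neighboring frame toward the current one given a flow field (the zero flow used below is a stand-in; real video-SR pipelines estimate the flow with a dedicated network):
python
import torch
import torch.nn.functional as F

def flow_warp(frame, flow):
    """Warp frame [B, C, H, W] by optical flow [B, 2, H, W] (x/y offsets in pixels)."""
    B, _, H, W = frame.shape
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((xs, ys), dim=0).float().to(frame)  # base sampling grid [2, H, W]
    coords = grid.unsqueeze(0) + flow                      # where each pixel samples from
    # Normalize coordinates to [-1, 1] for grid_sample
    coords_x = 2.0 * coords[:, 0] / (W - 1) - 1.0
    coords_y = 2.0 * coords[:, 1] / (H - 1) - 1.0
    grid_norm = torch.stack((coords_x, coords_y), dim=-1)  # [B, H, W, 2]
    return F.grid_sample(frame, grid_norm, mode='bilinear',
                         padding_mode='border', align_corners=True)

prev_frame = torch.randn(1, 3, 64, 64)
flow = torch.zeros(1, 2, 64, 64)  # zero flow: output should equal the input
print(torch.allclose(flow_warp(prev_frame, flow), prev_frame, atol=1e-5))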
10. Summary and Outlook
10.1 Key Takeaways
┌─────────────────────────────────────────────────────────────────┐
│            Core Knowledge of Image Super-Resolution             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  The essence of the problem:                                    │
│  • Recover a high-resolution image from a low-resolution one    │
│  • An ill-posed problem that requires learned image priors      │
│                                                                 │
│  Key techniques:                                                │
│  • Upsampling: PixelShuffle is the most common                  │
│  • Network design: residual learning, attention mechanisms      │
│  • Losses: pixel loss + perceptual loss + adversarial loss      │
│                                                                 │
│  Evaluation metrics:                                            │
│  • PSNR/SSIM: reconstruction fidelity                           │
│  • LPIPS/NIQE: perceptual quality                               │
│                                                                 │
│  Trends:                                                        │
│  • Real-world degradations                                      │
│  • Lightweight networks                                         │
│  • Transformer architectures                                    │
│  • Domain-specific applications                                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
10.2 Future Directions
1. Lightweight and efficient
- Deployment on edge devices
- Real-time processing requirements
- Model compression and quantization
2. Real-world scenarios
- Modeling complex degradations
- Blind super-resolution
- Few-shot/unsupervised learning
3. New architectures
- Transformer + CNN hybrids
- Diffusion models
- Neural implicit representations
4. Joint tasks
- Super-resolution + denoising
- Super-resolution + deblurring
- Super-resolution + object detection
I hope this article has given you a thorough understanding of image super-resolution. Questions and comments are welcome!