1. Problem Definition and Mathematical Modeling
1.1 Problem Description
Super-resolution (SR) is the inverse problem of reconstructing a high-resolution (HR) image from a low-resolution (LR) image.
Super-resolution reconstruction
LR input (64×64)            HR output (256×256)
┌──────────────┐            ┌────────────────────────────┐
│              │            │                            │
│    blurry    │     SR     │           sharp            │
│   low-res    │  ───────►  │          high-res          │
│    image     │ algorithm  │           image            │
│              │            │                            │
└──────────────┘            │                            │
                            │                            │
                            └────────────────────────────┘
Scale factor: ×4 (16 times as many pixels)
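1.2 Mathematical Model
Image degradation model
An LR image is generated as:
y = D(H(x)) + n
where:
x ∈ ℝᴴˣᵂ: the high-resolution image (in the exponent, H and W are the image height and width)
H: the blur operator (optical blur, motion blur, etc.)
D: the downsampling operator (with downsampling factor s)
n: noise (Gaussian, Poisson, etc.)
y ∈ ℝʰˣʷ: the low-resolution image (h = H/s, w = W/s)
Solving the inverse problem
The goal of super-resolution is:
x̂ = argmin_x ‖y - D(H(x))‖² + λ·R(x)
Here ‖y - D(H(x))‖² is the data-fidelity term and λ·R(x) is the regularization term.
The regularizer R(x) serves to:
- inject prior knowledge
- constrain the solution space
- suppress noise amplification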
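As a minimal sketch of the degradation model above, the snippet below simulates y = D(H(x)) + n with NumPy. The box blur, the stride-s decimation, and the names `degrade`, `blur_size`, and `noise_std` are illustrative assumptions; the text does not prescribe a specific H, D, or n.

```python
import numpy as np


def degrade(x, s=4, blur_size=5, noise_std=2.0, seed=0):
    """Simulate y = D(H(x)) + n for a 2-D grayscale image x (assumed setup)."""
    rng = np.random.default_rng(seed)
    x = x.astype(np.float64)
    # H: box blur, an assumed stand-in for optical or motion blur
    pad = blur_size // 2
    padded = np.pad(x, pad, mode="edge")
    blurred = np.zeros_like(x)
    for di in range(blur_size):
        for dj in range(blur_size):
            blurred += padded[di:di + x.shape[0], dj:dj + x.shape[1]]
    blurred /= blur_size ** 2
    # D: keep every s-th pixel along each axis
    down = blurred[::s, ::s]
    # n: additive Gaussian noise
    return down + rng.normal(0.0, noise_std, size=down.shape)


hr = np.arange(256 * 256, dtype=np.float64).reshape(256, 256) % 256
lr = degrade(hr, s=4)
print(hr.shape, "->", lr.shape)
```

Because blur and decimation discard information, many different HR images map to the same LR image, which is why the regularizer R(x) is needed to pick a plausible solution.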
1.3 Problem Taxonomy
Super-resolution problems can be categorized along several axes.
By number of inputs:
- Single-image SR (SISR): reconstruct the HR image from a single LR image
- Multi-frame SR: exploit complementary information across multiple LR frames
By scale factor:
- Fixed scale (×2, ×4, ×8)
- Arbitrary scale
By degradation type:
- Ideal degradation (bicubic downsampling)
- Blind SR (unknown degradation)
- Real-world SR
By application domain:
- Natural-image SR
- Face SR
- Remote-sensing image SR
- Medical image SR
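2. Traditional Methods
2.1 Interpolation-Based Methods
Nearest-Neighbor Interpolation
```python
import numpy as np


def nearest_neighbor_interpolation(lr_image, scale_factor):
    """
    Idea: copy each LR pixel into its corresponding HR region.
    Pros: simple and fast.
    Cons: produces visible jaggies and blocky artifacts.
    """
    h, w = lr_image.shape[:2]
    H, W = h * scale_factor, w * scale_factor
    # Keep any trailing channel axis so that color images also work
    hr_image = np.zeros((H, W) + lr_image.shape[2:], dtype=lr_image.dtype)
    for i in range(H):
        for j in range(W):
            # Map back to LR coordinates
            src_i = min(int(i / scale_factor), h - 1)
            src_j = min(int(j / scale_factor), w - 1)
            hr_image[i, j] = lr_image[src_i, src_j]
    return hr_image
```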
Bilinear Interpolation
```python
import numpy as np


def bilinear_interpolation(lr_image, scale_factor):
    """
    Idea: interpolate linearly along each of the two axes.
    Formula: f(x,y) = f(0,0)(1-x)(1-y) + f(1,0)x(1-y)
                    + f(0,1)(1-x)y + f(1,1)xy
    Pros: smoother than nearest neighbor at moderate cost.
    Cons: over-smooths and loses high-frequency detail.
    """
    h, w = lr_image.shape[:2]
    H, W = h * scale_factor, w * scale_factor
    # Keep any trailing channel axis so that color images also work
    hr_image = np.zeros((H, W) + lr_image.shape[2:], dtype=np.float64)
    for i in range(H):
        for j in range(W):
            # Map to the LR coordinate system
            x = j / scale_factor
            y = i / scale_factor
            # Coordinates of the four neighboring pixels
            x0, y0 = int(x), int(y)
            x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
            # Interpolation weights
            dx, dy = x - x0, y - y0
            # Bilinear blend
            hr_image[i, j] = (
                lr_image[y0, x0] * (1 - dx) * (1 - dy) +
                lr_image[y0, x1] * dx * (1 - dy) +
                lr_image[y1, x0] * (1 - dx) * dy +
                lr_image[y1, x1] * dx * dy
            )
    return hr_image
```
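A small worked example may make the weighting concrete. The sketch below samples a 2×2 image at its center, where all four weights equal 0.25; the helper name `bilinear_sample` and the list-of-lists image are illustrative assumptions.

```python
def bilinear_sample(img, y, x):
    """Bilinearly sample a 2-D list-of-lists image at fractional (y, x)."""
    h, w = len(img), len(img[0])
    y0, x0 = int(y), int(x)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    dy, dx = y - y0, x - x0
    return (img[y0][x0] * (1 - dx) * (1 - dy)
            + img[y0][x1] * dx * (1 - dy)
            + img[y1][x0] * (1 - dx) * dy
            + img[y1][x1] * dx * dy)


img = [[0.0, 10.0],
       [20.0, 30.0]]
center = bilinear_sample(img, 0.5, 0.5)
print(center)  # average of the four pixels
```

At integer coordinates the weights collapse to a single pixel, so the original samples are reproduced exactly.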
Bicubic Interpolation
```python
import numpy as np


def bicubic_kernel(x, a=-0.5):
    """
    Bicubic interpolation kernel (Keys kernel).
    With a = -0.5 it is the standard Catmull-Rom spline.
    """
    x = abs(x)
    if x <= 1:
        return (a + 2) * x**3 - (a + 3) * x**2 + 1
    elif x < 2:
        return a * x**3 - 5 * a * x**2 + 8 * a * x - 4 * a
    else:
        return 0


def bicubic_interpolation(lr_image, scale_factor):
    """
    Idea: weighted sum over a 4×4 neighborhood.
    Pros: preserves edges fairly well, good visual quality.
    Cons: more expensive, and high-frequency information is still lost.
    Bicubic resampling (used for downsampling) is also the default
    degradation in most SR papers.
    """
    h, w = lr_image.shape[:2]
    H, W = h * scale_factor, w * scale_factor
    hr_image = np.zeros((H, W), dtype=np.float64)
    for i in range(H):
        for j in range(W):
            x = j / scale_factor
            y = i / scale_factor
            x0, y0 = int(x), int(y)
            pixel = 0.0
            for m in range(-1, 3):
                for n in range(-1, 3):
                    # Clamp to the image border
                    px = min(max(x0 + m, 0), w - 1)
                    py = min(max(y0 + n, 0), h - 1)
                    # Weight = W(x) * W(y)
                    wx = bicubic_kernel(x - (x0 + m))
                    wy = bicubic_kernel(y - (y0 + n))
                    pixel += lr_image[py, px] * wx * wy
            # The kernel has negative lobes, so clip to the 8-bit range
            hr_image[i, j] = np.clip(pixel, 0, 255)
    return hr_image
```
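One property worth checking is that the four tap weights of the Keys kernel always sum to 1, so flat regions keep their brightness, while the outer taps are negative, which is why the result has to be clipped. The sketch below verifies this numerically; `keys_kernel` restates the kernel from this section and `tap_weights` is a hypothetical helper.

```python
def keys_kernel(x, a=-0.5):
    """Keys cubic kernel; a = -0.5 gives the Catmull-Rom spline."""
    x = abs(x)
    if x <= 1:
        return (a + 2) * x**3 - (a + 3) * x**2 + 1
    if x < 2:
        return a * x**3 - 5 * a * x**2 + 8 * a * x - 4 * a
    return 0.0


def tap_weights(t, a=-0.5):
    """Weights of the taps at offsets -1, 0, 1, 2 for a fractional position t in [0, 1)."""
    return [keys_kernel(t - m, a) for m in range(-1, 3)]


for t in (0.0, 0.25, 0.5, 0.75):
    w = tap_weights(t)
    print(t, [round(v, 4) for v in w], round(sum(w), 12))
```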
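2.2 Reconstruction-Based Methods
Sparse Coding
```python
import numpy as np
from sklearn.decomposition import MiniBatchDictionaryLearning, sparse_encode


class SparseCodingSR:
    """
    Idea: corresponding LR and HR patches share the same sparse representation.
    Pipeline:
    1. Train a coupled LR-HR dictionary pair (D_l, D_h)
    2. Compute the sparse code α of each LR patch
    3. Reconstruct the HR patch from α and D_h
    """
    def __init__(self, patch_size=5, dict_size=1024, sparsity=3):
        self.patch_size = patch_size
        self.dict_size = dict_size
        self.sparsity = sparsity

    def train_dictionary(self, lr_patches, hr_patches):
        """
        Joint dictionary learning.
        Objective: min ‖Y_l - D_l·α‖² + ‖Y_h - D_h·α‖² + λ‖α‖₁
        lr_patches: [N, d_l], hr_patches: [N, d_h]; row i of both is one patch pair.
        """
        # Fit a single dictionary on the concatenated LR/HR features so that
        # both halves are tied to the same sparse codes. Fitting two separate
        # dictionaries would not couple them. (K-SVD is a common alternative
        # to online dictionary learning.)
        joint = np.hstack([lr_patches, hr_patches])
        learner = MiniBatchDictionaryLearning(
            n_components=self.dict_size, alpha=1.0
        ).fit(joint)
        d_l = lr_patches.shape[1]
        dict_lr = learner.components_[:, :d_l]
        dict_hr = learner.components_[:, d_l:]
        # Rescale each atom pair so the LR atoms have unit norm
        norms = np.maximum(np.linalg.norm(dict_lr, axis=1, keepdims=True), 1e-12)
        self.dict_lr = dict_lr / norms
        self.dict_hr = dict_hr / norms

    def reconstruct(self, lr_image, scale_factor):
        """
        Super-resolution reconstruction.
        extract_patches and aggregate_patches are helpers not shown in this section.
        """
        # Extract LR patches
        lr_patches = extract_patches(lr_image, self.patch_size)
        # Sparse codes with at most `sparsity` non-zero coefficients
        sparse_codes = sparse_encode(
            lr_patches, self.dict_lr,
            algorithm='omp', n_nonzero_coefs=self.sparsity
        )
        # Reconstruct with the HR dictionary
        hr_patches = sparse_codes @ self.dict_hr
        # Aggregate the overlapping patches
        hr_image = aggregate_patches(hr_patches, scale_factor)
        return hr_image
```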
Self-Similarity Methods (Self-Example)
Core idea: an image contains similar structures that recur across scales. For a small structure in the input, the method searches other scales of the same image for similar patches and uses the higher-resolution detail found there to fill in the enlarged structure.
Representative method: Glasner et al. (2009)
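2.3 Edge-Based Methods
```python
class EdgeDirectedSR:
    """
    Core idea: reconstruct the edges first, then use them to guide the full image.
    Pipeline:
    1. Detect edges in the LR image
    2. Predict HR edge direction and strength
    3. Interpolate along the edge direction
    4. Optimize the full image under the edge constraint
    The four helper methods called below are placeholders that are not
    defined in this section.
    """
    def reconstruct(self, lr_image, scale_factor):
        # Step 1: edge detection
        edges_lr = self.detect_edges(lr_image)
        # Step 2: edge direction estimation
        directions = self.estimate_edge_directions(edges_lr)
        # Step 3: direction-adaptive interpolation
        hr_image = self.directional_interpolation(
            lr_image, directions, scale_factor
        )
        # Step 4: edge-guided optimization
        hr_image = self.edge_guided_optimization(
            hr_image, edges_lr, scale_factor
        )
        return hr_image
```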
3. Deep Learning Methods
3.1 SRCNN (2014): The Pioneering Work
SRCNN: Super-Resolution Convolutional Neural Network
Paper: "Image Super-Resolution Using Deep Convolutional Networks" (Dong et al., 2014)
Architecture:
LR input → [bicubic upsampling] → [feature extraction] → [non-linear mapping] → [reconstruction] → HR output
LR image ──► Bicubic ──► Conv1 (9×9) ──► Conv2 (1×1) ──► Conv3 (5×5) ──► HR
Input: the LR image after upsampling (already interpolated to the target size)
Output: the high-resolution image
```python
import torch
import torch.nn as nn


class SRCNN(nn.Module):
    """
    Three-layer CNN for super-resolution.
    Layer 1: patch extraction (feature extraction)
    Layer 2: non-linear mapping
    Layer 3: reconstruction
    """
    def __init__(self, num_channels=1):
        super().__init__()
        # Feature extraction layer
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        # Non-linear mapping layer
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        # Reconstruction layer
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        Forward pass.
        Note: the input must first be bicubically upsampled to the target size.
        """
        # Feature extraction
        x = self.relu(self.conv1(x))  # [B, 64, H, W]
        # Non-linear mapping
        x = self.relu(self.conv2(x))  # [B, 32, H, W]
        # Reconstruction
        x = self.conv3(x)             # [B, C, H, W]
        return x


# Training configuration
"""
Loss: MSE (pixel-wise)
Optimizer: SGD or Adam
Evaluation: PSNR / SSIM
Limitations:
- Upsampling before processing makes computation expensive
- Limited receptive field
- Struggles to recover high-frequency detail
"""
```
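3.2 FSRCNN (2016): The Accelerated Version
FSRCNN: Fast Super-Resolution CNN
Paper: "Accelerating the Super-Resolution Convolutional Neural Network" (Dong et al., 2016)
Improvements:
1. Operates directly in LR space (no upsampling of the input)
2. Upsamples at the end with a transposed convolution
3. Uses smaller convolution kernels and more layers
Architecture:
LR → [feature extraction] → [shrinking] → [mapping] × m → [expanding] → [transposed-convolution upsampling] → HR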
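For the final transposed convolution to map an h×w feature map to exactly sh×sw, its kernel size, stride, and padding have to be chosen together. The sketch below checks this with the output-size formula documented for PyTorch's `ConvTranspose2d`; the helper name `conv_transpose_out_size` is hypothetical, and the settings tested (kernel 9, padding 4, output padding s - 1) are the ones used in the FSRCNN code of this section.

```python
def conv_transpose_out_size(in_size, kernel_size, stride, padding,
                            output_padding=0, dilation=1):
    """Output length of a transposed convolution along one axis."""
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


for scale in (2, 3, 4):
    out = conv_transpose_out_size(64, kernel_size=9, stride=scale,
                                  padding=4, output_padding=scale - 1)
    print(f"x{scale}: 64 -> {out}")
```

With these settings the formula reduces to `in_size * stride`, so the same layer definition works for any integer scale factor.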
```python
import torch.nn as nn


class FSRCNN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=1, d=56, s=12, m=4):
        super().__init__()
        # Feature extraction
        self.feature_extraction = nn.Conv2d(num_channels, d, kernel_size=5, padding=2)
        # Shrinking layer (dimension reduction)
        self.shrinking = nn.Conv2d(d, s, kernel_size=1)
        # Non-linear mapping (m layers)
        self.mapping = nn.Sequential(*[
            nn.Sequential(
                nn.Conv2d(s, s, kernel_size=3, padding=1),
                nn.PReLU()
            ) for _ in range(m)
        ])
        # Expanding layer (dimension increase)
        self.expanding = nn.Conv2d(s, d, kernel_size=1)
        # Transposed-convolution upsampling (the key improvement)
        self.deconv = nn.ConvTranspose2d(
            d, num_channels,
            kernel_size=9,
            stride=scale_factor,
            padding=4,
            output_padding=scale_factor - 1
        )
        self.prelu = nn.PReLU()

    def forward(self, x):
        # Operate directly in LR space
        x = self.prelu(self.feature_extraction(x))  # [B, d, h, w]
        x = self.prelu(self.shrinking(x))           # [B, s, h, w]
        x = self.mapping(x)                         # [B, s, h, w]
        x = self.prelu(self.expanding(x))           # [B, d, h, w]
        x = self.deconv(x)                          # [B, C, H, W]
        return x


"""
Speed comparison:
- SRCNN: 0.43s (×4, 256×256)
- FSRCNN: 0.015s (×4, 256×256)
- Roughly a 28× speedup
"""
```
3.3 ESPCN (2016): Sub-Pixel Convolution
ESPCN: Efficient Sub-Pixel CNN
Paper: "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" (Shi et al., 2016)
Core innovation: the sub-pixel convolution layer (Sub-Pixel Convolution / PixelShuffle)
Principle:
Rearrange every group of r² feature channels into an r×r block of high-resolution pixels
[B, C×r², H, W] → [B, C, H×r, W×r]
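The rearrangement is easiest to see on a tiny example. The pure-Python sketch below assumes the index convention used by PyTorch's `nn.PixelShuffle`, namely out[c][y·r + i][x·r + j] = in[c·r² + i·r + j][y][x]; the nested-list representation and the function name `pixel_shuffle` are illustrative only.

```python
def pixel_shuffle(x, r):
    """Rearrange a nested list [C*r*r][H][W] into [C][H*r][W*r]."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    assert c_in % (r * r) == 0
    c_out = c_in // (r * r)
    out = [[[0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):          # row offset inside each r×r block
            for j in range(r):      # column offset inside each r×r block
                src = x[c * r * r + i * r + j]
                for y in range(h):
                    for xx in range(w):
                        out[c][y * r + i][xx * r + j] = src[y][xx]
    return out


# Four 1×2 channels, r = 2: channel k holds the values 10k and 10k + 1
lr_features = [[[0, 1]], [[10, 11]], [[20, 21]], [[30, 31]]]
for row in pixel_shuffle(lr_features, 2)[0]:
    print(row)
```

Each LR position expands into one 2×2 block whose four pixels come from the four channels, so no arithmetic is involved, only a reindexing.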
```python
import torch.nn as nn


class PixelShuffle(nn.Module):
    """
    Sub-pixel convolution layer (equivalent to the built-in nn.PixelShuffle).
    Input:  [B, C × r², H, W]
    Output: [B, C, H × r, W × r]
    Idea: rearrange the channel dimension into the spatial dimensions.
    """
    def __init__(self, upscale_factor):
        super().__init__()
        self.r = upscale_factor

    def forward(self, x):
        B, C, H, W = x.shape
        r = self.r
        # C = C_out × r²
        assert C % (r * r) == 0
        C_out = C // (r * r)
        # Rearrange: [B, C_out×r², H, W] → [B, C_out, H×r, W×r]
        return x.view(B, C_out, r, r, H, W).permute(0, 1, 4, 2, 5, 3).reshape(
            B, C_out, H * r, W * r
        )


class ESPCN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=1):
        super().__init__()
        # Feature extraction (in LR space)
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Tanh()
        )
        # Sub-pixel upsampling
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(32, num_channels * scale_factor ** 2, kernel_size=3, padding=1),
            PixelShuffle(scale_factor)
        )

    def forward(self, x):
        x = self.feature_extraction(x)
        x = self.sub_pixel(x)
        return x


"""
Advantages:
- Efficient: upsampling happens only in the last step
- Real-time capable: well suited to video
- Less prone to checkerboard artifacts than transposed convolution
  (they can still appear, depending on initialization)
"""
```
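3.4 VDSR (2016): Residual Learning
VDSR: Very Deep Super-Resolution
Paper: "Accurate Image Super-Resolution Using Very Deep Convolutional Networks" (Kim et al., 2016)
Core innovations:
1. Residual learning: learn the residual HR - LR_bicubic
2. Deep network: 20 convolutional layers
3. High learning rate: residual learning allows a larger learning rate
```python
import torch
import torch.nn as nn


class VDSR(nn.Module):
    """
    Residual-learning architecture.
    Core formula: HR = LR_bicubic + Residual
    The network only has to learn the residual, so it converges faster.
    """
    def __init__(self, num_channels=1, num_layers=20):
        super().__init__()
        layers = []
        # First layer
        layers.append(nn.Conv2d(num_channels, 64, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
        # Middle layers (18 layers for num_layers=20)
        for _ in range(num_layers - 2):
            layers.append(nn.Conv2d(64, 64, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        # Last layer
        layers.append(nn.Conv2d(64, num_channels, kernel_size=3, padding=1))
        self.network = nn.Sequential(*layers)
        # Learnable residual scale. This is an addition in this sketch;
        # the original VDSR adds the residual directly.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, lr_bicubic):
        """
        Input: the bicubically upsampled LR image
        Output: the super-resolved result
        HR = LR_bicubic + scale × Network(LR_bicubic)
        """
        residual = self.network(lr_bicubic)
        return lr_bicubic + self.residual_scale * residual


"""
Advantages of residual learning:
1. Smoother gradient flow (identity mapping)
2. Faster convergence
3. Deeper networks become trainable
4. Learning a residual is easier than learning the full image
"""
```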
3.5 EDSR (2017): Enhanced Residual Network
EDSR: Enhanced Deep Residual Network
Paper: "Enhanced Deep Residual Networks for Single Image Super-Resolution" (Lim et al., 2017)
Core improvements:
1. Removes BatchNorm (harmful for SR)
2. Simplifies the residual block
3. Uses residual scaling
```python
import torch.nn as nn


class ResidualBlock(nn.Module):
    """
    EDSR residual block.
    Improvements:
    - No BN layers (BN normalizes features and discards range information)
    - Residual scaling (×0.1) to stabilize training
    """
    def __init__(self, channels, residual_scale=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.residual_scale = residual_scale

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual * self.residual_scale


class UpsampleModule(nn.Module):
    """
    Upsampling module implemented with sub-pixel convolution.
    """
    def __init__(self, channels, scale_factor):
        super().__init__()
        if scale_factor == 2:
            self.up = nn.Sequential(
                nn.Conv2d(channels, channels * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2)
            )
        elif scale_factor == 4:
            # Two ×2 stages
            self.up = nn.Sequential(
                nn.Conv2d(channels, channels * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.Conv2d(channels, channels * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2)
            )
        else:
            raise ValueError(f"unsupported scale_factor: {scale_factor}")

    def forward(self, x):
        return self.up(x)


class EDSR(nn.Module):
    def __init__(self, scale_factor=4, num_channels=3,
                 num_features=256, num_blocks=32):
        super().__init__()
        # Shallow feature extraction
        self.head = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)
        # Deep feature extraction (stacked residual blocks)
        self.body = nn.Sequential(*[
            ResidualBlock(num_features) for _ in range(num_blocks)
        ])
        # Upsampling module
        self.upsample = UpsampleModule(num_features, scale_factor)
        # Reconstruction layer
        self.tail = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Shallow features
        shallow = self.head(x)
        # Deep features (with a global residual)
        deep = self.body(shallow)
        deep = deep + shallow  # global residual connection
        # Upsampling
        upsampled = self.upsample(deep)
        # Reconstruction
        output = self.tail(upsampled)
        return output


"""
EDSR results on the DIV2K dataset (PSNR):
×2: 34.65 dB
×3: 30.92 dB
×4: 28.80 dB
"""
```
3.6 RDN (2018): Densely Connected Network
RDN: Residual Dense Network
Paper: "Residual Dense Network for Image Super-Resolution" (Zhang et al., 2018)
Core innovations:
1. Residual dense block (RDB): dense connections inside each block
2. Feature fusion: features are collected from all RDBs
3. Use of hierarchical features
```python
import torch
import torch.nn as nn


class ResidualDenseBlock(nn.Module):
    """
    Residual dense block (RDB).
    Every layer is connected to all preceding layers.
    """
    def __init__(self, channels, growth_rate=32, num_layers=8):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                nn.Conv2d(channels + i * growth_rate, growth_rate,
                          kernel_size=3, padding=1)
            )
        # 1×1 convolution for local feature fusion
        self.fusion = nn.Conv2d(
            channels + num_layers * growth_rate,
            channels,
            kernel_size=1
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            out = self.relu(layer(torch.cat(features, dim=1)))
            features.append(out)
        # Fuse all features
        fused = self.fusion(torch.cat(features, dim=1))
        # Local residual
        return x + fused * 0.2


class RDN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=3,
                 num_features=64, num_blocks=16, growth_rate=32):
        super().__init__()
        # Shallow features
        self.sfe1 = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)
        self.sfe2 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        # Residual dense blocks
        self.rdbs = nn.ModuleList([
            ResidualDenseBlock(num_features, growth_rate)
            for _ in range(num_blocks)
        ])
        # Global feature fusion
        self.gff = nn.Sequential(
            nn.Conv2d(num_blocks * num_features, num_features, kernel_size=1),
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        )
        # Upsampling and reconstruction
        # (UpsampleModule is the sub-pixel module defined in the EDSR section)
        self.upsample = UpsampleModule(num_features, scale_factor)
        self.reconstruction = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Shallow features
        sfe1 = self.sfe1(x)
        sfe2 = self.sfe2(sfe1)
        # Collect the output of every RDB
        rdb_features = []
        out = sfe2
        for rdb in self.rdbs:
            out = rdb(out)
            rdb_features.append(out)
        # Global feature fusion (with a global residual)
        fused = self.gff(torch.cat(rdb_features, dim=1))
        fused = fused + sfe1  # global residual
        # Upsampling and reconstruction
        output = self.upsample(fused)
        output = self.reconstruction(output)
        return output
```
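3.7 RCAN (2018): Channel Attention
RCAN: Residual Channel Attention Network
Paper: "Image Super-Resolution Using Very Deep Residual Channel Attention Networks" (Zhang et al., 2018)
Core innovations:
1. Channel attention mechanism
2. Residual-in-residual (RIR) structure
3. Very deep network (400+ layers)
```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """
    Channel attention module.
    Idea: learn an importance weight for each channel.
    Method: global pooling → fully connected layers → Sigmoid
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        # Squeeze global information
        y = self.pool(x).view(b, c)
        # Per-channel weights
        y = self.fc(y).view(b, c, 1, 1)
        # Reweight the channels
        return x * y.expand_as(x)


class RCAB(nn.Module):
    """
    Residual channel attention block.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.ca = ChannelAttention(channels, reduction)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        # Channel attention
        residual = self.ca(residual)
        return x + residual


class ResidualGroup(nn.Module):
    """
    Residual group: a stack of RCABs followed by a conv, with a short skip.
    (Not part of the original listing; a minimal definition added so that
    RCAN below is self-contained.)
    """
    def __init__(self, channels, num_blocks, reduction=16):
        super().__init__()
        self.blocks = nn.Sequential(*[
            RCAB(channels, reduction) for _ in range(num_blocks)
        ])
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return x + self.conv(self.blocks(x))


class RCAN(nn.Module):
    """
    Residual-in-residual (RIR) network.
    Structure:
    Residual Group → Residual Block → Channel Attention
    """
    def __init__(self, scale_factor=4, num_channels=3,
                 num_features=64, num_groups=10, num_blocks=20):
        super().__init__()
        # Shallow features
        self.head = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)
        # Residual groups
        self.groups = nn.ModuleList([
            ResidualGroup(num_features, num_blocks)
            for _ in range(num_groups)
        ])
        # Fusion across groups
        self.fusion = nn.Conv2d(num_features * num_groups, num_features, kernel_size=1)
        # Conv before the global residual
        self.tail_conv = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        # Upsampling (UpsampleModule is defined in the EDSR section)
        self.upsample = UpsampleModule(num_features, scale_factor)
        self.reconstruction = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        head = self.head(x)
        # Collect the output of every group
        group_outputs = []
        out = head
        for group in self.groups:
            out = group(out)
            group_outputs.append(out)
        # Fusion + global residual
        fused = self.fusion(torch.cat(group_outputs, dim=1))
        out = self.tail_conv(fused) + head
        # Upsampling and reconstruction
        out = self.upsample(out)
        out = self.reconstruction(out)
        return out


"""
What channel attention does:
- Different channels capture different features (edges, texture, color, etc.)
- Attention lets the network focus on the important channels
- Similar in spirit to SENet, but applied to SR
"""
```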
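3.8 SwinIR (2021): Transformer Architecture
SwinIR: Swin Transformer for Image Restoration
Paper: "SwinIR: Image Restoration Using Swin Transformer" (Liang et al., 2021)
Core innovations:
1. Applies the Swin Transformer to image restoration
2. Shifted window attention
3. Models long-range dependencies
```python
import torch
import torch.nn as nn
from einops import rearrange


class WindowAttention(nn.Module):
    """
    Window attention.
    The feature map is split into non-overlapping windows and attention is
    computed inside each window, which cuts the cost from O(n²) to O(n × w²).
    """
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Relative position bias table (declared here, but not applied in the
        # simplified forward pass below)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x):
        # x: [num_windows × B, N, C] with N = window_size² tokens per window
        B, N, C = x.shape
        # QKV projection
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # Weighted aggregation
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return out


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block.
    Contains:
    1. Window attention (W-MSA) when shift_size = 0
    2. Shifted window attention (SW-MSA) when shift_size > 0
    3. MLP
    Simplification: the attention mask for shifted windows is omitted.
    """
    def __init__(self, dim, num_heads, window_size=8, shift_size=0):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        # x: [B, C, H, W]; H and W are assumed to be multiples of window_size
        B, C, H, W = x.shape
        ws = self.window_size
        x = rearrange(x, 'b c h w -> b h w c')
        # Window attention
        shortcut = x
        x = self.norm1(x)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        windows = rearrange(x, 'b (nh ws1) (nw ws2) c -> (b nh nw) (ws1 ws2) c',
                            ws1=ws, ws2=ws)
        windows = self.attn(windows)
        x = rearrange(windows, '(b nh nw) (ws1 ws2) c -> b (nh ws1) (nw ws2) c',
                      b=B, nh=H // ws, nw=W // ws, ws1=ws, ws2=ws)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x + shortcut
        # MLP
        shortcut = x
        x = self.norm2(x)
        x = self.mlp(x) + shortcut
        return rearrange(x, 'b h w c -> b c h w')


class SwinIR(nn.Module):
    """
    Swin Transformer super-resolution network.
    """
    def __init__(self, scale_factor=4, num_channels=3,
                 embed_dim=180, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                 window_size=8):
        super().__init__()
        # Shallow feature extraction
        self.conv_first = nn.Conv2d(num_channels, embed_dim, kernel_size=3, padding=1)
        # Deep feature extraction (Swin Transformer blocks);
        # regular and shifted windows alternate within each stage
        self.layers = nn.ModuleList()
        for i, depth in enumerate(depths):
            self.layers.append(
                nn.Sequential(*[
                    SwinTransformerBlock(
                        embed_dim, num_heads[i], window_size,
                        shift_size=0 if j % 2 == 0 else window_size // 2
                    )
                    for j in range(depth)
                ])
            )
        # Fusion
        self.fusion = nn.Conv2d(embed_dim * len(depths), embed_dim, kernel_size=1)
        # Upsampling (UpsampleModule is defined in the EDSR section)
        self.upsample = UpsampleModule(embed_dim, scale_factor)
        # Reconstruction
        self.reconstruction = nn.Conv2d(embed_dim, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Shallow features
        shallow = self.conv_first(x)
        # Deep features
        layer_outputs = []
        out = shallow
        for layer in self.layers:
            out = layer(out)
            layer_outputs.append(out)
        # Fusion
        fused = self.fusion(torch.cat(layer_outputs, dim=1))
        out = fused + shallow  # global residual
        # Upsampling and reconstruction
        out = self.upsample(out)
        out = self.reconstruction(out)
        return out
```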
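The complexity reduction claimed for window attention, from O(n²) to O(n × w²), can be made concrete by counting query-key pairs. The sketch below is only a counting exercise under the assumption that the feature map divides evenly into windows; `attention_pairs` is a hypothetical helper and ignores heads and channel width.

```python
def attention_pairs(height, width, window_size=None):
    """Number of query-key pairs evaluated by one attention layer."""
    n = height * width
    if window_size is None:
        return n * n                      # global attention: every token sees every token
    assert height % window_size == 0 and width % window_size == 0
    tokens_per_window = window_size ** 2
    num_windows = n // tokens_per_window
    return num_windows * tokens_per_window ** 2   # equals n * window_size²


global_pairs = attention_pairs(64, 64)
window_pairs = attention_pairs(64, 64, window_size=8)
print(global_pairs, window_pairs, global_pairs // window_pairs)
```

For a 64×64 feature map and 8×8 windows the saving is a factor of n / w² = 64, and it grows linearly with the image area.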
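4. Generative Adversarial Network Methods
4.1 SRGAN (2017)
SRGAN: Super-Resolution GAN
Paper: "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" (Ledig et al., 2017)
Core idea:
Train within the GAN framework: the generator learns super-resolution, and the discriminator distinguishes real HR images from generated ones.
Goal: produce visually realistic images rather than merely optimizing PSNR.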
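For reference, PSNR is a log-scaled mean squared error, PSNR = 10·log₁₀(MAX² / MSE), which is why networks trained with a pixel-wise MSE loss score well on it even when their outputs look over-smoothed. The sketch below is a minimal pure-Python version; the function name `psnr`, the nested-list images, and the 8-bit peak value of 255 are assumptions for illustration.

```python
import math


def psnr(reference, estimate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized 2-D images."""
    ref = [v for row in reference for v in row]
    est = [v for row in estimate for v in row]
    mse = sum((a - b) ** 2 for a, b in zip(ref, est)) / len(ref)
    if mse == 0:
        return math.inf               # identical images
    return 10 * math.log10(max_value ** 2 / mse)


hr = [[52, 60], [70, 80]]
close = [[53, 60], [70, 79]]          # small pixel errors
far = [[62, 50], [80, 70]]            # larger pixel errors
print(round(psnr(hr, close), 2), round(psnr(hr, far), 2))
```

A higher value means a smaller pixel-wise error, but it says nothing about whether the textures look plausible, which is the gap SRGAN's adversarial and perceptual losses try to close.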
```python
import torch
import torch.nn as nn


class SRResNet(nn.Module):
    """
    Generator network (ResNet-based).
    ResidualBlock is the block defined in the EDSR section.
    """
    def __init__(self, num_channels=3, num_features=64, num_blocks=16):
        super().__init__()
        # Shallow features
        self.head = nn.Sequential(
            nn.Conv2d(num_channels, num_features, kernel_size=9, padding=4),
            nn.PReLU()
        )
        # Residual blocks
        self.body = nn.Sequential(*[ResidualBlock(num_features) for _ in range(num_blocks)])
        # Fusion
        self.fusion = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features)
        )
        # Upsampling (sub-pixel convolution, two ×2 stages)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_features, num_features * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(num_features, num_features * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        # Reconstruction: maps the num_features feature channels back to image channels
        self.tail = nn.Conv2d(num_features, num_channels, kernel_size=9, padding=4)

    def forward(self, x):
        head = self.head(x)
        body = self.body(head)
        out = self.fusion(body) + head
        out = self.upsample(out)
        out = self.tail(out)
        return out


class Discriminator(nn.Module):
    """
    Discriminator network.
    Decides whether the input is a real HR image or a generated SR image.
    """
    def __init__(self, num_channels=3):
        super().__init__()

        def block(in_channels, out_channels, stride):
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            ]

        self.features = nn.Sequential(
            *block(num_channels, 64, 1),  # no downsampling
            *block(64, 64, 2),            # downsample ×2
            *block(64, 128, 1),
            *block(128, 128, 2),          # downsample ×2
            *block(128, 256, 1),
            *block(256, 256, 2),          # downsample ×2
            *block(256, 512, 1),
            *block(512, 512, 2)           # downsample ×2
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features)
        return output


class SRGAN:
    """
    SRGAN training framework.
    """
    def __init__(self, generator, discriminator,
                 content_weight=1, adversarial_weight=0.001):
        self.G = generator
        self.D = discriminator
        self.content_weight = content_weight
        self.adversarial_weight = adversarial_weight
        # Loss functions
        self.content_loss = nn.MSELoss()
        self.adversarial_loss = nn.BCELoss()
        # VGG feature extractor for the perceptual loss
        # (build_vgg_feature_extractor is not shown in this section)
        self.vgg = self.build_vgg_feature_extractor()

    def train_step(self, lr, hr):
        # ---- Train the discriminator ----
        # Real images are labeled 1
        real_output = self.D(hr)
        d_loss_real = self.adversarial_loss(real_output, torch.ones_like(real_output))
        # Generate the SR image
        sr = self.G(lr)
        # Generated images are labeled 0; detach so G receives no gradient here
        fake_output = self.D(sr.detach())
        d_loss_fake = self.adversarial_loss(fake_output, torch.zeros_like(fake_output))
        d_loss = (d_loss_real + d_loss_fake) / 2
        # ---- Train the generator ----
        # Content loss (MSE)
        loss_content = self.content_loss(sr, hr)
        # Perceptual loss (VGG features)
        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        loss_perceptual = self.content_loss(sr_features, hr_features)
        # Adversarial loss
        fake_output = self.D(sr)
        loss_adversarial = self.adversarial_loss(fake_output, torch.ones_like(fake_output))
        # Total loss
        g_loss = (self.content_weight * (loss_content + loss_perceptual) +
                  self.adversarial_weight * loss_adversarial)
        return d_loss, g_loss
```
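4.2 ESRGAN (2018)
ESRGAN: Enhanced Super-Resolution GAN
Paper: "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" (Wang et al., 2018)
Improvements:
1. Removes BN layers
2. Uses the Residual-in-Residual Dense Block (RRDB)
3. Relativistic discriminator
4. VGG19 perceptual loss
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RRDB(nn.Module):
    """
    Residual-in-Residual Dense Block.
    Structure: dense blocks nested inside a residual block.
    ResidualDenseBlock is the block defined in the RDN section.
    """
    def __init__(self, channels, growth_rate=32, num_layers=3):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(channels, growth_rate, num_layers)
        self.rdb2 = ResidualDenseBlock(channels, growth_rate, num_layers)
        self.rdb3 = ResidualDenseBlock(channels, growth_rate, num_layers)
        self.residual_scale = 0.2

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return x + out * self.residual_scale


class RelativisticDiscriminator(nn.Module):
    """
    Relativistic discriminator.
    Standard GAN:     D(x) = P(x is real)
    Relativistic GAN: D(x, y) = P(x is more realistic than y)
    Advantage: the discriminator judges relative realism, not just real vs. fake.
    """
    def __init__(self, num_channels=3):
        super().__init__()
        # Same structure as the standard discriminator, but the classifier is
        # assumed to return raw logits (no final Sigmoid).
        # build_features / build_classifier are not shown in this section.
        self.features = self.build_features(num_channels)
        self.classifier = self.build_classifier()

    def forward(self, real, fake):
        real_features = self.features(real)
        fake_features = self.features(fake)
        # Raw scores C(x) for both inputs
        real_output = self.classifier(real_features)
        fake_output = self.classifier(fake_features)
        return real_output, fake_output


def relativistic_loss(discriminator, real, fake):
    """
    Relativistic average adversarial loss.
    D_Ra(x_r, x_f) = σ(C(x_r) - E[C(x_f)])
    In a real training loop, `fake` is detached for the discriminator update.
    """
    real_logits, fake_logits = discriminator(real, fake)
    # Score each sample relative to the average score of the other class
    real_rel = real_logits - fake_logits.mean()
    fake_rel = fake_logits - real_logits.mean()
    # Discriminator loss: real should look more realistic than fake
    d_loss = (
        F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel)) +
        F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel))
    ) / 2
    # Generator loss (goal: generated images look more realistic than real ones)
    g_loss = (
        F.binary_cross_entropy_with_logits(fake_rel, torch.ones_like(fake_rel)) +
        F.binary_cross_entropy_with_logits(real_rel, torch.zeros_like(real_rel))
    ) / 2
    return d_loss, g_loss
```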
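To see how a relativistic loss behaves numerically, the sketch below evaluates the relativistic average formulation, D_Ra(x_r, x_f) = σ(C(x_r) - E[C(x_f)]), on plain lists of raw discriminator scores. The function names and the example scores are illustrative assumptions, not values from the paper.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def relativistic_d_loss(real_logits, fake_logits):
    """-E[log D_Ra(real, fake)] - E[log(1 - D_Ra(fake, real))]."""
    mean_real = sum(real_logits) / len(real_logits)
    mean_fake = sum(fake_logits) / len(fake_logits)
    loss_real = -sum(math.log(sigmoid(c - mean_fake)) for c in real_logits) / len(real_logits)
    loss_fake = -sum(math.log(1.0 - sigmoid(c - mean_real)) for c in fake_logits) / len(fake_logits)
    return loss_real + loss_fake


def relativistic_g_loss(real_logits, fake_logits):
    """The generator objective is the discriminator loss with the roles swapped."""
    return relativistic_d_loss(fake_logits, real_logits)


real_scores = [2.0, 3.0]      # discriminator currently rates real images higher
fake_scores = [-2.0, -1.0]
print(round(relativistic_d_loss(real_scores, fake_scores), 4),
      round(relativistic_g_loss(real_scores, fake_scores), 4))
```

When the discriminator cannot tell the two sets apart (equal scores), both losses equal 2·ln 2; here the discriminator loss is well below that value and the generator loss is well above it. Because the generator loss also depends on the real scores, the generator receives gradients from real data as well as generated data.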
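4.3 Real-ESRGAN (2021)
Paper: "Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data" (Wang et al., 2021)
Core innovations:
1. High-order degradation modeling (second-order degradation)
2. U-Net discriminator (with spectral normalization)
3. A degradation model aimed at real-world images
```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class HighOrderDegradation:
    """
    High-order degradation model.
    Real-world degradations are complex, and a first-order model is not enough;
    Real-ESRGAN simulates them with a second-order process.
    First order:  y = (x ⊗ k) ↓s + n
    Second order: y = ((x ⊗ k₁) ↓s₁ + n₁) ⊗ k₂ ↓s₂ + n₂
    The helpers random_blur_kernel, apply_blur, downsample, add_noise, and
    jpeg_compress are not shown in this section.
    """
    def __init__(self):
        self.degradation_types = [
            'blur',              # blur (Gaussian, motion)
            'downsample',        # downsampling
            'noise',             # noise (Gaussian, Poisson)
            'jpeg_compression'   # JPEG compression
        ]

    def apply_first_order(self, hr_image):
        """First-order degradation."""
        # Random blur kernel
        kernel = self.random_blur_kernel()
        blurred = self.apply_blur(hr_image, kernel)
        # Random downsampling
        scale = random.choice([2, 4])
        downsampled = self.downsample(blurred, scale)
        # Random noise
        noise_type = random.choice(['gaussian', 'poisson'])
        noisy = self.add_noise(downsampled, noise_type)
        # JPEG compression
        if random.random() > 0.5:
            noisy = self.jpeg_compress(noisy)
        return noisy

    def apply_second_order(self, hr_image):
        """Second-order degradation (closer to real-world images)."""
        # First pass
        degraded = self.apply_first_order(hr_image)
        # Second pass (degrade again; the two scale factors multiply)
        degraded = self.apply_first_order(degraded)
        return degraded


class UNetDiscriminator(nn.Module):
    """
    U-Net discriminator.
    Outputs a per-pixel real/fake map instead of a single global decision,
    which gives the generator richer gradient information.
    """
    def __init__(self, num_channels=3, num_features=64):
        super().__init__()
        # Encoder (each block halves the resolution)
        self.enc1 = self.down_block(num_channels, num_features)
        self.enc2 = self.down_block(num_features, num_features * 2)
        self.enc3 = self.down_block(num_features * 2, num_features * 4)
        # Bottleneck
        self.bottleneck = self.down_block(num_features * 4, num_features * 8)
        # Decoder (resolution-preserving convs applied after upsampling)
        self.dec3 = self.up_block(num_features * 8 + num_features * 4, num_features * 4)
        self.dec2 = self.up_block(num_features * 4 + num_features * 2, num_features * 2)
        self.dec1 = self.up_block(num_features * 2 + num_features, num_features)
        # Output layer (per-pixel logits)
        self.output = nn.Conv2d(num_features, 1, kernel_size=1)

    def down_block(self, in_ch, out_ch):
        # Spectral normalization is applied to every conv in the blocks
        return nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def up_block(self, in_ch, out_ch):
        return nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)),
            nn.LeakyReLU(0.2, inplace=True)
        )

    @staticmethod
    def resize_like(t, ref):
        # Upsample t to the spatial size of ref so the skip concatenation lines up
        return F.interpolate(t, size=ref.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        # Encode
        e1 = self.enc1(x)    # H/2
        e2 = self.enc2(e1)   # H/4
        e3 = self.enc3(e2)   # H/8
        # Bottleneck
        b = self.bottleneck(e3)  # H/16
        # Decode (with skip connections)
        d3 = self.dec3(torch.cat([self.resize_like(b, e3), e3], dim=1))
        d2 = self.dec2(torch.cat([self.resize_like(d3, e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self.resize_like(d2, e1), e1], dim=1))
        # Per-pixel output at the input resolution
        output = self.output(self.resize_like(d1, x))
        return output
```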
5. Diffusion Model Methods
5.1 SR3 (2021)
SR3: Image Super-Resolution via Iterative Refinement
Paper: "Image Super-Resolution via Iterative Refinement" (Saharia et al., 2021)
Core idea:
Model super-resolution as a conditional diffusion process.
Forward process: gradually add noise to the HR image.
Reverse process: recover the HR image from noise, conditioned on the LR image.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalUNet(nn.Module):
    """
    Conditional U-Net.
    Input: noisy image x_t + LR condition image
    Output: predicted noise ε_θ
    conv_block and SinusoidalPositionEmbeddings are helpers not shown in this
    section; conv_block is assumed to preserve the spatial size so that the
    skip concatenations line up.
    """
    def __init__(self, in_channels=6, out_channels=3, features=128):
        super().__init__()
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(features),
            nn.Linear(features, features * 4),
            nn.GELU(),
            nn.Linear(features * 4, features)
        )
        # Encoder
        self.enc1 = self.conv_block(in_channels, features)
        self.enc2 = self.conv_block(features, features * 2)
        self.enc3 = self.conv_block(features * 2, features * 4)
        # Bottleneck
        self.bottleneck = self.conv_block(features * 4, features * 8)
        # Decoder (with skip connections)
        self.dec3 = self.conv_block(features * 8 + features * 4, features * 4)
        self.dec2 = self.conv_block(features * 4 + features * 2, features * 2)
        self.dec1 = self.conv_block(features * 2 + features, features)
        # Output
        self.output = nn.Conv2d(features, out_channels, kernel_size=1)

    def forward(self, x, t, lr_condition):
        """
        x: noisy image [B, 3, H, W]
        t: timestep [B]
        lr_condition: LR condition image, upsampled to [B, 3, H, W]
        """
        # Concatenate the condition
        x = torch.cat([x, lr_condition], dim=1)  # [B, 6, H, W]
        # Time embedding
        t_emb = self.time_mlp(t)                 # [B, features]
        # Encode; inject the time embedding into the first feature map
        e1 = self.enc1(x) + t_emb[:, :, None, None]
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        # Bottleneck
        b = self.bottleneck(e3)
        # Decode
        d3 = self.dec3(torch.cat([b, e3], dim=1))
        d2 = self.dec2(torch.cat([d3, e2], dim=1))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))
        # Predict the noise
        noise_pred = self.output(d1)
        return noise_pred


class SR3:
    """
    SR3 super-resolution model.
    cosine_beta_schedule and denoise_step are helpers not shown in this section.
    """
    def __init__(self, model, num_timesteps=1000):
        self.model = model
        self.num_timesteps = num_timesteps
        # Noise schedule
        self.betas = self.cosine_beta_schedule(num_timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def forward_diffusion(self, hr_image, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0)
        x_t = √(ᾱ_t) * x_0 + √(1 - ᾱ_t) * ε
        """
        if noise is None:
            noise = torch.randn_like(hr_image)
        alpha_cumprod = self.alphas_cumprod[t].view(-1, 1, 1, 1)
        noisy_image = torch.sqrt(alpha_cumprod) * hr_image + \
            torch.sqrt(1 - alpha_cumprod) * noise
        return noisy_image, noise

    def reverse_diffusion(self, lr_image, num_steps=None):
        """
        Reverse diffusion: p(x_{t-1} | x_t, lr)
        Start from pure noise and denoise step by step.
        lr_image is assumed to be already upsampled to the HR size.
        """
        # By default, walk through every timestep of the training schedule
        num_steps = num_steps or self.num_timesteps
        # Initial noise
        x = torch.randn_like(lr_image)
        for t in reversed(range(num_steps)):
            # The model expects a batch of timesteps
            t_batch = torch.full((x.shape[0],), t, dtype=torch.long, device=x.device)
            # Predict the noise
            noise_pred = self.model(x, t_batch, lr_image)
            # One denoising step
            x = self.denoise_step(x, noise_pred, t)
        return x

    def train_step(self, lr_image, hr_image):
        """
        Training step:
        1. Sample a random timestep t
        2. Add noise to the HR image to obtain x_t
        3. Let the model predict the noise
        4. Compute the MSE loss
        """
        batch_size = lr_image.shape[0]
        # Random timesteps
        t = torch.randint(0, self.num_timesteps, (batch_size,))
        # Random noise
        noise = torch.randn_like(hr_image)
        # Forward diffusion
        noisy_hr, noise = self.forward_diffusion(hr_image, t, noise)
        # Predict the noise
        noise_pred = self.model(noisy_hr, t, lr_image)
        # Loss
        loss = F.mse_loss(noise_pred, noise)
        return loss
```
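The SR3 code in this section calls a `cosine_beta_schedule` helper without defining it. One common form is the cosine schedule of Nichol and Dhariwal; whether the author intended exactly this form is an assumption, and the SR3 paper describes its own schedule. The pure-Python sketch below builds the betas, accumulates ᾱ_t, and applies the forward formula x_t = √ᾱ_t·x_0 + √(1 - ᾱ_t)·ε to a single scalar "pixel".

```python
import math


def cosine_beta_schedule(num_timesteps, s=0.008, max_beta=0.999):
    """Betas derived from a cosine-shaped cumulative signal level (assumed form)."""
    def f(t):
        return math.cos((t / num_timesteps + s) / (1 + s) * math.pi / 2) ** 2

    # beta_t = 1 - ᾱ_t / ᾱ_{t-1}, clipped to keep the last steps well defined
    return [min(1 - f(t) / f(t - 1), max_beta) for t in range(1, num_timesteps + 1)]


betas = cosine_beta_schedule(1000)
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1 - beta
    alphas_cumprod.append(running)


def q_sample(x0, eps, t):
    """Forward diffusion of a scalar value x0 with noise sample eps at step t."""
    a_bar = alphas_cumprod[t]
    return math.sqrt(a_bar) * x0 + math.sqrt(1 - a_bar) * eps


for t in (0, 250, 500, 750, 999):
    print(t, round(alphas_cumprod[t], 4), round(q_sample(1.0, 0.0, t), 4))
```

Early timesteps keep almost all of the signal, while at the final timestep ᾱ_t is close to zero, so x_t is essentially pure noise, which is the starting point of the reverse process.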
5.2 DiffIR (2023)
DiffIR: Effective Diffusion Model for Image Restoration
Paper: "DiffIR: Effective Diffusion Model for Image Restoration" (Xia et al., 2023)
Improvements:
1. Two-stage training
2. A more efficient diffusion process
3. Fewer sampling steps
```python
import torch
import torch.nn.functional as F


class DiffIR:
    """
    Two-stage DiffIR.
    Stage 1: train the predictor
    - Input: the degraded image
    - Output: a coarse restoration
    Stage 2: train the diffusion refiner
    - Conditioned on the coarse result
    - Restores fine high-frequency detail
    UNet and GaussianDiffusion are helpers not shown in this section.
    """
    def __init__(self):
        self.predictor = UNet(in_channels=3, out_channels=3)
        # 3 noisy-image channels + 6 condition channels (coarse + degraded)
        self.refiner = ConditionalUNet(in_channels=9, out_channels=3)
        self.diffusion = GaussianDiffusion(num_timesteps=100)

    def train_predictor(self, degraded, clean):
        """Stage 1: train the predictor."""
        pred = self.predictor(degraded)
        loss = F.l1_loss(pred, clean)
        return loss

    def train_refiner(self, degraded, clean):
        """Stage 2: train the diffusion refiner."""
        # First obtain the coarse result from the frozen predictor
        with torch.no_grad():
            coarse = self.predictor(degraded)
        # Diffusion refinement
        t = torch.randint(0, 100, (degraded.shape[0],))
        noise = torch.randn_like(clean)
        noisy_clean = self.diffusion.forward_process(clean, t, noise)
        # Condition: coarse result + degraded image
        condition = torch.cat([coarse, degraded], dim=1)
        # Predict the noise
        noise_pred = self.refiner(noisy_clean, t, condition)
        loss = F.mse_loss(noise_pred, noise)
        return loss

    def inference(self, degraded):
        """Inference."""
        # Stage 1: coarse restoration
        coarse = self.predictor(degraded)
        # Stage 2: diffusion refinement
        condition = torch.cat([coarse, degraded], dim=1)
        refined = self.diffusion.sample(condition, num_steps=10)
        return refined
```
6. 视频超分辨率
6.1 时序对齐
python
class TemporalAlignment(nn.Module):
"""
时序对齐模块
问题: 相邻帧之间存在运动,需要对齐后才能融合
方法:
1. 光流估计
2. 可变形卷积
3. 注意力对齐
"""
pass
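class OpticalFlowAlignment(nn.Module):
    """
    Optical-flow-based temporal alignment.
    Pipeline:
      1. Estimate the optical flow between neighbouring frames
      2. Warp the frame with the flow
      3. Fuse features after alignment
    """
    def __init__(self):
        super().__init__()
        self.flow_estimator = FlowNet()  # flow estimation network (placeholder)

    def align(self, reference, target):
        """
        Align the target frame to the reference frame.
        """
        # Estimate the flow from reference to target
        flow = self.flow_estimator(reference, target)
        # Warp the target with the flow
        aligned = self.warp(target, flow)
        return aligned

    def warp(self, image, flow):
        """
        Warp an image with a flow field.
        Bilinear sampling keeps the warp differentiable.
        image: [B, C, H, W]; flow: [B, 2, H, W] in pixels (channel 0 = x, 1 = y)
        """
        B, C, H, W = image.shape
        # Build the pixel grid on the same device and dtype as the flow
        grid_y, grid_x = torch.meshgrid(
            torch.arange(H, device=flow.device, dtype=flow.dtype),
            torch.arange(W, device=flow.device, dtype=flow.dtype),
            indexing='ij'
        )
        grid = torch.stack([grid_x, grid_y], dim=0)  # [2, H, W]
        # Add the flow offsets (broadcast over the batch)
        grid = grid.unsqueeze(0) + flow  # [B, 2, H, W]
        # Normalize to [-1, 1] as expected by grid_sample
        norm_x = 2 * grid[:, 0] / max(W - 1, 1) - 1
        norm_y = 2 * grid[:, 1] / max(H - 1, 1) - 1
        grid = torch.stack([norm_x, norm_y], dim=-1)  # [B, H, W, 2]
        # Bilinear sampling
        aligned = F.grid_sample(image, grid, mode='bilinear',
                                padding_mode='border', align_corners=True)
        return aligned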
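class DeformableConvAlignment(nn.Module):
    """
    Deformable-convolution alignment.
    Advantage: the sampling locations are learned, which is more flexible
    than a single flow vector per pixel.
    """
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        # Predict the offsets
        self.offset_conv = nn.Conv2d(
            channels * 2,                   # input: reference + target features
            2 * kernel_size * kernel_size,  # output: (x, y) offset per kernel tap
            kernel_size=3, padding=1
        )
        # Deformable convolution (e.g. torchvision.ops.DeformConv2d)
        self.deform_conv = DeformConv2d(
            channels, channels,
            kernel_size=kernel_size, padding=kernel_size // 2
        )

    def forward(self, reference, target):
        # Compute the offsets from both frames
        offset = self.offset_conv(torch.cat([reference, target], dim=1))
        # Sample the target features at the offset locations
        aligned = self.deform_conv(target, offset)
        return aligned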
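To make the flow-based warp concrete, here is a small NumPy sketch of backward warping with bilinear sampling on a single-channel image. It is only an illustration under assumed conventions: the function name `warp_bilinear` is hypothetical, the flow is in pixels with channel 0 as the x offset, and out-of-range positions are clamped to the border (mirroring `padding_mode='border'`).

```python
import numpy as np

def warp_bilinear(image, flow):
    """Backward-warp an [H, W] image with a [2, H, W] flow field.

    flow[0] is the x (column) offset and flow[1] the y (row) offset, in pixels.
    Each output pixel samples the input at (x + flow_x, y + flow_y).
    """
    H, W = image.shape
    ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
    x = np.clip(xs + flow[0], 0, W - 1)
    y = np.clip(ys + flow[1], 0, H - 1)
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1 = np.minimum(x0 + 1, W - 1)
    y1 = np.minimum(y0 + 1, H - 1)
    wx = x - x0
    wy = y - y0
    top = (1 - wx) * image[y0, x0] + wx * image[y0, x1]
    bottom = (1 - wx) * image[y1, x0] + wx * image[y1, x1]
    return (1 - wy) * top + wy * bottom

image = np.arange(16, dtype=np.float64).reshape(4, 4)
zero_flow = np.zeros((2, 4, 4))
shift_flow = np.zeros((2, 4, 4))
shift_flow[0] = 1.0  # every output pixel samples one column to its right
half_flow = np.zeros((2, 4, 4))
half_flow[0] = 0.5   # sub-pixel shift: average of two neighbouring columns

identity = warp_bilinear(image, zero_flow)
shifted = warp_bilinear(image, shift_flow)
halfway = warp_bilinear(image, half_flow)
```

A zero flow returns the input unchanged, an integer flow shifts it, and a fractional flow interpolates between neighbours, which is the property that makes the warp differentiable with respect to the flow.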
6.2 Video SR Network Architectures
python
class BasicVSR(nn.Module):
    """
    BasicVSR: video super-resolution with bidirectional propagation.
    Paper: "BasicVSR: The Search for Essential Components in Video Super-Resolution
    and Beyond" (Chan et al., 2021)
    Core ideas:
      1. Bidirectional temporal propagation (forward + backward)
      2. Optical-flow alignment
      3. Residual blocks for feature extraction
    SpyNet, ResidualBlock and UpsampleModule are placeholder modules that are
    not defined in this section.
    """
    def __init__(self, num_channels=3, num_features=64, num_blocks=30):
        super().__init__()
        # Feature extractor
        self.feat_extract = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)
        # Optical flow estimator
        self.flow_estimator = SpyNet()
        # Propagation branches (bidirectional)
        self.backward_resblocks = nn.Sequential(*[
            ResidualBlock(num_features) for _ in range(num_blocks)
        ])
        self.forward_resblocks = nn.Sequential(*[
            ResidualBlock(num_features) for _ in range(num_blocks)
        ])
        # Fusion and reconstruction
        self.fusion = nn.Conv2d(num_features * 2, num_features, kernel_size=1)
        self.upsample = UpsampleModule(num_features, scale_factor=4)
        self.reconstruction = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)

    def warp(self, feat, flow):
        # Reuse the bilinear flow warp from OpticalFlowAlignment (Section 6.1)
        return OpticalFlowAlignment.warp(self, feat, flow)

    def forward(self, lr_frames):
        """
        lr_frames: [B, T, C, H, W] (T LR video frames)
        """
        B, T, C, H, W = lr_frames.shape
        # Extract per-frame features
        features = [self.feat_extract(lr_frames[:, t]) for t in range(T)]
        # Backward propagation (from the last frame to the first)
        backward_features = []
        for t in reversed(range(T)):
            feat = features[t]
            if t < T - 1:
                # Align the propagated feature of frame t+1 to frame t
                flow = self.flow_estimator(lr_frames[:, t], lr_frames[:, t + 1])
                aligned = self.warp(backward_features[-1], flow)
                feat = feat + aligned
            feat = self.backward_resblocks(feat)
            backward_features.append(feat)
        backward_features = backward_features[::-1]
        # Forward propagation (from the first frame to the last)
        forward_features = []
        for t in range(T):
            feat = features[t]
            if t > 0:
                # Align the propagated feature of frame t-1 to frame t
                flow = self.flow_estimator(lr_frames[:, t], lr_frames[:, t - 1])
                aligned = self.warp(forward_features[-1], flow)
                feat = feat + aligned
            # Fuse with the backward-branch feature of the same frame
            combined = torch.cat([feat, backward_features[t]], dim=1)
            feat = self.fusion(combined)
            feat = self.forward_resblocks(feat)
            forward_features.append(feat)
        # Reconstruct the HR frames
        hr_frames = []
        for feat in forward_features:
            hr = self.upsample(feat)
            hr = self.reconstruction(hr)
            hr_frames.append(hr)
        return torch.stack(hr_frames, dim=1)
class BasicVSRPlusPlus(nn.Module):
    """
    BasicVSR++: improved version.
    Improvements:
      1. Second-order propagation
      2. Flow-guided deformable alignment
      3. More efficient feature fusion
    """
    pass
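The effect of bidirectional propagation can be seen in a toy NumPy sketch, where alignment and residual blocks are replaced by a simple decayed sum. The `propagate` and `bidirectional` helpers and the decay value are illustrative assumptions, not part of BasicVSR itself. The point is only that the backward branch carries information from later frames to earlier ones, which a forward-only recurrence cannot do.

```python
import numpy as np

def propagate(features, decay=0.5, reverse=False):
    """Toy recurrence h_t = x_t + decay * h_prev, run forward or backward in time."""
    T = len(features)
    order = range(T - 1, -1, -1) if reverse else range(T)
    hidden = np.zeros_like(features[0])
    out = [None] * T
    for t in order:
        hidden = features[t] + decay * hidden
        out[t] = hidden
    return np.stack(out)

def bidirectional(features):
    """Concatenate forward and backward states per frame, as a fusion layer would see them."""
    forward = propagate(features, reverse=False)
    backward = propagate(features, reverse=True)
    return np.concatenate([forward, backward], axis=-1)

T, C = 5, 4
rng = np.random.default_rng(0)
frames = rng.standard_normal((T, C))
base = bidirectional(frames)

# Perturb only the last frame and check which per-frame outputs change.
perturbed = frames.copy()
perturbed[-1] += 1.0
delta = np.abs(bidirectional(perturbed) - base)
forward_changed = delta[:, :C].sum(axis=1) > 0   # forward branch
backward_changed = delta[:, C:].sum(axis=1) > 0  # backward branch
```

In this sketch the forward branch changes only at the last frame, while the backward branch changes at every frame, so the fused features of the first frame already depend on the end of the clip.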
7. Loss Functions and Evaluation Metrics
7.1 Loss Functions
python
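class SRLosses:
    """Collection of super-resolution loss functions."""
    @staticmethod
    def pixel_loss(pred, target, loss_type='l1'):
        """
        Pixel-wise loss.
        L1: ‖pred - target‖₁
        L2: ‖pred - target‖₂²
        L1: sharper edges, more robust to outliers.
        L2: smoother optimization; minimizing MSE directly maximizes PSNR.
        """
        if loss_type == 'l1':
            return F.l1_loss(pred, target)
        elif loss_type == 'l2':
            return F.mse_loss(pred, target)
        elif loss_type == 'charbonnier':
            # Charbonnier: a smooth, everywhere-differentiable approximation of L1
            eps = 1e-6
            return torch.mean(torch.sqrt((pred - target) ** 2 + eps))
        raise ValueError(f"Unknown loss_type: {loss_type}")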
    @staticmethod
    def perceptual_loss(pred, target, vgg_model, layer_weights=None):
        """
        Perceptual loss.
        Measures distance in VGG feature space rather than pixel space,
        which tends to give results closer to human visual perception.
        vgg_model is expected to return a dict mapping layer names to features.
        """
        if layer_weights is None:
            layer_weights = {
                'relu1_2': 1.0,
                'relu2_2': 1.0,
                'relu3_3': 1.0,
                'relu4_3': 1.0
            }
        loss = 0
        pred_features = vgg_model(pred)
        target_features = vgg_model(target)
        for layer, weight in layer_weights.items():
            loss += weight * F.l1_loss(
                pred_features[layer],
                target_features[layer]
            )
        return loss
    @staticmethod
    def style_loss(pred, target, vgg_model):
        """
        Style loss (Gram loss).
        Matches the Gram matrices of the features to preserve texture style.
        """
        def gram_matrix(features):
            B, C, H, W = features.shape
            features = features.view(B, C, -1)
            # Channel-by-channel correlations, normalized by the feature size
            gram = torch.bmm(features, features.transpose(1, 2))
            return gram / (C * H * W)

        pred_features = vgg_model(pred)
        target_features = vgg_model(target)
        loss = 0
        for layer in pred_features:
            pred_gram = gram_matrix(pred_features[layer])
            target_gram = gram_matrix(target_features[layer])
            loss += F.mse_loss(pred_gram, target_gram)
        return loss
    @staticmethod
    def adversarial_loss(discriminator_output, mode='original'):
        """
        Adversarial loss for the generator.
        discriminator_output holds the raw logits D(G(x)).
        Original GAN (non-saturating): -log(sigmoid(D(G(x))))
        LSGAN: (D(G(x)) - 1)²
        Hinge (generator side): -D(G(x))
        """
        if mode == 'original':
            return F.binary_cross_entropy_with_logits(
                discriminator_output,
                torch.ones_like(discriminator_output)
            )
        elif mode == 'lsgan':
            return F.mse_loss(
                discriminator_output,
                torch.ones_like(discriminator_output)
            )
        elif mode == 'hinge':
            return -discriminator_output.mean()
        raise ValueError(f"Unknown mode: {mode}")
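    @staticmethod
    def frequency_loss(pred, target, alpha=1.0):
        """
        Frequency-domain loss.
        Constrains the recovery of high-frequency detail in the Fourier domain.
        """
        # 2-D FFT over the last two (spatial) dimensions
        pred_fft = torch.fft.fft2(pred)
        target_fft = torch.fft.fft2(target)
        # L1 loss between the magnitude spectra
        loss = F.l1_loss(
            torch.abs(pred_fft),
            torch.abs(target_fft)
        )
        # The high-frequency bins could be weighted more heavily here.
        # After torch.fft.fftshift, high frequencies lie toward the spectrum edges.
        return loss * alpha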
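The frequency loss above leaves high-frequency weighting as an option. One possible way to do it is sketched here in NumPy. The linear radial weight and the function names `radial_frequency_weight` and `weighted_frequency_loss` are assumptions chosen for illustration, not a scheme taken from a specific paper. The weight is built from `np.fft.fftfreq`, so it works directly on the unshifted spectrum.

```python
import numpy as np

def radial_frequency_weight(h, w, alpha=1.0):
    """Weight 1 at DC, growing linearly with normalized radial frequency."""
    fy = np.fft.fftfreq(h)[:, None]  # cycles per pixel, in [-0.5, 0.5)
    fx = np.fft.fftfreq(w)[None, :]
    radius = np.sqrt(fx ** 2 + fy ** 2) / np.sqrt(0.5)  # 0 at DC, 1 at the Nyquist corner
    return 1.0 + alpha * radius

def weighted_frequency_loss(pred, target, alpha=1.0):
    """Mean weighted L1 distance between the FFT magnitudes of two [H, W] images."""
    weight = radial_frequency_weight(*pred.shape, alpha=alpha)
    diff = np.abs(np.abs(np.fft.fft2(pred)) - np.abs(np.fft.fft2(target)))
    return float(np.mean(weight * diff))

# Target: flat gray plus a checkerboard (pure Nyquist-frequency detail).
checker = (np.indices((8, 8)).sum(axis=0) % 2) * 2.0 - 1.0
target = 0.5 + 0.5 * checker
# Prediction: the detail is missing, as in an over-smoothed SR result.
pred = np.full((8, 8), 0.5)

loss_plain = weighted_frequency_loss(pred, target, alpha=0.0)
loss_weighted = weighted_frequency_loss(pred, target, alpha=1.0)
```

With `alpha=0` the function reduces to the unweighted magnitude loss; with `alpha=1` the missing checkerboard, which sits at the highest representable frequency, is penalized twice as much.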
7.2 Evaluation Metrics
python
import numpy as np
import torch
from skimage.metrics import structural_similarity as ssim

class SRMetrics:
    """Super-resolution evaluation metrics."""
    @staticmethod
    def calculate_psnr(pred, target, max_val=1.0):
        """
        PSNR (Peak Signal-to-Noise Ratio)
        Formula: PSNR = 10 * log10(MAX² / MSE)
        Range: typically 20-40 dB, higher is better
        Limitation: does not fully match human visual perception
        """
        # Cast to float so integer images do not overflow in the subtraction
        diff = pred.astype(np.float64) - target.astype(np.float64)
        mse = np.mean(diff ** 2)
        if mse == 0:
            return float('inf')
        return 10 * np.log10(max_val ** 2 / mse)
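    @staticmethod
    def calculate_ssim(pred, target, window_size=11):
        """
        SSIM (Structural Similarity Index)
        Compares luminance, contrast and structure
        Range: [-1, 1], closer to 1 is better
        Advantage: matches human perception better than PSNR
        Expects [H, W, C] images with values in [0, 1].
        """
        return ssim(pred, target, data_range=1.0,
                    win_size=window_size, channel_axis=-1)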
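    @staticmethod
    def calculate_lpips(pred, target, net='alex'):
        """
        LPIPS (Learned Perceptual Image Patch Similarity)
        Computes a perceptual distance with a pretrained network
        Range: non-negative, typically within [0, 1]; lower is better
        Advantage: correlates with human judgement better than PSNR or SSIM
        Expects [H, W, 3] images with values in [0, 1].
        """
        import lpips
        # Building the network on every call is slow; cache it for batch evaluation
        loss_fn = lpips.LPIPS(net=net)
        # Convert HWC arrays to 1xCxHxW tensors
        pred_t = torch.from_numpy(pred).permute(2, 0, 1).unsqueeze(0).float()
        target_t = torch.from_numpy(target).permute(2, 0, 1).unsqueeze(0).float()
        # normalize=True rescales [0, 1] inputs to the [-1, 1] range LPIPS expects
        with torch.no_grad():
            return loss_fn(pred_t, target_t, normalize=True).item()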
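    @staticmethod
    def calculate_niqe(image):
        """
        NIQE (Natural Image Quality Evaluator)
        No-reference image quality assessment
        Needs no ground truth; based on natural image statistics
        Range: lower is better
        extract_niqe_features and compute_niqe_score are placeholders that
        are not defined in this section.
        """
        # Extract natural scene statistics features
        features = extract_niqe_features(image)
        # Compare against the distribution fitted on natural images
        niqe_score = compute_niqe_score(features)
        return niqe_score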
# Metric comparison
"""
┌───────────────────────────────────────────────────────────────────────────┐
│                             Metric comparison                             │
├──────────┬─────────────┬───────────────────┬──────────────────────────────┤
│ Metric   │ Needs GT    │ Perceptual corr.  │ Use case                     │
├──────────┼─────────────┼───────────────────┼──────────────────────────────┤
│ PSNR     │ ✓           │ Low               │ Objective fidelity           │
│ SSIM     │ ✓           │ Medium            │ Structure preservation       │
│ LPIPS    │ ✓           │ High              │ Perceptual quality           │
│ FID      │ ✓ (dataset) │ High              │ Generative quality           │
│ NIQE     │ ✗           │ Medium            │ No-reference evaluation      │
└──────────┴─────────────┴───────────────────┴──────────────────────────────┘
"""
8. Engineering Practice and Deployment
8.1 Model Selection Guide
┌─────────────────────────────────────────────────────────────┐
│                Model selection decision tree                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Requirements analysis:                                     │
│  ├── Real-time requirement?                                 │
│  │   ├── Yes → ESPCN, FSRCNN (lightweight)                  │
│  │   └── No  → keep evaluating                              │
│  │                                                          │
│  ├── Visual quality first?                                  │
│  │   ├── Yes → ESRGAN, Real-ESRGAN (GAN-based)              │
│  │   └── No  → keep evaluating                              │
│  │                                                          │
│  ├── PSNR first?                                            │
│  │   ├── Yes → EDSR, RCAN, SwinIR (regression-based)        │
│  │   └── No  → keep evaluating                              │
│  │                                                          │
│  ├── Real-world degradation?                                │
│  │   ├── Yes → Real-ESRGAN, DiffIR (blind SR)               │
│  │   └── No  → standard methods                             │
│  │                                                          │
│  └── Resource-constrained?                                  │
│      ├── Yes → lightweight model + knowledge distillation   │
│      └── No  → large model + ensembling                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘
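The checklist above can also be written as a small function, which makes it explicit that the questions are independent and that several suggestions may apply at once. This is a hypothetical helper for illustration only: the function name and flag names are assumptions, and the strings simply restate the tree.

```python
def suggest_models(realtime=False, perceptual=False, psnr_first=False,
                   real_world=False, low_resource=False):
    """Walk the checklist in order and collect every matching suggestion."""
    suggestions = []
    if realtime:
        suggestions.append("ESPCN / FSRCNN (lightweight)")
    if perceptual:
        suggestions.append("ESRGAN / Real-ESRGAN (GAN-based)")
    if psnr_first:
        suggestions.append("EDSR / RCAN / SwinIR (regression-based)")
    if real_world:
        suggestions.append("Real-ESRGAN / DiffIR (blind SR)")
    else:
        suggestions.append("standard methods")
    if low_resource:
        suggestions.append("lightweight model + knowledge distillation")
    else:
        suggestions.append("large model + ensembling")
    return suggestions

# Example: a real-time mobile deployment on bicubic-style inputs.
mobile = suggest_models(realtime=True, low_resource=True)
# Example: offline restoration of real-world photos, judged visually.
offline = suggest_models(perceptual=True, real_world=True)
```

Real-ESRGAN appears twice in the second example's output because it satisfies both the visual-quality and the real-world branches of the tree.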
8.2 Model Deployment
python
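class SRModelDeployer:
    """Deployment utilities for super-resolution models."""
    @staticmethod
    def export_onnx(model, input_shape, save_path):
        """
        Export the model to ONNX.
        """
        # Export in inference mode so dropout / batch norm behave deterministically
        model.eval()
        dummy_input = torch.randn(input_shape)
        torch.onnx.export(
            model,
            dummy_input,
            save_path,
            opset_version=11,
            input_names=['input'],
            output_names=['output'],
            # Allow variable batch size and spatial resolution at run time
            dynamic_axes={
                'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                'output': {0: 'batch_size', 2: 'height', 3: 'width'}
            }
        )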
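    @staticmethod
    def optimize_tensorrt(onnx_path, engine_path, fp16=True):
        """
        TensorRT optimization.
        Written against the TensorRT 8.4+ Python API; older releases use
        config.max_workspace_size and builder.build_engine instead.
        """
        import tensorrt as trt
        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)
        # Parse the ONNX model and surface parser errors
        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
                raise RuntimeError('ONNX parse failed: ' + '; '.join(errors))
        # Builder configuration
        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        # A model exported with dynamic axes also needs an optimization
        # profile (min/opt/max input shapes) added to config before building.
        # Build the serialized engine
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            raise RuntimeError('TensorRT engine build failed')
        # Save
        with open(engine_path, 'wb') as f:
            f.write(serialized_engine)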
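    @staticmethod
    def tile_process(model, image, tile_size=512, tile_overlap=32, scale=4):
        """
        Process a large image tile by tile.
        Problem: a large image may not fit in GPU memory in one pass.
        Solution: run the model on overlapping tiles and feather-blend the overlaps.
        model maps an [h, w, ...] array to an [h*scale, w*scale, ...] array;
        create_weight_map is a helper that is not defined in this section.
        """
        h, w = image.shape[:2]
        # Collect tile positions in LR coordinates
        positions = []
        stride = tile_size - tile_overlap
        for y in range(0, h, stride):
            for x in range(0, w, stride):
                # Clamp the last tiles so they stay inside the image
                y_end = min(y + tile_size, h)
                x_end = min(x + tile_size, w)
                y_start = max(0, y_end - tile_size)
                x_start = max(0, x_end - tile_size)
                positions.append((y_start, x_start, y_end, x_end))
        # The output buffers live on the HR grid, not the LR grid
        out_shape = (h * scale, w * scale) + tuple(image.shape[2:])
        result = np.zeros(out_shape, dtype=np.float32)
        weight_map = np.zeros(out_shape, dtype=np.float32)
        for y0, x0, y1, x1 in positions:
            # Model inference on one LR tile
            sr_tile = model(image[y0:y1, x0:x1])
            # Feathering weights, sized for the upscaled tile
            weight = create_weight_map(sr_tile.shape, tile_overlap * scale)
            # Accumulate at the upscaled position
            ys, xs, ye, xe = y0 * scale, x0 * scale, y1 * scale, x1 * scale
            result[ys:ye, xs:xe] += sr_tile * weight
            weight_map[ys:ye, xs:xe] += weight
        # Normalize by the accumulated weights
        result = result / (weight_map + 1e-8)
        return result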
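`tile_process` relies on a `create_weight_map` helper that the section does not define. One possible version is sketched below, assuming a linear ramp over the overlap width that stays strictly positive at the tile edge, so the final division never hits a zero weight at the image border. The ramp shape is an assumption; a cosine or Gaussian window would serve equally well.

```python
import numpy as np

def create_weight_map(shape, overlap):
    """Feathering weights for a tile of shape (H, W) or (H, W, C).

    Weights ramp linearly from a small positive value at the tile edge up to 1
    over `overlap` pixels and stay at 1 in the interior.
    """
    h, w = shape[:2]

    def ramp(n):
        r = np.ones(n, dtype=np.float64)
        k = min(overlap, n // 2)
        if k > 0:
            edge = np.arange(1, k + 1) / (k + 1)  # strictly between 0 and 1
            r[:k] = edge
            r[n - k:] = edge[::-1]
        return r

    weight = np.outer(ramp(h), ramp(w))
    if len(shape) == 3:
        weight = np.repeat(weight[:, :, None], shape[2], axis=2)
    return weight

# Blend two overlapping 8x8 tiles of a constant image and normalize.
weight = create_weight_map((8, 8), overlap=2)
image = np.ones((8, 12))
result = np.zeros_like(image)
weight_sum = np.zeros_like(image)
for x0 in (0, 4):
    tile = image[:, x0:x0 + 8]
    result[:, x0:x0 + 8] += tile * weight
    weight_sum[:, x0:x0 + 8] += weight
blended = result / weight_sum
```

Because every pixel receives a positive total weight, the normalized blend of a constant image is again constant, which is a convenient check that no seams or border artifacts are introduced by the weighting itself.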
8.3 Training Tricks
python
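class SRTrainingTricks:
    """Training tricks for super-resolution."""
    @staticmethod
    def progressive_training(model, dataset, scale_factors=(2, 3, 4)):
        """
        Progressive training.
        Train on a small scale factor first, then fine-tune on larger ones.
        This can stabilize training and improve results.
        model.set_scale_factor, create_dataloader and train are placeholders
        that are not defined in this section.
        """
        for scale in scale_factors:
            print(f"Training with scale factor ×{scale}")
            # Switch the upsampling factor of the model
            model.set_scale_factor(scale)
            # Prepare data at the matching scale
            train_loader = create_dataloader(dataset, scale)
            # Train
            train(model, train_loader, epochs=100)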
    @staticmethod
    def patch_training(hr_images, patch_size=192, batch_size=16, scale_factor=4):
        """
        Patch-based training.
        Randomly crops patches from the HR images to reduce GPU memory use.
        Returns a Dataset; pass batch_size to the DataLoader that wraps it.
        Requires cv2, numpy (as np) and torch to be imported.
        """
        class PatchDataset(torch.utils.data.Dataset):
            def __init__(self, hr_images, patch_size, scale_factor):
                self.hr_images = hr_images
                self.patch_size = patch_size
                self.scale_factor = scale_factor
                self.lr_patch_size = patch_size // scale_factor

            def __len__(self):
                return len(self.hr_images)

            def __getitem__(self, idx):
                # Pick an image
                img = self.hr_images[idx % len(self.hr_images)]
                # Randomly crop an HR patch (+1 so an exact-size image is valid)
                h, w = img.shape[:2]
                y = np.random.randint(0, h - self.patch_size + 1)
                x = np.random.randint(0, w - self.patch_size + 1)
                hr_patch = img[y:y+self.patch_size, x:x+self.patch_size]
                # Downsample to get the LR patch
                lr_patch = cv2.resize(
                    hr_patch,
                    (self.lr_patch_size, self.lr_patch_size),
                    interpolation=cv2.INTER_CUBIC
                )
                # Random flips and rotations (data augmentation)
                lr_patch, hr_patch = self.augment(lr_patch, hr_patch)
                return lr_patch, hr_patch

            def augment(self, lr, hr):
                # Random horizontal flip
                if np.random.random() > 0.5:
                    lr = np.flip(lr, axis=1).copy()
                    hr = np.flip(hr, axis=1).copy()
                # Random vertical flip
                if np.random.random() > 0.5:
                    lr = np.flip(lr, axis=0).copy()
                    hr = np.flip(hr, axis=0).copy()
                # Random rotation by a multiple of 90 degrees
                k = np.random.randint(0, 4)
                lr = np.rot90(lr, k).copy()
                hr = np.rot90(hr, k).copy()
                return lr, hr

        return PatchDataset(hr_images, patch_size, scale_factor=scale_factor)
    @staticmethod
    def self_ensemble(model, lr_image):
        """
        Self-ensemble.
        Runs the model on several transformed copies of the input, undoes
        each transform on the output, and averages the results.
        Typically gains about 0.1-0.3 dB PSNR at 8x the inference cost.
        """
        def augment_transform(img, mode):
            if mode == 0: return img
            elif mode == 1: return np.flip(img, axis=0).copy()
            elif mode == 2: return np.flip(img, axis=1).copy()
            elif mode == 3: return np.rot90(img, k=1).copy()
            elif mode == 4: return np.rot90(img, k=2).copy()
            elif mode == 5: return np.rot90(img, k=3).copy()
            elif mode == 6: return np.flip(np.rot90(img, k=1), axis=0).copy()
            elif mode == 7: return np.flip(np.rot90(img, k=1), axis=1).copy()

        def deaugment_transform(img, mode):
            # Inverse of augment_transform: undo the flip first, then the rotation
            if mode == 0: return img
            elif mode == 1: return np.flip(img, axis=0).copy()
            elif mode == 2: return np.flip(img, axis=1).copy()
            elif mode == 3: return np.rot90(img, k=-1).copy()
            elif mode == 4: return np.rot90(img, k=-2).copy()
            elif mode == 5: return np.rot90(img, k=-3).copy()
            elif mode == 6: return np.rot90(np.flip(img, axis=0), k=-1).copy()
            elif mode == 7: return np.rot90(np.flip(img, axis=1), k=-1).copy()

        # The 8 flip/rotation transforms
        results = []
        for mode in range(8):
            augmented = augment_transform(lr_image, mode)
            sr = model(augmented)
            original = deaugment_transform(sr, mode)
            results.append(original)
        # Average
        return np.mean(results, axis=0)
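A compact way to check a self-ensemble implementation is to run it on a model that already commutes with flips and rotations: the ensemble should then reproduce the plain output. The sketch below does this with a toy nearest-neighbour ×2 upsampler. It enumerates the same 8 transforms in a different but equivalent way (0-3 rotations, optionally followed by a horizontal flip); the function names are illustrative assumptions.

```python
import numpy as np

def d4_transform(img, mode):
    """Apply one of 8 transforms: (mode % 4) quarter turns, then a flip if mode >= 4."""
    out = np.rot90(img, k=mode % 4)
    return np.flip(out, axis=1) if mode >= 4 else out

def d4_inverse(img, mode):
    """Undo d4_transform: remove the flip first, then rotate back."""
    out = np.flip(img, axis=1) if mode >= 4 else img
    return np.rot90(out, k=-(mode % 4))

def self_ensemble(model, lr_image):
    """Average the de-transformed outputs over all 8 transforms."""
    outputs = [d4_inverse(model(d4_transform(lr_image, m)), m) for m in range(8)]
    return np.mean(outputs, axis=0)

def nearest_x2(img):
    """Toy x2 'model': nearest-neighbour upsampling, which commutes with flips and rotations."""
    return np.repeat(np.repeat(img, 2, axis=0), 2, axis=1)

rng = np.random.default_rng(0)
lr = rng.uniform(size=(3, 5))  # non-square on purpose
roundtrips = [np.array_equal(d4_inverse(d4_transform(lr, m), m), lr) for m in range(8)]
ensembled = self_ensemble(nearest_x2, lr)
plain = nearest_x2(lr)
```

A learned SR network is not exactly equivariant, which is why averaging its 8 de-transformed outputs changes (and usually slightly improves) the result; for the toy model the two outputs coincide.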
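Appendix
A. Algorithm Timeline
2014 ──┬── SRCNN (first CNN applied to SR)
       │
2015 ──┼── DRCN (recursive network)
       │
2016 ──┼── FSRCNN (accelerated) / ESPCN (sub-pixel convolution) / VDSR (residual learning)
       │
2017 ──┼── SRGAN (GAN-based) / EDSR (enhanced residual network)
       │
2018 ──┼── ESRGAN (enhanced GAN) / RDN (dense connections) / RCAN (channel attention)
       │
2019 ──┼── SRFBN (feedback network) / RankSRGAN (perceptual metric optimization)
       │
2020 ──┼── HAN (holistic attention) / RRDB (widely adopted)
       │
2021 ──┼── SwinIR (Transformer) / Real-ESRGAN (real-world) / SR3 (diffusion)
       │
2022 ──┼── BasicVSR++ (video SR)
       │
2023 ──┼── StableSR (Stable Diffusion prior) / DiffIR (efficient diffusion) / SeeSR (semantics-guided)
       │
2024+ ─┴── Ongoing: more efficient architectures, more realistic degradations, multimodal fusion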
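B. Common Datasets
| Dataset | Images | Use | Notes |
|---|---|---|---|
| DIV2K | 1000 | Training/validation | High quality, widely used |
| Set5 | 5 | Testing | Classic test set |
| Set14 | 14 | Testing | Diverse scenes |
| BSD100 | 100 | Testing | Natural images |
| Urban100 | 100 | Testing | Urban buildings, repetitive structures |
| Manga109 | 109 | Testing | Japanese manga |
| Flickr2K | 2650 | Training | High-resolution real images |
| OST | 10000+ | Training | Large-scale scene dataset |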