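Super-Resolution Algorithms: An In-Depth Guide

Contents

  1. Problem Definition and Mathematical Modeling
  2. Traditional Methods
  3. Deep-Learning-Based Methods
  4. GAN-Based Methods
  5. Diffusion-Model Methods
  6. Video Super-Resolution
  7. Loss Functions and Evaluation Metrics
  8. Engineering Practice and Deployment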

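1. Problem Definition and Mathematical Modeling

1.1 Problem Statement

Super-resolution (SR) is the inverse problem of reconstructing a high-resolution (HR) image from a low-resolution (LR) one. The problem is ill-posed: many different HR images downsample to the same LR image.

                 Super-resolution reconstruction

  LR input (64×64)             HR output (256×256)
  ┌──────────────┐             ┌────────────────────────────┐
  │              │             │                            │
  │   Blurry     │    SR       │     Sharp                  │
  │   low-res    │  ──────►    │     high-res               │
  │   image      │ algorithm   │     image                  │
  │              │             │                            │
  └──────────────┘             │                            │
                               │                            │
                               └────────────────────────────┘

  Scale factor: ×4 per side (16× the pixel count)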

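1.2 Mathematical Model

Image degradation model

The LR image is generated as:

  y = D(H(x)) + n

  where:
    x ∈ ℝᴴˣᵂ     --- high-resolution image
    H              --- blur operator (optical blur, motion blur, etc.)
    D              --- downsampling operator (downsampling factor s)
    n              --- noise (Gaussian, Poisson, etc.)
    y ∈ ℝʰˣʷ     --- low-resolution image (h = H/s, w = W/s)

  Note: in the dimensions, H and W are the image height and width;
  in the formula, H is the blur operator.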
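
The degradation model can be made concrete with a small sketch. It uses only the standard library and nested lists; the 3×3 box blur standing in for H, the "keep every s-th pixel" rule standing in for D, and the function names are illustrative assumptions, not choices made by the article.

```python
import random

def box_blur(img, k=3):
    """H: k×k mean filter; border pixels are handled by clamping indices."""
    rows, cols = len(img), len(img[0])
    r = k // 2
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            acc = 0.0
            for di in range(-r, r + 1):
                for dj in range(-r, r + 1):
                    ii = min(max(i + di, 0), rows - 1)
                    jj = min(max(j + dj, 0), cols - 1)
                    acc += img[ii][jj]
            out[i][j] = acc / (k * k)
    return out

def downsample(img, s):
    """D: keep every s-th pixel along both axes."""
    return [row[::s] for row in img[::s]]

def degrade(hr, s=2, sigma=0.01, seed=0):
    """y = D(H(x)) + n, with additive Gaussian noise of std sigma."""
    rng = random.Random(seed)
    lr = downsample(box_blur(hr), s)
    return [[v + rng.gauss(0.0, sigma) for v in row] for row in lr]

# 8×8 checkerboard as a toy HR image
hr = [[float((i + j) % 2) for j in range(8)] for i in range(8)]
lr = degrade(hr, s=2)
print(len(lr), len(lr[0]))  # 4 4
```

With s = 2 the 8×8 input shrinks to 4×4, matching h = H/s and w = W/s.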
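
Solving the inverse problem

The goal of super-resolution:

  x̂ = argmin_x ‖y - D(H(x))‖² + λ·R(x)
              ─────────────────   ──────
                data fidelity    regularizer

  Role of the regularization term R(x):
    - injects prior knowledge
    - constrains the solution space
    - suppresses noise amplification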
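
As a sketch of how such an objective can be minimized, the example below runs plain gradient descent on a 1-D toy problem. The forward operator (averaging neighboring pairs, standing in for D∘H), the smoothness regularizer R(x) = Σ (x[k+1] - x[k])², and the values of λ, step size and iteration count are all assumptions made for illustration.

```python
def forward(x):
    """A = D∘H: average each pair of neighbors, halving the length."""
    return [(x[2 * i] + x[2 * i + 1]) / 2 for i in range(len(x) // 2)]

def objective(x, y, lam):
    data = sum((yi - ai) ** 2 for yi, ai in zip(y, forward(x)))
    reg = sum((x[k + 1] - x[k]) ** 2 for k in range(len(x) - 1))
    return data + lam * reg

def gradient(x, y, lam):
    residual = [yi - ai for yi, ai in zip(y, forward(x))]
    # Data term: d/dx_k (y_i - (x_2i + x_2i+1)/2)² = -residual_i
    g = [-residual[k // 2] for k in range(len(x))]
    # Smoothness term
    for k in range(len(x) - 1):
        d = 2 * lam * (x[k + 1] - x[k])
        g[k] -= d
        g[k + 1] += d
    return g

def solve(y, lam=0.1, step=0.2, iters=500):
    x = [v for v in y for _ in range(2)]  # init: nearest-neighbor upsampling
    for _ in range(iters):
        g = gradient(x, y, lam)
        x = [xk - step * gk for xk, gk in zip(x, g)]
    return x

y = [0.0, 0.0, 1.0, 1.0]
x0 = [v for v in y for _ in range(2)]
x_hat = solve(y)
print(objective(x0, y, 0.1), objective(x_hat, y, 0.1))
```

The nearest-neighbor initialization already fits the data exactly, so any drop in the objective comes from the regularizer trading a little data fidelity for a smoother step edge.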

1.3 Problem Taxonomy

┌─────────────────────────────────────────────────────────────┐
│                 Taxonomy of super-resolution problems       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  By number of inputs:                                       │
│  ├── Single-image SR (SISR)                                 │
│  │   └─ reconstruct HR from a single LR image               │
│  └─ Multi-frame SR                                          │
│      └─ exploit complementary information across LR frames  │
│                                                             │
│  By scale factor:                                           │
│  ├── Fixed scale (×2, ×4, ×8)                               │
│  └─ Arbitrary scale                                         │
│                                                             │
│  By degradation type:                                       │
│  ├── Ideal degradation (bicubic downsampling)               │
│  ├── Blind SR (unknown degradation)                         │
│  └─ Real-world SR                                           │
│                                                             │
│  By application domain:                                     │
│  ├── Natural images                                         │
│  ├── Face SR                                                │
│  ├── Remote-sensing imagery                                 │
│  └─ Medical imaging                                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

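2. Traditional Methods

2.1 Interpolation-Based Methods

Nearest-neighbor interpolation

import numpy as np

def nearest_neighbor_interpolation(lr_image, scale_factor):
    """
    Principle: copy each LR pixel into the corresponding HR region.

    Pros: trivial to compute, very fast.
    Cons: produces visible jaggies and blocky artifacts.
    """
    h, w = lr_image.shape[:2]
    H, W = h * scale_factor, w * scale_factor
    # Keep any trailing channel axis so color images work too
    hr_image = np.zeros((H, W) + lr_image.shape[2:], dtype=lr_image.dtype)

    for i in range(H):
        for j in range(W):
            # Map the HR coordinate back to the LR grid
            src_i = min(int(i / scale_factor), h - 1)
            src_j = min(int(j / scale_factor), w - 1)
            hr_image[i, j] = lr_image[src_i, src_j]

    return hr_image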

Bilinear interpolation

def bilinear_interpolation(lr_image, scale_factor):
    """
    Principle: interpolate linearly along each of the two axes.

    On the unit square:
        f(x,y) = f(0,0)(1-x)(1-y) + f(1,0)x(1-y)
               + f(0,1)(1-x)y + f(1,1)xy

    Pros: smoother than nearest neighbor at moderate cost.
    Cons: over-smooths and loses high-frequency detail.
    """
    h, w = lr_image.shape[:2]
    H, W = h * scale_factor, w * scale_factor
    # Keep any trailing channel axis so color images work too
    hr_image = np.zeros((H, W) + lr_image.shape[2:], dtype=np.float64)

    for i in range(H):
        for j in range(W):
            # Map the HR pixel into LR coordinates
            x = j / scale_factor
            y = i / scale_factor

            # Indices of the four surrounding LR pixels (clamped at the border)
            x0, y0 = int(x), int(y)
            x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)

            # Fractional offsets act as the interpolation weights
            dx, dy = x - x0, y - y0

            # Bilinear blend of the four neighbors
            hr_image[i, j] = (
                lr_image[y0, x0] * (1 - dx) * (1 - dy) +
                lr_image[y0, x1] * dx * (1 - dy) +
                lr_image[y1, x0] * (1 - dx) * dy +
                lr_image[y1, x1] * dx * dy
            )

    return hr_image
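
Bicubic interpolation

def bicubic_kernel(x, a=-0.5):
    """
    Bicubic convolution kernel (Keys kernel).

    With a = -0.5 it coincides with the Catmull-Rom spline.
    """
    x = abs(x)
    if x <= 1:
        return (a + 2) * x**3 - (a + 3) * x**2 + 1
    elif x < 2:
        return a * x**3 - 5 * a * x**2 + 8 * a * x - 4 * a
    else:
        return 0.0

def bicubic_interpolation(lr_image, scale_factor):
    """
    Principle: weighted sum over a 4×4 neighborhood.

    Pros: preserves edges better than bilinear; good visual quality.
    Cons: more computation, and high-frequency detail is still lost.

    Bicubic resampling is also what most SR papers use, in the
    downsampling direction, to synthesize LR inputs from HR images.
    """
    h, w = lr_image.shape[:2]
    H, W = h * scale_factor, w * scale_factor
    # Keep any trailing channel axis so color images work too
    hr_image = np.zeros((H, W) + lr_image.shape[2:], dtype=np.float64)

    for i in range(H):
        for j in range(W):
            x = j / scale_factor
            y = i / scale_factor

            x0, y0 = int(x), int(y)

            pixel = 0.0
            for m in range(-1, 3):
                for n in range(-1, 3):
                    # Clamp neighbor indices at the image border
                    px = min(max(x0 + m, 0), w - 1)
                    py = min(max(y0 + n, 0), h - 1)

                    # Separable weight: W(dx) * W(dy)
                    wx = bicubic_kernel(x - (x0 + m))
                    wy = bicubic_kernel(y - (y0 + n))

                    pixel += lr_image[py, px] * wx * wy

            # The kernel has negative lobes, so clip overshoot (assumes an 8-bit range)
            hr_image[i, j] = np.clip(pixel, 0, 255)

    return hr_image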
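
A quick way to sanity-check the Keys kernel is to look at the four tap weights it assigns along one axis. The sketch below (standard library only; `tap_weights` is a hypothetical helper name) shows that the weights sum to 1 for any fractional offset and that the outer taps are negative, which is why the interpolated value can overshoot and is clipped.

```python
def keys_kernel(x, a=-0.5):
    """Same piecewise cubic as bicubic_kernel, repeated here to be self-contained."""
    x = abs(x)
    if x <= 1:
        return (a + 2) * x**3 - (a + 3) * x**2 + 1
    if x < 2:
        return a * x**3 - 5 * a * x**2 + 8 * a * x - 4 * a
    return 0.0

def tap_weights(t, a=-0.5):
    """Weights of the neighbors at offsets -1, 0, 1, 2 for a sample at fraction t in [0, 1)."""
    return [keys_kernel(t - m, a) for m in range(-1, 3)]

w = tap_weights(0.25)
print(w)        # [-0.0703125, 0.8671875, 0.2265625, -0.0234375]
print(sum(w))   # 1.0
```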

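2.2 Reconstruction-Based Methods

Sparse coding

from sklearn.decomposition import MiniBatchDictionaryLearning, sparse_encode

class SparseCodingSR:
    """
    Principle: corresponding LR and HR patches share one sparse representation.

    Pipeline:
      1. Train a coupled LR-HR dictionary pair (D_l, D_h)
      2. Solve for the sparse code α of each LR patch
      3. Reconstruct the HR patch from α and D_h
    """
    def __init__(self, patch_size=5, dict_size=1024, sparsity=3):
        self.patch_size = patch_size
        self.dict_size = dict_size
        self.sparsity = sparsity

    def train_dictionary(self, lr_patches, hr_patches):
        """
        Joint dictionary learning.

        Objective: min ‖Y_l - D_l·α‖² + ‖Y_h - D_h·α‖² + λ‖α‖₁

        lr_patches / hr_patches: arrays of shape [n_samples, n_features];
        row i of each describes the same image location.
        """
        # Concatenate the features so each sample gets ONE code for both
        # dictionaries. Fitting two separate learners would produce atoms
        # that do not correspond to each other.
        joint = np.hstack([lr_patches, hr_patches])
        learner = MiniBatchDictionaryLearning(
            n_components=self.dict_size, alpha=1.0
        ).fit(joint)

        n_lr = lr_patches.shape[1]
        d_l = learner.components_[:, :n_lr]
        d_h = learner.components_[:, n_lr:]

        # Rescale each atom pair so the LR atoms have unit norm
        norms = np.linalg.norm(d_l, axis=1, keepdims=True) + 1e-12
        self.dict_lr = d_l / norms
        self.dict_hr = d_h / norms

    def reconstruct(self, lr_image, scale_factor):
        """
        Super-resolution reconstruction.

        extract_patches and aggregate_patches are helper functions that
        are not defined in this article.
        """
        # Extract LR patches
        lr_patches = extract_patches(lr_image, self.patch_size)

        # Sparse codes w.r.t. the LR dictionary (OMP, at most `sparsity` atoms)
        sparse_codes = sparse_encode(
            lr_patches, self.dict_lr,
            algorithm="omp", n_nonzero_coefs=self.sparsity
        )

        # Reconstruct HR patches with the HR dictionary
        hr_patches = np.dot(sparse_codes, self.dict_hr)

        # Merge the overlapping patches into one image
        hr_image = aggregate_patches(hr_patches, scale_factor)

        return hr_image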

Self-example (self-similarity) methods

Core idea: an image contains similar structures that recur across scales.

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│   small structure  ───── similar ─────►  enlarged structure │
│                                                             │
│   ┌───┐                                    ┌─────────┐      │
│   │ ▪ │   search for similar patches       │ ▪  ▪    │      │
│   └───┘   at other scales  ─────────►      │         │      │
│           and borrow their HR detail       │ ▪  ▪    │      │
│                                            └─────────┘      │
│                                                             │
│   Representative work: Glasner et al. (2009)                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

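2.3 Edge-Based Methods

class EdgeDirectedSR:
    """
    Core idea: reconstruct the edges first, then use them to guide
    reconstruction of the whole image.

    Pipeline:
      1. Detect edges in the LR image
      2. Predict HR edge direction and strength
      3. Interpolate along the edge direction
      4. Refine the full image with the edges as constraints

    This is an outline only: the four helper methods called below are
    not defined in this article.
    """
    def reconstruct(self, lr_image, scale_factor):
        # Step 1: edge detection
        edges_lr = self.detect_edges(lr_image)

        # Step 2: edge direction estimation
        directions = self.estimate_edge_directions(edges_lr)

        # Step 3: direction-adaptive interpolation
        hr_image = self.directional_interpolation(
            lr_image, directions, scale_factor
        )

        # Step 4: edge-guided refinement
        hr_image = self.edge_guided_optimization(
            hr_image, edges_lr, scale_factor
        )

        return hr_image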

3. Deep-Learning-Based Methods

3.1 SRCNN (2014): The Pioneering Work

SRCNN: Super-Resolution Convolutional Neural Network
Paper: "Image Super-Resolution Using Deep Convolutional Networks" (Dong et al., 2014)

Architecture:
  LR input → [bicubic upsampling] → [feature extraction] → [non-linear mapping] → [reconstruction] → HR output

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│   LR Image ──► Bicubic ──► Conv1 ──► Conv2 ──► Conv3 ──► HR │
│   (low-res)   (upsample)   (9×9)     (1×1)     (5×5)        │
│                                                             │
│   Input:  bicubic-upsampled LR image (target size)          │
│   Output: high-resolution image                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘
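
To get a feel for how small this network is, the sketch below counts parameters and the receptive field for the 9-1-5 layout in the diagram. It assumes a single-channel input with 64 and 32 filters in the first two layers (the widths used by the SRCNN class in this section); the helper names are illustrative.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of one k×k convolution."""
    return c_in * c_out * k * k + c_out

def receptive_field(kernel_sizes):
    """Receptive field of stacked stride-1 convolutions."""
    rf = 1
    for k in kernel_sizes:
        rf += k - 1
    return rf

layers = [(1, 64, 9), (64, 32, 1), (32, 1, 5)]   # (c_in, c_out, kernel)
total = sum(conv_params(*layer) for layer in layers)
rf = receptive_field([k for _, _, k in layers])
print(total, rf)  # 8129 13
```

Each output pixel therefore sees only a 13×13 window of the upsampled input, which is one way to read the "limited receptive field" limitation of this design.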

import torch
import torch.nn as nn

class SRCNN(nn.Module):
    """
    Three-layer CNN for super-resolution.

    Layer 1: patch extraction (feature extraction)
    Layer 2: non-linear mapping
    Layer 3: reconstruction
    """
    def __init__(self, num_channels=1):
        super().__init__()

        # Feature extraction layer
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)

        # Non-linear mapping layer
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)

        # Reconstruction layer
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        Forward pass.

        Note: the input must already be bicubic-upsampled to the target size.
        """
        # Feature extraction
        x = self.relu(self.conv1(x))       # [B, 64, H, W]

        # Non-linear mapping
        x = self.relu(self.conv2(x))       # [B, 32, H, W]

        # Reconstruction
        x = self.conv3(x)                  # [B, C, H, W]

        return x

# Training setup
"""
Loss: pixel-wise MSE
Optimizer: SGD or Adam
Evaluation: PSNR / SSIM

Limitations:
  - Upsampling first means every layer runs at HR resolution, which is costly
  - Limited receptive field
  - Struggles to recover high-frequency detail
"""

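3.2 FSRCNN (2016): The Accelerated Version

FSRCNN: Fast Super-Resolution CNN
Paper: "Accelerating the Super-Resolution Convolutional Neural Network" (Dong et al., 2016)

Improvements:
  1. Operates directly in LR space (no upsampling of the input)
  2. Upsamples at the very end with a transposed convolution
  3. Uses smaller kernels and more layers

Architecture:
  LR → [feature extraction] → [shrinking] → [mapping] × m → [expanding] → [transposed-conv upsampling] → HR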
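
The final transposed convolution has to land exactly on the target size. The sketch below evaluates the standard output-size formula for a transposed convolution (dilation 1), assuming kernel size 9, padding 4 and output_padding = s - 1 as in the FSRCNN class in this section; the function name is illustrative.

```python
def conv_transpose_out(size, kernel, stride, padding, output_padding):
    """Output length of a 1-D transposed convolution with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

for s in (2, 3, 4):
    out = conv_transpose_out(64, kernel=9, stride=s, padding=4, output_padding=s - 1)
    print(s, out)
# 2 128
# 3 192
# 4 256
```

With these settings the output is exactly s times the input for every scale factor, so no cropping or padding is needed afterwards.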

class FSRCNN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=1, d=56, s=12, m=4):
        super().__init__()

        # Feature extraction
        self.feature_extraction = nn.Conv2d(num_channels, d, kernel_size=5, padding=2)

        # Shrinking layer (reduce channels d → s)
        self.shrinking = nn.Conv2d(d, s, kernel_size=1)

        # Non-linear mapping (m stacked 3×3 layers)
        self.mapping = nn.Sequential(*[
            nn.Sequential(
                nn.Conv2d(s, s, kernel_size=3, padding=1),
                nn.PReLU()
            ) for _ in range(m)
        ])

        # Expanding layer (restore channels s → d)
        self.expanding = nn.Conv2d(s, d, kernel_size=1)

        # Transposed-convolution upsampling (the key change versus SRCNN)
        self.deconv = nn.ConvTranspose2d(
            d, num_channels,
            kernel_size=9,
            stride=scale_factor,
            padding=4,
            output_padding=scale_factor - 1
        )

        # Note: a single PReLU instance is shared by the three layers below
        self.prelu = nn.PReLU()

    def forward(self, x):
        # Everything up to the last layer runs in LR space
        x = self.prelu(self.feature_extraction(x))  # [B, d, h, w]
        x = self.prelu(self.shrinking(x))            # [B, s, h, w]
        x = self.mapping(x)                          # [B, s, h, w]
        x = self.prelu(self.expanding(x))            # [B, d, h, w]
        x = self.deconv(x)                           # [B, C, H, W]
        return x

"""
Speed comparison (indicative; timings depend on hardware):
  - SRCNN: 0.43 s (×4, 256×256)
  - FSRCNN: 0.015 s (×4, 256×256)
  - roughly 28× faster
"""

3.3 ESPCN (2016): Sub-Pixel Convolution

ESPCN: Efficient Sub-Pixel CNN
Paper: "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel
      Convolutional Neural Network" (Shi et al., 2016)

Key innovation: the sub-pixel convolution layer (PixelShuffle)

Principle:
  For each output channel, r² feature maps are interleaved so that every
  LR pixel expands into an r×r block of HR pixels.

  [B, C×r², H, W] → [B, C, H×r, W×r]
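
The rearrangement is easiest to see on a tiny example. The sketch below implements it on nested lists with the standard library only (no batch axis); it follows the channel ordering used by PyTorch's PixelShuffle, where input channel c·r² + i·r + j supplies the pixel at row offset i and column offset j of each r×r block.

```python
def pixel_shuffle(x, r):
    """x: nested list [C*r*r][H][W]  ->  nested list [C][H*r][W*r]."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    c_out = c_in // (r * r)
    out = [[[0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):          # row offset inside the r×r block
            for j in range(r):      # column offset inside the r×r block
                src = x[c * r * r + i * r + j]
                for row in range(h):
                    for col in range(w):
                        out[c][row * r + i][col * r + j] = src[row][col]
    return out

# Four 1×2 channels become one 2×4 image (r = 2)
x = [[[0, 1]], [[10, 11]], [[20, 21]], [[30, 31]]]
print(pixel_shuffle(x, 2))  # [[[0, 10, 1, 11], [20, 30, 21, 31]]]
```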

class PixelShuffle(nn.Module):
    """
    Sub-pixel convolution (rearrangement) layer.

    Input:  [B, C × r², H, W]
    Output: [B, C, H × r, W × r]

    Principle: move data from the channel dimension into the spatial dimensions.
    Equivalent to the built-in nn.PixelShuffle.
    """
    def __init__(self, upscale_factor):
        super().__init__()
        self.r = upscale_factor

    def forward(self, x):
        B, C, H, W = x.shape
        r = self.r

        # C = C_out × r²
        assert C % (r * r) == 0
        C_out = C // (r * r)

        # Rearrange: [B, C_out×r², H, W] → [B, C_out, H×r, W×r]
        return x.view(B, C_out, r, r, H, W).permute(0, 1, 4, 2, 5, 3).reshape(
            B, C_out, H * r, W * r
        )

class ESPCN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=1):
        super().__init__()

        # Feature extraction (in LR space)
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Tanh()
        )

        # Sub-pixel upsampling
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(32, num_channels * scale_factor ** 2, kernel_size=3, padding=1),
            PixelShuffle(scale_factor)
        )

    def forward(self, x):
        x = self.feature_extraction(x)
        x = self.sub_pixel(x)
        return x

"""
Advantages:
  - Efficient: all convolutions run at LR resolution; upsampling is the last step
  - Fast enough for real-time use, which suits video
  - Tends to produce fewer checkerboard artifacts than transposed convolution
"""

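3.4 VDSR (2016): Residual Learning

VDSR: Very Deep Super-Resolution
Paper: "Accurate Image Super-Resolution Using Very Deep Convolutional Networks"
      (Kim et al., 2016)

Key innovations:
  1. Residual learning: the network predicts the residual HR - LR_bicubic
  2. Deep network: 20 convolutional layers
  3. High learning rate: residual learning, combined with gradient clipping, allows much larger learning rates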
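
The VDSR paper pairs its high learning rates with adjustable gradient clipping: gradients are clipped to [-θ/γ, θ/γ], where γ is the current learning rate, so the size of each parameter update stays bounded by θ. A minimal sketch of that rule on plain Python lists follows; the function names and the numbers are illustrative assumptions.

```python
def adjustable_clip(grads, theta, lr):
    """Clip each gradient to [-theta/lr, theta/lr]."""
    bound = theta / lr
    return [max(-bound, min(bound, g)) for g in grads]

def sgd_step(params, grads, theta, lr):
    """One SGD update using the clipped gradients."""
    clipped = adjustable_clip(grads, theta, lr)
    return [p - lr * g for p, g in zip(params, clipped)]

# A huge gradient (50.0) and a small one (-0.5)
new_params = sgd_step([0.0, 0.0], [50.0, -0.5], theta=0.4, lr=0.1)
print(new_params)
```

The first parameter moves by θ = 0.4 instead of 5.0, while the small gradient passes through unchanged.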

class VDSR(nn.Module):
    """
    Residual-learning architecture.

    Core formula: HR = LR_bicubic + Residual

    The network only has to learn the residual, so it converges faster.
    """
    def __init__(self, num_channels=1, num_layers=20):
        super().__init__()

        layers = []

        # First layer
        layers.append(nn.Conv2d(num_channels, 64, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))

        # Middle layers (18 of them when num_layers=20)
        for _ in range(num_layers - 2):
            layers.append(nn.Conv2d(64, 64, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))

        # Last layer
        layers.append(nn.Conv2d(64, num_channels, kernel_size=3, padding=1))

        self.network = nn.Sequential(*layers)

        # Learnable residual scale. This is an addition in this sketch;
        # the core formula above adds the residual unscaled.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, lr_bicubic):
        """
        Input:  bicubic-upsampled LR image
        Output: super-resolved image

        HR = LR_bicubic + scale × Network(LR_bicubic)
        """
        residual = self.network(lr_bicubic)
        return lr_bicubic + self.residual_scale * residual

"""
Benefits of residual learning:
  1. Smoother gradient flow (identity shortcut)
  2. Faster convergence
  3. Makes deeper networks trainable
  4. Learning the residual is easier than learning the full image
"""

3.5 EDSR (2017): Enhanced Deep Residual Network

EDSR: Enhanced Deep Residual Network
Paper: "Enhanced Deep Residual Networks for Single Image Super-Resolution"
      (Lim et al., 2017)

Key improvements:
  1. Removes BatchNorm (it hurts SR quality)
  2. Simplifies the residual block
  3. Uses residual scaling

class ResidualBlock(nn.Module):
    """
    EDSR residual block.

    Changes versus a standard ResNet block:
      - No BN layers (BN normalizes features and discards range information)
      - Residual scaling (×0.1) to stabilize training
    """
    def __init__(self, channels, residual_scale=0.1):
        super().__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.residual_scale = residual_scale

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)

        return x + residual * self.residual_scale

class EDSR(nn.Module):
    def __init__(self, scale_factor=4, num_channels=3,
                 num_features=256, num_blocks=32):
        super().__init__()

        # Shallow feature extraction
        self.head = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)

        # Deep feature extraction (stacked residual blocks)
        self.body = nn.Sequential(*[
            ResidualBlock(num_features) for _ in range(num_blocks)
        ])

        # Upsampling module (UpsampleModule is defined right after this class)
        self.upsample = UpsampleModule(num_features, scale_factor)

        # Reconstruction layer
        self.tail = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Shallow features
        shallow = self.head(x)

        # Deep features with a global residual connection
        deep = self.body(shallow)
        deep = deep + shallow

        # Upsampling
        upsampled = self.upsample(deep)

        # Reconstruction
        output = self.tail(upsampled)

        return output

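class UpsampleModule(nn.Module):
    """
    Upsampling module built from sub-pixel convolutions.
    """
    def __init__(self, channels, scale_factor):
        super().__init__()

        if scale_factor == 2:
            self.up = nn.Sequential(
                nn.Conv2d(channels, channels * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2)
            )
        elif scale_factor == 3:
            self.up = nn.Sequential(
                nn.Conv2d(channels, channels * 9, kernel_size=3, padding=1),
                nn.PixelShuffle(3)
            )
        elif scale_factor == 4:
            # ×4 is done as two ×2 stages
            self.up = nn.Sequential(
                nn.Conv2d(channels, channels * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.Conv2d(channels, channels * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2)
            )
        else:
            # Previously self.up was silently left undefined for other scales
            raise ValueError(f"unsupported scale factor: {scale_factor}")

    def forward(self, x):
        return self.up(x)

"""
EDSR results on the DIV2K dataset (PSNR):
  ×2: 34.65 dB
  ×3: 30.92 dB
  ×4: 28.80 dB
"""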

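3.6 RDN (2018): Densely Connected Network

RDN: Residual Dense Network
Paper: "Residual Dense Network for Image Super-Resolution" (Zhang et al., 2018)

Key innovations:
  1. Residual dense block (RDB): dense connections inside each block
  2. Feature fusion: features are collected from every RDB
  3. Use of hierarchical features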

class ResidualDenseBlock(nn.Module):
    """
    Residual dense block (RDB).

    Every layer receives the outputs of all preceding layers in the block.
    """
    def __init__(self, channels, growth_rate=32, num_layers=8):
        super().__init__()

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Layer i sees the block input plus i earlier outputs
            self.layers.append(
                nn.Conv2d(channels + i * growth_rate, growth_rate,
                         kernel_size=3, padding=1)
            )

        # 1×1 convolution for local feature fusion
        self.fusion = nn.Conv2d(
            channels + num_layers * growth_rate,
            channels,
            kernel_size=1
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = [x]

        for layer in self.layers:
            out = self.relu(layer(torch.cat(features, dim=1)))
            features.append(out)

        # Fuse all collected features
        fused = self.fusion(torch.cat(features, dim=1))

        # Local residual connection, with the residual scaled by 0.2
        return x + fused * 0.2

class RDN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=3,
                 num_features=64, num_blocks=16, growth_rate=32):
        super().__init__()

        # Shallow feature extraction
        self.sfe1 = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)
        self.sfe2 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)

        # Residual dense blocks
        self.rdbs = nn.ModuleList([
            ResidualDenseBlock(num_features, growth_rate)
            for _ in range(num_blocks)
        ])

        # Global feature fusion: 1×1 conv over all RDB outputs, then a 3×3 conv
        self.gff = nn.Sequential(
            nn.Conv2d(num_blocks * num_features, num_features, kernel_size=1),
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        )

        # Upsampling and reconstruction
        self.upsample = UpsampleModule(num_features, scale_factor)
        self.reconstruction = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Shallow features
        sfe1 = self.sfe1(x)
        sfe2 = self.sfe2(sfe1)

        # Collect the output of every RDB
        rdb_features = []
        out = sfe2
        for rdb in self.rdbs:
            out = rdb(out)
            rdb_features.append(out)

        # Global feature fusion with a global residual connection
        fused = self.gff(torch.cat(rdb_features, dim=1))
        fused = fused + sfe1

        # Upsampling and reconstruction
        output = self.upsample(fused)
        output = self.reconstruction(output)

        return output

3.7 RCAN (2018): Channel Attention

RCAN: Residual Channel Attention Network
Paper: "Image Super-Resolution Using Very Deep Residual Channel Attention Networks"
      (Zhang et al., 2018)

Key innovations:
  1. Channel attention mechanism
  2. Residual-in-residual (RIR) structure
  3. Very deep network (over 400 layers)
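
Channel attention boils down to three steps: average each channel to one number, pass those numbers through a small bottleneck network ending in a sigmoid, and scale each channel by its gate. The sketch below does this on nested lists with the standard library only; the hand-picked weights `w1` and `w2` are purely illustrative (in a real network they are learned).

```python
import math

def channel_attention(x, w1, w2):
    """x: [C][H][W]; w1: [C/r][C]; w2: [C][C/r] (no biases)."""
    # Squeeze: global average pooling, one scalar per channel
    z = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in x]
    # Excite: bottleneck with ReLU, then a sigmoid gate per channel
    hidden = [max(0.0, sum(w * v for w, v in zip(row, z))) for row in w1]
    gate = [1 / (1 + math.exp(-sum(w * v for w, v in zip(row, hidden)))) for row in w2]
    # Scale: multiply every pixel of channel c by gate[c]
    out = [[[v * g for v in row] for row in ch] for ch, g in zip(x, gate)]
    return out, gate

x = [[[1.0, 1.0], [1.0, 1.0]],   # channel 0, mean 1
     [[2.0, 2.0], [2.0, 2.0]]]   # channel 1, mean 2
w1 = [[0.0, 1.0]]                # 2 channels -> 1 hidden unit
w2 = [[1.0], [-1.0]]             # 1 hidden unit -> 2 gates
out, gate = channel_attention(x, w1, w2)
print(gate)
```

With these weights channel 0 is kept at about 88% of its strength and channel 1 is suppressed to about 12%.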

class ChannelAttention(nn.Module):
    """
    Channel attention module.

    Principle: learn an importance weight for every channel.
    Method: global pooling → fully connected bottleneck → sigmoid
    """
    def __init__(self, channels, reduction=16):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling

        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.shape

        # Squeeze global spatial information into one value per channel
        y = self.pool(x).view(b, c)

        # Per-channel weights in (0, 1)
        y = self.fc(y).view(b, c, 1, 1)

        # Reweight the channels
        return x * y.expand_as(x)

class RCAB(nn.Module):
    """
    Residual channel attention block.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.ca = ChannelAttention(channels, reduction)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)

        # Channel attention on the residual branch
        residual = self.ca(residual)

        return x + residual

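class ResidualGroup(nn.Module):
    """
    Residual group: a stack of RCABs plus a conv, wrapped in a short skip.

    This class was referenced but not defined in the original listing;
    the version here is a minimal sketch.
    """
    def __init__(self, channels, num_blocks, reduction=16):
        super().__init__()
        self.blocks = nn.Sequential(*[RCAB(channels, reduction) for _ in range(num_blocks)])
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return x + self.conv(self.blocks(x))

class RCAN(nn.Module):
    """
    Residual-in-residual (RIR) network.

    Hierarchy:
      Residual Group → Residual Block → Channel Attention
    """
    def __init__(self, scale_factor=4, num_channels=3,
                 num_features=64, num_groups=10, num_blocks=20):
        super().__init__()

        # Shallow features
        self.head = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)

        # Residual groups
        self.groups = nn.ModuleList([
            ResidualGroup(num_features, num_blocks)
            for _ in range(num_groups)
        ])

        # Inter-group fusion by concatenation (a variant used in this article;
        # the RCAN paper chains the groups under one long skip connection)
        self.fusion = nn.Conv2d(num_features * num_groups, num_features, kernel_size=1)

        # Conv before the global residual
        self.tail_conv = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)

        # Upsampling and reconstruction
        self.upsample = UpsampleModule(num_features, scale_factor)
        self.reconstruction = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        head = self.head(x)

        # Collect the output of every group
        group_outputs = []
        out = head
        for group in self.groups:
            out = group(out)
            group_outputs.append(out)

        # Fusion + global residual
        fused = self.fusion(torch.cat(group_outputs, dim=1))
        out = self.tail_conv(fused) + head

        # Upsampling and reconstruction
        out = self.upsample(out)
        out = self.reconstruction(out)

        return out

"""
What channel attention does:
  - Different channels capture different features (edges, textures, color, ...)
  - Attention lets the network focus on the informative channels
  - Same idea as SENet, applied to the SR task
"""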

3.8 SwinIR (2021): Transformer Architecture

SwinIR: Swin Transformer for Image Restoration
Paper: "SwinIR: Image Restoration Using Swin Transformer" (Liang et al., 2021)

Key innovations:
  1. Applies the Swin Transformer to image restoration
  2. Shifted window attention
  3. Models long-range dependencies

import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """
    Window attention.

    The feature map is split into non-overlapping windows and attention is
    computed inside each window. For n tokens and window side w this cuts
    the cost from O(n²) to O(n × w²).
    """
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # Relative position bias table.
        # Note: this simplified forward() never adds the bias to the logits.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x):
        # x: [num_windows × B, N, C] with N = window_size² tokens per window
        B, N, C = x.shape

        # QKV projection
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Weighted aggregation
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)

        return out

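class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block.

    Contains:
      1. Window attention (W-MSA) when shift_size == 0
      2. Shifted window attention (SW-MSA) when shift_size > 0
      3. MLP

    Simplification: the attention mask normally used for shifted windows
    is omitted. H and W must be multiples of window_size.
    """
    def __init__(self, dim, num_heads, window_size=8, shift_size=0):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        # x: [B, C, H, W] → channels-last so LayerNorm/Linear act on C
        B, C, H, W = x.shape
        ws = self.window_size
        x = x.permute(0, 2, 3, 1)

        # Window attention
        shortcut = x
        x = self.norm1(x)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        # Partition into windows: [B × num_windows, ws², C]
        x = x.reshape(B, H // ws, ws, W // ws, ws, C).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(-1, ws * ws, C)
        x = self.attn(x)
        # Merge windows back: [B, H, W, C]
        x = x.reshape(B, H // ws, W // ws, ws, ws, C).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, H, W, C)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x + shortcut

        # MLP
        shortcut = x
        x = self.norm2(x)
        x = self.mlp(x) + shortcut

        return x.permute(0, 3, 1, 2)  # back to [B, C, H, W]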

class SwinIR(nn.Module):
    """
    Swin Transformer super-resolution network (simplified).
    """
    def __init__(self, scale_factor=4, num_channels=3,
                 embed_dim=180, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                 window_size=8):
        super().__init__()

        # Shallow feature extraction
        self.conv_first = nn.Conv2d(num_channels, embed_dim, kernel_size=3, padding=1)

        # Deep feature extraction: Swin blocks alternating regular and shifted windows
        self.layers = nn.ModuleList()
        for i, depth in enumerate(depths):
            self.layers.append(
                nn.Sequential(*[
                    SwinTransformerBlock(
                        embed_dim, num_heads[i], window_size,
                        shift_size=0 if j % 2 == 0 else window_size // 2
                    )
                    for j in range(depth)
                ])
            )

        # Fusion of the per-stage outputs
        self.fusion = nn.Conv2d(embed_dim * len(depths), embed_dim, kernel_size=1)

        # Upsampling
        self.upsample = UpsampleModule(embed_dim, scale_factor)

        # Reconstruction
        self.reconstruction = nn.Conv2d(embed_dim, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Shallow features
        shallow = self.conv_first(x)

        # Deep features
        layer_outputs = []
        out = shallow
        for layer in self.layers:
            out = layer(out)
            layer_outputs.append(out)

        # Fusion + global residual
        fused = self.fusion(torch.cat(layer_outputs, dim=1))
        out = fused + shallow

        # Upsampling and reconstruction
        out = self.upsample(out)
        out = self.reconstruction(out)

        return out
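
The complexity claim in the WindowAttention docstring, O(n²) versus O(n × w²), can be checked with a quick count of query-key pairs. The feature-map size of 64×64 below is an arbitrary choice for illustration; the window side of 8 matches the default used above.

```python
def attention_pairs(h, w, window=None):
    """Number of query-key pairs for an h×w token grid."""
    n = h * w
    if window is None:
        return n * n                # global attention: every token sees every token
    return n * window * window      # windowed: every token sees its own window only

global_pairs = attention_pairs(64, 64)
window_pairs = attention_pairs(64, 64, window=8)
print(global_pairs, window_pairs, global_pairs // window_pairs)
# 16777216 262144 64
```

The saving grows linearly with the number of tokens, which is what makes attention affordable on large feature maps.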

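4. GAN-Based Methods

4.1 SRGAN (2017)

SRGAN: Super-Resolution GAN
Paper: "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
      (Ledig et al., 2017)

Core idea:
  Train within the GAN framework: the generator learns to super-resolve, and
  the discriminator learns to tell real HR images from generated ones.

  Goal: produce visually realistic images rather than only maximizing PSNR.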

class SRResNet(nn.Module):
    """
    Generator network (ResNet-based).

    Note: this listing reuses the BN-free ResidualBlock from the EDSR section.
    """
    def __init__(self, num_channels=3, num_features=64, num_blocks=16):
        super().__init__()

        # Shallow features
        self.head = nn.Sequential(
            nn.Conv2d(num_channels, num_features, kernel_size=9, padding=4),
            nn.PReLU()
        )

        # Residual blocks
        self.body = nn.Sequential(*[ResidualBlock(num_features) for _ in range(num_blocks)])

        # Fusion conv before the global skip connection
        self.fusion = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features)
        )

        # Upsampling (two ×2 sub-pixel stages, i.e. ×4 overall)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_features, num_features * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(num_features, num_features * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )

        # Reconstruction: the input has num_features channels, not num_channels
        self.tail = nn.Conv2d(num_features, num_channels, kernel_size=9, padding=4)

    def forward(self, x):
        head = self.head(x)
        body = self.body(head)
        out = self.fusion(body) + head
        out = self.upsample(out)
        out = self.tail(out)
        return out

class Discriminator(nn.Module):
    """
    Discriminator network.

    Decides whether the input is a real HR image or a generated SR image.
    """
    def __init__(self, num_channels=3):
        super().__init__()

        def block(in_channels, out_channels, stride):
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                         stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            ]

        self.features = nn.Sequential(
            *block(num_channels, 64, 1),      # no downsampling
            *block(64, 64, 2),                 # downsample ×2
            *block(64, 128, 1),
            *block(128, 128, 2),               # downsample ×2
            *block(128, 256, 1),
            *block(256, 256, 2),               # downsample ×2
            *block(256, 512, 1),
            *block(512, 512, 2)                # downsample ×2
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid()                       # probability that the input is real
        )

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features)
        return output

class SRGAN:
    """
    SRGAN training wrapper.

    train_step only computes the two losses; the caller is expected to run
    backward() and the optimizer steps for D and G.
    """
    def __init__(self, generator, discriminator,
                 content_weight=1, adversarial_weight=0.001):
        self.G = generator
        self.D = discriminator
        self.content_weight = content_weight
        self.adversarial_weight = adversarial_weight

        # Loss functions
        self.content_loss = nn.MSELoss()
        self.adversarial_loss = nn.BCELoss()

        # VGG feature extractor for the perceptual loss
        # (build_vgg_feature_extractor is not defined in this article)
        self.vgg = self.build_vgg_feature_extractor()

    def train_step(self, lr, hr):
        # ──── Discriminator ────
        # Real images are labeled 1
        real_output = self.D(hr)
        d_loss_real = self.adversarial_loss(real_output, torch.ones_like(real_output))

        # Generate the SR image
        sr = self.G(lr)

        # Generated images are labeled 0; detach so D's loss does not update G
        fake_output = self.D(sr.detach())
        d_loss_fake = self.adversarial_loss(fake_output, torch.zeros_like(fake_output))

        d_loss = (d_loss_real + d_loss_fake) / 2

        # ──── Generator ────
        # Content loss (pixel-wise MSE)
        loss_content = self.content_loss(sr, hr)

        # Perceptual loss (MSE between VGG features)
        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        loss_perceptual = self.content_loss(sr_features, hr_features)

        # Adversarial loss: G wants D to label its output as real
        fake_output = self.D(sr)
        loss_adversarial = self.adversarial_loss(fake_output, torch.ones_like(fake_output))

        # Total generator loss
        g_loss = (self.content_weight * (loss_content + loss_perceptual) +
                  self.adversarial_weight * loss_adversarial)

        return d_loss, g_loss

4.2 ESRGAN (2018)

ESRGAN: Enhanced Super-Resolution GAN
Paper: "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" 
      (Wang et al., 2018)

Improvements over SRGAN:
  1. Removes the BN layers from the generator
  2. Uses the Residual-in-Residual Dense Block (RRDB) as the basic block
  3. Relativistic discriminator
  4. VGG19 perceptual loss

class RRDB(nn.Module):
    """
    Residual-in-Residual Dense Block
    
    Structure: three residual dense blocks wrapped in an outer residual connection
    """
    def __init__(self, channels, growth_rate=32, num_layers=3):
        super().__init__()
        
        self.rdb1 = ResidualDenseBlock(channels, growth_rate, num_layers)
        self.rdb2 = ResidualDenseBlock(channels, growth_rate, num_layers)
        self.rdb3 = ResidualDenseBlock(channels, growth_rate, num_layers)
        
        self.residual_scale = 0.2
    
    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        
        return x + out * self.residual_scale

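class RelativisticDiscriminator(nn.Module):
    """
    Relativistic discriminator
    
    Standard GAN:     D(x) = P(x is real)
    Relativistic GAN: D(x, y) = P(x is more realistic than y)
    
    Advantage: the discriminator judges how realistic an image is relative
    to the other class, not only whether it is real or fake
    """
    def __init__(self, num_channels=3):
        super().__init__()
        # Same backbone as the standard discriminator, but the classifier must
        # return raw logits C(x) (no final Sigmoid); build_* helpers not shown
        self.features = self.build_features(num_channels)
        self.classifier = self.build_classifier()
    
    def forward(self, real, fake):
        # Raw logits for both batches; they are compared in relativistic_loss
        real_output = self.classifier(self.features(real))
        fake_output = self.classifier(self.features(fake))
        
        return real_output, fake_output

def relativistic_loss(discriminator, real, fake):
    """
    Relativistic average adversarial loss (RaGAN), as used in ESRGAN:
    D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)])
    """
    bce = F.binary_cross_entropy_with_logits
    
    # Discriminator loss (fake is detached so no gradient reaches the generator)
    real_output, fake_output = discriminator(real, fake.detach())
    d_loss = (
        bce(real_output - fake_output.mean(), torch.ones_like(real_output)) +
        bce(fake_output - real_output.mean(), torch.zeros_like(fake_output))
    ) / 2
    
    # Generator loss (goal: generated images look more realistic than real ones)
    real_output, fake_output = discriminator(real, fake)
    real_output = real_output.detach()
    g_loss = (
        bce(fake_output - real_output.mean(), torch.ones_like(fake_output)) +
        bce(real_output - fake_output.mean(), torch.zeros_like(real_output))
    ) / 2
    
    return d_loss, g_loss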
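A relativistic average (RaGAN) loss compares each logit with the mean logit of the opposite class before applying the sigmoid. The sketch below works this through in plain Python on hand-picked logits; the function names (`bce_with_logits`, `ragan_losses`) and the example values are assumptions made for illustration and are not taken from the ESRGAN code base.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_with_logits(logit, label):
    """Binary cross-entropy for a single logit/label pair."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

def mean(values):
    return sum(values) / len(values)

def ragan_losses(real_logits, fake_logits):
    """Relativistic average losses for the discriminator and the generator."""
    real_mean, fake_mean = mean(real_logits), mean(fake_logits)
    # D wants real images to score above the average fake, and fakes below the average real
    d_loss = (mean([bce_with_logits(r - fake_mean, 1.0) for r in real_logits]) +
              mean([bce_with_logits(f - real_mean, 0.0) for f in fake_logits])) / 2
    # G wants the opposite ordering, so the labels are swapped
    g_loss = (mean([bce_with_logits(f - real_mean, 1.0) for f in fake_logits]) +
              mean([bce_with_logits(r - fake_mean, 0.0) for r in real_logits])) / 2
    return d_loss, g_loss

# Discriminator currently separates the two batches well
d_loss, g_loss = ragan_losses(real_logits=[2.0, 3.0], fake_logits=[-1.0, 0.0])
print(f"d_loss={d_loss:.4f}  g_loss={g_loss:.4f}")
```

When both batches receive identical logits, every difference is zero and both losses equal log 2, which is the equilibrium value; when the discriminator separates the batches, its loss drops while the generator's loss grows.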

4.3 Real-ESRGAN (2021)

Real-ESRGAN
Paper: "Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data" 
      (Wang et al., 2021)

Key contributions:
  1. High-order degradation modeling (second-order in practice)
  2. U-Net discriminator with spectral normalization
  3. A degradation model aimed at real-world images, trained on synthetic pairs only

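import random

class HighOrderDegradation:
    """
    High-order degradation model
    
    Real-world degradations are complex, and a single first-order pass
    cannot cover them. Real-ESRGAN applies the degradation process twice.
    
    First order:  y = (x ⊗ k) ↓s + n
    Second order: y = (((x ⊗ k₁) ↓s₁ + n₁) ⊗ k₂) ↓s₂ + n₂
    """
    def __init__(self):
        self.degradation_types = [
            'blur',            # blur (Gaussian, motion)
            'downsample',      # downsampling
            'noise',           # noise (Gaussian, Poisson)
            'jpeg_compression' # JPEG compression
        ]
    
    def apply_first_order(self, hr_image, scale=None):
        """First-order degradation (blur, resize, noise and JPEG helpers not shown)"""
        # Random blur kernel
        kernel = self.random_blur_kernel()
        blurred = self.apply_blur(hr_image, kernel)
        
        # Downsampling (random factor unless the caller fixes it)
        if scale is None:
            scale = random.choice([2, 4])
        downsampled = self.downsample(blurred, scale)
        
        # Random noise
        noise_type = random.choice(['gaussian', 'poisson'])
        noisy = self.add_noise(downsampled, noise_type)
        
        # JPEG compression, applied with probability 0.5
        if random.random() > 0.5:
            noisy = self.jpeg_compress(noisy)
        
        return noisy
    
    def apply_second_order(self, hr_image):
        """Second-order degradation (closer to real-world data)"""
        # Two ×2 stages give an overall ×4 LR image. Two independent random
        # scales from [2, 4] would produce anything between ×4 and ×16.
        degraded = self.apply_first_order(hr_image, scale=2)
        
        # Second stage (degrade again)
        degraded = self.apply_first_order(degraded, scale=2)
        
        return degraded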
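The helper methods used above (`random_blur_kernel`, `apply_blur`, `downsample`, `add_noise`) are not defined in this article. As a rough NumPy sketch of what one degradation stage could look like, the code below blurs with a Gaussian kernel, decimates, and adds Gaussian noise; the kernel size, sigma, noise level, and all function names are assumptions, and the real Real-ESRGAN pipeline is considerably richer (random resizing, Poisson noise, JPEG, sinc filters).

```python
import numpy as np

def gaussian_kernel(size=7, sigma=1.5):
    """Normalized 2-D Gaussian blur kernel."""
    ax = np.arange(size) - (size - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2.0 * sigma ** 2))
    kernel = np.outer(g, g)
    return kernel / kernel.sum()

def blur(image, kernel):
    """Filter a 2-D image with a symmetric kernel, reflect-padding the borders."""
    pad = kernel.shape[0] // 2
    padded = np.pad(image, pad, mode="reflect")
    out = np.zeros(image.shape, dtype=np.float64)
    for dy in range(kernel.shape[0]):
        for dx in range(kernel.shape[1]):
            out += kernel[dy, dx] * padded[dy:dy + image.shape[0],
                                           dx:dx + image.shape[1]]
    return out

def degrade(hr, scale=2, sigma=1.5, noise_std=0.02, rng=None):
    """One stage: y = (x ⊗ k) ↓s + n, clipped back to [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    blurred = blur(hr, gaussian_kernel(sigma=sigma))
    lr = blurred[::scale, ::scale]                      # plain decimation
    lr = lr + rng.normal(0.0, noise_std, lr.shape)      # additive Gaussian noise
    return np.clip(lr, 0.0, 1.0)

rng = np.random.default_rng(0)
hr = rng.random((64, 64))                 # stand-in grayscale HR image in [0, 1]
lr1 = degrade(hr, scale=2, rng=rng)       # first-order result, 32×32
lr2 = degrade(lr1, scale=2, rng=rng)      # second-order result, 16×16
print(lr1.shape, lr2.shape)
```

Chaining two ×2 stages yields the overall ×4 pair that a ×4 model would be trained on.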

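class UNetDiscriminator(nn.Module):
    """
    U-Net discriminator
    
    Outputs a per-pixel real/fake map instead of a single global score,
    which gives the generator richer, spatially localized gradients
    """
    def __init__(self, num_channels=3, num_features=64):
        super().__init__()
        
        # Encoder (each block halves the resolution)
        self.enc1 = self.conv_block(num_channels, num_features, stride=2)
        self.enc2 = self.conv_block(num_features, num_features * 2, stride=2)
        self.enc3 = self.conv_block(num_features * 2, num_features * 4, stride=2)
        
        # Bottleneck
        self.bottleneck = self.conv_block(num_features * 4, num_features * 8, stride=2)
        
        # Decoder (stride 1; the upsampling happens in forward)
        self.dec3 = self.conv_block(num_features * 8 + num_features * 4, num_features * 4)
        self.dec2 = self.conv_block(num_features * 4 + num_features * 2, num_features * 2)
        self.dec1 = self.conv_block(num_features * 2 + num_features, num_features)
        
        # Output layer (per-pixel logits)
        self.output = nn.Conv2d(num_features, 1, kernel_size=1)
    
    @staticmethod
    def conv_block(in_ch, out_ch, stride=1):
        # Spectral normalization is applied to the conv of every block
        kernel_size = 4 if stride == 2 else 3
        return nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=1)),
            nn.LeakyReLU(0.2, inplace=True)
        )
    
    @staticmethod
    def up(x, ref):
        # Resize x to the spatial size of ref so the two can be concatenated
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)
    
    def forward(self, x):
        # Encode
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        
        # Bottleneck
        b = self.bottleneck(e3)
        
        # Decode (upsample, then concatenate the skip connection)
        d3 = self.dec3(torch.cat([self.up(b, e3), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up(d3, e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up(d2, e1), e1], dim=1))
        
        # Per-pixel output at the input resolution
        output = self.output(self.up(d1, x))
        
        return output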
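Spectral normalization divides a layer's weight by an estimate of its largest singular value, which bounds how strongly the layer can amplify its input. The NumPy sketch below shows the underlying power iteration on a flattened convolution weight; the function name, the tensor shape, and the iteration count are illustrative assumptions (a training-time implementation typically runs only one iteration per step and reuses the vector between steps).

```python
import numpy as np

def spectral_norm_estimate(weight, num_iters=100, seed=0):
    """Estimate the largest singular value of a 2-D matrix by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=weight.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(num_iters):
        v = weight.T @ u
        v /= np.linalg.norm(v)
        u = weight @ v
        u /= np.linalg.norm(u)
    return float(u @ weight @ v)

rng = np.random.default_rng(1)
# A conv weight [out_ch, in_ch, kH, kW] is flattened to [out_ch, in_ch * kH * kW]
conv_weight = rng.normal(size=(8, 4, 3, 3))
w = conv_weight.reshape(conv_weight.shape[0], -1)

sigma = spectral_norm_estimate(w, num_iters=1000)
w_sn = w / sigma      # normalized weight: largest singular value close to 1
print(sigma, np.linalg.svd(w_sn, compute_uv=False)[0])
```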

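5. Diffusion Model Methods

5.1 SR3 (2021)

SR3: Image Super-Resolution via Iterative Refinement
Paper: "Image Super-Resolution via Iterative Refinement" (Saharia et al., 2021)

Core idea:
  Model super-resolution as a conditional diffusion process
  
  Forward process: gradually add noise to the HR image
  Reverse process: recover the HR image from noise, conditioned on the LR image
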
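class ConditionalUNet(nn.Module):
    """
    Conditional U-Net
    
    Input:  noisy image x_t concatenated with the LR conditioning image
    Output: predicted noise ε_θ
    """
    def __init__(self, in_channels=6, out_channels=3, features=128):
        super().__init__()
        
        # Time embedding (SinusoidalPositionEmbeddings is defined elsewhere)
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(features),
            nn.Linear(features, features * 4),
            nn.GELU(),
            nn.Linear(features * 4, features)
        )
        # Projects the time embedding to the bottleneck width
        self.time_proj = nn.Linear(features, features * 8)
        
        # Encoder
        self.enc1 = self.conv_block(in_channels, features)
        self.enc2 = self.conv_block(features, features * 2)
        self.enc3 = self.conv_block(features * 2, features * 4)
        
        # Bottleneck
        self.bottleneck = self.conv_block(features * 4, features * 8)
        
        # Decoder (with skip connections)
        self.dec3 = self.conv_block(features * 8 + features * 4, features * 4)
        self.dec2 = self.conv_block(features * 4 + features * 2, features * 2)
        self.dec1 = self.conv_block(features * 2 + features, features)
        
        # Output
        self.output = nn.Conv2d(features, out_channels, kernel_size=1)
    
    def conv_block(self, in_ch, out_ch):
        # Simplification: resolution-preserving blocks, so the skip
        # concatenations in forward have matching spatial sizes
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU()
        )
    
    def forward(self, x, t, lr_condition):
        """
        x: noisy image [B, 3, H, W]
        t: timestep [B]
        lr_condition: LR image upsampled (e.g. bicubically) to HR size [B, 3, H, W]
        """
        # Concatenate the condition along the channel axis
        x = torch.cat([x, lr_condition], dim=1)  # [B, 6, H, W]
        
        # Time embedding
        t_emb = self.time_mlp(t)
        
        # Encode
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        
        # Bottleneck, with the time embedding injected as a per-channel bias
        b = self.bottleneck(e3) + self.time_proj(t_emb)[:, :, None, None]
        
        # Decode
        d3 = self.dec3(torch.cat([b, e3], dim=1))
        d2 = self.dec2(torch.cat([d3, e2], dim=1))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))
        
        # Predict the noise
        noise_pred = self.output(d1)
        
        return noise_pred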

class SR3:
    """
    SR3 super-resolution model
    """
    def __init__(self, model, num_timesteps=1000):
        self.model = model
        self.num_timesteps = num_timesteps
        
        # Noise schedule (cosine_beta_schedule and denoise_step are helpers not shown here)
        self.betas = self.cosine_beta_schedule(num_timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    
    def forward_diffusion(self, hr_image, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0)
        
        x_t = √(ᾱ_t) * x_0 + √(1 - ᾱ_t) * ε
        """
        if noise is None:
            noise = torch.randn_like(hr_image)
        
        alpha_cumprod = self.alphas_cumprod[t].view(-1, 1, 1, 1)
        
        noisy_image = torch.sqrt(alpha_cumprod) * hr_image + \
                     torch.sqrt(1 - alpha_cumprod) * noise
        
        return noisy_image, noise
    
    @torch.no_grad()
    def reverse_diffusion(self, lr_image):
        """
        Reverse diffusion: p(x_{t-1} | x_t, lr)
        
        Starts from pure noise and denoises step by step over the full schedule.
        lr_image must already be upsampled to the HR size. Sampling with fewer
        steps requires a re-spaced schedule, not just a shorter loop.
        """
        # Initial noise
        x = torch.randn_like(lr_image)
        
        for t in reversed(range(self.num_timesteps)):
            # The model expects one timestep per batch element
            t_batch = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
            noise_pred = self.model(x, t_batch, lr_image)
            
            # One denoising step
            x = self.denoise_step(x, noise_pred, t)
        
        return x
    
    def train_step(self, lr_image, hr_image):
        """
        Training step
        
        1. Sample a random timestep t
        2. Add noise to the HR image to obtain x_t
        3. Let the model predict the noise
        4. Compute the MSE loss
        """
        batch_size = lr_image.shape[0]
        
        # Random timesteps
        t = torch.randint(0, self.num_timesteps, (batch_size,))
        
        # Random noise
        noise = torch.randn_like(hr_image)
        
        # Forward diffusion
        noisy_hr, noise = self.forward_diffusion(hr_image, t, noise)
        
        # Predict the noise
        noise_pred = self.model(noisy_hr, t, lr_image)
        
        # Loss
        loss = F.mse_loss(noise_pred, noise)
        
        return loss
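The class above calls `cosine_beta_schedule` and `denoise_step` without defining them. The NumPy sketch below shows one plausible form of each: the cosine schedule of Nichol and Dhariwal (2021) and a standard DDPM ancestral sampling step with variance β_t. The signatures, the offset `s=0.008`, and the clipping value are assumptions for illustration; the SR3 paper itself conditions on a continuous noise level and uses its own schedule.

```python
import math
import numpy as np

def cosine_beta_schedule(num_timesteps, s=0.008, max_beta=0.999):
    """Betas derived from a cosine-shaped cumulative product alpha-bar."""
    def alpha_bar(i):
        return math.cos((i / num_timesteps + s) / (1 + s) * math.pi / 2) ** 2
    betas = [min(1 - alpha_bar(i + 1) / alpha_bar(i), max_beta)
             for i in range(num_timesteps)]
    return np.array(betas)

def denoise_step(x_t, noise_pred, t, betas, alphas_cumprod, rng):
    """One DDPM step x_t -> x_{t-1} for an integer timestep t."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    # Posterior mean expressed through the predicted noise
    mean = (x_t - beta_t / np.sqrt(1.0 - alphas_cumprod[t]) * noise_pred) / np.sqrt(alpha_t)
    if t == 0:
        return mean                      # no noise is added at the last step
    sigma = np.sqrt(beta_t)              # the sigma_t^2 = beta_t variance choice
    return mean + sigma * rng.standard_normal(x_t.shape)

betas = cosine_beta_schedule(1000)
alphas_cumprod = np.cumprod(1.0 - betas)

# Sanity check: with a perfect noise prediction, the t = 0 step returns x_0
rng = np.random.default_rng(0)
x0 = rng.random((4, 4))
eps = rng.standard_normal((4, 4))
x_t = np.sqrt(alphas_cumprod[0]) * x0 + np.sqrt(1 - alphas_cumprod[0]) * eps
x_prev = denoise_step(x_t, eps, 0, betas, alphas_cumprod, rng)
print(np.allclose(x_prev, x0))
```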

5.2 DiffIR (2023)

DiffIR: Efficient Diffusion Model for Image Restoration
Paper: "DiffIR: Efficient Diffusion Model for Image Restoration" (Xia et al., 2023)

Improvements:
  1. Two-stage training
  2. A more efficient diffusion process
  3. Fewer sampling steps

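class DiffIR:
    """
    Simplified two-stage sketch in the spirit of DiffIR
    
    Stage 1: train the predictor
      - Takes the degraded image as input
      - Outputs a coarse restoration
    
    Stage 2: train the diffusion refiner
      - Conditioned on the coarse result
      - Restores fine high-frequency detail
    
    Note: the paper runs the diffusion on a compact prior vector extracted by
    a pretrained encoder, not on full-resolution images as sketched here; that
    is what makes very few sampling steps sufficient.
    UNet and GaussianDiffusion are helpers not shown here.
    """
    def __init__(self, num_timesteps=100):
        self.num_timesteps = num_timesteps
        self.predictor = UNet(in_channels=3, out_channels=3)
        # 9 input channels: noisy image (3) + coarse result (3) + degraded image (3)
        self.refiner = ConditionalUNet(in_channels=9, out_channels=3)
        self.diffusion = GaussianDiffusion(num_timesteps=num_timesteps)
    
    def train_predictor(self, degraded, clean):
        """Stage 1: train the predictor"""
        pred = self.predictor(degraded)
        loss = F.l1_loss(pred, clean)
        return loss
    
    def train_refiner(self, degraded, clean):
        """Stage 2: train the diffusion refiner"""
        # Coarse result from the frozen predictor
        with torch.no_grad():
            coarse = self.predictor(degraded)
        
        # Diffuse the clean image to a random timestep
        t = torch.randint(0, self.num_timesteps, (degraded.shape[0],))
        noise = torch.randn_like(clean)
        noisy_clean = self.diffusion.forward_process(clean, t, noise)
        
        # Condition: coarse result + degraded image
        condition = torch.cat([coarse, degraded], dim=1)
        
        # Predict the noise
        noise_pred = self.refiner(noisy_clean, t, condition)
        
        loss = F.mse_loss(noise_pred, noise)
        return loss
    
    @torch.no_grad()
    def inference(self, degraded):
        """Inference"""
        # Stage 1: coarse restoration
        coarse = self.predictor(degraded)
        
        # Stage 2: diffusion refinement
        condition = torch.cat([coarse, degraded], dim=1)
        refined = self.diffusion.sample(condition, num_steps=10)
        
        return refined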

6. Video Super-Resolution

6.1 Temporal Alignment

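class TemporalAlignment(nn.Module):
    """
    Temporal alignment module
    
    Problem: neighboring frames contain motion, so they must be aligned
    before their information can be fused
    
    Approaches:
      1. Optical flow estimation
      2. Deformable convolution
      3. Attention-based alignment
    """
    pass

class OpticalFlowAlignment(nn.Module):
    """
    Flow-based temporal alignment
    
    Pipeline:
      1. Estimate the optical flow between neighboring frames
      2. Warp the frame with the flow
      3. Fuse features after alignment
    """
    def __init__(self):
        super().__init__()
        self.flow_estimator = FlowNet()  # flow estimation network (not shown here)
    
    def align(self, reference, target):
        """
        Align the target frame to the reference frame
        """
        # Estimate the optical flow
        flow = self.flow_estimator(reference, target)
        
        # Warp the target with the flow
        aligned = self.warp(target, flow)
        
        return aligned
    
    def warp(self, image, flow):
        """
        Warp an image [B, C, H, W] with a flow field [B, 2, H, W] (x, y order)
        
        Bilinear sampling keeps the warp differentiable
        """
        B, C, H, W = image.shape
        
        # Pixel-coordinate grid [2, H, W], on the same device as the flow
        grid_y, grid_x = torch.meshgrid(
            torch.arange(H, device=flow.device),
            torch.arange(W, device=flow.device), indexing='ij'
        )
        grid = torch.stack([grid_x, grid_y], dim=0).to(flow.dtype)
        
        # Add the flow offsets: [2, H, W] + [B, 2, H, W] -> [B, 2, H, W]
        grid = grid.unsqueeze(0) + flow
        
        # Normalize to [-1, 1] and move the coordinates last: [B, H, W, 2]
        grid_x = 2 * grid[:, 0] / max(W - 1, 1) - 1
        grid_y = 2 * grid[:, 1] / max(H - 1, 1) - 1
        grid = torch.stack([grid_x, grid_y], dim=-1)
        
        # Bilinear sampling
        aligned = F.grid_sample(image, grid, mode='bilinear', 
                               padding_mode='border', align_corners=True)
        
        return aligned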
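`F.grid_sample` expects sampling positions in normalized coordinates, where -1 and 1 address the first and last pixel centers when `align_corners=True`. The plain-Python sketch below spells out that mapping for a single coordinate; the helper names are hypothetical and exist only for this illustration.

```python
def normalize_coord(pixel, size):
    """Map a pixel position in [0, size - 1] to [-1, 1] (align_corners=True)."""
    return 2.0 * pixel / (size - 1) - 1.0

def denormalize_coord(value, size):
    """Inverse mapping from [-1, 1] back to pixel positions."""
    return (value + 1.0) * (size - 1) / 2.0

width = 5
# A pixel at x = 2 with a horizontal flow of +1.5 px samples from x = 3.5
source_x = 2 + 1.5
print(normalize_coord(source_x, width))   # 0.75
```

Positions that land outside [-1, 1] after adding the flow are handled by `padding_mode`; with `'border'` they are clamped to the edge pixels.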

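from torchvision.ops import DeformConv2d

class DeformableConvAlignment(nn.Module):
    """
    Deformable-convolution alignment
    
    Advantage: the sampling positions are learned per kernel tap, which is
    more flexible than a single flow vector per pixel
    """
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        
        # Offset prediction
        self.offset_conv = nn.Conv2d(
            channels * 2,  # input: reference + target features
            2 * kernel_size * kernel_size,  # output: (x, y) offset per kernel tap
            kernel_size=3, padding=1
        )
        
        # Deformable convolution
        self.deform_conv = DeformConv2d(
            channels, channels, 
            kernel_size=kernel_size, padding=kernel_size // 2
        )
    
    def forward(self, reference, target):
        # Predict the offsets from both feature maps
        offset = self.offset_conv(torch.cat([reference, target], dim=1))
        
        # Sample the target features at the offset positions
        aligned = self.deform_conv(target, offset)
        
        return aligned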

6.2 Video SR Network Architectures

class BasicVSR(nn.Module):
    """
    BasicVSR: video super-resolution with bidirectional propagation
    
    Paper: "BasicVSR: The Search for Essential Components in Video Super-Resolution 
          and Beyond" (Chan et al., 2021)
    
    Core ideas:
      1. Bidirectional temporal propagation (forward + backward)
      2. Optical-flow alignment
      3. Residual blocks for feature refinement
    
    SpyNet, ResidualBlock and UpsampleModule are defined elsewhere.
    """
    def __init__(self, num_channels=3, num_features=64, num_blocks=30):
        super().__init__()
        
        # Feature extractor
        self.feat_extract = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)
        
        # Optical flow estimator
        self.flow_estimator = SpyNet()
        
        # Propagation branches (bidirectional)
        self.backward_resblocks = nn.Sequential(*[
            ResidualBlock(num_features) for _ in range(num_blocks)
        ])
        self.forward_resblocks = nn.Sequential(*[
            ResidualBlock(num_features) for _ in range(num_blocks)
        ])
        
        # Fusion and reconstruction
        self.fusion = nn.Conv2d(num_features * 2, num_features, kernel_size=1)
        self.upsample = UpsampleModule(num_features, scale_factor=4)
        self.reconstruction = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1)
    
    def warp(self, feat, flow):
        # Reuses the flow-warping routine of OpticalFlowAlignment (section 6.1)
        return OpticalFlowAlignment.warp(self, feat, flow)
    
    def forward(self, lr_frames):
        """
        lr_frames: [B, T, C, H, W] (T LR video frames)
        """
        B, T, C, H, W = lr_frames.shape
        
        # Per-frame features
        features = [self.feat_extract(lr_frames[:, t]) for t in range(T)]
        
        # Backward propagation in time (last frame -> first frame)
        backward_features = []
        for t in reversed(range(T)):
            feat = features[t]
            if t < T - 1:
                # Align the state propagated from frame t + 1 to frame t
                flow = self.flow_estimator(lr_frames[:, t], lr_frames[:, t + 1])
                feat = feat + self.warp(backward_features[-1], flow)
            backward_features.append(self.backward_resblocks(feat))
        backward_features = backward_features[::-1]
        
        # Forward propagation in time (first frame -> last frame)
        forward_features = []
        for t in range(T):
            feat = features[t]
            if t > 0:
                flow = self.flow_estimator(lr_frames[:, t], lr_frames[:, t - 1])
                feat = feat + self.warp(forward_features[-1], flow)
            
            # Fuse with the backward branch
            feat = self.fusion(torch.cat([feat, backward_features[t]], dim=1))
            forward_features.append(self.forward_resblocks(feat))
        
        # Reconstruct the HR frames
        hr_frames = [self.reconstruction(self.upsample(feat)) for feat in forward_features]
        
        return torch.stack(hr_frames, dim=1)
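One way to see why bidirectional propagation matters is to track which input frames can influence each output frame. The toy sketch below replaces feature tensors with Python sets of frame indices; it is a bookkeeping illustration only, and the function name is hypothetical.

```python
def bidirectional_receptive_field(num_frames):
    """Return, per frame, the input frames visible to the backward and fused branches."""
    # Backward branch: frame t sees itself plus everything frame t + 1 has seen
    backward = [None] * num_frames
    seen = set()
    for t in reversed(range(num_frames)):
        seen = seen | {t}
        backward[t] = set(seen)

    # Forward branch fused with the backward state of the same frame
    outputs = []
    seen = set()
    for t in range(num_frames):
        seen = seen | {t} | backward[t]
        outputs.append(set(seen))
    return backward, outputs

backward, outputs = bidirectional_receptive_field(5)
print(backward[4], backward[0])   # the last frame sees only itself; the first sees all
print(outputs[2])                 # after fusion, every frame sees the whole clip
```

A sliding-window model with a window of three frames would limit each output to at most three inputs, whereas here every output can draw on the entire clip.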

class BasicVSRPlusPlus(nn.Module):
    """
    BasicVSR++: improved version of BasicVSR
    
    Improvements:
      1. Second-order grid propagation
      2. Flow-guided deformable alignment
      3. More effective feature fusion
    """
    pass

7. Loss Functions and Evaluation Metrics

7.1 Loss Functions

class SRLosses:
    """Collection of super-resolution loss functions"""
    
    @staticmethod
    def pixel_loss(pred, target, loss_type='l1'):
        """
        Pixel-wise loss
        
        L1: ‖pred - target‖₁
        L2: ‖pred - target‖₂²
        
        L1: sharper edges, more robust to outliers
        L2: directly maximizes PSNR, but tends to over-smooth textures
        """
        if loss_type == 'l1':
            return F.l1_loss(pred, target)
        elif loss_type == 'l2':
            return F.mse_loss(pred, target)
        elif loss_type == 'charbonnier':
            # Smooth approximation of L1 that stays differentiable at zero
            eps = 1e-6
            return torch.mean(torch.sqrt((pred - target) ** 2 + eps))
        raise ValueError(f"Unknown loss_type: {loss_type}")
    
    @staticmethod
    def perceptual_loss(pred, target, vgg_model, layer_weights=None):
        """
        Perceptual loss
        
        Measures the distance in VGG feature space instead of pixel space,
        which favors results that look right to human observers.
        vgg_model must return a dict mapping layer names to feature maps.
        """
        if layer_weights is None:
            layer_weights = {
                'relu1_2': 1.0,
                'relu2_2': 1.0,
                'relu3_3': 1.0,
                'relu4_3': 1.0
            }
        
        loss = 0
        pred_features = vgg_model(pred)
        target_features = vgg_model(target)
        
        for layer, weight in layer_weights.items():
            loss += weight * F.l1_loss(
                pred_features[layer], 
                target_features[layer]
            )
        
        return loss
    
    @staticmethod
    def style_loss(pred, target, vgg_model):
        """
        Style loss (Gram loss)
        
        Matches the Gram matrices of the features to preserve texture statistics
        """
        def gram_matrix(features):
            B, C, H, W = features.shape
            features = features.view(B, C, -1)
            gram = torch.bmm(features, features.transpose(1, 2))
            return gram / (C * H * W)
        
        pred_features = vgg_model(pred)
        target_features = vgg_model(target)
        
        loss = 0
        for layer in pred_features:
            pred_gram = gram_matrix(pred_features[layer])
            target_gram = gram_matrix(target_features[layer])
            loss += F.mse_loss(pred_gram, target_gram)
        
        return loss
    
    @staticmethod
    def adversarial_loss(discriminator_output, mode='original'):
        """
        Generator-side adversarial loss (expects raw discriminator logits)
        
        Original (non-saturating) GAN: -log(sigmoid(D(G(x))))
        LSGAN: (D(G(x)) - 1)²
        Hinge: -D(G(x))   (the max(0, 1 - D(·)) form belongs to the discriminator)
        """
        if mode == 'original':
            return F.binary_cross_entropy_with_logits(
                discriminator_output, 
                torch.ones_like(discriminator_output)
            )
        elif mode == 'lsgan':
            return F.mse_loss(
                discriminator_output, 
                torch.ones_like(discriminator_output)
            )
        elif mode == 'hinge':
            return -discriminator_output.mean()
        raise ValueError(f"Unknown mode: {mode}")
    
    @staticmethod
    def frequency_loss(pred, target, alpha=1.0):
        """
        Frequency-domain loss
        
        Constrains the recovery of high-frequency detail in the Fourier domain
        """
        # FFT
        pred_fft = torch.fft.fft2(pred)
        target_fft = torch.fft.fft2(target)
        
        # L1 loss on the amplitude spectra
        loss = F.l1_loss(
            torch.abs(pred_fft), 
            torch.abs(target_fft)
        )
        
        # The high-frequency part could be weighted more strongly.
        # In the unshifted fft2 layout the low frequencies sit at the corners
        # and the highest ones at the center; fftshift swaps the two.
        
        return loss * alpha
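The frequency loss above leaves the high-frequency weighting as a comment. One possible weighting, sketched here in NumPy, scales each FFT bin by its radial frequency so that errors in fine detail count more than errors in smooth regions. The linear weighting form, the `alpha` parameter, and both function names are assumptions for illustration, not part of any cited method.

```python
import numpy as np

def radial_frequency_weight(height, width, alpha=1.0):
    """Weight map laid out like an unshifted np.fft.fft2 output."""
    fy = np.fft.fftfreq(height)[:, None]   # cycles per pixel, in [-0.5, 0.5)
    fx = np.fft.fftfreq(width)[None, :]
    radius = np.sqrt(fx ** 2 + fy ** 2)
    # 1 at the DC bin, rising linearly to 1 + alpha at the highest frequency
    return 1.0 + alpha * radius / radius.max()

def weighted_frequency_l1(pred, target, alpha=1.0):
    """L1 distance between amplitude spectra, emphasizing high frequencies."""
    weight = radial_frequency_weight(*pred.shape[-2:], alpha=alpha)
    diff = np.abs(np.abs(np.fft.fft2(pred)) - np.abs(np.fft.fft2(target)))
    return float(np.mean(weight * diff))

weight = radial_frequency_weight(8, 8, alpha=1.0)
rng = np.random.default_rng(0)
img = rng.random((8, 8))
print(weight[0, 0], weight[4, 4])          # DC bin vs. highest-frequency bin
print(weighted_frequency_l1(img, img))     # identical images give zero loss
```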

7.2 Evaluation Metrics

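import numpy as np
import torch
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

class SRMetrics:
    """Super-resolution evaluation metrics"""
    
    @staticmethod
    def calculate_psnr(pred, target, max_val=1.0):
        """
        PSNR (Peak Signal-to-Noise Ratio)
        
        Formula: PSNR = 10 * log10(MAX² / MSE)
        
        Range: typically 20-40 dB; higher is better
        Limitation: correlates only loosely with human perception
        """
        # Cast first so integer images do not overflow when squared
        diff = pred.astype(np.float64) - target.astype(np.float64)
        mse = np.mean(diff ** 2)
        if mse == 0:
            return float('inf')
        return 10 * np.log10(max_val ** 2 / mse)
    
    @staticmethod
    def calculate_ssim(pred, target, window_size=11):
        """
        SSIM (Structural Similarity Index)
        
        Compares luminance, contrast, and structure
        
        Range: [-1, 1]; closer to 1 is better
        Advantage: tracks human perception better than PSNR
        Expects float images in [0, 1] with shape [H, W, C]
        """
        return ssim(pred, target, data_range=1.0, 
                   win_size=window_size, channel_axis=-1)
    
    @staticmethod
    def calculate_lpips(pred, target, net='alex'):
        """
        LPIPS (Learned Perceptual Image Patch Similarity)
        
        Perceptual distance computed from the features of a pretrained network
        Range: non-negative, usually below 1; lower is better
        Advantage: tracks human judgments more closely than PSNR or SSIM
        """
        import lpips
        # Building the model on every call is slow; cache it in real use
        loss_fn = lpips.LPIPS(net=net)
        
        # [H, W, C] arrays in [0, 1] -> [1, C, H, W] tensors
        pred_t = torch.from_numpy(pred).permute(2, 0, 1).unsqueeze(0).float()
        target_t = torch.from_numpy(target).permute(2, 0, 1).unsqueeze(0).float()
        
        # normalize=True rescales [0, 1] inputs to the [-1, 1] range LPIPS expects
        return loss_fn(pred_t, target_t, normalize=True).item()
    
    @staticmethod
    def calculate_niqe(image):
        """
        NIQE (Natural Image Quality Evaluator)
        
        No-reference image quality assessment
        Needs no ground truth; based on natural image statistics
        Range: lower is better
        """
        # Extract natural scene statistics features (helper not shown here)
        features = extract_niqe_features(image)
        
        # Compare against the distribution of natural images (helper not shown here)
        niqe_score = compute_niqe_score(features)
        
        return niqe_score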
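A quick worked example makes the PSNR formula concrete and shows why `max_val` has to match the data range. This is a plain-Python sketch; the function name `psnr_from_mse` is hypothetical.

```python
import math

def psnr_from_mse(mse, max_val=1.0):
    """PSNR in dB for a given mean squared error and peak value."""
    if mse == 0:
        return float("inf")
    return 10.0 * math.log10(max_val ** 2 / mse)

# Images in [0, 1] with a uniform error of 0.1 per pixel: MSE = 0.01
print(psnr_from_mse(0.01, max_val=1.0))        # 20 dB

# The same images stored as 8-bit values: the error is 25.5, so MSE = 650.25
print(psnr_from_mse(650.25, max_val=255.0))    # also 20 dB

# Mixing the two conventions gives a meaningless number
print(psnr_from_mse(650.25, max_val=1.0))      # about -28 dB
```

Each tenfold reduction in MSE adds 10 dB, so an MSE of 0.0001 on [0, 1] images corresponds to 40 dB.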

# Metric comparison
"""
┌───────────────────────────────────────────────────────────────────────────┐
│                       Evaluation metric comparison                        │
├──────────┬──────────────┬──────────────────┬──────────────────────────────┤
│  Metric  │  Needs GT    │  Perceptual corr.│  Use case                    │
├──────────┼──────────────┼──────────────────┼──────────────────────────────┤
│  PSNR    │  ✓           │  Low             │  Objective fidelity          │
│  SSIM    │  ✓           │  Medium          │  Structure preservation      │
│  LPIPS   │  ✓           │  High            │  Perceptual quality          │
│  FID     │  ✓ (dataset) │  High            │  Generative quality          │
│  NIQE    │  ✗           │  Medium          │  No-reference assessment     │
└──────────┴──────────────┴──────────────────┴──────────────────────────────┘
"""

8. Engineering Practice and Deployment

8.1 Model Selection Guide

┌─────────────────────────────────────────────────────────────┐
│                Model selection decision tree                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Requirements:                                              │
│  ├── Is real-time speed required?                           │
│  │   ├── Yes → ESPCN, FSRCNN (lightweight)                  │
│  │   └── No  → keep evaluating                              │
│  │                                                          │
│  ├── Is visual quality the priority?                        │
│  │   ├── Yes → ESRGAN, Real-ESRGAN (GAN-based)              │
│  │   └── No  → keep evaluating                              │
│  │                                                          │
│  ├── Is PSNR the priority?                                  │
│  │   ├── Yes → EDSR, RCAN, SwinIR (regression-based)        │
│  │   └── No  → keep evaluating                              │
│  │                                                          │
│  ├── Real-world degradations?                               │
│  │   ├── Yes → Real-ESRGAN, DiffIR (blind SR)               │
│  │   └── No  → standard methods                             │
│  │                                                          │
│  └── Limited compute or memory?                             │
│      ├── Yes → lightweight model + knowledge distillation   │
│      └── No  → large model + ensembling                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

8.2 Model Deployment

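class SRModelDeployer:
    """Deployment utilities for super-resolution models"""
    
    @staticmethod
    def export_onnx(model, input_shape, save_path):
        """
        Export to ONNX
        """
        model.eval()
        dummy_input = torch.randn(input_shape)
        
        torch.onnx.export(
            model,
            dummy_input,
            save_path,
            opset_version=11,
            input_names=['input'],
            output_names=['output'],
            # Batch size and spatial size stay dynamic
            dynamic_axes={
                'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                'output': {0: 'batch_size', 2: 'height', 3: 'width'}
            }
        )
    
    @staticmethod
    def optimize_tensorrt(onnx_path, engine_path, fp16=True):
        """
        TensorRT optimization (written against the TensorRT 8.4+ Python API)
        """
        import tensorrt as trt
        
        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)
        
        # Parse the ONNX file and surface parser errors
        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
                raise RuntimeError('ONNX parsing failed: ' + '; '.join(errors))
        
        # Builder configuration
        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB
        
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        
        # Note: a model exported with dynamic axes also needs an optimization
        # profile (min/opt/max input shapes) added to config before building
        
        # Build and save the serialized engine
        serialized_engine = builder.build_serialized_network(network, config)
        with open(engine_path, 'wb') as f:
            f.write(serialized_engine)
    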
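    @staticmethod
    def tile_process(model, image, scale=4, tile_size=512, tile_overlap=32):
        """
        Process a large image tile by tile
        
        Problem: a large image may not fit into GPU memory in one pass
        Solution: process overlapping tiles and blend the overlaps with feathered weights
        
        image: float array [H, W, C]; model maps an [h, w, C] tile to
        [h * scale, w * scale, C]. create_weight_map is a helper not shown here
        and must return strictly positive weights.
        """
        h, w = image.shape[:2]
        stride = tile_size - tile_overlap
        
        # The output is scale times larger than the input
        result = np.zeros((h * scale, w * scale) + image.shape[2:], dtype=np.float32)
        weight_map = np.zeros_like(result)
        
        for y in range(0, h, stride):
            for x in range(0, w, stride):
                # Shift border tiles inward so every tile has the full size
                y_end = min(y + tile_size, h)
                x_end = min(x + tile_size, w)
                y_start = max(0, y_end - tile_size)
                x_start = max(0, x_end - tile_size)
                
                # Model inference on one tile
                sr_tile = model(image[y_start:y_end, x_start:x_end])
                
                # Feathered weights, sized for the upscaled tile
                weight = create_weight_map(sr_tile.shape, tile_overlap * scale)
                
                # Accumulate at the upscaled position
                ys, xs = y_start * scale, x_start * scale
                ye, xe = ys + sr_tile.shape[0], xs + sr_tile.shape[1]
                result[ys:ye, xs:xe] += sr_tile * weight
                weight_map[ys:ye, xs:xe] += weight
        
        # Normalize by the accumulated weights
        result = result / (weight_map + 1e-8)
        
        return result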
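`tile_process` relies on a `create_weight_map` helper that the article does not define. A minimal NumPy sketch of one possible version follows: weights ramp up linearly over the overlap region at each tile border and stay at 1 in the interior. The linear ramp and the exact signature are assumptions; a cosine or Gaussian window would work just as well.

```python
import numpy as np

def create_weight_map(tile_shape, overlap):
    """Feathering weights for a tile of shape [H, W] or [H, W, C]."""
    h, w = tile_shape[:2]

    def ramp(length):
        weights = np.ones(length, dtype=np.float32)
        n = min(overlap, length // 2)
        if n > 0:
            # Strictly positive ramp: 1/(n+1), 2/(n+1), ..., n/(n+1)
            edge = np.arange(1, n + 1, dtype=np.float32) / (n + 1)
            weights[:n] = edge
            weights[-n:] = edge[::-1]
        return weights

    weight = np.outer(ramp(h), ramp(w))
    # Trailing singleton axes let the map broadcast over channels
    return weight.reshape((h, w) + (1,) * (len(tile_shape) - 2))

weights = create_weight_map((8, 8, 3), overlap=2)
print(weights.shape)        # (8, 8, 1)
print(weights[:, 4, 0])     # ramp up, plateau at 1, ramp down
```

Because the rising ramp of one tile and the falling ramp of its neighbor add up to 1 when the overlaps line up, the blend is seamless there, and the final division by the accumulated weight map handles every other case.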

8.3 Training Tricks

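class SRTrainingTricks:
    """Training tricks for super-resolution"""
    
    @staticmethod
    def progressive_training(model, dataset, scale_factors=(2, 3, 4)):
        """
        Progressive training
        
        Train on a small scale factor first, then fine-tune on larger ones.
        This tends to stabilize training and improve the final result.
        set_scale_factor, create_dataloader and train are helpers not shown here.
        """
        for scale in scale_factors:
            print(f"Training with scale factor ×{scale}")
            
            # Switch the model's upsampling factor
            model.set_scale_factor(scale)
            
            # Prepare data for this scale
            train_loader = create_dataloader(dataset, scale)
            
            # Train
            train(model, train_loader, epochs=100)
    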
    @staticmethod
    def patch_training(hr_images, patch_size=192, scale_factor=4):
        """
        Patch-based training
        
        Randomly crops patches from the HR images to reduce GPU memory use.
        Requires OpenCV (import cv2) and images of at least patch_size per side.
        """
        class PatchDataset(torch.utils.data.Dataset):
            def __init__(self, hr_images, patch_size, scale_factor):
                self.hr_images = hr_images
                self.patch_size = patch_size
                self.scale_factor = scale_factor
                self.lr_patch_size = patch_size // scale_factor
            
            def __len__(self):
                return len(self.hr_images)
            
            def __getitem__(self, idx):
                # Pick the image for this index
                img = self.hr_images[idx % len(self.hr_images)]
                
                # Random HR crop (randint's upper bound is exclusive, hence + 1)
                h, w = img.shape[:2]
                y = np.random.randint(0, h - self.patch_size + 1)
                x = np.random.randint(0, w - self.patch_size + 1)
                hr_patch = img[y:y+self.patch_size, x:x+self.patch_size]
                
                # Downsample to obtain the LR patch
                lr_patch = cv2.resize(
                    hr_patch, 
                    (self.lr_patch_size, self.lr_patch_size),
                    interpolation=cv2.INTER_CUBIC
                )
                
                # Random flips and rotations (data augmentation)
                lr_patch, hr_patch = self.augment(lr_patch, hr_patch)
                
                return lr_patch, hr_patch
            
            def augment(self, lr, hr):
                # Random horizontal flip
                if np.random.random() > 0.5:
                    lr = np.flip(lr, axis=1).copy()
                    hr = np.flip(hr, axis=1).copy()
                
                # Random vertical flip
                if np.random.random() > 0.5:
                    lr = np.flip(lr, axis=0).copy()
                    hr = np.flip(hr, axis=0).copy()
                
                # Random rotation by a multiple of 90 degrees (same k for both)
                k = np.random.randint(0, 4)
                lr = np.rot90(lr, k).copy()
                hr = np.rot90(hr, k).copy()
                
                return lr, hr
        
        return PatchDataset(hr_images, patch_size, scale_factor)
    
    @staticmethod
    def self_ensemble(model, lr_image):
        """
        Self-ensemble
        
        Runs the model on several transformed copies of the input, undoes each
        transform on the output, and averages the results.
        Typically reported to add roughly 0.1-0.3 dB of PSNR at 8× the inference cost.
        """
        def augment_transform(img, mode):
            if mode == 0: return img
            elif mode == 1: return np.flip(img, axis=0).copy()
            elif mode == 2: return np.flip(img, axis=1).copy()
            elif mode == 3: return np.rot90(img, k=1).copy()
            elif mode == 4: return np.rot90(img, k=2).copy()
            elif mode == 5: return np.rot90(img, k=3).copy()
            elif mode == 6: return np.flip(np.rot90(img, k=1), axis=0).copy()
            elif mode == 7: return np.flip(np.rot90(img, k=1), axis=1).copy()
        
        def deaugment_transform(img, mode):
            # Inverse of augment_transform: undo the operations in reverse order
            if mode == 0: return img
            elif mode == 1: return np.flip(img, axis=0).copy()
            elif mode == 2: return np.flip(img, axis=1).copy()
            elif mode == 3: return np.rot90(img, k=-1).copy()
            elif mode == 4: return np.rot90(img, k=-2).copy()
            elif mode == 5: return np.rot90(img, k=-3).copy()
            elif mode == 6: return np.rot90(np.flip(img, axis=0), k=-1).copy()
            elif mode == 7: return np.rot90(np.flip(img, axis=1), k=-1).copy()
        
        # The 8 flip/rotation combinations
        results = []
        for mode in range(8):
            augmented = augment_transform(lr_image, mode)
            sr = model(augmented)
            original = deaugment_transform(sr, mode)
            results.append(original)
        
        # Average
        return np.mean(results, axis=0)
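Self-ensemble only works if every inverse transform exactly undoes its forward transform. The NumPy sketch below checks this with a compact parameterization of the same eight symmetries (a rotation by `mode % 4` quarter turns, followed by a horizontal flip for modes 4 to 7) and a nearest-neighbor upscaler standing in for the SR model. The parameterization and all names are assumptions for illustration; the stand-in model is exactly equivariant to flips and rotations, so the ensemble must reproduce its plain output.

```python
import numpy as np

def augment(img, mode):
    """Apply one of the 8 flip/rotation symmetries to an [H, W, C] image."""
    img = np.rot90(img, k=mode % 4)
    return np.flip(img, axis=1) if mode >= 4 else img

def deaugment(img, mode):
    """Inverse of augment: undo the flip first, then the rotation."""
    img = np.flip(img, axis=1) if mode >= 4 else img
    return np.rot90(img, k=-(mode % 4))

def upscale_nearest(img, scale=2):
    """Stand-in for an SR model: nearest-neighbor upscaling."""
    return img.repeat(scale, axis=0).repeat(scale, axis=1)

def self_ensemble(model, lr_image):
    outputs = [deaugment(model(augment(lr_image, m)), m) for m in range(8)]
    return np.mean(outputs, axis=0)

rng = np.random.default_rng(0)
lr = rng.random((6, 4, 3))            # non-square on purpose
sr = self_ensemble(upscale_nearest, lr)
print(sr.shape)                        # (12, 8, 3)
print(np.allclose(sr, upscale_nearest(lr)))
```

With a real network the eight outputs differ slightly, and averaging them is what reduces the error.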

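Appendix

A. Algorithm Timeline

2014  ──┬──  SRCNN (first CNN applied to SR)
        │
2015  ──┼──  DRCN (recursive network)
        │
2016  ──┼──  FSRCNN (accelerated SRCNN) / ESPCN (sub-pixel convolution) / VDSR (residual learning)
        │
2017  ──┼──  SRGAN (GAN-based) / EDSR (enhanced residual network)
        │
2018  ──┼──  ESRGAN (enhanced GAN, introduces RRDB) / RDN (dense connections) / RCAN (channel attention)
        │
2019  ──┼──  SRFBN (feedback network) / RankSRGAN (optimizes perceptual metrics)
        │
2020  ──┼──  HAN (holistic attention)
        │
2021  ──┼──  SwinIR (Transformer) / Real-ESRGAN (real-world) / SR3 (diffusion)
        │
2022  ──┼──  BasicVSR++ (video SR)
        │
2023  ──┼──  DiffIR (efficient diffusion) / StableSR (Stable Diffusion prior) / SeeSR (semantics-aware)
        │
2024+ ──┴──  Ongoing: more efficient architectures, more realistic degradations, multimodal fusion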

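B. Common Datasets

┌────────────┬──────────┬─────────────────────────┬──────────────────────────────────────────┐
│  Dataset   │  Images  │  Use                    │  Notes                                   │
├────────────┼──────────┼─────────────────────────┼──────────────────────────────────────────┤
│  DIV2K     │  1000    │  Training / validation  │  High quality, widely used               │
│  Set5      │  5       │  Testing                │  Classic benchmark                       │
│  Set14     │  14      │  Testing                │  Diverse scenes                          │
│  BSD100    │  100     │  Testing                │  Natural images                          │
│  Urban100  │  100     │  Testing                │  Urban buildings, repetitive structures  │
│  Manga109  │  109     │  Testing                │  Japanese manga                          │
│  Flickr2K  │  2650    │  Training               │  High-resolution real-world photos       │
│  OST       │  10000+  │  Training               │  Large-scale scene dataset               │
└────────────┴──────────┴─────────────────────────┴──────────────────────────────────────────┘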