A Comprehensive Guide to Deep-Learning Image Super-Resolution: From Beginner to Expert

This article systematically covers deep-learning single-image super-resolution (SISR): problem definition, network architecture design, loss functions, evaluation metrics, frontier methods, and practical applications, all with complete PyTorch implementations, to help you master this core image-processing technique.


1. What Is Image Super-Resolution

1.1 Problem Definition

Image super-resolution (SR) is the task of reconstructing a high-resolution (HR) image from a low-resolution (LR) one.

The super-resolution task:

┌─────────────┐                    ┌─────────────────────┐
│             │                    │                     │
│   64×64     │    SR network      │      256×256        │
│  LR image   │  ──────────→       │      HR image       │
│             │    ×4 upscale      │                     │
└─────────────┘                    └─────────────────────┘

Goal: recover sharp textures, edges, and details

1.2 Mathematical Formulation

The image degradation process is usually modeled as:

Degradation model:

I_LR = D(I_HR; θ_D)

where:
D(I_HR; θ_D) = (I_HR ⊗ κ) ↓_s + n

Notation:
- I_HR: the original high-resolution image
- κ: blur kernel (e.g., Gaussian)
- ⊗: convolution
- ↓_s: downsampling by scale factor s
- n: additive noise (usually white Gaussian noise)

The super-resolution task:
given I_LR, reconstruct I_SR ≈ I_HR

I_SR = F(I_LR; θ_F)

where F is the super-resolution model and θ_F are its parameters

1.3 Why Is Super-Resolution Hard?

Super-resolution is an ill-posed problem:

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│  One LR image can correspond to infinitely many HR images   │
│                                                             │
│        HR₁ ─┐                                               │
│        HR₂ ─┼──→ downsample ──→ LR                          │
│        HR₃ ─┘                                               │
│        ...                                                  │
│                                                             │
│  Information is lost during downsampling and cannot be      │
│  perfectly recovered                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Challenges:
1. High-frequency details (textures, edges) are lost during downsampling
2. The missing details must be "guessed" from limited information
3. Different scenes call for different priors

1.4 Application Scenarios

Practical applications of super-resolution:

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│  Surveillance    │ sharpen footage, aid face recognition    │
│  Medical imaging │ sharpen CT/MRI scans, aid diagnosis      │
│  Remote sensing  │ enhance imagery, extract ground features │
│  Video           │ upscale old / low-res video to HD        │
│  Smartphones     │ digital zoom, night-mode enhancement     │
│  Games / film    │ HD remasters of old games and movies     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. Common Datasets and Evaluation Metrics

2.1 Benchmark Datasets

┌─────────────────────────────────────────────────────────────┐
│              Common super-resolution datasets               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Training sets:                                             │
│  ┌─────────────────────────────────────────────────────┐   │
│  │ DIV2K     │ 1000 high-quality 2K images, most used │   │
│  │ Flickr2K  │ 2650 2K images, often used with DIV2K  │   │
│  │ ImageNet  │ large-scale, used for pre-training     │   │
│  │ T91       │ 91 images, an early small dataset      │   │
│  │ BSDS500   │ 500 natural images                     │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  Test sets:                                                 │
│  ┌─────────────────────────────────────────────────────┐   │
│  │ Set5      │ 5 classic test images                  │   │
│  │ Set14     │ 14 test images                         │   │
│  │ BSD100    │ 100 natural images                     │   │
│  │ Urban100  │ 100 urban buildings, rich in edges     │   │
│  │ Manga109  │ 109 Japanese manga images              │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2.2 Degradation Models

python
"""
Common image degradation pipelines
"""
import cv2
import numpy as np
from scipy.ndimage import gaussian_filter

def bicubic_degradation(hr_image, scale=4):
    """
    BI模式:双三次下采样(最常用)
    
    直接使用bicubic插值进行下采样
    """
    h, w = hr_image.shape[:2]
    lr_image = cv2.resize(hr_image, (w//scale, h//scale), 
                         interpolation=cv2.INTER_CUBIC)
    return lr_image


def blur_downsample_degradation(hr_image, scale=3, kernel_size=7, sigma=1.6):
    """
    BD模式:模糊 + 下采样
    
    先用高斯核模糊,再下采样
    """
    # 高斯模糊
    blurred = cv2.GaussianBlur(hr_image, (kernel_size, kernel_size), sigma)
    
    # Downsample
    h, w = blurred.shape[:2]
    lr_image = cv2.resize(blurred, (w//scale, h//scale), 
                         interpolation=cv2.INTER_CUBIC)
    return lr_image


def downsample_noise_degradation(hr_image, scale=3, noise_level=30):
    """
    DN模式:下采样 + 噪声
    
    先下采样,再加高斯噪声
    """
    # 下采样
    h, w = hr_image.shape[:2]
    lr_image = cv2.resize(hr_image, (w//scale, h//scale), 
                         interpolation=cv2.INTER_CUBIC)
    
    # Add Gaussian noise
    noise = np.random.normal(0, noise_level, lr_image.shape)
    lr_image = np.clip(lr_image + noise, 0, 255).astype(np.uint8)
    
    return lr_image


def complex_degradation(hr_image, scale=4, blur_sigma=1.5, noise_sigma=10, 
                       jpeg_quality=70):
    """
    复杂退化模式(更接近真实世界)
    
    模糊 → 下采样 → 噪声 → JPEG压缩
    """
    # 1. 模糊
    blurred = cv2.GaussianBlur(hr_image, (21, 21), blur_sigma)
    
    # 2. Downsample
    h, w = blurred.shape[:2]
    downsampled = cv2.resize(blurred, (w//scale, h//scale), 
                            interpolation=cv2.INTER_CUBIC)
    
    # 3. Add noise
    noise = np.random.normal(0, noise_sigma, downsampled.shape)
    noisy = np.clip(downsampled + noise, 0, 255).astype(np.uint8)
    
    # 4. JPEG compression
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
    _, encoded = cv2.imencode('.jpg', noisy, encode_param)
    lr_image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    
    return lr_image

2.3 Evaluation Metrics

PSNR (Peak Signal-to-Noise Ratio)
python
import numpy as np
import torch

def calculate_psnr(img1, img2, max_val=255.0):
    """
    计算PSNR(Peak Signal-to-Noise Ratio)
    
    PSNR = 10 * log10(MAX² / MSE)
    
    PSNR越高,图像质量越好
    一般来说:
    - PSNR < 30dB: 质量较差
    - 30-40dB: 质量可接受
    - PSNR > 40dB: 质量很好
    """
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    
    psnr = 10 * np.log10((max_val ** 2) / mse)
    return psnr


def calculate_psnr_torch(img1, img2, max_val=1.0):
    """PyTorch版本的PSNR计算"""
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    
    psnr = 10 * torch.log10((max_val ** 2) / mse)
    return psnr.item()
SSIM (Structural Similarity)
python
def calculate_ssim(img1, img2, window_size=11, data_range=255.0):
    """
    Compute SSIM (Structural Similarity Index).

    SSIM compares three components:
    1. luminance
    2. contrast
    3. structure

    SSIM = (2*μx*μy + C1)(2*σxy + C2) / ((μx² + μy² + C1)(σx² + σy² + C2))

    SSIM lies in [-1, 1]; the closer to 1, the better.

    Note: this is a simplified implementation with a uniform (box)
    window; the reference implementation uses an 11×11 Gaussian window.
    """
    from scipy.ndimage import uniform_filter

    # Stabilizing constants, scaled to the pixel value range
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    
    # Local means
    mu1 = uniform_filter(img1, window_size)
    mu2 = uniform_filter(img2, window_size)
    
    # Local variances and covariance
    sigma1_sq = uniform_filter(img1 ** 2, window_size) - mu1 ** 2
    sigma2_sq = uniform_filter(img2 ** 2, window_size) - mu2 ** 2
    sigma12 = uniform_filter(img1 * img2, window_size) - mu1 * mu2
    
    # SSIM map
    numerator = (2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2)
    
    ssim_map = numerator / denominator
    
    return np.mean(ssim_map)
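As a quick usage sketch (my own glue code, not from the original post), the metrics above can score a plain bicubic baseline against the degradation functions from section 2.2; with real photographs instead of the random stand-in array below, the numbers become meaningful:

python
# Hypothetical end-to-end check: degrade, upscale with bicubic, then score.
hr = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)  # stand-in for a real HR image
lr = bicubic_degradation(hr, scale=4)
sr = cv2.resize(lr, (128, 128), interpolation=cv2.INTER_CUBIC)  # bicubic baseline "SR"

print(f"PSNR: {calculate_psnr(sr, hr):.2f} dB")
print(f"SSIM: {calculate_ssim(sr[..., 0], hr[..., 0]):.4f}")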
Perceptual quality metrics
python
"""
Perceptual quality metrics

PSNR/SSIM measure pixel-level differences, which do not fully reflect
human visual perception; perceptual metrics focus on how images look.
"""

# LPIPS (Learned Perceptual Image Patch Similarity)
# pip install lpips
import lpips

def calculate_lpips(img1, img2, net='alex'):
    """
    计算LPIPS
    
    使用预训练网络(如AlexNet、VGG)提取特征
    计算特征空间的距离
    
    LPIPS越低越好
    """
    loss_fn = lpips.LPIPS(net=net)
    
    # Convert to torch tensors in the range [-1, 1]
    img1_tensor = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1
    img2_tensor = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1
    
    with torch.no_grad():
        distance = loss_fn(img1_tensor, img2_tensor)
    
    return distance.item()


# NIQE (Natural Image Quality Evaluator): no-reference assessment
def calculate_niqe(img):
    """
    计算NIQE(无需参考图像)
    
    基于自然图像统计特性
    NIQE越低越好
    
    需要安装:pip install pyiqa
    """
    import pyiqa
    
    niqe_metric = pyiqa.create_metric('niqe')
    
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    score = niqe_metric(img_tensor)
    
    return score.item()

3. Upsampling Methods

3.1 Classical Interpolation

python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InterpolationUpsampler:
    """
    传统插值上采样方法
    """
    
    @staticmethod
    def nearest_upsample(x, scale_factor):
        """
        最近邻插值
        
        简单快速,但会产生块状伪影
        """
        return F.interpolate(x, scale_factor=scale_factor, mode='nearest')
    
    @staticmethod
    def bilinear_upsample(x, scale_factor):
        """
        双线性插值
        
        结果平滑,但可能模糊
        """
        return F.interpolate(x, scale_factor=scale_factor, 
                           mode='bilinear', align_corners=False)
    
    @staticmethod
    def bicubic_upsample(x, scale_factor):
        """
        双三次插值
        
        比双线性更平滑,计算量稍大
        """
        return F.interpolate(x, scale_factor=scale_factor, 
                           mode='bicubic', align_corners=False)

3.2 Transposed Convolution

python
class TransposedConvUpsampler(nn.Module):
    """
    转置卷积上采样(也叫反卷积)
    
    可学习的上采样方式
    
    原理:
    - 在输入特征图周围/之间添加padding
    - 然后进行标准卷积
    
    问题:容易产生棋盘格伪影(checkerboard artifacts)
    """
    
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        
        # 转置卷积
        # kernel_size = 2 * scale_factor
        # stride = scale_factor
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=2*scale_factor,
            stride=scale_factor,
            padding=scale_factor // 2
        )
    
    def forward(self, x):
        return self.deconv(x)


class TransposedConvUpsamplerV2(nn.Module):
    """
    改进的转置卷积(减少棋盘格伪影)
    
    使用小kernel + 后续卷积
    """
    
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=scale_factor,
            stride=scale_factor,
            padding=0
        )
        # Follow-up convolution for smoothing
        self.conv = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
    
    def forward(self, x):
        x = self.deconv(x)
        x = self.conv(x)
        return x
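A quick shape check (my own snippet, using the two classes above) confirming that both variants double the spatial size at the default scale_factor=2:

python
x = torch.randn(1, 64, 24, 24)
print(TransposedConvUpsampler(64, 64)(x).shape)    # torch.Size([1, 64, 48, 48])
print(TransposedConvUpsamplerV2(64, 64)(x).shape)  # torch.Size([1, 64, 48, 48])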

3.3 Sub-Pixel Convolution

python
class SubPixelUpsampler(nn.Module):
    """
    亚像素卷积上采样(PixelShuffle)
    
    原理:
    1. 先用卷积增加通道数(C → C * r²)
    2. 然后重排像素(Pixel Shuffle)
    
    优点:
    - 高效:大部分计算在低分辨率空间进行
    - 无棋盘格伪影
    - 是目前最流行的上采样方式
    
    H×W×(C*r²) → (H*r)×(W*r)×C
    """
    
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        
        # Convolution that expands the channel count
        self.conv = nn.Conv2d(
            in_channels, 
            out_channels * (scale_factor ** 2),
            kernel_size=3, 
            padding=1
        )
        
        # Pixel rearrangement
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        return x


def pixel_shuffle_demo():
    """
    PixelShuffle原理演示
    """
    print("PixelShuffle原理:")
    print("=" * 50)
    
    # 假设输入是 1×1×4 (H=1, W=1, C=4)
    # scale_factor = 2
    # 输出是 2×2×1
    
    # 输入特征图的4个通道
    # [a, b, c, d] 重排为 2×2
    # [[a, b],
    #  [c, d]]
    
    x = torch.arange(1, 17).float().view(1, 4, 2, 2)
    print(f"输入形状: {x.shape}")  # [1, 4, 2, 2]
    
    ps = nn.PixelShuffle(2)
    y = ps(x)
    print(f"输出形状: {y.shape}")  # [1, 1, 4, 4]
    
    print("\n通道数减少为原来的1/r²,空间尺寸增加r倍")

pixel_shuffle_demo()

3.4 Upsampling Strategy Comparison

┌─────────────────────────────────────────────────────────────────┐
│                 Upsampling strategy comparison                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Pre-upsampling:                                                │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │  LR ──→ interpolate up ──→ deep network ──→ HR          │   │
│  │  Pros: simple and direct                                │   │
│  │  Cons: heavy compute (convolutions in HR space)         │   │
│  │  Examples: SRCNN, VDSR                                  │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Post-upsampling:                                               │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │  LR ──→ deep network ──→ upsampling layer ──→ HR        │   │
│  │  Pros: efficient (convolutions in LR space)             │   │
│  │  Cons: the upsampling layer design is critical          │   │
│  │  Examples: ESPCN, EDSR, RCAN                            │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Progressive upsampling:                                        │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │  LR ──→ net + ×2 ──→ net + ×2 ──→ HR (×4)               │   │
│  │  Pros: gradual reconstruction, more stable              │   │
│  │  Cons: a more complex network                           │   │
│  │  Examples: LapSRN, ProSR                                │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

The compute gap between pre- and post-upsampling is made concrete in the sketch below.
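As a back-of-the-envelope illustration (my own sketch, with an assumed 64-channel width and ×4 scale), counting the multiply-accumulate operations of a single 3×3 convolution shows why post-upsampling wins:

python
def conv3x3_macs(h, w, channels=64):
    # MACs of one 3×3 convolution with `channels` input and output channels
    return h * w * channels * channels * 3 * 3

scale = 4
lr_h, lr_w = 64, 64
hr_h, hr_w = lr_h * scale, lr_w * scale

pre = conv3x3_macs(hr_h, hr_w)   # pre-upsampling: the conv runs at HR resolution
post = conv3x3_macs(lr_h, lr_w)  # post-upsampling: the conv runs at LR resolution

print(f"HR-space conv: {pre:,} MACs")
print(f"LR-space conv: {post:,} MACs")
print(f"ratio: {pre / post:.0f}x")  # scale² = 16× more compute per layer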

4. Loss Functions

4.1 Pixel-Level Losses

python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PixelLoss(nn.Module):
    """
    像素级损失函数
    
    直接衡量像素值的差异
    """
    
    def __init__(self, loss_type='l1'):
        super().__init__()
        self.loss_type = loss_type
    
    def forward(self, pred, target):
        if self.loss_type == 'l1':
            # L1 loss (MAE)
            # More robust to outliers
            return F.l1_loss(pred, target)
        
        elif self.loss_type == 'l2':
            # L2 loss (MSE)
            # Directly related to PSNR
            # Tends to over-smooth
            return F.mse_loss(pred, target)
        
        elif self.loss_type == 'charbonnier':
            # Charbonnier loss
            # A smooth approximation of L1, differentiable everywhere
            eps = 1e-6
            diff = pred - target
            return torch.mean(torch.sqrt(diff ** 2 + eps ** 2))
        
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")


class CharbonnierLoss(nn.Module):
    """
    Charbonnier损失(L1的可微近似)
    
    L_char = sqrt((pred - target)² + ε²)
    
    当|x|远大于ε时,近似于L1
    在0点处可微(L1在0点不可微)
    """
    
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
    
    def forward(self, pred, target):
        diff = pred - target
        loss = torch.sqrt(diff ** 2 + self.eps ** 2)
        return torch.mean(loss)

4.2 Perceptual Loss

python
import torchvision.models as models


class VGGPerceptualLoss(nn.Module):
    """
    VGG感知损失
    
    使用预训练VGG网络提取高层特征
    计算特征空间的距离
    
    优点:
    - 更关注语义和结构相似性
    - 生成的图像视觉效果更好
    
    缺点:
    - 可能导致颜色偏移
    - PSNR可能下降
    """
    
    def __init__(self, layer_weights=None, use_input_norm=True):
        super().__init__()
        
        # 加载预训练VGG19
        vgg = models.vgg19(pretrained=True).features
        
        # 定义要提取的层
        # conv1_2, conv2_2, conv3_4, conv4_4, conv5_4
        self.layer_indices = [2, 7, 16, 25, 34]
        
        # Default layer weights
        if layer_weights is None:
            self.layer_weights = [0.1, 0.1, 1.0, 1.0, 1.0]
        else:
            self.layer_weights = layer_weights
        
        # Split VGG into sequential stages
        self.stages = nn.ModuleList()
        prev_idx = 0
        for idx in self.layer_indices:
            self.stages.append(nn.Sequential(*list(vgg.children())[prev_idx:idx+1]))
            prev_idx = idx + 1
        
        # Freeze the parameters
        for param in self.parameters():
            param.requires_grad = False
        
        # Input normalization (ImageNet mean and std)
        self.use_input_norm = use_input_norm
        self.register_buffer(
            'mean', 
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            'std', 
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )
    
    def forward(self, pred, target):
        """
        计算感知损失
        """
        if self.use_input_norm:
            pred = (pred - self.mean) / self.std
            target = (target - self.mean) / self.std
        
        loss = 0.0
        pred_feat = pred
        target_feat = target
        
        for stage, weight in zip(self.stages, self.layer_weights):
            pred_feat = stage(pred_feat)
            target_feat = stage(target_feat)
            
            # L1 distance in feature space
            loss += weight * F.l1_loss(pred_feat, target_feat)
        
        return loss


class ContentStyleLoss(nn.Module):
    """
    内容损失 + 风格损失
    
    内容损失:特征的L2距离
    风格损失:Gram矩阵的距离(捕捉纹理信息)
    """
    
    def __init__(self, content_weight=1.0, style_weight=0.1):
        super().__init__()
        self.content_weight = content_weight
        self.style_weight = style_weight
        
        # VGG feature extractor
        vgg = models.vgg19(pretrained=True).features[:16]
        self.vgg = vgg
        for param in self.vgg.parameters():
            param.requires_grad = False
    
    def gram_matrix(self, x):
        """
        计算Gram矩阵
        
        G = F @ F^T
        捕捉特征通道之间的相关性(纹理信息)
        """
        b, c, h, w = x.size()
        features = x.view(b, c, h * w)
        gram = torch.bmm(features, features.transpose(1, 2))
        return gram / (c * h * w)
    
    def forward(self, pred, target):
        pred_feat = self.vgg(pred)
        target_feat = self.vgg(target)
        
        # Content loss
        content_loss = F.mse_loss(pred_feat, target_feat)
        
        # Style loss
        pred_gram = self.gram_matrix(pred_feat)
        target_gram = self.gram_matrix(target_feat)
        style_loss = F.mse_loss(pred_gram, target_gram)
        
        return self.content_weight * content_loss + self.style_weight * style_loss

4.3 Adversarial Loss

python
class GANLoss(nn.Module):
    """
    GAN损失
    
    让生成的SR图像在判别器看来像真实HR图像
    """
    
    def __init__(self, gan_type='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.gan_type = gan_type
        self.real_label = real_label
        self.fake_label = fake_label
        
        if gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_type == 'wgan':
            self.loss = None  # Wasserstein distance
        elif gan_type == 'hinge':
            self.loss = None  # hinge loss
        else:
            raise ValueError(f"Unknown GAN type: {gan_type}")
    
    def get_target_tensor(self, pred, is_real):
        """获取目标标签"""
        if is_real:
            return torch.full_like(pred, self.real_label)
        else:
            return torch.full_like(pred, self.fake_label)
    
    def forward(self, pred, is_real):
        """
        计算GAN损失
        
        Args:
            pred: 判别器输出
            is_real: 是否为真实样本
        """
        if self.gan_type in ['vanilla', 'lsgan']:
            target = self.get_target_tensor(pred, is_real)
            return self.loss(pred, target)
        
        elif self.gan_type == 'wgan':
            if is_real:
                return -pred.mean()
            else:
                return pred.mean()
        
        elif self.gan_type == 'hinge':
            # Discriminator-side hinge loss; a hinge generator
            # typically uses -pred.mean() instead.
            if is_real:
                return F.relu(1.0 - pred).mean()
            else:
                return F.relu(1.0 + pred).mean()


class RelativisticGANLoss(nn.Module):
    """
    相对GAN损失(ESRGAN中使用)
    
    不只是判断"真假",而是判断"谁更真"
    
    D_Ra(x_r, x_f) = σ(C(x_r) - E[C(x_f)])
    D_Ra(x_f, x_r) = σ(C(x_f) - E[C(x_r)])
    """
    
    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()
    
    def forward(self, real_pred, fake_pred, is_discriminator):
        """
        Args:
            real_pred: 判别器对真实图像的预测
            fake_pred: 判别器对生成图像的预测
            is_discriminator: 是否是判别器的损失
        """
        # 相对预测
        real_relative = real_pred - fake_pred.mean()
        fake_relative = fake_pred - real_pred.mean()
        
        if is_discriminator:
            # Discriminator: real images should look more real than fakes
            real_loss = self.loss(real_relative, torch.ones_like(real_relative))
            fake_loss = self.loss(fake_relative, torch.zeros_like(fake_relative))
            return (real_loss + fake_loss) / 2
        else:
            # Generator: fakes should look more real than real images
            real_loss = self.loss(real_relative, torch.zeros_like(real_relative))
            fake_loss = self.loss(fake_relative, torch.ones_like(fake_relative))
            return (real_loss + fake_loss) / 2

4.4 Combined Loss

python
class SRLoss(nn.Module):
    """
    综合超分辨率损失
    
    结合多种损失函数
    """
    
    def __init__(self, pixel_weight=1.0, perceptual_weight=0.1, 
                 adversarial_weight=0.01):
        super().__init__()
        
        self.pixel_weight = pixel_weight
        self.perceptual_weight = perceptual_weight
        self.adversarial_weight = adversarial_weight
        
        # Pixel loss
        self.pixel_loss = nn.L1Loss()
        
        # Perceptual loss
        if perceptual_weight > 0:
            self.perceptual_loss = VGGPerceptualLoss()
        
        # Adversarial loss
        if adversarial_weight > 0:
            self.gan_loss = GANLoss(gan_type='vanilla')
    
    def forward(self, pred, target, discriminator=None):
        """
        计算总损失
        """
        losses = {}
        
        # Pixel term
        losses['pixel'] = self.pixel_loss(pred, target)
        total_loss = self.pixel_weight * losses['pixel']
        
        # Perceptual term
        if self.perceptual_weight > 0:
            losses['perceptual'] = self.perceptual_loss(pred, target)
            total_loss += self.perceptual_weight * losses['perceptual']
        
        # Adversarial term
        if self.adversarial_weight > 0 and discriminator is not None:
            fake_pred = discriminator(pred)
            losses['adversarial'] = self.gan_loss(fake_pred, is_real=True)
            total_loss += self.adversarial_weight * losses['adversarial']
        
        losses['total'] = total_loss
        
        return total_loss, losses

5. Classic Network Architectures

5.1 SRCNN: The Pioneer

python
class SRCNN(nn.Module):
    """
    SRCNN - 第一个深度学习超分辨率模型(2014)
    
    三层结构,对应传统方法的三个步骤:
    1. 特征提取(Patch extraction)
    2. 非线性映射(Non-linear mapping)
    3. 重建(Reconstruction)
    
    输入:双三次插值放大后的图像
    输出:高分辨率图像
    """
    
    def __init__(self, num_channels=3, feature_dim=64, mapping_dim=32):
        super().__init__()
        
        # 特征提取层
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, feature_dim, kernel_size=9, padding=4),
            nn.ReLU(inplace=True)
        )
        
        # 非线性映射层
        self.mapping = nn.Sequential(
            nn.Conv2d(feature_dim, mapping_dim, kernel_size=1),
            nn.ReLU(inplace=True)
        )
        
        # 重建层
        self.reconstruction = nn.Conv2d(mapping_dim, num_channels, kernel_size=5, padding=2)
    
    def forward(self, x):
        """
        前向传播
        
        输入x应该是双三次插值放大后的图像
        """
        feat = self.feature_extraction(x)
        mapped = self.mapping(feat)
        out = self.reconstruction(mapped)
        return out


def srcnn_example():
    """SRCNN使用示例"""
    model = SRCNN()
    
    # 先用bicubic放大,再输入网络
    lr_image = torch.randn(1, 3, 64, 64)
    lr_upscaled = F.interpolate(lr_image, scale_factor=4, mode='bicubic')
    
    sr_image = model(lr_upscaled)
    print(f"输入(放大后): {lr_upscaled.shape}")
    print(f"输出: {sr_image.shape}")

5.2 VDSR: Deeper Networks with Residual Learning

python
class VDSR(nn.Module):
    """
    VDSR - Very Deep Super Resolution(2016)
    
    关键创新:
    1. 更深的网络(20层)
    2. 全局残差学习(学习残差而非完整图像)
    3. 梯度裁剪解决梯度爆炸
    4. 可学习的多尺度(一个模型处理多种放大倍数)
    
    网络学习的是残差:R = HR - LR_upscaled
    输出:LR_upscaled + R
    """
    
    def __init__(self, num_channels=1, num_features=64, num_layers=20):
        super().__init__()
        
        layers = []
        
        # First layer
        layers.append(nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
        
        # Middle layers
        for _ in range(num_layers - 2):
            layers.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        
        # Last layer
        layers.append(nn.Conv2d(num_features, num_channels, kernel_size=3, padding=1))
        
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        """
        前向传播
        
        全局残差学习:输出 = 输入 + 网络输出(残差)
        """
        residual = self.network(x)
        return x + residual
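Like SRCNN, VDSR operates in HR space, so the LR input must be bicubically upscaled before the forward pass. A minimal usage sketch (my own, using the single-channel default above):

python
model = VDSR(num_channels=1)

lr_image = torch.randn(1, 1, 64, 64)
lr_upscaled = F.interpolate(lr_image, scale_factor=4, mode='bicubic')

sr_image = model(lr_upscaled)  # same spatial size as its input
print(sr_image.shape)          # torch.Size([1, 1, 256, 256])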

5.3 Residual Block Design

python
class BasicBlock(nn.Module):
    """
    基础残差块
    
    Conv → ReLU → Conv + Skip Connection
    """
    
    def __init__(self, num_features, kernel_size=3):
        super().__init__()
        
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size//2)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size//2)
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out + residual


class ResidualBlock(nn.Module):
    """
    标准残差块(去除BN层)
    
    EDSR发现:在超分辨率任务中,BN层会消耗大量显存且不提升性能
    """
    
    def __init__(self, num_features, kernel_size=3, res_scale=1.0):
        super().__init__()
        
        self.res_scale = res_scale
        
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size//2)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size//2)
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        
        # Residual scaling (stabilizes training)
        out = out * self.res_scale
        
        return out + residual


class ResidualDenseBlock(nn.Module):
    """
    残差密集块(RDB)- RDN中使用
    
    结合残差学习和密集连接
    每一层都接收前面所有层的特征
    """
    
    def __init__(self, num_features, growth_rate=32, num_layers=5):
        super().__init__()
        
        self.layers = nn.ModuleList()
        
        for i in range(num_layers):
            in_channels = num_features + i * growth_rate
            self.layers.append(nn.Sequential(
                nn.Conv2d(in_channels, growth_rate, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
        
        # Local feature fusion
        self.lff = nn.Conv2d(
            num_features + num_layers * growth_rate,
            num_features,
            kernel_size=1
        )
    
    def forward(self, x):
        features = [x]
        
        for layer in self.layers:
            # Dense connection: concatenate all previous features
            out = layer(torch.cat(features, dim=1))
            features.append(out)
        
        # Local feature fusion
        out = self.lff(torch.cat(features, dim=1))
        
        # Residual connection
        return out + x

5.4 EDSR: Deep Residual Networks without BN

python
class EDSR(nn.Module):
    """
    EDSR - Enhanced Deep Residual Networks(2017)
    
    关键改进:
    1. 去除BN层(节省显存,提升性能)
    2. 残差缩放(稳定深层网络训练)
    3. 更宽的网络(256通道)
    4. 后置上采样(计算高效)
    """
    
    def __init__(self, num_channels=3, num_features=256, num_blocks=32, 
                 scale_factor=4, res_scale=0.1):
        super().__init__()
        
        self.scale_factor = scale_factor
        
        # Head: feature extraction
        self.head = nn.Conv2d(num_channels, num_features, 3, padding=1)
        
        # Body: stacked residual blocks
        body = []
        for _ in range(num_blocks):
            body.append(ResidualBlock(num_features, res_scale=res_scale))
        body.append(nn.Conv2d(num_features, num_features, 3, padding=1))
        self.body = nn.Sequential(*body)
        
        # Tail: upsampling
        self.upsample = self._make_upsample(num_features, scale_factor)
        
        # Output layer
        self.tail = nn.Conv2d(num_features, num_channels, 3, padding=1)
    
    def _make_upsample(self, num_features, scale_factor):
        """构建上采样模块"""
        layers = []
        
        if scale_factor == 2 or scale_factor == 4:
            for _ in range(scale_factor // 2):
                layers.append(nn.Conv2d(num_features, num_features * 4, 3, padding=1))
                layers.append(nn.PixelShuffle(2))
        elif scale_factor == 3:
            layers.append(nn.Conv2d(num_features, num_features * 9, 3, padding=1))
            layers.append(nn.PixelShuffle(3))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        # Feature extraction
        head_feat = self.head(x)
        
        # Residual learning
        body_feat = self.body(head_feat)
        body_feat = body_feat + head_feat  # global residual
        
        # Upsample
        upsampled = self.upsample(body_feat)
        
        # Output
        out = self.tail(upsampled)
        
        return out
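A quick sanity check (my own snippet): build a baseline-sized EDSR (64 features, 16 blocks; the constructor defaults above correspond to the large model), run a random LR batch through it, and count parameters:

python
model = EDSR(num_features=64, num_blocks=16, scale_factor=4)

lr = torch.randn(1, 3, 48, 48)
sr = model(lr)
print(sr.shape)  # torch.Size([1, 3, 192, 192])

num_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {num_params / 1e6:.2f}M")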

5.5 RCAN: Channel Attention Network

python
class ChannelAttention(nn.Module):
    """
    通道注意力模块(CA)
    
    自适应地给不同通道分配不同的权重
    让网络关注更重要的特征通道
    
    结构:
    全局平均池化 → FC → ReLU → FC → Sigmoid
    """
    
    def __init__(self, num_features, reduction=16):
        super().__init__()
        
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Conv2d(num_features, num_features // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features // reduction, num_features, 1, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        # Global information aggregation
        y = self.avg_pool(x)
        
        # Channel weights
        y = self.fc(y)
        
        # Channel-wise reweighting
        return x * y


class RCAB(nn.Module):
    """
    残差通道注意力块(RCAB)
    
    Conv → ReLU → Conv → CA → + Skip
    """
    
    def __init__(self, num_features, reduction=16, res_scale=1.0):
        super().__init__()
        
        self.res_scale = res_scale
        
        self.conv1 = nn.Conv2d(num_features, num_features, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, 3, padding=1)
        self.ca = ChannelAttention(num_features, reduction)
    
    def forward(self, x):
        residual = x
        
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.ca(out)
        
        out = out * self.res_scale
        
        return out + residual


class ResidualGroup(nn.Module):
    """
    残差组(RG)
    
    多个RCAB + 短跳连接
    """
    
    def __init__(self, num_features, num_rcab=20, reduction=16):
        super().__init__()
        
        modules = []
        for _ in range(num_rcab):
            modules.append(RCAB(num_features, reduction))
        modules.append(nn.Conv2d(num_features, num_features, 3, padding=1))
        
        self.body = nn.Sequential(*modules)
    
    def forward(self, x):
        return self.body(x) + x


class RCAN(nn.Module):
    """
    RCAN - Residual Channel Attention Networks(2018)
    
    关键创新:
    1. 通道注意力机制
    2. 残差组(Residual Group)结构
    3. 长短跳连接
    """
    
    def __init__(self, num_channels=3, num_features=64, num_groups=10, 
                 num_rcab=20, reduction=16, scale_factor=4):
        super().__init__()
        
        # Head
        self.head = nn.Conv2d(num_channels, num_features, 3, padding=1)
        
        # Body: residual groups
        body = []
        for _ in range(num_groups):
            body.append(ResidualGroup(num_features, num_rcab, reduction))
        body.append(nn.Conv2d(num_features, num_features, 3, padding=1))
        self.body = nn.Sequential(*body)
        
        # Upsampling
        self.upsample = self._make_upsample(num_features, scale_factor)
        
        # Tail
        self.tail = nn.Conv2d(num_features, num_channels, 3, padding=1)
    
    def _make_upsample(self, num_features, scale_factor):
        layers = []
        if scale_factor in [2, 4, 8]:
            for _ in range(int(np.log2(scale_factor))):
                layers.append(nn.Conv2d(num_features, num_features * 4, 3, padding=1))
                layers.append(nn.PixelShuffle(2))
        elif scale_factor == 3:
            layers.append(nn.Conv2d(num_features, num_features * 9, 3, padding=1))
            layers.append(nn.PixelShuffle(3))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        head_feat = self.head(x)
        body_feat = self.body(head_feat) + head_feat
        upsampled = self.upsample(body_feat)
        out = self.tail(upsampled)
        return out

6. GAN-Based Super-Resolution

6.1 SRGAN

python
class SRResNet(nn.Module):
    """
    SRResNet - SRGAN的生成器
    
    基于残差块的深度网络
    """
    
    def __init__(self, num_channels=3, num_features=64, num_blocks=16, scale_factor=4):
        super().__init__()
        
        # First convolution
        self.conv1 = nn.Sequential(
            nn.Conv2d(num_channels, num_features, 9, padding=4),
            nn.PReLU()
        )
        
        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlockBN(num_features) for _ in range(num_blocks)]
        )
        
        # Second convolution
        self.conv2 = nn.Sequential(
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.BatchNorm2d(num_features)
        )
        
        # Upsampling
        self.upsample = self._make_upsample(num_features, scale_factor)
        
        # Output convolution
        self.conv3 = nn.Conv2d(num_features, num_channels, 9, padding=4)
    
    def _make_upsample(self, num_features, scale_factor):
        layers = []
        for _ in range(int(np.log2(scale_factor))):
            layers.extend([
                nn.Conv2d(num_features, num_features * 4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU()
            ])
        return nn.Sequential(*layers)
    
    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.res_blocks(feat1)
        feat2 = self.conv2(feat2) + feat1
        upsampled = self.upsample(feat2)
        out = self.conv3(upsampled)
        return out


class ResidualBlockBN(nn.Module):
    """带BN的残差块(SRGAN使用)"""
    
    def __init__(self, num_features):
        super().__init__()
        
        self.conv_block = nn.Sequential(
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.BatchNorm2d(num_features),
            nn.PReLU(),
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.BatchNorm2d(num_features)
        )
    
    def forward(self, x):
        return x + self.conv_block(x)


class Discriminator(nn.Module):
    """
    SRGAN判别器
    
    VGG风格的分类网络
    判断输入是真实HR图像还是生成的SR图像
    """
    
    def __init__(self, input_shape=(3, 96, 96)):
        super().__init__()
        
        in_channels, in_height, in_width = input_shape
        
        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, 3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        
        self.features = nn.Sequential(*layers)
        
        # Feature-map size after four stride-2 downsamplings
        ds_size = in_height // 2 ** 4
        
        self.classifier = nn.Sequential(
            nn.Linear(512 * ds_size * ds_size, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )
    
    def forward(self, x):
        features = self.features(x)
        features = features.view(features.size(0), -1)
        validity = self.classifier(features)
        return validity

6.2 ESRGAN

python
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block (RRDB)
    
    ESRGAN的核心模块
    比SRResNet的残差块更强大
    
    结构:3个RDB + 残差连接
    """
    
    def __init__(self, num_features, growth_rate=32, res_scale=0.2):
        super().__init__()
        
        self.res_scale = res_scale
        
        self.rdb1 = ResidualDenseBlock(num_features, growth_rate)
        self.rdb2 = ResidualDenseBlock(num_features, growth_rate)
        self.rdb3 = ResidualDenseBlock(num_features, growth_rate)
    
    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        
        return x + self.res_scale * out


class RRDBNet(nn.Module):
    """
    RRDB Network - ESRGAN的生成器
    
    关键改进:
    1. 去除BN层
    2. RRDB模块(更强的特征提取能力)
    3. 相对判别器
    """
    
    def __init__(self, num_channels=3, num_features=64, num_blocks=23, 
                 growth_rate=32, scale_factor=4):
        super().__init__()
        
        # First convolution
        self.conv_first = nn.Conv2d(num_channels, num_features, 3, padding=1)
        
        # RRDB blocks
        self.rrdb_blocks = nn.Sequential(
            *[RRDB(num_features, growth_rate) for _ in range(num_blocks)]
        )
        
        # Body convolution
        self.conv_body = nn.Conv2d(num_features, num_features, 3, padding=1)
        
        # Upsampling (×4 via two PixelShuffle stages)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_features, num_features * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_features, num_features * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Final convolutions
        self.conv_last = nn.Sequential(
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_features, num_channels, 3, padding=1)
        )
    
    def forward(self, x):
        feat_first = self.conv_first(x)
        feat_body = self.rrdb_blocks(feat_first)
        feat_body = self.conv_body(feat_body)
        feat = feat_first + feat_body
        upsampled = self.upsample(feat)
        out = self.conv_last(upsampled)
        return out
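The post defines the generators, discriminator, and losses separately; as glue, here is a minimal, hedged sketch of a single ESRGAN-style training iteration that combines pixel, perceptual, and relativistic adversarial terms with the classes defined earlier. The loss weights (0.01, 1.0, 0.005) follow the convention reported for ESRGAN but should be treated as illustrative, and the random tensors stand in for a real dataloader batch:

python
generator = RRDBNet(num_blocks=4)  # small, just for the sketch
discriminator = Discriminator(input_shape=(3, 128, 128))
pixel_loss = nn.L1Loss()
perceptual_loss = VGGPerceptualLoss()
ra_loss = RelativisticGANLoss()

g_optim = torch.optim.Adam(generator.parameters(), lr=1e-4)
d_optim = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

lr_batch = torch.rand(2, 3, 32, 32)    # stand-in for real LR crops
hr_batch = torch.rand(2, 3, 128, 128)  # stand-in for real HR crops

# Generator update
sr = generator(lr_batch)
real_pred = discriminator(hr_batch).detach()  # no gradient into D here
fake_pred = discriminator(sr)
g_loss = (1e-2 * pixel_loss(sr, hr_batch)
          + perceptual_loss(sr, hr_batch)
          + 5e-3 * ra_loss(real_pred, fake_pred, is_discriminator=False))
g_optim.zero_grad()
g_loss.backward()
g_optim.step()

# Discriminator update
real_pred = discriminator(hr_batch)
fake_pred = discriminator(sr.detach())  # do not backprop into the generator
d_loss = ra_loss(real_pred, fake_pred, is_discriminator=True)
d_optim.zero_grad()
d_loss.backward()
d_optim.step()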

7. Lightweight Super-Resolution Networks

7.1 FSRCNN: Fast Super-Resolution

python
class FSRCNN(nn.Module):
    """
    FSRCNN - Fast Super-Resolution CNN(2016)
    
    关键改进:
    1. 后置上采样(在LR空间做卷积)
    2. 沙漏结构(通道先扩张后收缩)
    3. 转置卷积上采样
    
    比SRCNN快40倍以上
    """
    
    def __init__(self, scale_factor=4, num_channels=1, d=56, s=12, m=4):
        """
        Args:
            d: 特征提取层的通道数
            s: 收缩层的通道数
            m: 映射层的数量
        """
        super().__init__()
        
        # Feature extraction
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=2),
            nn.PReLU()
        )
        
        # Shrinking
        self.shrinking = nn.Sequential(
            nn.Conv2d(d, s, kernel_size=1),
            nn.PReLU()
        )
        
        # Non-linear mapping
        mapping = []
        for _ in range(m):
            mapping.extend([
                nn.Conv2d(s, s, kernel_size=3, padding=1),
                nn.PReLU()
            ])
        self.mapping = nn.Sequential(*mapping)
        
        # Expanding
        self.expanding = nn.Sequential(
            nn.Conv2d(s, d, kernel_size=1),
            nn.PReLU()
        )
        
        # Transposed-convolution upsampling
        self.deconv = nn.ConvTranspose2d(
            d, num_channels,
            kernel_size=9,
            stride=scale_factor,
            padding=4,
            output_padding=scale_factor - 1
        )
    
    def forward(self, x):
        feat = self.feature_extraction(x)
        shrunk = self.shrinking(feat)
        mapped = self.mapping(shrunk)
        expanded = self.expanding(mapped)
        out = self.deconv(expanded)
        return out

7.2 IMDN: Information Multi-Distillation Network

python
class IMDModule(nn.Module):
    """
    信息多蒸馏模块(IMDN的核心)
    
    渐进式提取特征,每一步保留部分特征
    """
    
    def __init__(self, in_channels, distillation_rate=0.25):
        super().__init__()
        
        self.distilled_channels = int(in_channels * distillation_rate)
        self.remaining_channels = int(in_channels - self.distilled_channels)
        
        self.c1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.c2 = nn.Conv2d(self.remaining_channels, in_channels, 3, padding=1)
        self.c3 = nn.Conv2d(self.remaining_channels, in_channels, 3, padding=1)
        self.c4 = nn.Conv2d(self.remaining_channels, self.distilled_channels, 3, padding=1)
        
        self.act = nn.LeakyReLU(0.05, inplace=True)
        
        # Fusion layer
        self.fusion = nn.Conv2d(self.distilled_channels * 4, in_channels, 1)
    
    def forward(self, x):
        out1 = self.act(self.c1(x))
        distilled1, remaining1 = torch.split(out1, [self.distilled_channels, self.remaining_channels], dim=1)
        
        out2 = self.act(self.c2(remaining1))
        distilled2, remaining2 = torch.split(out2, [self.distilled_channels, self.remaining_channels], dim=1)
        
        out3 = self.act(self.c3(remaining2))
        distilled3, remaining3 = torch.split(out3, [self.distilled_channels, self.remaining_channels], dim=1)
        
        distilled4 = self.act(self.c4(remaining3))
        
        # Concatenate all distilled features
        out = torch.cat([distilled1, distilled2, distilled3, distilled4], dim=1)
        out = self.fusion(out)
        
        return out + x


class IMDN(nn.Module):
    """
    IMDN - Information Multi-Distillation Network(2019)
    
    轻量级超分辨率网络
    参数量约715K,性能优秀
    """
    
    def __init__(self, num_channels=3, num_features=64, num_blocks=6, scale_factor=4):
        super().__init__()
        
        # Feature extraction
        self.conv_first = nn.Conv2d(num_channels, num_features, 3, padding=1)
        
        # IMD blocks
        self.imdbs = nn.ModuleList([
            IMDModule(num_features) for _ in range(num_blocks)
        ])
        
        # Feature fusion
        self.fusion = nn.Conv2d(num_features * num_blocks, num_features, 1)
        
        # Upsampling
        self.upsample = SubPixelUpsampler(num_features, num_features, scale_factor)
        
        # Output
        self.conv_last = nn.Conv2d(num_features, num_channels, 3, padding=1)
    
    def forward(self, x):
        feat_first = self.conv_first(x)
        
        features = []
        feat = feat_first
        for imdb in self.imdbs:
            feat = imdb(feat)
            features.append(feat)
        
        # Concatenate the outputs of all blocks
        feat_cat = torch.cat(features, dim=1)
        feat_fused = self.fusion(feat_cat) + feat_first
        
        upsampled = self.upsample(feat_fused)
        out = self.conv_last(upsampled)
        
        return out

8. A Complete Training Pipeline

8.1 Dataset Preparation

python
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class SRDataset(Dataset):
    """
    超分辨率数据集
    """
    
    def __init__(self, hr_dir, scale_factor=4, patch_size=96, augment=True):
        """
        Args:
            hr_dir: 高分辨率图像目录
            scale_factor: 放大倍数
            patch_size: HR patch大小
            augment: 是否数据增强
        """
        self.hr_dir = hr_dir
        self.scale_factor = scale_factor
        self.patch_size = patch_size
        self.augment = augment
        
        self.lr_patch_size = patch_size // scale_factor
        
        # Collect all image paths
        self.image_paths = []
        for f in os.listdir(hr_dir):
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
                self.image_paths.append(os.path.join(hr_dir, f))
        
        # Preprocessing
        self.to_tensor = transforms.ToTensor()
    
    def __len__(self):
        return len(self.image_paths)
    
    def random_crop(self, hr_image):
        """随机裁剪HR patch"""
        w, h = hr_image.size
        
        # 确保能裁剪出完整的patch
        x = np.random.randint(0, w - self.patch_size + 1)
        y = np.random.randint(0, h - self.patch_size + 1)
        
        hr_patch = hr_image.crop((x, y, x + self.patch_size, y + self.patch_size))
        
        return hr_patch
    
    def augment_patch(self, hr_patch, lr_patch):
        """数据增强"""
        # 随机水平翻转
        if np.random.random() < 0.5:
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
        
        # Random vertical flip
        if np.random.random() < 0.5:
            hr_patch = hr_patch.transpose(Image.FLIP_TOP_BOTTOM)
            lr_patch = lr_patch.transpose(Image.FLIP_TOP_BOTTOM)
        
        # Random rotation by a multiple of 90°
        if np.random.random() < 0.5:
            angle = np.random.choice([90, 180, 270])
            hr_patch = hr_patch.rotate(angle)
            lr_patch = lr_patch.rotate(angle)
        
        return hr_patch, lr_patch
    
    def __getitem__(self, idx):
        # Load the HR image
        hr_image = Image.open(self.image_paths[idx]).convert('RGB')
        
        # Random crop
        hr_patch = self.random_crop(hr_image)
        
        # Downsample to create the LR patch
        lr_patch = hr_patch.resize(
            (self.lr_patch_size, self.lr_patch_size),
            Image.BICUBIC
        )
        
        # Data augmentation
        if self.augment:
            hr_patch, lr_patch = self.augment_patch(hr_patch, lr_patch)
        
        # Convert to tensors
        hr_tensor = self.to_tensor(hr_patch)
        lr_tensor = self.to_tensor(lr_patch)
        
        return lr_tensor, hr_tensor


class SRTestDataset(Dataset):
    """测试数据集(不裁剪)"""
    
    def __init__(self, lr_dir, hr_dir=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        
        self.lr_paths = sorted([
            os.path.join(lr_dir, f) for f in os.listdir(lr_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ])
        
        if hr_dir:
            self.hr_paths = sorted([
                os.path.join(hr_dir, f) for f in os.listdir(hr_dir)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ])
        else:
            self.hr_paths = None
        
        self.to_tensor = transforms.ToTensor()
    
    def __len__(self):
        return len(self.lr_paths)
    
    def __getitem__(self, idx):
        lr_image = Image.open(self.lr_paths[idx]).convert('RGB')
        lr_tensor = self.to_tensor(lr_image)
        
        if self.hr_paths:
            hr_image = Image.open(self.hr_paths[idx]).convert('RGB')
            hr_tensor = self.to_tensor(hr_image)
            return lr_tensor, hr_tensor, os.path.basename(self.lr_paths[idx])
        
        return lr_tensor, os.path.basename(self.lr_paths[idx])

8.2 Trainer

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class SRTrainer:
    """
    超分辨率训练器
    """
    
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Build the model
        self.model = self._build_model().to(self.device)
        
        # Optimizer
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config['lr'],
            betas=(0.9, 0.999)
        )
        
        # Learning-rate schedule
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=config['lr_milestones'],
            gamma=0.5
        )
        
        # Loss function
        self.criterion = self._build_loss()
        
        # TensorBoard
        self.writer = SummaryWriter(config['log_dir'])
        
        # Best metric so far
        self.best_psnr = 0
    
    def _build_model(self):
        """构建模型"""
        model_name = self.config.get('model', 'edsr')
        
        if model_name == 'edsr':
            return EDSR(
                num_features=self.config.get('num_features', 64),
                num_blocks=self.config.get('num_blocks', 16),
                scale_factor=self.config['scale_factor']
            )
        elif model_name == 'rcan':
            return RCAN(
                num_features=self.config.get('num_features', 64),
                num_groups=self.config.get('num_groups', 10),
                scale_factor=self.config['scale_factor']
            )
        else:
            raise ValueError(f"Unknown model: {model_name}")
    
    def _build_loss(self):
        """构建损失函数"""
        loss_type = self.config.get('loss', 'l1')
        
        if loss_type == 'l1':
            return nn.L1Loss()
        elif loss_type == 'l2':
            return nn.MSELoss()
        elif loss_type == 'charbonnier':
            return CharbonnierLoss()
        else:
            raise ValueError(f"Unknown loss: {loss_type}")
    
    def train_epoch(self, train_loader, epoch):
        """训练一个epoch"""
        self.model.train()
        total_loss = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for batch_idx, (lr, hr) in enumerate(pbar):
            lr = lr.to(self.device)
            hr = hr.to(self.device)
            
            # Forward pass
            sr = self.model(lr)
            loss = self.criterion(sr, hr)
            
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            
            # Update the progress bar
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            # Log to TensorBoard
            global_step = epoch * len(train_loader) + batch_idx
            self.writer.add_scalar('train/loss', loss.item(), global_step)
        
        return total_loss / len(train_loader)
    
    @torch.no_grad()
    def validate(self, val_loader):
        """验证"""
        self.model.eval()
        
        total_psnr = 0
        total_ssim = 0
        count = 0
        
        for lr, hr, _ in val_loader:
            lr = lr.to(self.device)
            hr = hr.to(self.device)
            
            sr = self.model(lr)
            
            # Compute metrics (clamp the SR output to the valid range first)
            for i in range(sr.size(0)):
                sr_np = sr[i].clamp(0, 1).cpu().numpy().transpose(1, 2, 0) * 255
                hr_np = hr[i].cpu().numpy().transpose(1, 2, 0) * 255
                
                total_psnr += calculate_psnr(sr_np, hr_np)
                total_ssim += calculate_ssim(sr_np[..., 0], hr_np[..., 0])
                count += 1
        
        avg_psnr = total_psnr / count
        avg_ssim = total_ssim / count
        
        return avg_psnr, avg_ssim
    
    def train(self, train_loader, val_loader, num_epochs):
        """完整训练流程"""
        for epoch in range(1, num_epochs + 1):
            # Train
            train_loss = self.train_epoch(train_loader, epoch)
            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}")
            
            # Update the learning rate
            self.scheduler.step()
            
            # Validate
            if epoch % self.config['val_interval'] == 0:
                psnr, ssim = self.validate(val_loader)
                print(f"Validation: PSNR = {psnr:.2f}, SSIM = {ssim:.4f}")
                
                self.writer.add_scalar('val/psnr', psnr, epoch)
                self.writer.add_scalar('val/ssim', ssim, epoch)
                
                # Save the best model
                if psnr > self.best_psnr:
                    self.best_psnr = psnr
                    self.save_checkpoint('best.pth', epoch)
                    print(f"New best model! PSNR = {psnr:.2f}")
            
            # Periodic checkpoint
            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(f'epoch_{epoch}.pth', epoch)
    
    def save_checkpoint(self, filename, epoch):
        """保存检查点"""
        os.makedirs(self.config['checkpoint_dir'], exist_ok=True)
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_psnr': self.best_psnr
        }, os.path.join(self.config['checkpoint_dir'], filename))
    
    def load_checkpoint(self, path):
        """加载检查点"""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_psnr = checkpoint.get('best_psnr', 0)
        return checkpoint.get('epoch', 0)

8.3 Training Script

python
def main():
    """主函数"""
    
    config = {
        # 数据
        'train_hr_dir': './data/DIV2K/train_HR',
        'val_lr_dir': './data/Set5/LR_bicubic/X4',
        'val_hr_dir': './data/Set5/HR',
        
        # Model
        'model': 'edsr',
        'num_features': 64,
        'num_blocks': 16,
        'scale_factor': 4,
        
        # Training
        'batch_size': 16,
        'patch_size': 96,
        'num_epochs': 300,
        'lr': 1e-4,
        'lr_milestones': [100, 200],
        'loss': 'l1',
        
        # Misc
        'val_interval': 10,
        'save_interval': 50,
        'log_dir': './logs',
        'checkpoint_dir': './checkpoints',
        'num_workers': 4
    }
    
    # Build the datasets
    train_dataset = SRDataset(
        config['train_hr_dir'],
        scale_factor=config['scale_factor'],
        patch_size=config['patch_size']
    )
    
    val_dataset = SRTestDataset(
        config['val_lr_dir'],
        config['val_hr_dir']
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False
    )
    
    # Build the trainer
    trainer = SRTrainer(config)
    
    # Start training
    trainer.train(train_loader, val_loader, config['num_epochs'])
    
    print("Training completed!")


if __name__ == '__main__':
    main()

9. Frontier Research Directions

9.1 Real-World Super-Resolution

┌─────────────────────────────────────────────────────────────────┐
│             Real-world super-resolution challenges              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  The problem:                                                   │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Classical training assumes LR = Bicubic(HR)             │   │
│  │ Real degradations mix blur, noise, compression,         │   │
│  │ sensor noise, ...                                       │   │
│  │ Models trained on bicubic data do poorly on real images │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Approaches:                                                    │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ 1. Richer degradation models                            │   │
│  │    blur + downsampling + noise + JPEG compression       │   │
│  │                                                         │   │
│  │ 2. Blind super-resolution                               │   │
│  │    estimate the degradation instead of assuming a       │   │
│  │    known kernel                                         │   │
│  │                                                         │   │
│  │ 3. Real datasets                                        │   │
│  │    RealSR: real image pairs captured at different       │   │
│  │    focal lengths                                        │   │
│  │                                                         │   │
│  │ 4. Unsupervised learning                                │   │
│  │    training methods that need no paired data            │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Point 1 is illustrated by the sketch below.
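As a rough illustration of randomized degradation synthesis (a hedged sketch in the spirit of pipelines like BSRGAN/Real-ESRGAN, not any specific paper's recipe; the parameter ranges are made up), each training sample can draw its own degradation parameters and reuse complex_degradation from section 2.2:

python
def random_degradation(hr_image, scale=4, rng=None):
    # Sample a fresh degradation per training example so the model
    # sees a whole distribution of blurs, noise levels, and JPEG qualities.
    rng = rng or np.random.default_rng()
    return complex_degradation(
        hr_image,
        scale=scale,
        blur_sigma=rng.uniform(0.2, 3.0),       # blur strength
        noise_sigma=rng.uniform(1, 25),         # Gaussian noise level
        jpeg_quality=int(rng.integers(30, 95))  # compression quality
    )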

9.2 Transformer-Based Super-Resolution

Recent Transformer approaches to super-resolution:

IPT (2021):
- a pretrained image-processing Transformer
- large-scale ImageNet pre-training
- very large (115M parameters)

SwinIR (2021):
- built on the Swin Transformer
- local window attention + shifted windows
- moderate size (11.8M parameters)
- state of the art on several benchmarks

ESRT (2021):
- a lightweight Transformer
- only 751K parameters
- a good balance of efficiency and quality

The window-attention idea at the heart of SwinIR is sketched below.
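For intuition only, here is a heavily simplified sketch of (non-shifted) window self-attention: the feature map is partitioned into small non-overlapping windows and standard multi-head attention runs inside each window, keeping the cost roughly linear in image size. It omits the shifted windows, relative position bias, and everything else that makes SwinIR work; all names and sizes here are my own assumptions.

python
import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """Self-attention within non-overlapping windows (toy version)."""

    def __init__(self, dim=64, window=8, heads=4):
        super().__init__()
        self.window = window
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):               # x: (B, C, H, W); H, W divisible by window
        b, c, h, w = x.shape
        ws = self.window
        # Partition into (B * num_windows) sequences of ws*ws tokens
        x = x.view(b, c, h // ws, ws, w // ws, ws)
        x = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, ws * ws, c)
        x, _ = self.attn(x, x, x)       # attention only inside each window
        # Merge the windows back into a feature map
        x = x.reshape(b, h // ws, w // ws, ws, ws, c)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(b, c, h, w)
        return x

feat = torch.randn(1, 64, 32, 32)
print(WindowAttention()(feat).shape)    # torch.Size([1, 64, 32, 32])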

9.3 Domain-Specific Applications

Face super-resolution:
- exploit face priors (landmarks, parsing maps)
- identity-preserving losses
- preserving facial attributes

Medical image super-resolution:
- CT/MRI enhancement
- super-resolution of 3D volumetric data
- preserving diagnostically relevant information

Remote-sensing super-resolution:
- handling very large images
- multispectral / hyperspectral data
- fusing time-series data

Video super-resolution:
- exploiting temporal redundancy
- optical-flow alignment
- real-time processing requirements

10. Summary and Outlook

10.1 Key Takeaways

┌─────────────────────────────────────────────────────────────────┐
│             Image super-resolution: the essentials              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  The problem:                                                   │
│  • recover a high-resolution image from a low-resolution one   │
│  • ill-posed, so image priors must be learned                  │
│                                                                 │
│  Key techniques:                                                │
│  • upsampling: PixelShuffle is the most common                 │
│  • network design: residual learning, attention mechanisms     │
│  • losses: pixel + perceptual + adversarial                    │
│                                                                 │
│  Evaluation:                                                    │
│  • PSNR / SSIM: reconstruction fidelity                        │
│  • LPIPS / NIQE: perceptual quality                            │
│                                                                 │
│  Trends:                                                        │
│  • real-world degradations                                      │
│  • lightweight networks                                         │
│  • Transformer architectures                                    │
│  • domain-specific applications                                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

10.2 Future Directions

1. Lightweight and efficient
   - deployment on edge devices
   - real-time processing
   - model compression and quantization

2. Real-world scenarios
   - modeling complex degradations
   - blind super-resolution
   - few-shot / unsupervised learning

3. New architectures
   - Transformer + CNN hybrids
   - diffusion models
   - neural implicit representations

4. Joint tasks
   - super-resolution + denoising
   - super-resolution + deblurring
   - super-resolution + object detection

I hope this article has given you a thorough understanding of image super-resolution. Questions and comments are welcome!

