【图像超分】论文复现:轻量化超分 | RLFN的Pytorch源码复现,跑通源码,整合到EDSR-PyTorch中进行训练、测试
前言
论文题目:Residual Local Feature Network for Efficient Super-Resolution
论文地址:https://arxiv.org/abs/2205.07514
论文源码:https://github.com/bytedance/RLFN
NTIRE 2022高效超分辨率挑战赛运行赛道第一名
摘要 :基于深度学习的方法在单幅图像超分辨率(SISR)中取得了很好的效果。然而,高效超分辨率的最新进展主要集中在减少参数数量和FLOPs(每秒所执行的浮点运算次数,用来衡量计算机的计算能力以及模型的复杂度),并通过复杂的层连接策略提高特征利用率来聚合更强大的特征。这些结构可能不是实现更高运行速度所必需的,这使得它们难以部署到资源受限的设备上。本文提出了一种新的残差局部特征网络(RLFN)。主要思想是使用三层卷积进行残差局部特征学习,简化特征聚合,在模型性能和推理时间之间实现了很好的权衡。此外,我们回顾了流行的对比损失,并观察到其特征提取器的中间特征的选择对性能有很大影响。此外,我们还提出了一种新的多阶段暖启动训练策略 。在每个阶段,利用前几个阶段的预训练权值来提高模型的性能。结合改进的对比损失和训练策略,所提出的RLFN在运行时间方面优于所有最先进的高效图像SR模型,同时保持SR的PSNR和SSIM。此外,我们还获得了NTIRE 2022 高效超分辨率挑战赛运行赛道第一名。
网络结构
RLFN主要由三部分组成:第一部分特征提取卷积、多个堆叠残差局部特征块(rlfb)和重构模块。

数据
| model | Runtime[ms] | Params[M] | Flops[G] | Acts[M] | ~GPU Mem[M]~ |
|---|---|---|---|---|---|
| RLFN_ntire | 27.11 | 0.317 | 19.70 | 80.05 | 377.91 |
模型代码
模型现在有三个版本,分别为rlfn.py,rlfn_ntire.py,rlfn_s.py
rlfn.py
py
# -*- coding: utf-8 -*-
# Copyright 2022 ByteDance
import torch.nn as nn
from model import block
class RLFN(nn.Module):
"""
Residual Local Feature Network (RLFN)
Model definition of RLFN in `Residual Local Feature Network for
Efficient Super-Resolution`
"""
def __init__(self,
in_channels=3,
out_channels=3,
feature_channels=52,
upscale=4):
super(RLFN, self).__init__()
self.conv_1 = block.conv_layer(in_channels,
feature_channels,
kernel_size=3)
self.block_1 = block.RLFB(feature_channels)
self.block_2 = block.RLFB(feature_channels)
self.block_3 = block.RLFB(feature_channels)
self.block_4 = block.RLFB(feature_channels)
self.block_5 = block.RLFB(feature_channels)
self.block_6 = block.RLFB(feature_channels)
self.conv_2 = block.conv_layer(feature_channels,
feature_channels,
kernel_size=3)
self.upsampler = block.pixelshuffle_block(feature_channels,
out_channels,
upscale_factor=upscale)
def forward(self, x):
out_feature = self.conv_1(x)
out_b1 = self.block_1(out_feature)
out_b2 = self.block_2(out_b1)
out_b3 = self.block_3(out_b2)
out_b4 = self.block_4(out_b3)
out_b5 = self.block_5(out_b4)
out_b6 = self.block_6(out_b5)
out_low_resolution = self.conv_2(out_b6) + out_feature
output = self.upsampler(out_low_resolution)
return output
rlfn_ntire.py
py
# -*- coding: utf-8 -*-
# Copyright 2022 ByteDance
import torch.nn as nn
from model import block
class RLFN_Prune(nn.Module):
"""
Residual Local Feature Network (RLFN)
Model definition of RLFN in NTIRE 2022 Efficient SR Challenge
"""
def __init__(self,
in_channels=3,
out_channels=3,
feature_channels=46,
mid_channels=48,
upscale=4):
super(RLFN_Prune, self).__init__()
self.conv_1 = block.conv_layer(in_channels,
feature_channels,
kernel_size=3)
self.block_1 = block.RLFB(feature_channels, mid_channels)
self.block_2 = block.RLFB(feature_channels, mid_channels)
self.block_3 = block.RLFB(feature_channels, mid_channels)
self.block_4 = block.RLFB(feature_channels, mid_channels)
self.conv_2 = block.conv_layer(feature_channels,
feature_channels,
kernel_size=3)
self.upsampler = block.pixelshuffle_block(feature_channels,
out_channels,
upscale_factor=upscale)
def forward(self, x):
out_feature = self.conv_1(x)
out_b1 = self.block_1(out_feature)
out_b2 = self.block_2(out_b1)
out_b3 = self.block_3(out_b2)
out_b4 = self.block_4(out_b3)
out_low_resolution = self.conv_2(out_b4) + out_feature
output = self.upsampler(out_low_resolution)
return output
rlfn_s.py
py
# -*- coding: utf-8 -*-
# Copyright 2022 ByteDance
import torch.nn as nn
from model import block
class RLFN_S(nn.Module):
"""
Residual Local Feature Network (RLFN)
Model definition of RLFN_S in `Residual Local Feature Network for
Efficient Super-Resolution`
"""
def __init__(self,
in_channels=3,
out_channels=3,
feature_channels=48,
upscale=4):
super(RLFN_S, self).__init__()
self.conv_1 = block.conv_layer(in_channels,
feature_channels,
kernel_size=3)
self.block_1 = block.RLFB(feature_channels)
self.block_2 = block.RLFB(feature_channels)
self.block_3 = block.RLFB(feature_channels)
self.block_4 = block.RLFB(feature_channels)
self.block_5 = block.RLFB(feature_channels)
self.block_6 = block.RLFB(feature_channels)
self.conv_2 = block.conv_layer(feature_channels,
feature_channels,
kernel_size=3)
self.upsampler = block.pixelshuffle_block(feature_channels,
out_channels,
upscale_factor=upscale)
def forward(self, x):
out_feature = self.conv_1(x)
out_b1 = self.block_1(out_feature)
out_b2 = self.block_2(out_b1)
out_b3 = self.block_3(out_b2)
out_b4 = self.block_4(out_b3)
out_b5 = self.block_5(out_b4)
out_b6 = self.block_6(out_b5)
out_low_resolution = self.conv_2(out_b6) + out_feature
output = self.upsampler(out_low_resolution)
return output
block.py
py
# -*- coding: utf-8 -*-
# 编码声明,确保文件支持中文等Unicode字符
# Copyright 2022 ByteDance
# 版权声明,归属ByteDance公司
from collections import OrderedDict
# 导入OrderedDict,用于创建有序字典
import torch.nn as nn
# 导入PyTorch的神经网络模块
import torch.nn.functional as F
# 导入PyTorch的函数式接口,包含各种激活函数、池化等操作
def _make_pair(value):
# 将输入值转换为长度为2的元组(如果输入是整数)
if isinstance(value, int):
# 如果输入是整数,将其转换为两个相同元素的元组
value = (value,) * 2
return value
def conv_layer(in_channels,
out_channels,
kernel_size,
bias=True):
"""
重写卷积层,实现自适应填充(padding)
"""
kernel_size = _make_pair(kernel_size)
# 计算填充大小,使卷积前后特征图尺寸不变(当步长为1时)
padding = (int((kernel_size[0] - 1) / 2),
int((kernel_size[1] - 1) / 2))
# 返回创建的卷积层
return nn.Conv2d(in_channels,
out_channels,
kernel_size,
padding=padding,
bias=bias)
def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
"""
激活函数层工厂函数,支持['relu', 'lrelu', 'prelu']三种类型
参数
----------
act_type: str
激活函数类型,必须是['relu', 'lrelu', 'prelu']中的一种
inplace: bool
是否使用inplace操作(节省内存)
neg_slope: float
'lrelu'或'prelu'在负区间的斜率
n_prelu: int
'prelu'的参数数量
----------
"""
act_type = act_type.lower()
# 转换为小写,确保输入不区分大小写
if act_type == 'relu':
# 创建ReLU激活层
layer = nn.ReLU(inplace)
elif act_type == 'lrelu':
# 创建LeakyReLU激活层
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == 'prelu':
# 创建PReLU激活层
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
else:
# 不支持的激活函数类型则抛出异常
raise NotImplementedError(
'激活层 [{:s}] 未实现'.format(act_type))
return layer
def sequential(*args):
"""
将传入的模块按顺序添加到Sequential容器中
参数
----------
args: 按顺序传入的模块定义
-------
"""
if len(args) == 1:
# 如果只有一个参数
if isinstance(args[0], OrderedDict):
# 不支持OrderedDict作为输入
raise NotImplementedError(
'sequential不支持OrderedDict输入')
return args[0]
modules = []
# 遍历所有传入的模块
for module in args:
if isinstance(module, nn.Sequential):
# 如果是Sequential容器,将其内部模块展开添加
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
# 如果是单个模块,直接添加
modules.append(module)
# 创建并返回新的Sequential容器
return nn.Sequential(*modules)
def pixelshuffle_block(in_channels,
out_channels,
upscale_factor=2,
kernel_size=3):
"""
根据 upscale_factor 对特征进行上采样(像素重排)
"""
# 创建卷积层,输出通道数为目标通道数乘以 upscale_factor 的平方
conv = conv_layer(in_channels,
out_channels * (upscale_factor ** 2),
kernel_size)
# 创建像素重排层,实现上采样
pixel_shuffle = nn.PixelShuffle(upscale_factor)
# 将卷积层和像素重排层按顺序组合
return sequential(conv, pixel_shuffle)
class ESA(nn.Module):
"""
增强空间注意力机制(ESA)的修改版,源自论文
`Residual Feature Aggregation Network for Image Super-Resolution`
注:此处删除了原实现中未使用的`conv_max`和`conv3_`相关代码
"""
def __init__(self, esa_channels, n_feats, conv):#esa_channels:16
super(ESA, self).__init__()
# 初始化ESA通道数
f = esa_channels
# 1x1卷积压缩通道数
self.conv1 = conv(n_feats, f, kernel_size=1)
# 1x1卷积处理跳跃连接的特征
self.conv_f = conv(f, f, kernel_size=1)
# 3x3卷积,步长为2,无填充(用于降采样)
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
# 3x3卷积,带填充(用于特征提取)
self.conv3 = conv(f, f, kernel_size=3, padding=1)
# 1x1卷积恢复通道数
self.conv4 = conv(f, n_feats, kernel_size=1)
# Sigmoid激活函数,生成注意力权重
self.sigmoid = nn.Sigmoid()
# ReLU激活函数
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
# 对输入特征进行通道压缩
c1_ = (self.conv1(x))
# 降采样
c1 = self.conv2(c1_)
# 最大池化进一步降采样
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
# 特征提取
c3 = self.conv3(v_max)
# 上采样恢复到原始特征图尺寸
c3 = F.interpolate(c3, (x.size(2), x.size(3)),
mode='bilinear', align_corners=False)
# 处理跳跃连接的特征
cf = self.conv_f(c1_)
# 特征融合并恢复通道数
c4 = self.conv4(c3 + cf)
# 生成注意力权重
m = self.sigmoid(c4)
# 注意力加权
return x * m
class RLFB(nn.Module):
"""
残差局部特征块(RLFB)
"""
def __init__(self,
in_channels,
mid_channels=None,
out_channels=None,
esa_channels=16):
super(RLFB, self).__init__()
# 如果未指定中间通道数,默认与输入通道数相同
if mid_channels is None:
mid_channels = in_channels
# 如果未指定输出通道数,默认与输入通道数相同
if out_channels is None:
out_channels = in_channels
# 第一个卷积层(3x3)
self.c1_r = conv_layer(in_channels, mid_channels, 3)
# 第二个卷积层(3x3)
self.c2_r = conv_layer(mid_channels, mid_channels, 3)
# 第三个卷积层(3x3),恢复到输入通道数
self.c3_r = conv_layer(mid_channels, in_channels, 3)
# 1x1卷积调整通道数到输出通道数
self.c5 = conv_layer(in_channels, out_channels, 1)
# ESA注意力模块
self.esa = ESA(esa_channels, out_channels, nn.Conv2d)
# LeakyReLU激活函数
self.act = activation('lrelu', neg_slope=0.05)
def forward(self, x):
# 第一层卷积
out = (self.c1_r(x))
# 激活函数
out = self.act(out)
# 第二层卷积
out = (self.c2_r(out))
# 激活函数
out = self.act(out)
# 第三层卷积
out = (self.c3_r(out))
# 激活函数
out = self.act(out)
# 残差连接(跳跃连接)
out = out + x
# 通过1x1卷积和ESA注意力模块
out = self.esa(self.c5(out))
return out
复现过程
准备工作
首先配置一下EDSR的环境
下载DIV2K的数据集,数据集地址:https://data.vision.ee.ethz.ch/cvl/DIV2K/

下载RLFN的项目,网址:https://github.com/bytedance/RLFN

下载EDSR的项目,网址:https://github.com/sanghyun-son/EDSR-PyTorch

训练

在这个文件中有一个dir_data

改为自己下载的数据集的位置
然后在FMEN中复制rlfn.py 和block.py到EDSR中的src中的model

在src中打开终端
bash
python main.py --model FMEN --scale 4 --patch_size 48 --epochs 300 --save RLFN_baseline_x4 --reset
测试
测试就在RLFN项目中有一个test_demo.py中测试
test_demo.py
py
# 版权声明:原作者为Yawei Li等人,协议为MIT,可能经字节跳动修改
# 导入操作系统路径处理库,用于处理文件/文件夹路径
import os.path
# 导入日志记录库,用于输出运行过程中的关键信息(如模型参数、处理进度)
import logging
# 从collections导入有序字典类,用于有序存储测试结果(如运行时间)
from collections import OrderedDict
# 导入PyTorch库,深度学习框架,用于加载模型、处理张量和GPU计算
import torch
# 从utils工具包导入日志配置函数,用于初始化日志格式和存储路径
from utils import utils_logger
# 从utils工具包导入图像处理函数(如读取图片、张量转换、保存图片),命名为util简化调用
from utils import utils_image as util
# 从utils工具包导入模型统计函数,用于计算模型的计算量(FLOPs)和激活次数
from utils.model_summary import get_model_flops, get_model_activation
# 从model模型包导入RLFN_Prune类,这是超分辨率任务使用的核心模型
from model.rlfn_ntire import RLFN_Prune
from model.rlfn import RLFN
def main():
# 1. 初始化日志系统:日志名称为"NTIRE2022-EfficientSR",日志保存到"NTIRE2022-EfficientSR.log"文件
utils_logger.logger_info('NTIRE2022-EfficientSR', log_path='NTIRE2022-EfficientSR.log')
# 获取日志实例,后续用logger.info()输出关键信息到控制台和日志文件
logger = logging.getLogger('NTIRE2022-EfficientSR')
# --------------------------------
# 2. 基础配置:设置数据路径、GPU/CPU环境
# --------------------------------
# 注释:原测试集为DIV2K的901-1000张图,当前改为自定义数据路径
# 测试数据根目录:拼接当前工作目录(os.getcwd())和"data"文件夹,即"./data"
testsets = os.path.join(os.getcwd(), 'data')
# 低分辨率(LR)图片文件夹名:自定义为"shangbo_Low_images",对应路径为"./data/shangbo_Low_images"
testset_L = 'Urban100/image_SRF_4/LR'
# 初始化当前GPU设备(若有多个GPU,默认使用第0个)
torch.cuda.current_device()
# 清空GPU缓存,释放未使用的显存,避免显存不足问题
torch.cuda.empty_cache()
# 禁用cudnn的benchmark模式:避免首次运行时花时间优化,适合输入图片尺寸不固定的场景
torch.backends.cudnn.benchmark = False
# 选择计算设备:优先使用GPU(cuda),若无GPU则使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# --------------------------------
# 3. 加载超分辨率模型
# --------------------------------
# 模型权重文件路径:拼接"model_zoo"文件夹和"rlfn_ntire_x4.pth",即"./model_zoo/rlfn_ntire_x4.pth"
model_path = os.path.join('model_zoo', 'rlfn_ntire_x4.pth')
# 初始化RLFN_Prune模型:输入通道数3(RGB彩色图),输出通道数3(同样为RGB图)
model = RLFN_Prune(in_channels=3, out_channels=3)
# 加载预训练权重到模型:strict=True表示权重文件的键必须与模型参数完全匹配,避免加载错误
model.load_state_dict(torch.load(model_path), strict=True)
# 设置模型为评估模式(eval()):关闭训练时的 dropout、批量归一化(BN)更新,确保推理结果稳定
model.eval()
# 冻结模型所有参数:禁用梯度计算,减少显存占用,加速推理
for k, v in model.named_parameters():
v.requires_grad = False
# 将模型移动到选定的设备(GPU/CPU),确保计算在目标设备上进行
model = model.to(device)
# 计算并记录模型参数总数:sum()累加所有参数的元素数量(numel())
number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
# 将参数数量输出到日志
logger.info('Params number: {}'.format(number_parameters))
# --------------------------------
# 4. 配置图片读取路径和结果保存路径
# --------------------------------
# 低分辨率图片文件夹的完整路径:拼接数据根目录和低分辨率文件夹名
L_folder = os.path.join(testsets, testset_L)
# 超分辨率结果保存文件夹的完整路径:在数据根目录下新建"xxx_results"文件夹
E_folder = os.path.join(testsets, testset_L + '_results')
# 调用工具函数创建结果文件夹:若文件夹已存在则不重复创建,避免报错
util.mkdir(E_folder)
# 初始化有序字典,用于记录测试结果:这里先只记录每张图的运行时间
test_results = OrderedDict()
test_results['runtime'] = []
# 将低分辨率图片路径和结果保存路径输出到日志,确认路径正确
logger.info(L_folder)
logger.info(E_folder)
# 初始化图片计数变量,用于记录当前处理的是第几张图
idx = 0
# 初始化GPU计时事件:用于精确测量模型推理时间(比CPU计时更准确,避免GPU异步影响)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
# 5. 循环读取低分辨率图片,逐张进行超分辨率处理
# util.get_image_paths(L_folder):获取文件夹下所有支持格式的图片路径列表
for img in util.get_image_paths(L_folder):
# --------------------------------
# (1) 读取并预处理低分辨率图片(img_L)
# --------------------------------
# 图片计数+1,更新当前处理的图片序号
idx += 1
# 提取图片文件名和后缀:如路径"a/b/c.png",img_name是"c",ext是".png"
img_name, ext = os.path.splitext(os.path.basename(img))
# 将当前处理的图片序号和文件名输出到日志,方便追踪进度
logger.info('{:->4d}--> {:>10s}'.format(idx, img_name + ext))
# 读取低分辨率图片:uint格式(像素值0-255),3通道(RGB)
img_L = util.imread_uint(img, n_channels=3)
# 将uint格式图片转为4维张量(batch, channel, height, width):batch=1(单张图),方便模型输入
img_L = util.uint2tensor4(img_L)
# 像素值缩放:原张量是0-1范围(uint2tensor4默认转换),乘以255恢复为0-255范围,匹配模型训练时的输入格式
img_L = img_L * 255.
# 将预处理后的张量移动到目标设备(GPU/CPU),确保与模型在同一设备上
img_L = img_L.to(device)
# 记录推理开始时间(GPU事件计时)
start.record()
# 模型推理:输入低分辨率张量,输出超分辨率张量(img_E)
img_E = model(img_L)
# 记录推理结束时间(GPU事件计时)
end.record()
# 等待GPU完成所有计算(同步操作),确保计时准确,避免异步导致的时间误差
torch.cuda.synchronize()
# 将当前图片的推理时间(毫秒)存入测试结果字典
test_results['runtime'].append(start.elapsed_time(end)) # milliseconds
# 注释:以下是CPU计时的备用代码,当前未启用;原理与GPU计时类似,但精度较低
# torch.cuda.synchronize()
# start = time.time()
# img_E = model(img_L)
# torch.cuda.synchronize()
# end = time.time()
# test_results['runtime'].append(end-start) # seconds
# --------------------------------
# (2) 后处理并保存超分辨率图片(img_E)
# --------------------------------
# 像素值反向缩放:将模型输出的0-255范围张量,除以255恢复为0-1范围,方便后续转换为uint格式
img_E = img_E / 255.
# 将4维张量转为uint格式图片(0-255):自动处理张量维度,去除batch维度
img_E = util.tensor2uint(img_E)
# 保存超分辨率图片:路径为结果文件夹+原文件名前4位+后缀(如"1234.png"),避免文件名过长
util.imsave(img_E, os.path.join(E_folder, img_name + '_SR' + ext)) # 加_SR区分超分图
# 6. 统计并输出模型性能指标(计算量、激活次数、参数数量)
# 设置模型输入维度:(通道数, 高度, 宽度),即3通道256x256图片,用于计算FLOPs和激活次数
input_dim = (3, 256, 256) # set the input dimension
# 计算模型在指定输入维度下的激活次数和卷积层数量
activations, num_conv = get_model_activation(model, input_dim)
# 激活次数单位转换:除以1e6,转为"百万次(M)",方便阅读
activations = activations / 10 ** 6
# 输出激活次数到日志,保留4位小数
logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
# 输出卷积层数量到日志
logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))
# 计算模型在指定输入维度下的计算量(FLOPs):False表示不打印详细层信息
flops = get_model_flops(model, input_dim, False)
# 计算量单位转换:除以1e9,转为"十亿次(G)",方便阅读
flops = flops / 10 ** 9
# 输出计算量到日志,保留4位小数
logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))
# 重新计算模型参数总数(与前面一致,此处为重复验证或统一格式)
num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
# 参数数量单位转换:除以1e6,转为"百万个(M)",方便阅读
num_parameters = num_parameters / 10 ** 6
# 输出参数数量到日志,保留4位小数
logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
# 7. 计算并输出平均推理时间(单位:毫秒)
# 平均时间计算:所有图片的运行时间总和 ÷ 图片数量(直接保留毫秒单位)
ave_runtime_ms = sum(test_results['runtime']) / len(test_results['runtime'])
# 输出平均推理时间到日志,保留6位小数,显示测试文件夹路径
logger.info('------> Average runtime of ({}) is : {:.6f} milliseconds'.format(L_folder, ave_runtime_ms))
# 8. 程序入口:若当前脚本是直接运行(而非被导入),则执行main()函数
if __name__ == '__main__':
main()