Super-Resolution | 2025 FIWHN: Lightweight SR SOTA! An Efficient Hybrid Network Built on Wide Residuals and a Transformer (Hands-On Code)

Paper title: Efficient Image Super-Resolution with Feature Interaction Weighted Hybrid Network

Paper: https://arxiv.org/abs/2212.14181
Official code: https://github.com/IVIPLab/FIWHN

Companion article: Super-Resolution | 2025 FIWHN: Lightweight SR SOTA! An Efficient Hybrid Network Built on Wide Residuals and a Transformer (Paper Deep-Dive)

Abstract

This post is distilled from the IEEE TMM paper "Efficient Image Super-Resolution with Feature Interaction Weighted Hybrid Network (FIWHN)". Targeting two pain points of lightweight networks (feature loss caused by activation functions, and CNNs' lack of global modeling ability), it reimplements the paper's core code and provides two powerful plug-and-play modules: WDIB (Wide-residual Distillation Interaction Block) and TransBlock (an efficient Transformer block). The code is fully packaged: copy it as-is and drop it into YOLO, UNet, ResNet, or similar networks for your own modifications.


Table of Contents

Part 1: Module Principles and Practical Analysis

1. Background and the Pain Points the Paper Solves

In computer vision tasks (especially super-resolution and the small-object layers of detectors), we often face two awkward situations:

  1. Features "die" at the activation function: activation functions such as ReLU, widely used in traditional CNNs, discard intermediate-layer feature information; as the network gets deeper, many fine textures are simply lost.

  2. The local/global split: CNNs excel at extracting local texture, while Transformers excel at capturing global relationships. Existing hybrid networks either chain the two serially or run them in parallel and crudely stitch the outputs together; the lack of deep feature interaction between the two branches produces artifacts.
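To make the first pain point concrete, here is a tiny sketch (my illustration, not from the paper) measuring how much of a zero-mean feature map a plain ReLU throws away:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
feat = torch.randn(1, 16, 8, 8)  # toy zero-mean intermediate feature map

# ReLU zeroes every negative response, discarding that information outright
zeroed = (nn.ReLU()(feat) == 0).float().mean().item()
print(f"fraction of activations zeroed by ReLU: {zeroed:.2f}")
```

Roughly half the responses are wiped out. This is exactly the loss that the wide-residual design tries to limit by widening the channels before the activation, so that more information survives it.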

2. Core Module Principles

To address the problems above, FIWHN proposes two core components, which I have extracted as standalone PyTorch modules:

  • WDIB (Wide-residual Distillation Interaction Block):

  • Corresponding class name: MY

  • Principle: uses a "wide residual" mechanism that expands the channel count before the activation function to prevent feature loss, and adds "feature interaction" and "distillation" mechanisms that fuse features from different levels with learned weighting coefficients.

  • Effect: greatly strengthens the network's ability to preserve local detail.

  • TransBlock (Efficient Transformer):

  • Corresponding class name: TransBlock (includes EffAttention)

  • Principle: to tame the heavy compute of a standard Transformer, it adopts an efficient attention mechanism (Efficient Attention) that lowers memory usage through grouping and sequence splitting while still capturing long-range dependencies.

  • Effect: compensates for the CNN's limited global receptive field.
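As a quick illustration of the wide-residual idea (expand cheaply with 1x1 convolutions before the activation, then contract), here is a minimal sketch. It mirrors the structure of the SRBW blocks in the full code of Part 2, but it is a simplified stand-in, not the paper's exact block:

```python
import torch
import torch.nn as nn

class WideResidualSketch(nn.Module):
    """Widen channels before the activation so fewer responses are lost,
    then contract back; a simplified cousin of the SRBW blocks."""
    def __init__(self, n_feats):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats * 2, 1),   # cheap 1x1 expansion
            nn.ReLU(inplace=True),                # activation on the wide feature
            nn.Conv2d(n_feats * 2, n_feats // 2, 1),
            nn.Conv2d(n_feats // 2, n_feats, 3, padding=1),
        )

    def forward(self, x):
        return x + self.body(x)  # residual connection

x = torch.randn(1, 32, 16, 16)
y = WideResidualSketch(32)(x)
print(y.shape)  # torch.Size([1, 32, 16, 16])
```

The key design choice is that the only layer touched by ReLU is the twice-as-wide one, so the activation's zeroing happens in a space with redundancy to spare.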

3. Architecture Diagram

See Figure 3 of the paper, which details the internal structure of WDIB (how the wide-residual connections and distillation are wired) and the design of the Efficient Transformer.

4. Use Cases and Modification Suggestions

This code is well suited to improving the following scenarios:

  • YOLO backbones or necks: replace the original C2f or Bottleneck with WDIB to strengthen feature extraction.
  • Image restoration / super-resolution: use it directly as a deep feature extractor.
  • Small-object detection: use TransBlock to enrich global context and reduce missed detections.
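The swap itself is ordinary module surgery. Here is a minimal sketch of the pattern, using nn.Identity() as a placeholder for any shape-preserving block such as MY from Part 2 (the backbone below is a toy of my own, not real YOLO code):

```python
import torch
import torch.nn as nn

# A toy backbone standing in for a YOLO/ResNet stage (hypothetical layout).
backbone = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(32, 32, 3, padding=1),  # the block we want to replace
)

# Swap the third layer for any module that preserves the feature-map shape;
# nn.Identity() is used here purely as a placeholder.
backbone[2] = nn.Identity()

x = torch.randn(1, 3, 64, 64)
print(backbone(x).shape)  # torch.Size([1, 32, 64, 64])
```

The only hard requirement is that the replacement module accepts and returns tensors of the same (B, C, H, W) shape as the block it displaces.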

Part 2: Complete Core Code

Author's note: the code below contains the full helper functions, the core modules (WDIB/MY and TransBlock), and a test main function. It is implemented in PyTorch; save it as a .py file and run it as-is.
"""
FIWHN 核心即插即用模块
提取自 FIWHN-基于特征交互加权混合网络的高效图像超分辨率

🔥 核心创新模块:
1. MY (WDIB): Wide-residual Distillation Interaction Block
   - 论文主要创新:特征交互和蒸馏机制
   - 实现了"特征交互加权"的核心思想
   
2. TransBlock: Efficient Transformer
   - 论文第二大创新:高效Transformer设计
   - CNN与Transformer的混合网络

📦 支撑模块:
- CoffConv: 系数卷积("加权"机制的关键)
- SRBW1, SRBW2: WDIB的构建块
- sa_layer: 空间-通道注意力
- EffAttention, Mlp: Transformer组件
- Scale: 可学习缩放因子

测试环境: anaconda torchv5
"""

import torch
import torch.nn as nn
import math
from torch.nn.parameter import Parameter


# ============================
# Helper functions
# ============================

def std(x):
    """
    Compute the per-channel standard deviation of a feature map.

    Args:
        x: input feature map (B, C, H, W)

    Returns:
        standard-deviation map (B, C, 1, 1)
    """
    return torch.std(x, dim=[2, 3], keepdim=True)


def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1):
    """
    Activation-function factory.

    Args:
        act_type: activation type ('relu', 'lrelu', 'prelu')
        inplace: whether to operate in place
        neg_slope: negative slope for LeakyReLU
        n_prelu: number of PReLU parameters

    Returns:
        the activation layer
    """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'lrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


def same_padding(images, ksizes, strides, rates):
    """
    Apply "same" zero-padding for the given kernel/stride/dilation.

    Args:
        images: input images (B, C, H, W)
        ksizes: kernel size [kh, kw]
        strides: strides [sh, sw]
        rates: dilation rates [rh, rw]

    Returns:
        the zero-padded images
    """
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
    padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
    
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images


def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract sliding-window patches from an image.

    Args:
        images: input images (B, C, H, W)
        ksizes: patch size [kh, kw]
        strides: strides [sh, sw]
        rates: dilation rates [rh, rw]
        padding: padding mode ('same' or 'valid')

    Returns:
        patches (B, C*kh*kw, L), where L is the number of patches
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    
    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches


def reverse_patches(images, out_size, ksizes, strides, padding):
    """
    Fold patches back into an image (overlapping patches are summed).

    Args:
        images: patches (B, C*kh*kw, L)
        out_size: output spatial size (H, W)
        ksizes: patch size [kh, kw]
        strides: stride
        padding: padding size

    Returns:
        the reassembled image (B, C, H, W)
    """
    fold = torch.nn.Fold(output_size=out_size,
                         kernel_size=ksizes,
                         dilation=1,
                         padding=padding,
                         stride=strides)
    patches = fold(images)
    return patches


# ============================
# Basic modules
# ============================

class Scale(nn.Module):
    """
    Learnable scaling factor.

    Args:
        init_value: initial scale value
    """
    def __init__(self, init_value=1e-3):
        super().__init__()
        self.scale = nn.Parameter(torch.FloatTensor([init_value]))

    def forward(self, input):
        return input * self.scale





# ============================
# Attention modules
# ============================

class sa_layer(nn.Module):
    """
    Shuffle attention layer (channel branch + spatial branch).

    Args:
        n_feats: number of feature channels
        groups: number of groups (default 4)
    """
    def __init__(self, n_feats, groups=4):
        super(sa_layer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1, n_feats // (2 * groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, n_feats // (2 * groups), 1, 1))
        self.sweight = Parameter(torch.zeros(1, n_feats // (2 * groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, n_feats // (2 * groups), 1, 1))

        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(n_feats // (2 * groups), n_feats // (2 * groups))

    @staticmethod
    def channel_shuffle(x, groups):
        """通道混洗"""
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.shape

        x = x.reshape(b * self.groups, -1, h, w)
        x_0, x_1 = x.chunk(2, dim=1)

        # channel attention
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)

        # spatial attention
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)

        # concatenate along channel axis
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)

        out = self.channel_shuffle(out, 2)
        return out


class CoffConv(nn.Module):
    """
    Coefficient convolution: fuses a mean-pooling branch and a
    standard-deviation branch into per-channel weights.

    Args:
        n_feats: number of feature channels
    """
    def __init__(self, n_feats):
        super(CoffConv, self).__init__()
        # Upper branch: mean pooling
        self.upper_branch = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(n_feats, n_feats // 8, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats // 8, n_feats, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
        )

        self.std = std
        # Lower branch: standard deviation
        self.lower_branch = nn.Sequential(
            nn.Conv2d(n_feats, n_feats // 8, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats // 8, n_feats, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
        )

    def forward(self, fea):
        upper = self.upper_branch(fea)
        lower = self.std(fea)
        lower = self.lower_branch(lower)
        out = torch.add(upper, lower) / 2
        return out


# ============================
# Residual blocks
# ============================

class SRBW1(nn.Module):
    """
    Simple Residual Block with Weight, variant 1.

    Args:
        n_feats: number of feature channels
        wn: weight-normalization wrapper (kept for API compatibility)
        act: activation function
    """
    def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x), act=nn.ReLU(True)):
        super(SRBW1, self).__init__()
        self.res_scale = Scale(1)
        self.x_scale = Scale(1)
        body = []
        body.append(nn.Conv2d(n_feats, n_feats*2, kernel_size=1, padding=0))
        body.append(act)
        body.append(nn.Conv2d(n_feats*2, n_feats//2, kernel_size=1, padding=0))
        body.append(nn.Conv2d(n_feats//2, n_feats, kernel_size=3, padding=1))

        self.body = nn.Sequential(*body)
        self.SAlayer = sa_layer(n_feats)

    def forward(self, x):
        # Wide body plus a learnably scaled identity shortcut
        y = self.res_scale(self.body(x)) + self.x_scale(x)
        return y


class SRBW2(nn.Module):
    """
    Simple Residual Block with Weight, variant 2 (halves the channels).

    Args:
        n_feats: number of feature channels
        wn: weight-normalization wrapper (kept for API compatibility)
        act: activation function
    """
    def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x), act=nn.ReLU(True)):
        super(SRBW2, self).__init__()
        self.res_scale = Scale(1)
        self.x_scale = Scale(1)
        body = []
        body.append(nn.Conv2d(n_feats, n_feats*2, kernel_size=1, padding=0))
        body.append(act)
        body.append(nn.Conv2d(n_feats*2, n_feats//2, kernel_size=1, padding=0))
        body.append(nn.Conv2d(n_feats//2, n_feats//2, kernel_size=3, padding=1))

        self.body = nn.Sequential(*body)
        self.SAlayer = sa_layer(n_feats//2)
        self.conv = nn.Conv2d(n_feats, n_feats//2, kernel_size=3, padding=1)

    def forward(self, x):
        y = self.res_scale(self.body(x)) + self.x_scale(self.conv(x))
        return y


# ============================
# Interaction-distillation block
# ============================

class MY(nn.Module):
    """
    Main interaction-distillation block: the core implementation of WDIB,
    built from several SRBW blocks plus CoffConv coefficient convolutions.

    Args:
        n_feats: number of feature channels
        act: activation function
    """
    def __init__(self, n_feats, act=nn.ReLU(True)):
        super(MY, self).__init__()

        self.act = activation('lrelu', neg_slope=0.05)
        wn = lambda x: torch.nn.utils.weight_norm(x)
        self.srb1 = SRBW1(n_feats)
        self.srb2 = SRBW1(n_feats)
        self.rb1 = SRBW1(n_feats)
        self.rb2 = SRBW1(n_feats)
        self.A1_coffconv = CoffConv(n_feats)
        self.B1_coffconv = CoffConv(n_feats)
        self.A2_coffconv = CoffConv(n_feats)
        self.B2_coffconv = CoffConv(n_feats)
        self.conv_distilled1 = nn.Conv2d(n_feats, n_feats, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_distilled2 = nn.Conv2d(n_feats, n_feats, kernel_size=1, stride=1, padding=0, bias=False)
        self.sigmoid1 = nn.Sigmoid()
        self.sigmoid2 = nn.Sigmoid()
        self.sigmoid3 = nn.Sigmoid()
        self.scale_x1 = Scale(1)
        self.scale_x2 = Scale(1)
        self.srb3 = SRBW1(n_feats)
        self.srb4 = SRBW1(n_feats)
        self.fuse1 = SRBW2(n_feats*2)
        self.fuse2 = nn.Conv2d(2*n_feats, n_feats, kernel_size=1, stride=1, padding=0, bias=False, dilation=1)

    def forward(self, x):
        out_a = self.act(self.srb1(x))
        distilled_a1 = remaining_a1 = out_a
        out_a = self.rb1(remaining_a1)
        A1 = self.A1_coffconv(out_a)
        out_b_1 = A1 * out_a + x
        B1 = self.B1_coffconv(x)
        out_a_1 = B1 * x + out_a

        out_b = self.act(self.srb2(out_b_1))
        distilled_b1 = remaining_b1 = out_b
        out_b = self.rb2(remaining_b1)
        A2 = self.A2_coffconv(out_a_1)
        out_b_2 = A2 * out_a_1 + out_b
        out_b_2 = out_b_2 * self.sigmoid1(self.conv_distilled1(distilled_b1))
        B2 = self.B2_coffconv(out_b)
        out_a_2 = out_b * B2 + out_a_1
        out_a_2 = out_a_2 * self.sigmoid2(self.conv_distilled2(distilled_a1))

        out_a_out = self.srb3(out_a_2)
        out_b_out = self.srb4(out_b_2)

        out1 = self.fuse1(torch.cat([self.scale_x1(out_a_out), self.scale_x2(out_b_out)], dim=1))
        out2 = self.sigmoid3(self.fuse2(torch.cat([self.scale_x1(out_a_out), self.scale_x2(out_b_out)], dim=1)))
        
        out = out2 * out_b_out
        y1 = out1 + out

        return y1


# ============================
# Transformer modules
# ============================

class Mlp(nn.Module):
    """
    Multi-layer perceptron.

    Args:
        in_features: input feature dimension
        hidden_features: hidden feature dimension
        out_features: output feature dimension
        act_layer: activation layer
        drop: dropout rate
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features//4
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class EffAttention(nn.Module):
    """
    Efficient attention: splits the token sequence into chunks so that
    attention is computed over shorter sequences, lowering memory cost.

    Args:
        dim: feature dimension
        num_heads: number of attention heads
        qkv_bias: whether the QKV projections use a bias
        qk_scale: override for the QK scaling factor
        attn_drop: attention dropout rate
        proj_drop: projection dropout rate
    """
    def __init__(self, dim, num_heads=9, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.reduce = nn.Linear(dim, dim//2, bias=qkv_bias)
        self.qkv = nn.Linear(dim//2, dim//2 * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim//2, dim)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        x = self.reduce(x)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Split the sequence into up to 4 chunks. Note N / 4, not N // 4:
        # integer division would make math.ceil a no-op and could yield
        # an extra undersized chunk when N is not divisible by 4.
        q_all = torch.split(q, math.ceil(N / 4), dim=-2)
        k_all = torch.split(k, math.ceil(N / 4), dim=-2)
        v_all = torch.split(v, math.ceil(N / 4), dim=-2)

        output = []
        for q, k, v in zip(q_all, k_all, v_all):
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            trans_x = (attn @ v).transpose(1, 2)
            output.append(trans_x)
        x = torch.cat(output, dim=1)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        return x


class TransBlock(nn.Module):
    """
    Transformer block combining efficient attention and an MLP.

    Args:
        n_feat: number of feature channels
        dim: Transformer token dimension; must equal C * kh * kw of the
             extracted patches (e.g. C * 9 for 3x3 patches)
        num_heads: number of attention heads
        mlp_ratio: MLP expansion ratio
        qkv_bias: whether the QKV projections use a bias
        qk_scale: override for the QK scaling factor
        drop: dropout rate
        attn_drop: attention dropout rate
        drop_path: drop-path rate
        act_layer: activation layer
        norm_layer: normalization layer
    """
    def __init__(
        self, n_feat=64, dim=768, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
        drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
        super(TransBlock, self).__init__()
        self.dim = dim
        self.atten = EffAttention(self.dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                  qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.norm1 = nn.LayerNorm(self.dim)
        self.mlp = Mlp(in_features=dim, hidden_features=dim//4, act_layer=act_layer, drop=drop)
        self.norm2 = nn.LayerNorm(self.dim)

    def forward(self, x):
        b, c, h, w = x.shape
        x = extract_image_patches(x, ksizes=[3, 3],
                                  strides=[1, 1],
                                  rates=[1, 1],
                                  padding='same')
        x = x.permute(0, 2, 1)
        x = x + self.atten(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        x = x.permute(0, 2, 1)
        x = reverse_patches(x, (h, w), (3, 3), 1, 1)
        
        return x


# ============================
# Test code
# ============================

if __name__ == "__main__":
    # Input: B C H W, output: B C H W
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 32, 64, 64).to(device)

    print("=" * 60)
    print("FIWHN plug-and-play module test")
    print(f"Device: {device}")
    print("=" * 60)

    # Test MY (the core of WDIB), the paper's main contribution
    print("[Core 1] MY - Wide-residual Distillation Interaction Block:")
    my_block = MY(32).to(device)
    y_my = my_block(x)
    print(f"  input shape:  {x.shape}")
    print(f"  output shape: {y_my.shape}")
    print("  WDIB is the paper's main contribution: feature interaction and distillation")

    # Test TransBlock, the paper's second contribution
    print("[Core 2] TransBlock - Efficient Transformer:")
    trans_block = TransBlock(n_feat=32, dim=32 * 9).to(device)
    y_trans = trans_block(x)
    print(f"  input shape:  {x.shape}")
    print(f"  output shape: {y_trans.shape}")
    print("  Efficient Transformer design; a CNN/Transformer hybrid")

    

Part 3: Verification and Summary

To confirm the code works, I tested it locally (PyTorch + CUDA). Running the __main__ section above prints the input and output shapes, showing that the modules run end to end and preserve the feature-map size (the padding is handled correctly), so they are genuinely plug-and-play.

Summary

FIWHN uses WDIB to fight feature loss in deep networks, and TransBlock to bring in global attention at very low compute cost. If you are working on YOLO improvements, semantic segmentation, or image restoration, these two modules are well worth a try!

