【Block总结】HWD,小波下采样,适用分类、分割、目标检测等任务|即插即用

论文信息

Haar wavelet downsampling (HWD) 是一项针对语义分割的创新模块,旨在通过减少特征图的空间分辨率来提高深度卷积神经网络(DCNNs)的性能。该论文的主要贡献在于提出了一种新的下采样方法,能够在下采样阶段有效地减少信息损失。论文的详细信息如下:

  • 标题: Haar Wavelet Downsampling: A Simple but Effective Downsampling Module for Semantic Segmentation
  • 作者: Guoping Xu, Wentao Liao, Xuan Zhang, Chang Li, Xinwei He, Xinglong Wu
  • 发表年份: 2023
  • 期刊: Pattern Recognition
  • DOI : 10.1016/j.patcog.2023.109819

创新点

HWD 模块的核心创新在于:

  • 信息保留: 通过使用 Haar 小波变换,HWD 能够在下采样过程中最大限度地保留信息,避免传统下采样方法中常见的信息损失。

  • 特征熵指数: 论文中提出了一种新的度量标准,称为特征熵指数(Feature Entropy Index, FEI),用于评估下采样特征图与预测结果之间的信息不确定性。

  • 易于集成: HWD 模块可以直接替代现有的池化层或带步幅的卷积层,而不会显著增加计算开销。

方法

HWD 模块的实现方法包括以下几个步骤:

  1. Haar 小波变换: 该模块利用 Haar 小波变换对特征图进行下采样,降低空间分辨率的同时保留重要信息。

  2. 特征图编码: 在下采样过程中,部分空间信息被编码到通道维度,以便后续的卷积层能够提取判别性特征。

  3. 集成到 CNN 中: HWD 模块可以无缝集成到现有的卷积神经网络架构中,增强其语义分割能力。

HWD模块与传统下采样方法相比有哪些优势?

Haar Wavelet Downsampling (HWD) 模块相较于传统下采样方法(如最大池化和步幅卷积)具有多项显著优势:

优势

  1. 信息保留能力:

    • HWD 模块通过 Haar 小波变换进行下采样,能够在降低特征图的空间分辨率时最大限度地保留重要信息。这种方法有效减少了传统下采样过程中常见的信息损失,尤其是在语义分割任务中,保持空间信息对于像素级预测至关重要[2][5]。
  2. 特征熵指数(FEI):

    • HWD 引入了一种新的度量标准,称为特征熵指数(Feature Entropy Index, FEI),用于评估下采样后特征图与预测结果之间的信息不确定性。FEI 可以帮助量化下采样方法在保留关键信息方面的能力,从而为模型的性能提供更深入的理解[2][5]。
  3. 计算开销低:

    • HWD 模块可以直接替代现有的池化层或带步幅的卷积层,而不会显著增加计算开销。这使得 HWD 易于集成到现有的卷积神经网络架构中,提升了模型的灵活性和适应性[2][3][5]。
  4. 广泛的适用性:

    • 实验表明,HWD 模块在不同模态的图像数据集和多种 CNN 架构中均能有效提高分割性能。这种广泛的适用性使得 HWD 成为一种通用的下采样解决方案,适合多种应用场景[4][6]。
  5. 减少信息不确定性:

    • HWD 模块在下采样过程中有效减少了信息不确定性,相比传统方法,能够更好地保留特征的判别性,从而提升模型的整体表现[5][6]。

效果与实验结果

论文通过一系列综合实验验证了 HWD 模块的有效性,结果显示:

  • HWD 模块在与七种最先进的分割方法进行比较时,表现出更优的性能。

  • 实验结果表明,HWD 在保持高精度的同时,显著减少了信息损失,提升了模型的整体表现。

总结

Haar wavelet downsampling 模块为语义分割任务提供了一种简单而有效的下采样解决方案。通过引入 Haar 小波变换,该模块不仅提高了信息保留能力,还通过特征熵指数的引入,为特征重要性评估提供了新的视角。综合实验结果表明,HWD 模块在多种语义分割任务中均表现出色,具有广泛的应用潜力。

代码

python 复制代码
from pytorch_wavelets import DWTForward
import torch
from torch import nn
class HWD(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(HWD, self).__init__()
        self.wt = DWTForward(J=1, mode='zero', wave='haar')
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_ch * 4, out_ch, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        yL, yH = self.wt(x)
        y_HL = yH[0][:, :, 0, ::]
        y_LH = yH[0][:, :, 1, ::]
        y_HH = yH[0][:, :, 2, ::]
        x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)
        x = self.conv_bn_relu(x)
        return x




if __name__ == "__main__":
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, height, width,channels)
    x = torch.randn(1,32,40,40).to(device)
    # 初始化 HWD 模块
    dim=32
    block = HWD(dim,dim)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)
相关推荐
人工不智能57721 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
盟接之桥21 小时前
盟接之桥说制造:引流品 × 利润品,全球电商平台高效产品组合策略(供讨论)
大数据·linux·服务器·网络·人工智能·制造
kfyty72521 小时前
集成 spring-ai 2.x 实践中遇到的一些问题及解决方案
java·人工智能·spring-ai
h64648564h21 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
数据与后端架构提升之路21 小时前
论系统安全架构设计及其应用(基于AI大模型项目)
人工智能·安全·系统安全
忆~遂愿21 小时前
ops-cv 算子库深度解析:面向视觉任务的硬件优化与数据布局(NCHW/NHWC)策略
java·大数据·linux·人工智能
Liue6123123121 小时前
YOLO11-C3k2-MBRConv3改进提升金属表面缺陷检测与分类性能_焊接裂纹气孔飞溅物焊接线识别
人工智能·分类·数据挖掘
一切尽在,你来21 小时前
第二章 预告内容
人工智能·langchain·ai编程
23遇见1 天前
基于 CANN 框架的 AI 加速:ops-nn 仓库的关键技术解读
人工智能
Codebee1 天前
OoderAgent 企业版 2.0 发布的意义:一次生态战略的全面升级
人工智能