即插即用系列(代码实践)| TGRS 2025 GST-Net: 残差注意力增强+空间与通道的双重过滤结合的红外小目标检测

论文题目:GST-Net: Global Spatio-Temporal Detection Network for Infrared Small Objects in Complex Ground Scenarios

中文题目 :GST-Net:复杂地面场景下红外小目标的全局时空检测框架
应用任务:红外小目标检测 (IRSTD)、视频目标检测、特征增强
论文原文 (Paper)https://ieeexplore.ieee.org/abstract/document/11098927
官方代码 (Code)https://github.com/elvintanhust/GST-Det

摘要

本文结合 红外小目标检测 (IRSTD) 领域的经典论文《GST-Net》中的设计思想,针对复杂地面背景下目标微弱、易被噪声淹没 的痛点,提供了一个通用的即插即用模块------Res_CBAM_block 。该模块将经典的 CBAM (Convolutional Block Attention Module) 嵌入到残差结构中,通过**通道注意力(关注"什么")空间注意力(关注"哪里")**的串联,有效抑制背景杂波,增强小目标的特征响应,是构建高性能红外检测 Backbone 的基础组件。


目录

第一部分:模块原理与实战分析

1. 论文背景与解决的痛点

在红外小目标检测(尤其是涉及视频序列的 GST-Net 任务)中,我们面临着极其恶劣的成像环境:

  • 低信噪比 (Low SCR):目标通常只有几个像素大,且亮度可能比背景还低。
  • 复杂背景干扰:地面场景中包含树木、道路、建筑物等高频纹理,这些纹理在卷积神经网络眼中很容易被误判为目标。
  • 特征淹没:随着网络层数加深,微小的目标特征很容易在下采样过程中丢失。

痛点总结 :我们需要一种机制,能够在特征提取的每一个阶段,都显式地告诉网络"哪里是目标,哪里是背景",防止目标信息流失。

2. 核心模块原理揭秘

虽然 GST-Net 论文中提出了复杂的 RMPE 和 GSTDEM 模块,但其底层特征提取往往依赖于强大的注意力机制。这里提供的 Res_CBAM_block 是实现特征增强的"万金油"模块,其核心逻辑如下:

  • 双重注意力机制 (Dual Attention)

  • 通道注意力 (Channel Attention) :利用全局平均池化和最大池化,压缩空间维度,学习每个通道的权重。它负责判断哪些特征通道包含目标信息(例如,抑制包含大面积背景噪声的通道)。

  • 空间注意力 (Spatial Attention) :在通道维度进行压缩,学习空间上的权重图。它负责定位图像的哪个位置是目标(高亮小目标区域)。

  • 残差连接 (Residual Connection)

  • 直接将注意力增强后的特征与原始输入相加。这保证了梯度能够顺畅传播,防止因为多层注意力导致的网络退化,同时实现了"特征细化"的效果。

  • fea_add_module (特征融合)

  • 一个简单但有效的逐元素加法模块,通常用于融合不同层级或不同分支(如时空双流)的特征。

3. 架构图解

4. 适用场景与魔改建议

这套代码非常适合用于以下场景的改进:

  • 红外/遥感小目标检测:替换 ResNet 中的 BasicBlock,显著降低虚警率。
  • U-Net 编码器增强:在 U-Net 的下采样路径中加入 Res_CBAM,保护小目标特征不被丢失。
  • 特征融合阶段:在 FPN(特征金字塔)的横向连接处使用,增强多尺度特征的表达能力。

第二部分:核心完整代码

python 复制代码
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel Attention Module from CBAM"""
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """Spatial Attention Module from CBAM"""
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class Res_CBAM_block(nn.Module):
    """Residual Block with CBAM (Convolutional Block Attention Module)"""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        if stride != 1 or out_channels != in_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = None

        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        out += residual
        out = self.relu(out)
        return out


class fea_add_module(nn.Module):
    """Feature Addition Module with Dual-stream Attention Fusion"""
    def __init__(self, channels):
        super().__init__()
        self.ca1 = ChannelAttention(channels * 2)
        self.ca2 = ChannelAttention(channels)
        self.sa = SpatialAttention()
        self.relu = nn.ReLU(inplace=True)

        self.shortcut1 = nn.Sequential(
            nn.Conv2d(channels * 2, channels * 2, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels * 2))

        self.shortcut2 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels))

        self.center_layer = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, S, T):
        ST = torch.cat((S, T), dim=1)
        out1 = self.ca1(ST) * self.sa(ST) * ST
        res1 = self.shortcut1(ST)
        out1 += res1
        out2 = self.center_layer(out1)
        res2 = self.shortcut2(out2)
        out = self.ca2(out2) * self.sa(out2) * out2
        out += res2
        out = self.relu(out)
        return out


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print("=" * 60)
    print("Testing SPSA Modules")
    print("=" * 60)
    
    # Test ChannelAttention
    print("\n1. Testing ChannelAttention")
    x = torch.randn(1, 32, 256, 256).to(device)
    ca = ChannelAttention(in_planes=32).to(device)
    print(f"   Module: {ca.__class__.__name__}")
    output = ca(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 32, 1, 1), "ChannelAttention output shape mismatch!"
    print("   ✓ ChannelAttention test passed!")
    
    # Test SpatialAttention
    print("\n2. Testing SpatialAttention")
    x = torch.randn(1, 32, 256, 256).to(device)
    sa = SpatialAttention(kernel_size=7).to(device)
    print(f"   Module: {sa.__class__.__name__}")
    output = sa(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 1, 256, 256), "SpatialAttention output shape mismatch!"
    print("   ✓ SpatialAttention test passed!")
    
    # Test Res_CBAM_block
    print("\n3. Testing Res_CBAM_block")
    x = torch.randn(1, 32, 256, 256).to(device)
    res_cbam = Res_CBAM_block(in_channels=32, out_channels=64, stride=2).to(device)
    print(f"   Module: {res_cbam.__class__.__name__}")
    output = res_cbam(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 64, 128, 128), "Res_CBAM_block output shape mismatch!"
    print("   ✓ Res_CBAM_block test passed!")
    
    # Test Res_CBAM_block with same channels
    print("\n4. Testing Res_CBAM_block (same channels)")
    x = torch.randn(1, 32, 256, 256).to(device)
    res_cbam = Res_CBAM_block(in_channels=32, out_channels=32, stride=1).to(device)
    print(f"   Module: {res_cbam.__class__.__name__}")
    output = res_cbam(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 32, 256, 256), "Res_CBAM_block output shape mismatch!"
    print("   ✓ Res_CBAM_block test passed!")
    
    # Test fea_add_module
    print("\n5. Testing fea_add_module")
    s = torch.randn(1, 32, 256, 256).to(device)
    t = torch.randn(1, 32, 256, 256).to(device)
    fea_add = fea_add_module(channels=32).to(device)
    print(f"   Module: {fea_add.__class__.__name__}")
    output = fea_add(s, t)
    print(f"   输入张量S形状: {s.shape}")
    print(f"   输入张量T形状: {t.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 32, 256, 256), "fea_add_module output shape mismatch!"
    print("   ✓ fea_add_module test passed!")
    
    print("\n" + "=" * 60)
    print("All tests passed successfully! ✓")
    print("=" * 60)

第三部分:结果验证与总结

总结

在 GST-Net 等高性能红外检测框架中,注意力机制 是提升性能的基石。Res_CBAM_block 虽然结构简单,但它通过模拟人类视觉的"聚焦"过程,有效地解决了小目标特征微弱的难题。无论你是做视频检测还是单帧检测,加上这个模块,大概率能看到 Loss 下降和 Recall 提升!


喜欢这篇硬核复现的话,欢迎点赞收藏,订阅专栏获取更多 CV/红外目标检测 顶会论文的即插即用代码!

相关推荐
KaneLogger6 小时前
【Agent】openclaw + opencode 打造助手 安装篇
人工智能·google·程序员
知识浅谈7 小时前
一步步带你把 OpenClaw 玩宕机(附云服务器避坑部署教程)
人工智能
冬奇Lab7 小时前
OpenClaw 深度解析(四):插件 SDK 与扩展开发机制
人工智能·开源·源码阅读
IT_陈寒8 小时前
SpringBoot实战:5个让你的API性能翻倍的隐藏技巧
前端·人工智能·后端
机器之心9 小时前
让AI自我进化?斯坦福华人博士答辩视频火了,庞若鸣参与评审
人工智能·openai
iceiceiceice9 小时前
iOS PDF阅读器段评实现:如何从 PDFSelection 精准还原一个自然段
前端·人工智能·ios
AI攻城狮10 小时前
RAG Chunking 为什么这么难?5 大挑战 + 最佳实践指南
人工智能·云原生·aigc
yiyu071610 小时前
3分钟搞懂深度学习AI:梯度下降:迷雾中的下山路
人工智能·深度学习
掘金安东尼10 小时前
玩转龙虾🦞,openclaw 核心命令行收藏(持续更新)v2026.3.2
人工智能