即插即用系列(代码实践)| TGRS 2025 GST-Net: 残差注意力增强+空间与通道的双重过滤结合的红外小目标检测

论文题目:GST-Net: Global Spatio-Temporal Detection Network for Infrared Small Objects in Complex Ground Scenarios

中文题目 :GST-Net:复杂地面场景下红外小目标的全局时空检测框架
应用任务:红外小目标检测 (IRSTD)、视频目标检测、特征增强
论文原文 (Paper)https://ieeexplore.ieee.org/abstract/document/11098927
官方代码 (Code)https://github.com/elvintanhust/GST-Det

摘要

本文结合 红外小目标检测 (IRSTD) 领域的经典论文《GST-Net》中的设计思想,针对复杂地面背景下目标微弱、易被噪声淹没 的痛点,提供了一个通用的即插即用模块------Res_CBAM_block 。该模块将经典的 CBAM (Convolutional Block Attention Module) 嵌入到残差结构中,通过**通道注意力(关注"什么")空间注意力(关注"哪里")**的串联,有效抑制背景杂波,增强小目标的特征响应,是构建高性能红外检测 Backbone 的基础组件。


目录

第一部分:模块原理与实战分析

1. 论文背景与解决的痛点

在红外小目标检测(尤其是涉及视频序列的 GST-Net 任务)中,我们面临着极其恶劣的成像环境:

  • 低信噪比 (Low SCR):目标通常只有几个像素大,且亮度可能比背景还低。
  • 复杂背景干扰:地面场景中包含树木、道路、建筑物等高频纹理,这些纹理在卷积神经网络眼中很容易被误判为目标。
  • 特征淹没:随着网络层数加深,微小的目标特征很容易在下采样过程中丢失。

痛点总结 :我们需要一种机制,能够在特征提取的每一个阶段,都显式地告诉网络"哪里是目标,哪里是背景",防止目标信息流失。

2. 核心模块原理揭秘

虽然 GST-Net 论文中提出了复杂的 RMPE 和 GSTDEM 模块,但其底层特征提取往往依赖于强大的注意力机制。这里提供的 Res_CBAM_block 是实现特征增强的"万金油"模块,其核心逻辑如下:

  • 双重注意力机制 (Dual Attention)

  • 通道注意力 (Channel Attention) :利用全局平均池化和最大池化,压缩空间维度,学习每个通道的权重。它负责判断哪些特征通道包含目标信息(例如,抑制包含大面积背景噪声的通道)。

  • 空间注意力 (Spatial Attention) :在通道维度进行压缩,学习空间上的权重图。它负责定位图像的哪个位置是目标(高亮小目标区域)。

  • 残差连接 (Residual Connection)

  • 直接将注意力增强后的特征与原始输入相加。这保证了梯度能够顺畅传播,防止因为多层注意力导致的网络退化,同时实现了"特征细化"的效果。

  • fea_add_module (特征融合)

  • 一个简单但有效的逐元素加法模块,通常用于融合不同层级或不同分支(如时空双流)的特征。

3. 架构图解

4. 适用场景与魔改建议

这套代码非常适合用于以下场景的改进:

  • 红外/遥感小目标检测:替换 ResNet 中的 BasicBlock,显著降低虚警率。
  • U-Net 编码器增强:在 U-Net 的下采样路径中加入 Res_CBAM,保护小目标特征不被丢失。
  • 特征融合阶段:在 FPN(特征金字塔)的横向连接处使用,增强多尺度特征的表达能力。

第二部分:核心完整代码

python 复制代码
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel Attention Module from CBAM"""
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """Spatial Attention Module from CBAM"""
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class Res_CBAM_block(nn.Module):
    """Residual Block with CBAM (Convolutional Block Attention Module)"""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        if stride != 1 or out_channels != in_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = None

        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        out += residual
        out = self.relu(out)
        return out


class fea_add_module(nn.Module):
    """Feature Addition Module with Dual-stream Attention Fusion"""
    def __init__(self, channels):
        super().__init__()
        self.ca1 = ChannelAttention(channels * 2)
        self.ca2 = ChannelAttention(channels)
        self.sa = SpatialAttention()
        self.relu = nn.ReLU(inplace=True)

        self.shortcut1 = nn.Sequential(
            nn.Conv2d(channels * 2, channels * 2, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels * 2))

        self.shortcut2 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels))

        self.center_layer = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, S, T):
        ST = torch.cat((S, T), dim=1)
        out1 = self.ca1(ST) * self.sa(ST) * ST
        res1 = self.shortcut1(ST)
        out1 += res1
        out2 = self.center_layer(out1)
        res2 = self.shortcut2(out2)
        out = self.ca2(out2) * self.sa(out2) * out2
        out += res2
        out = self.relu(out)
        return out


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print("=" * 60)
    print("Testing SPSA Modules")
    print("=" * 60)
    
    # Test ChannelAttention
    print("\n1. Testing ChannelAttention")
    x = torch.randn(1, 32, 256, 256).to(device)
    ca = ChannelAttention(in_planes=32).to(device)
    print(f"   Module: {ca.__class__.__name__}")
    output = ca(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 32, 1, 1), "ChannelAttention output shape mismatch!"
    print("   ✓ ChannelAttention test passed!")
    
    # Test SpatialAttention
    print("\n2. Testing SpatialAttention")
    x = torch.randn(1, 32, 256, 256).to(device)
    sa = SpatialAttention(kernel_size=7).to(device)
    print(f"   Module: {sa.__class__.__name__}")
    output = sa(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 1, 256, 256), "SpatialAttention output shape mismatch!"
    print("   ✓ SpatialAttention test passed!")
    
    # Test Res_CBAM_block
    print("\n3. Testing Res_CBAM_block")
    x = torch.randn(1, 32, 256, 256).to(device)
    res_cbam = Res_CBAM_block(in_channels=32, out_channels=64, stride=2).to(device)
    print(f"   Module: {res_cbam.__class__.__name__}")
    output = res_cbam(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 64, 128, 128), "Res_CBAM_block output shape mismatch!"
    print("   ✓ Res_CBAM_block test passed!")
    
    # Test Res_CBAM_block with same channels
    print("\n4. Testing Res_CBAM_block (same channels)")
    x = torch.randn(1, 32, 256, 256).to(device)
    res_cbam = Res_CBAM_block(in_channels=32, out_channels=32, stride=1).to(device)
    print(f"   Module: {res_cbam.__class__.__name__}")
    output = res_cbam(x)
    print(f"   输入张量形状: {x.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 32, 256, 256), "Res_CBAM_block output shape mismatch!"
    print("   ✓ Res_CBAM_block test passed!")
    
    # Test fea_add_module
    print("\n5. Testing fea_add_module")
    s = torch.randn(1, 32, 256, 256).to(device)
    t = torch.randn(1, 32, 256, 256).to(device)
    fea_add = fea_add_module(channels=32).to(device)
    print(f"   Module: {fea_add.__class__.__name__}")
    output = fea_add(s, t)
    print(f"   输入张量S形状: {s.shape}")
    print(f"   输入张量T形状: {t.shape}")
    print(f"   输出张量形状: {output.shape}")
    assert output.shape == (1, 32, 256, 256), "fea_add_module output shape mismatch!"
    print("   ✓ fea_add_module test passed!")
    
    print("\n" + "=" * 60)
    print("All tests passed successfully! ✓")
    print("=" * 60)

第三部分:结果验证与总结

总结

在 GST-Net 等高性能红外检测框架中,注意力机制 是提升性能的基石。Res_CBAM_block 虽然结构简单,但它通过模拟人类视觉的"聚焦"过程,有效地解决了小目标特征微弱的难题。无论你是做视频检测还是单帧检测,加上这个模块,大概率能看到 Loss 下降和 Recall 提升!


喜欢这篇硬核复现的话,欢迎点赞收藏,订阅专栏获取更多 CV/红外目标检测 顶会论文的即插即用代码!

相关推荐
九.九7 小时前
ops-transformer:AI 处理器上的高性能 Transformer 算子库
人工智能·深度学习·transformer
春日见7 小时前
拉取与合并:如何让个人分支既包含你昨天的修改,也包含 develop 最新更新
大数据·人工智能·深度学习·elasticsearch·搜索引擎
恋猫de小郭7 小时前
AI 在提高你工作效率的同时,也一直在增加你的疲惫和焦虑
前端·人工智能·ai编程
deephub8 小时前
Agent Lightning:微软开源的框架无关 Agent 训练方案,LangChain/AutoGen 都能用
人工智能·microsoft·langchain·大语言模型·agent·强化学习
偷吃的耗子8 小时前
【CNN算法理解】:三、AlexNet 训练模块(附代码)
深度学习·算法·cnn
大模型RAG和Agent技术实践8 小时前
从零构建本地AI合同审查系统:架构设计与流式交互实战(完整源代码)
人工智能·交互·智能合同审核
老邋遢8 小时前
第三章-AI知识扫盲看这一篇就够了
人工智能
互联网江湖8 小时前
Seedance2.0炸场:长短视频们“修坝”十年,不如AI放水一天?
人工智能
PythonPioneer8 小时前
在AI技术迅猛发展的今天,传统职业该如何“踏浪前行”?
人工智能
冬奇Lab9 小时前
一天一个开源项目(第20篇):NanoBot - 轻量级AI Agent框架,极简高效的智能体构建工具
人工智能·开源·agent