即插即用关系感知全局注意力模块RGA,涨点起飞

题目:Relation-Aware Global Attention for Person Re-identification

论文地址:https://arxiv.org/pdf/1904.02998

创新点

  • 全局关系感知:RGA模块通过捕获全局结构信息(如特征节点之间的相似度或亲和力)来挖掘全局范围的关系,这种结构化的全球信息有助于更好地进行语义推断和注意力分配。

  • 特征节点的紧凑表示:通过将每个特征节点与所有其他节点的两两关系和特征本身进行堆叠表示,RGA模块能够在更大范围内评估特征的重要性,从而提高辨别力。

  • 双重应用(空间和通道注意力):RGA模块在空间(RGA-S)和通道(RGA-C)维度均可应用,通过序列或并行结合使用进一步提升模型性能。

  • 轻量级设计:RGA通过浅层卷积来计算全局注意力值,相比传统非局部块,具有更高的适应性和计算效率。

方法

整体结构

论文提出的模型结构基于ResNet-50骨干网络,在其四个残差块后引入了Relation-Aware Global Attention(RGA)模块,通过空间和通道维度的关系感知增强特征表征。RGA模块结合了全局结构信息,使模型能够关注行人图像中的关键区域,同时使用分类和三元组损失优化模型性能,实现更高的识别精度。

  • 基线网络(ResNet-50):使用ResNet-50作为基础骨干网络,同时去掉了最后一层的空间下采样,以更好地保留空间信息。

  • RGA模块的插入:在ResNet-50的四个残差块(即conv2_x、conv3_x、conv4_x和conv5_x)之后分别添加RGA模块。RGA模块的插入分为空间维度注意力(RGA-S)和通道维度注意力(RGA-C),它们可以单独使用或联合使用,以增强特征提取的全局结构感知。

  • 训练损失函数:模型训练使用了识别损失(classification loss)和三元组损失(triplet loss),前者通过标签平滑技术强化分类效果,后者通过hard mining技术提升样本区分能力。

  • 数据增强和优化器:采用随机裁剪、水平翻转、随机擦除等常规的数据增强策略,并使用Adam优化器进行训练,以更好地适应行人重识别任务。

即插即用模块作用

RGA_model 作为一个即插即用模块,主要适用于:

  • 行人重识别:在行人重识别任务中,RGA可帮助模型关注行人关键区域,忽略背景杂讯。

  • 目标检测和图像分类:在目标检测和图像分类任务中,RGA模块能够提升模型对目标区域的识别精度。

  • 复杂背景和遮挡处理:在含有复杂背景、部分遮挡或姿态多变的图像场景中,RGA模块可以增强特征的全局感知,改善模型的稳健性和精确度。

消融实验结果

  • 展示了不同配置下RGA模块对模型性能的影响,包括仅使用空间维度注意力(RGA-S)、仅使用通道维度注意力(RGA-C)以及二者的组合(如RGA-SC),并分别对比了基线模型和各模块的不同组合形式。

  • 无论单独使用RGA-S或RGA-C,都能显著提高模型性能,证明了RGA模块对特征提取的增强作用。

  • 组合使用RGA-S和RGA-C(尤其是RGA-SC序列组合)时,模型性能达到最佳,进一步验证了空间和通道关系感知的联合作用对提升识别效果的重要性

即插即用模块

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F


class RGA_Module(nn.Module):
    def __init__(self, in_channel, in_spatial, use_spatial=True, use_channel=True,
                 cha_ratio=8, spa_ratio=8, down_ratio=8):
        super(RGA_Module, self).__init__()

        self.in_channel = in_channel
        self.in_spatial = in_spatial

        self.use_spatial = use_spatial
        self.use_channel = use_channel

        print('Use_Spatial_Att: {};\tUse_Channel_Att: {}.'.format(self.use_spatial, self.use_channel))

        # self.inter_channel = in_channel // cha_ratio
        # self.inter_spatial = in_spatial // spa_ratio
        self.inter_channel = max(in_channel // cha_ratio, 1)
        self.inter_spatial = max(in_spatial // spa_ratio, 1)

        # Embedding functions for original features
        if self.use_spatial:
            self.gx_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.gx_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

        # Embedding functions for relation features
        if self.use_spatial:
            self.gg_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial * 2, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
        if self.use_channel:
            self.gg_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )

        # Networks for learning attention weights
        if self.use_spatial:
            num_channel_s = 1 + self.inter_spatial
            self.W_spatial = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_s, out_channels=num_channel_s // down_ratio,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(num_channel_s // down_ratio),
                nn.ReLU(),
                nn.Conv2d(in_channels=num_channel_s // down_ratio, out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )
        if self.use_channel:
            num_channel_c = max(1 + self.inter_channel, 1) # 确保至少为1
            self.W_channel = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_c, out_channels=max(num_channel_c // down_ratio, 1), # 防止为0
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(max(num_channel_c // down_ratio, 1)), # 同样防止为0
                nn.ReLU(),
                nn.Conv2d(in_channels=max(num_channel_c // down_ratio, 1), out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )

        # Embedding functions for modeling relations
        if self.use_spatial:
            self.theta_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
            self.phi_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.theta_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
            self.phi_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

    def forward(self, x):
        b, c, h, w = x.size()

        if self.use_spatial:
            # spatial attention
            theta_xs = self.theta_spatial(x)
            phi_xs = self.phi_spatial(x)
            theta_xs = theta_xs.view(b, self.inter_channel, -1)
            theta_xs = theta_xs.permute(0, 2, 1)
            phi_xs = phi_xs.view(b, self.inter_channel, -1)
            Gs = torch.matmul(theta_xs, phi_xs)
            Gs_in = Gs.permute(0, 2, 1).view(b, h * w, h, w)
            Gs_out = Gs.view(b, h * w, h, w)
            Gs_joint = torch.cat((Gs_in, Gs_out), 1)
            Gs_joint = self.gg_spatial(Gs_joint)

            g_xs = self.gx_spatial(x)
            g_xs = torch.mean(g_xs, dim=1, keepdim=True)
            ys = torch.cat((g_xs, Gs_joint), 1)

            W_ys = self.W_spatial(ys)
            if not self.use_channel:
                out = F.sigmoid(W_ys.expand_as(x)) * x
                return out
            else:
                x = F.sigmoid(W_ys.expand_as(x)) * x

        if self.use_channel:
            # channel attention
            xc = x.view(b, c, -1).permute(0, 2, 1).unsqueeze(-1)
            theta_xc = self.theta_channel(xc).squeeze(-1).permute(0, 2, 1)
            phi_xc = self.phi_channel(xc).squeeze(-1)
            Gc = torch.matmul(theta_xc, phi_xc)
            Gc_in = Gc.permute(0, 2, 1).unsqueeze(-1)
            Gc_out = Gc.unsqueeze(-1)
            Gc_joint = torch.cat((Gc_in, Gc_out), 1)
            Gc_joint = self.gg_channel(Gc_joint)

            g_xc = self.gx_channel(xc)
            g_xc = torch.mean(g_xc, dim=1, keepdim=True)
            yc = torch.cat((g_xc, Gc_joint), 1)

            W_yc = self.W_channel(yc).transpose(1, 2)
            out = F.sigmoid(W_yc) * x

            return out

if __name__ == '__main__':
    block = RGA_Module(in_channel=3, in_spatial=32*32) # in_spatial应是h*w
    input = torch.rand(32, 3, 32, 32)
    output = block(input)
    print(input.size())
    print(output.size())
相关推荐
新手小白勇闯新世界20 分钟前
论文阅读-用于图像识别的深度残差学习
论文阅读·人工智能·深度学习·学习·计算机视觉
新手小白勇闯新世界28 分钟前
论文阅读- --DeepI2P:通过深度分类进行图像到点云配准
论文阅读·深度学习·算法·计算机视觉
Jurio.6 小时前
【SPIE单独出版审核,见刊检索稳定!】2024年遥感技术与图像处理国际学术会议(RSTIP 2024,11月29-12月1日)
大数据·图像处理·人工智能·深度学习·机器学习·计算机视觉·学术会议
真的是我26 小时前
基于MATLAB课程设计-图像处理完整版
图像处理·人工智能·计算机视觉·matlab
广州视觉芯软件有限公司8 小时前
MFC,DLL界面库设计注意
c++·人工智能·计算机视觉·mfc
pen-ai1 天前
【机器学习】19. CNN 卷积神经网络 Convolutional neural network
人工智能·深度学习·机器学习·计算机视觉·cnn
孤单网愈云1 天前
11.4OpenCV_图像预处理02
人工智能·opencv·计算机视觉
我就想睡到自然醒1 天前
【计算机视觉基础】卷积
人工智能·计算机视觉
Angelina_Jolie1 天前
即插即用显著位置注意力spab,涨点起飞
计算机视觉
bigshark_software1 天前
2024-11-04 问AI: [AI面试题] 解释计算机视觉的概念
人工智能·计算机视觉