即插即用关系感知全局注意力模块RGA,涨点起飞

题目:Relation-Aware Global Attention for Person Re-identification

论文地址:https://arxiv.org/pdf/1904.02998

创新点

  • 全局关系感知:RGA模块通过捕获全局结构信息(如特征节点之间的相似度或亲和力)来挖掘全局范围的关系,这种结构化的全球信息有助于更好地进行语义推断和注意力分配。

  • 特征节点的紧凑表示:通过将每个特征节点与所有其他节点的两两关系和特征本身进行堆叠表示,RGA模块能够在更大范围内评估特征的重要性,从而提高辨别力。

  • 双重应用(空间和通道注意力):RGA模块在空间(RGA-S)和通道(RGA-C)维度均可应用,通过序列或并行结合使用进一步提升模型性能。

  • 轻量级设计:RGA通过浅层卷积来计算全局注意力值,相比传统非局部块,具有更高的适应性和计算效率。

方法

整体结构

论文提出的模型结构基于ResNet-50骨干网络,在其四个残差块后引入了Relation-Aware Global Attention(RGA)模块,通过空间和通道维度的关系感知增强特征表征。RGA模块结合了全局结构信息,使模型能够关注行人图像中的关键区域,同时使用分类和三元组损失优化模型性能,实现更高的识别精度。

  • 基线网络(ResNet-50):使用ResNet-50作为基础骨干网络,同时去掉了最后一层的空间下采样,以更好地保留空间信息。

  • RGA模块的插入:在ResNet-50的四个残差块(即conv2_x、conv3_x、conv4_x和conv5_x)之后分别添加RGA模块。RGA模块的插入分为空间维度注意力(RGA-S)和通道维度注意力(RGA-C),它们可以单独使用或联合使用,以增强特征提取的全局结构感知。

  • 训练损失函数:模型训练使用了识别损失(classification loss)和三元组损失(triplet loss),前者通过标签平滑技术强化分类效果,后者通过hard mining技术提升样本区分能力。

  • 数据增强和优化器:采用随机裁剪、水平翻转、随机擦除等常规的数据增强策略,并使用Adam优化器进行训练,以更好地适应行人重识别任务。

即插即用模块作用

RGA_model 作为一个即插即用模块,主要适用于:

  • 行人重识别:在行人重识别任务中,RGA可帮助模型关注行人关键区域,忽略背景杂讯。

  • 目标检测和图像分类:在目标检测和图像分类任务中,RGA模块能够提升模型对目标区域的识别精度。

  • 复杂背景和遮挡处理:在含有复杂背景、部分遮挡或姿态多变的图像场景中,RGA模块可以增强特征的全局感知,改善模型的稳健性和精确度。

消融实验结果

  • 展示了不同配置下RGA模块对模型性能的影响,包括仅使用空间维度注意力(RGA-S)、仅使用通道维度注意力(RGA-C)以及二者的组合(如RGA-SC),并分别对比了基线模型和各模块的不同组合形式。

  • 无论单独使用RGA-S或RGA-C,都能显著提高模型性能,证明了RGA模块对特征提取的增强作用。

  • 组合使用RGA-S和RGA-C(尤其是RGA-SC序列组合)时,模型性能达到最佳,进一步验证了空间和通道关系感知的联合作用对提升识别效果的重要性

即插即用模块

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F


class RGA_Module(nn.Module):
    def __init__(self, in_channel, in_spatial, use_spatial=True, use_channel=True,
                 cha_ratio=8, spa_ratio=8, down_ratio=8):
        super(RGA_Module, self).__init__()

        self.in_channel = in_channel
        self.in_spatial = in_spatial

        self.use_spatial = use_spatial
        self.use_channel = use_channel

        print('Use_Spatial_Att: {};\tUse_Channel_Att: {}.'.format(self.use_spatial, self.use_channel))

        # self.inter_channel = in_channel // cha_ratio
        # self.inter_spatial = in_spatial // spa_ratio
        self.inter_channel = max(in_channel // cha_ratio, 1)
        self.inter_spatial = max(in_spatial // spa_ratio, 1)

        # Embedding functions for original features
        if self.use_spatial:
            self.gx_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.gx_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

        # Embedding functions for relation features
        if self.use_spatial:
            self.gg_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial * 2, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
        if self.use_channel:
            self.gg_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )

        # Networks for learning attention weights
        if self.use_spatial:
            num_channel_s = 1 + self.inter_spatial
            self.W_spatial = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_s, out_channels=num_channel_s // down_ratio,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(num_channel_s // down_ratio),
                nn.ReLU(),
                nn.Conv2d(in_channels=num_channel_s // down_ratio, out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )
        if self.use_channel:
            num_channel_c = max(1 + self.inter_channel, 1) # 确保至少为1
            self.W_channel = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_c, out_channels=max(num_channel_c // down_ratio, 1), # 防止为0
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(max(num_channel_c // down_ratio, 1)), # 同样防止为0
                nn.ReLU(),
                nn.Conv2d(in_channels=max(num_channel_c // down_ratio, 1), out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )

        # Embedding functions for modeling relations
        if self.use_spatial:
            self.theta_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
            self.phi_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.theta_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
            self.phi_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

    def forward(self, x):
        b, c, h, w = x.size()

        if self.use_spatial:
            # spatial attention
            theta_xs = self.theta_spatial(x)
            phi_xs = self.phi_spatial(x)
            theta_xs = theta_xs.view(b, self.inter_channel, -1)
            theta_xs = theta_xs.permute(0, 2, 1)
            phi_xs = phi_xs.view(b, self.inter_channel, -1)
            Gs = torch.matmul(theta_xs, phi_xs)
            Gs_in = Gs.permute(0, 2, 1).view(b, h * w, h, w)
            Gs_out = Gs.view(b, h * w, h, w)
            Gs_joint = torch.cat((Gs_in, Gs_out), 1)
            Gs_joint = self.gg_spatial(Gs_joint)

            g_xs = self.gx_spatial(x)
            g_xs = torch.mean(g_xs, dim=1, keepdim=True)
            ys = torch.cat((g_xs, Gs_joint), 1)

            W_ys = self.W_spatial(ys)
            if not self.use_channel:
                out = F.sigmoid(W_ys.expand_as(x)) * x
                return out
            else:
                x = F.sigmoid(W_ys.expand_as(x)) * x

        if self.use_channel:
            # channel attention
            xc = x.view(b, c, -1).permute(0, 2, 1).unsqueeze(-1)
            theta_xc = self.theta_channel(xc).squeeze(-1).permute(0, 2, 1)
            phi_xc = self.phi_channel(xc).squeeze(-1)
            Gc = torch.matmul(theta_xc, phi_xc)
            Gc_in = Gc.permute(0, 2, 1).unsqueeze(-1)
            Gc_out = Gc.unsqueeze(-1)
            Gc_joint = torch.cat((Gc_in, Gc_out), 1)
            Gc_joint = self.gg_channel(Gc_joint)

            g_xc = self.gx_channel(xc)
            g_xc = torch.mean(g_xc, dim=1, keepdim=True)
            yc = torch.cat((g_xc, Gc_joint), 1)

            W_yc = self.W_channel(yc).transpose(1, 2)
            out = F.sigmoid(W_yc) * x

            return out

if __name__ == '__main__':
    block = RGA_Module(in_channel=3, in_spatial=32*32) # in_spatial应是h*w
    input = torch.rand(32, 3, 32, 32)
    output = block(input)
    print(input.size())
    print(output.size())
相关推荐
如若1234 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
加密新世界6 小时前
优化 Solana 程序
人工智能·算法·计算机视觉
WeeJot嵌入式10 小时前
OpenCV:计算机视觉的瑞士军刀
计算机视觉
思通数科多模态大模型10 小时前
10大核心应用场景,解锁AI检测系统的智能安全之道
人工智能·深度学习·安全·目标检测·计算机视觉·自然语言处理·数据挖掘
学不会lostfound10 小时前
三、计算机视觉_05MTCNN人脸检测
pytorch·深度学习·计算机视觉·mtcnn·p-net·r-net·o-net
Mr.谢尔比11 小时前
李宏毅机器学习课程知识点摘要(1-5集)
人工智能·pytorch·深度学习·神经网络·算法·机器学习·计算机视觉
思通数科AI全行业智能NLP系统11 小时前
六大核心应用场景,解锁AI检测系统的智能安全之道
图像处理·人工智能·深度学习·安全·目标检测·计算机视觉·知识图谱
李歘歘15 小时前
Stable Diffusion经典应用场景
人工智能·深度学习·计算机视觉
饭碗、碗碗香15 小时前
OpenCV笔记:图像去噪对比
人工智能·笔记·opencv·计算机视觉
蚂蚁没问题s17 小时前
图像处理 - 色彩空间转换
图像处理·人工智能·算法·机器学习·计算机视觉