即插即用关系感知全局注意力模块RGA,涨点起飞

题目:Relation-Aware Global Attention for Person Re-identification

论文地址:https://arxiv.org/pdf/1904.02998

创新点

  • 全局关系感知:RGA模块通过捕获全局结构信息(如特征节点之间的相似度或亲和力)来挖掘全局范围的关系,这种结构化的全球信息有助于更好地进行语义推断和注意力分配。

  • 特征节点的紧凑表示:通过将每个特征节点与所有其他节点的两两关系和特征本身进行堆叠表示,RGA模块能够在更大范围内评估特征的重要性,从而提高辨别力。

  • 双重应用(空间和通道注意力):RGA模块在空间(RGA-S)和通道(RGA-C)维度均可应用,通过序列或并行结合使用进一步提升模型性能。

  • 轻量级设计:RGA通过浅层卷积来计算全局注意力值,相比传统非局部块,具有更高的适应性和计算效率。

方法

整体结构

论文提出的模型结构基于ResNet-50骨干网络,在其四个残差块后引入了Relation-Aware Global Attention(RGA)模块,通过空间和通道维度的关系感知增强特征表征。RGA模块结合了全局结构信息,使模型能够关注行人图像中的关键区域,同时使用分类和三元组损失优化模型性能,实现更高的识别精度。

  • 基线网络(ResNet-50):使用ResNet-50作为基础骨干网络,同时去掉了最后一层的空间下采样,以更好地保留空间信息。

  • RGA模块的插入:在ResNet-50的四个残差块(即conv2_x、conv3_x、conv4_x和conv5_x)之后分别添加RGA模块。RGA模块的插入分为空间维度注意力(RGA-S)和通道维度注意力(RGA-C),它们可以单独使用或联合使用,以增强特征提取的全局结构感知。

  • 训练损失函数:模型训练使用了识别损失(classification loss)和三元组损失(triplet loss),前者通过标签平滑技术强化分类效果,后者通过hard mining技术提升样本区分能力。

  • 数据增强和优化器:采用随机裁剪、水平翻转、随机擦除等常规的数据增强策略,并使用Adam优化器进行训练,以更好地适应行人重识别任务。

即插即用模块作用

RGA_model 作为一个即插即用模块,主要适用于:

  • 行人重识别:在行人重识别任务中,RGA可帮助模型关注行人关键区域,忽略背景杂讯。

  • 目标检测和图像分类:在目标检测和图像分类任务中,RGA模块能够提升模型对目标区域的识别精度。

  • 复杂背景和遮挡处理:在含有复杂背景、部分遮挡或姿态多变的图像场景中,RGA模块可以增强特征的全局感知,改善模型的稳健性和精确度。

消融实验结果

  • 展示了不同配置下RGA模块对模型性能的影响,包括仅使用空间维度注意力(RGA-S)、仅使用通道维度注意力(RGA-C)以及二者的组合(如RGA-SC),并分别对比了基线模型和各模块的不同组合形式。

  • 无论单独使用RGA-S或RGA-C,都能显著提高模型性能,证明了RGA模块对特征提取的增强作用。

  • 组合使用RGA-S和RGA-C(尤其是RGA-SC序列组合)时,模型性能达到最佳,进一步验证了空间和通道关系感知的联合作用对提升识别效果的重要性

即插即用模块

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F


class RGA_Module(nn.Module):
    def __init__(self, in_channel, in_spatial, use_spatial=True, use_channel=True,
                 cha_ratio=8, spa_ratio=8, down_ratio=8):
        super(RGA_Module, self).__init__()

        self.in_channel = in_channel
        self.in_spatial = in_spatial

        self.use_spatial = use_spatial
        self.use_channel = use_channel

        print('Use_Spatial_Att: {};\tUse_Channel_Att: {}.'.format(self.use_spatial, self.use_channel))

        # self.inter_channel = in_channel // cha_ratio
        # self.inter_spatial = in_spatial // spa_ratio
        self.inter_channel = max(in_channel // cha_ratio, 1)
        self.inter_spatial = max(in_spatial // spa_ratio, 1)

        # Embedding functions for original features
        if self.use_spatial:
            self.gx_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.gx_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

        # Embedding functions for relation features
        if self.use_spatial:
            self.gg_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial * 2, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
        if self.use_channel:
            self.gg_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )

        # Networks for learning attention weights
        if self.use_spatial:
            num_channel_s = 1 + self.inter_spatial
            self.W_spatial = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_s, out_channels=num_channel_s // down_ratio,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(num_channel_s // down_ratio),
                nn.ReLU(),
                nn.Conv2d(in_channels=num_channel_s // down_ratio, out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )
        if self.use_channel:
            num_channel_c = max(1 + self.inter_channel, 1) # 确保至少为1
            self.W_channel = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_c, out_channels=max(num_channel_c // down_ratio, 1), # 防止为0
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(max(num_channel_c // down_ratio, 1)), # 同样防止为0
                nn.ReLU(),
                nn.Conv2d(in_channels=max(num_channel_c // down_ratio, 1), out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )

        # Embedding functions for modeling relations
        if self.use_spatial:
            self.theta_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
            self.phi_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.theta_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
            self.phi_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

    def forward(self, x):
        b, c, h, w = x.size()

        if self.use_spatial:
            # spatial attention
            theta_xs = self.theta_spatial(x)
            phi_xs = self.phi_spatial(x)
            theta_xs = theta_xs.view(b, self.inter_channel, -1)
            theta_xs = theta_xs.permute(0, 2, 1)
            phi_xs = phi_xs.view(b, self.inter_channel, -1)
            Gs = torch.matmul(theta_xs, phi_xs)
            Gs_in = Gs.permute(0, 2, 1).view(b, h * w, h, w)
            Gs_out = Gs.view(b, h * w, h, w)
            Gs_joint = torch.cat((Gs_in, Gs_out), 1)
            Gs_joint = self.gg_spatial(Gs_joint)

            g_xs = self.gx_spatial(x)
            g_xs = torch.mean(g_xs, dim=1, keepdim=True)
            ys = torch.cat((g_xs, Gs_joint), 1)

            W_ys = self.W_spatial(ys)
            if not self.use_channel:
                out = F.sigmoid(W_ys.expand_as(x)) * x
                return out
            else:
                x = F.sigmoid(W_ys.expand_as(x)) * x

        if self.use_channel:
            # channel attention
            xc = x.view(b, c, -1).permute(0, 2, 1).unsqueeze(-1)
            theta_xc = self.theta_channel(xc).squeeze(-1).permute(0, 2, 1)
            phi_xc = self.phi_channel(xc).squeeze(-1)
            Gc = torch.matmul(theta_xc, phi_xc)
            Gc_in = Gc.permute(0, 2, 1).unsqueeze(-1)
            Gc_out = Gc.unsqueeze(-1)
            Gc_joint = torch.cat((Gc_in, Gc_out), 1)
            Gc_joint = self.gg_channel(Gc_joint)

            g_xc = self.gx_channel(xc)
            g_xc = torch.mean(g_xc, dim=1, keepdim=True)
            yc = torch.cat((g_xc, Gc_joint), 1)

            W_yc = self.W_channel(yc).transpose(1, 2)
            out = F.sigmoid(W_yc) * x

            return out

if __name__ == '__main__':
    block = RGA_Module(in_channel=3, in_spatial=32*32) # in_spatial应是h*w
    input = torch.rand(32, 3, 32, 32)
    output = block(input)
    print(input.size())
    print(output.size())
相关推荐
paixiaoxin5 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
AI视觉网奇6 小时前
人脸生成3d模型 Era3D
人工智能·计算机视觉
编码小哥6 小时前
opencv中的色彩空间
opencv·计算机视觉
吃个糖糖7 小时前
34 Opencv 自定义角点检测
人工智能·opencv·计算机视觉
葡萄爱9 小时前
OpenCV图像分割
人工智能·opencv·计算机视觉
深度学习lover12 小时前
<项目代码>YOLO Visdrone航拍目标识别<目标检测>
python·yolo·目标检测·计算机视觉·visdrone航拍目标识别
编码小哥14 小时前
深入解析Mat对象:计算机视觉中的核心数据结构
opencv·计算机视觉
liuming199214 小时前
Halcon中histo_2dim(Operator)算子原理及应用详解
图像处理·人工智能·深度学习·算法·机器学习·计算机视觉·视觉检测
Asiram_15 小时前
大数据机器学习与计算机视觉应用08:反向传播
大数据·机器学习·计算机视觉
深度学习lover1 天前
[项目代码] YOLOv8 遥感航拍飞机和船舶识别 [目标检测]
python·yolo·目标检测·计算机视觉·遥感航拍飞机和船舶识别