Title: Relation-Aware Global Attention for Person Re-identification
Paper: https://arxiv.org/pdf/1904.02998
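Key contributions
- Global relation awareness: the RGA module mines relations over the whole feature map by capturing global structural information, namely the pairwise similarities (affinities) between feature nodes. This structured global information supports better semantic inference and attention assignment.
- Compact representation of feature nodes: for each feature node, the module stacks its pairwise relations with all other nodes together with the feature itself, so the importance of a feature is judged from a global scope, which improves discriminative power.
- Two forms (spatial and channel attention): RGA can be applied along the spatial dimension (RGA-S) and along the channel dimension (RGA-C). Combining the two, sequentially or in parallel, improves performance further.
- Lightweight design: RGA computes the global attention values with shallow convolutions, which makes it more flexible and computationally cheaper than a conventional non-local block.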
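The stacking described in the first two points can be sketched in a few lines. This is a simplified illustration, not the module itself: `relation_aware_features`, `theta_w` and `phi_w` are hypothetical names, the two embeddings are plain matrices rather than 1x1 convolutions with BatchNorm and ReLU, and the raw feature is concatenated directly instead of being reduced and pooled first.

```python
import torch


def relation_aware_features(x, theta_w, phi_w):
    """Stack each spatial node's feature with its relations to all nodes.

    x:        (C, H, W) feature map, giving N = H * W nodes of dimension C
    theta_w:  (C, D) stand-in embedding for the "query" side (assumption)
    phi_w:    (C, D) stand-in embedding for the "key" side (assumption)
    returns:  (N, C + 2N) tensor, one row per node
    """
    c, h, w = x.shape
    nodes = x.reshape(c, h * w).t()           # (N, C): one feature per position
    # affinity[i, j] = <theta(x_i), phi(x_j)>, asymmetric in general
    affinity = (nodes @ theta_w) @ (nodes @ phi_w).t()   # (N, N)
    # Row i holds node i's relations to all nodes and all nodes' relations to i
    relations = torch.cat((affinity, affinity.t()), dim=1)  # (N, 2N)
    return torch.cat((nodes, relations), dim=1)              # (N, C + 2N)


torch.manual_seed(0)
x = torch.randn(4, 2, 3)                      # C=4, H=2, W=3, so N=6
y = relation_aware_features(x, torch.randn(4, 2), torch.randn(4, 2))
print(y.shape)                                # torch.Size([6, 16])
```

Because every row carries relations to all N positions, a shallow network applied to a single row already sees global context, which is the reason the attention weights can be computed with 1x1 convolutions.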
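Method
Overall architecture
The model is built on a ResNet-50 backbone, with a Relation-Aware Global Attention (RGA) module inserted after each of its four residual stages to strengthen the feature representation with relation awareness along the spatial and channel dimensions. By drawing on global structural information, the RGA module lets the model focus on the key regions of a pedestrian image. The network is trained with a classification loss and a triplet loss to reach higher re-identification accuracy.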
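- Baseline network (ResNet-50): ResNet-50 serves as the backbone, with its last spatial downsampling step removed so that more spatial information is preserved.
- Inserting the RGA modules: an RGA module is added after each of the four residual stages of ResNet-50 (conv2_x, conv3_x, conv4_x and conv5_x). Each module can apply spatial attention (RGA-S), channel attention (RGA-C), or both, to make feature extraction aware of global structure.
- Training losses: the model is trained with an identification (classification) loss and a triplet loss. The former uses label smoothing to regularize classification; the latter uses hard mining to improve the separation between samples.
- Data augmentation and optimizer: training uses the usual augmentation strategies (random cropping, horizontal flipping, random erasing) and the Adam optimizer.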
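The two training losses listed above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' training code: the function names, the smoothing factor `0.1`, the margin `0.3`, and the unweighted sum of the two terms are all assumptions, and the batch-hard variant shown here is one common form of hard mining.

```python
import torch
import torch.nn.functional as F


def batch_hard_triplet_loss(embeddings, labels, margin=0.3):
    """Triplet loss with batch-hard mining (margin value is an assumption)."""
    dist = torch.cdist(embeddings, embeddings)               # (B, B) Euclidean
    same_id = labels.unsqueeze(0) == labels.unsqueeze(1)     # (B, B) bool mask
    # Hardest positive: the farthest sample sharing the anchor's identity
    d_ap = dist.masked_fill(~same_id, float('-inf')).max(dim=1).values
    # Hardest negative: the closest sample with a different identity
    d_an = dist.masked_fill(same_id, float('inf')).min(dim=1).values
    return F.relu(d_ap - d_an + margin).mean()


def reid_loss(logits, embeddings, labels, smoothing=0.1, margin=0.3):
    """Label-smoothed classification loss plus batch-hard triplet loss."""
    cls = F.cross_entropy(logits, labels, label_smoothing=smoothing)
    tri = batch_hard_triplet_loss(embeddings, labels, margin)
    return cls + tri                                         # equal weighting assumed


labels = torch.tensor([0, 0, 1, 1])
embeddings = torch.tensor([[0., 0.], [3., 0.], [1., 0.], [4., 0.]])
logits = torch.zeros(4, 2)
print(reid_loss(logits, embeddings, labels))
```

Hard mining only works when each batch contains several images per identity, so a sampler that draws a fixed number of identities with a fixed number of images each is usually paired with this loss.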
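Where the plug-and-play module applies
As a plug-and-play module, RGA_Module is mainly suited to:
- Person re-identification: RGA helps the model attend to the key regions of a pedestrian and ignore background clutter.
- Object detection and image classification: inserted into a backbone, the RGA module can improve how accurately the model picks out the target region.
- Complex backgrounds and occlusion: in scenes with cluttered backgrounds, partial occlusion or varied poses, the RGA module strengthens the global awareness of the features and can improve robustness and accuracy.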
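Ablation results
- The ablation study shows how different RGA configurations affect performance: spatial attention only (RGA-S), channel attention only (RGA-C), and combinations of the two (such as RGA-SC), each compared against the baseline model.
- Using either RGA-S or RGA-C on its own already improves clearly over the baseline, which confirms that the RGA module strengthens feature extraction.
- Combining RGA-S and RGA-C, in particular the sequential RGA-SC arrangement, gives the best performance, which further supports the value of modeling spatial and channel relations jointly.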
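The combination orders compared in the ablation can be illustrated with two stand-in gates. The gates below are deliberately trivial (a sigmoid over a mean) and are not RGA; all function names are hypothetical. The sequential spatial-then-channel order matches the forward pass of the module in the next section, while the parallel form shown here (both masks computed from the same input and multiplied) is an assumption about how such a combination could be written.

```python
import torch


def spatial_gate(x):
    # Stand-in spatial mask: one weight per position, shape (B, 1, H, W)
    return torch.sigmoid(x.mean(dim=1, keepdim=True))


def channel_gate(x):
    # Stand-in channel mask: one weight per channel, shape (B, C, 1, 1)
    return torch.sigmoid(x.mean(dim=(2, 3), keepdim=True))


def rga_sc(x):
    # Sequential, spatial first: the channel gate sees spatially refined features
    x = spatial_gate(x) * x
    return channel_gate(x) * x


def rga_cs(x):
    # Sequential, channel first
    x = channel_gate(x) * x
    return spatial_gate(x) * x


def rga_parallel(x):
    # Parallel (assumed form): both masks are computed from the same input
    return spatial_gate(x) * channel_gate(x) * x


x = torch.randn(2, 8, 4, 4)
for fn in (rga_sc, rga_cs, rga_parallel):
    print(fn.__name__, tuple(fn(x).shape))
```

In the sequential variants the second gate depends on the output of the first, so the two orders generally give different results even though both masks have the same shapes.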
Plug-and-play module
```python
import torch
from torch import nn


class RGA_Module(nn.Module):
    """Relation-aware global attention with optional spatial (RGA-S) and channel (RGA-C) branches."""

    def __init__(self, in_channel, in_spatial, use_spatial=True, use_channel=True,
                 cha_ratio=8, spa_ratio=8, down_ratio=8):
        super().__init__()

        self.in_channel = in_channel
        self.in_spatial = in_spatial  # must equal H * W of the input feature map
        self.use_spatial = use_spatial
        self.use_channel = use_channel
        print('Use_Spatial_Att: {};\tUse_Channel_Att: {}.'.format(self.use_spatial, self.use_channel))

        # Reduced dimensions, clamped to at least 1 so that small inputs still work
        self.inter_channel = max(in_channel // cha_ratio, 1)
        self.inter_spatial = max(in_spatial // spa_ratio, 1)

        # Embedding functions for original features
        if self.use_spatial:
            self.gx_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.gx_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

        # Embedding functions for relation features
        if self.use_spatial:
            self.gg_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial * 2, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
        if self.use_channel:
            self.gg_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )

        # Networks for learning attention weights
        if self.use_spatial:
            num_channel_s = 1 + self.inter_spatial
            hidden_s = max(num_channel_s // down_ratio, 1)  # never 0 channels
            self.W_spatial = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_s, out_channels=hidden_s,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(hidden_s),
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_s, out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )
        if self.use_channel:
            num_channel_c = 1 + self.inter_channel
            hidden_c = max(num_channel_c // down_ratio, 1)  # never 0 channels
            self.W_channel = nn.Sequential(
                nn.Conv2d(in_channels=num_channel_c, out_channels=hidden_c,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(hidden_c),
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_c, out_channels=1,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1)
            )

        # Embedding functions for modeling relations
        if self.use_spatial:
            self.theta_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
            self.phi_spatial = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_channel),
                nn.ReLU()
            )
        if self.use_channel:
            self.theta_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )
            self.phi_channel = nn.Sequential(
                nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(self.inter_spatial),
                nn.ReLU()
            )

    def forward(self, x):
        b, c, h, w = x.size()

        if self.use_spatial:
            # Spatial attention: pairwise relations between the h*w positions
            theta_xs = self.theta_spatial(x)
            phi_xs = self.phi_spatial(x)
            theta_xs = theta_xs.view(b, self.inter_channel, -1)
            theta_xs = theta_xs.permute(0, 2, 1)
            phi_xs = phi_xs.view(b, self.inter_channel, -1)
            Gs = torch.matmul(theta_xs, phi_xs)               # (b, hw, hw)
            Gs_in = Gs.permute(0, 2, 1).view(b, h * w, h, w)
            Gs_out = Gs.view(b, h * w, h, w)
            Gs_joint = torch.cat((Gs_in, Gs_out), 1)          # (b, 2*hw, h, w)
            Gs_joint = self.gg_spatial(Gs_joint)

            g_xs = self.gx_spatial(x)
            g_xs = torch.mean(g_xs, dim=1, keepdim=True)
            ys = torch.cat((g_xs, Gs_joint), 1)

            W_ys = self.W_spatial(ys)                         # (b, 1, h, w)
            x = torch.sigmoid(W_ys.expand_as(x)) * x

        if self.use_channel:
            # Channel attention: pairwise relations between the c channels
            xc = x.view(b, c, -1).permute(0, 2, 1).unsqueeze(-1)  # (b, hw, c, 1)
            theta_xc = self.theta_channel(xc).squeeze(-1).permute(0, 2, 1)
            phi_xc = self.phi_channel(xc).squeeze(-1)
            Gc = torch.matmul(theta_xc, phi_xc)               # (b, c, c)
            Gc_in = Gc.permute(0, 2, 1).unsqueeze(-1)
            Gc_out = Gc.unsqueeze(-1)
            Gc_joint = torch.cat((Gc_in, Gc_out), 1)          # (b, 2c, c, 1)
            Gc_joint = self.gg_channel(Gc_joint)

            g_xc = self.gx_channel(xc)
            g_xc = torch.mean(g_xc, dim=1, keepdim=True)
            yc = torch.cat((g_xc, Gc_joint), 1)

            W_yc = self.W_channel(yc).transpose(1, 2)         # (b, c, 1, 1)
            x = torch.sigmoid(W_yc) * x

        # With both branches disabled the input is returned unchanged
        return x


if __name__ == '__main__':
    block = RGA_Module(in_channel=3, in_spatial=32 * 32)  # in_spatial must be h * w
    x = torch.rand(32, 3, 32, 32)
    output = block(x)
    print(x.size())
    print(output.size())
```
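Since `in_spatial` has to equal h * w of the feature map the module receives, the constructor arguments differ for each ResNet-50 stage. The helper below is a small sketch for working them out; `rga_stage_args` is a hypothetical name, and the 256 x 128 input size in the usage lines is an assumption. The channel counts and strides are those of a standard ResNet-50, with conv5_x kept at stride 16 because the last downsampling step is removed.

```python
def rga_stage_args(height, width):
    """Return RGA_Module constructor arguments for each ResNet-50 stage.

    Assumes a standard ResNet-50 whose last spatial downsampling is removed,
    so conv5_x keeps the same resolution as conv4_x (total stride 16).
    """
    # (stage name, output channels, total stride relative to the input image)
    stages = [
        ("conv2_x", 256, 4),
        ("conv3_x", 512, 8),
        ("conv4_x", 1024, 16),
        ("conv5_x", 2048, 16),
    ]
    return {
        name: {"in_channel": ch, "in_spatial": (height // s) * (width // s)}
        for name, ch, s in stages
    }


# Example with an assumed 256 x 128 input image
for name, kwargs in rga_stage_args(256, 128).items():
    print(name, kwargs)
```

Each dictionary can be unpacked into the constructor, for example `RGA_Module(**rga_stage_args(256, 128)["conv4_x"])`. Note that the spatial branch builds an (h*w) x (h*w) relation matrix per image, so early stages with large feature maps are the most memory-hungry.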