【深度学习|目标跟踪】StrongSORT 详解(以及StrongSORT++)

StrongSort详解

1、论文及源码

论文地址:https://arxiv.org/pdf/2202.13514

源码地址:https://github.com/dyhBUPT/StrongSORT?tab=readme-ov-file

2、DeepSORT回顾

参考此篇博客

3、StrongSORT的EMA

EMA即指数加权移动平均,在StrongSort中,self.track在update时,会将当前提取到的目标特征向量进行如下两步操作:

  • 将特征向量归一化
  • 指数加权移动平均

在代码中的体现,在track 脚本文件中,Track 类:

相比较于DeepSort中将每次的匹配上的目标的特征向量直接存储到对应track的gallery中,StrongSort的做法可以有效的平滑视频中间由于遮挡,噪声或其他不利因素导致的目标特征的衰减,使得在特征在匹配时的准确率能够有效提升。

Trackerupdate 的最后,我们会更新对应track 的特征库:

nn_matching 脚本中的NearestNeighborDistanceMetric 类中:

这里的self.budget可以在config.yaml中指定,意思是存储当前帧的前多少帧的特征进特征库。本质上还是和DeepSort一样,给每个track弄了个特征库,只不过特征库中的特征向量是从第一帧开始就进行指数加权移动平均并归一化的特征向量。

4、StrongSORT的NSA Kalman

在DeepSort中使用的是一个普通的kalman filter,即通过状态量直接估计下一时刻的状态,只有位置信息,而StrongSort中融合了目标的置信度信息,在计算噪声的均值和方差时,加入了track对应的检测目标的置信度信息,自适应的调整噪声。自适应计算噪声协方差的公式:

其中,Rk是预设的常数测量噪声协方差,Ck 是状态 k 下的检测置信度分数。即当置信度高时,意味着这次检测的结果有较小的噪声,对下次的状态预测的影响较小。

现在来看一下源码中关于NSA Kalman的流程:

  • strong_sort.py中StrongSort类的self.tracker.predict()
  • 进入tracker.py中的Tracker类的predict()方法,遍历self.tracks然后进入track的predict()预测,然后进入track.py中的Track类的predict()预测,然后进入kalman_filter.py中的KalmanFilter类中的predict()方法来计算均值和方差。
  • 然后回到strong_sort.py中的self.tracker.update(),进入tracker.py的Tracker类的update()方法,经过self._match()匹配完成之后,会得到匹配列表,未匹配的跟踪,未匹配的检测。然后我们遍历匹配列表来,跟新我们已经匹配上的track,进入track.py中的Track类的update()方法,在这里我们会看到加入了检测框的confidence来更新均值和方差:
  • 进入kalman_filter.py中的KalmanFilter类中的update()方法,进入self.project()方法计算出融合了confidence后的均值和方差。

这便是代码中NSA Kalman的一个实现过程。

5、StrongSORT的MC

在DeepSort中,虽然说结合了外观特征和运动特征来进行跟踪,但是DeepSort的lambda权重是设置成0或1的,因此并没有真正的结合外观特征追踪和运动特征追踪,只是将他们分成两个阶段分别匹配。而在StrongSort中,在外观特征匹配阶段,引入了一个mc_lambda权重作用在运动特征的门控矩阵上,结合外观代价矩阵来计算得到最后的cost_matrix,公式如下:

代码如下:

这是在_match()方法中的gated_metric()方法中调用。

6、StrongSORT的BOT特征提取器

这是一个关于行人重识别的特征提取网络,具体的还没有深入了解过,我个人觉得这里的特征提取器用什么不是StrongSort的亮点,毕竟这个模块是个即插即用的模块,用不同的分类网络或者孪生网络训练出来的特征提取网络都可以用在这个地方。

AFLink提出将两段30帧(可以自己调整,修改AFLink的分类网络重新训练即可)的时空序列作为输入,这两段tracklets分别是1 * 30 * 3维度的一个特征向量,其中30表示帧数,3分别表示帧数(时间),x,y(空间),即结合了时空的一段序列。网络将会输出这两段tracklets的相似置信度。

  • AFLink网络的定义:
python 复制代码
"""
@Author: Du Yunhao
@Filename: model.py
@Contact: [email protected]
@Time: 2021/12/28 14:13
@Discription: model
"""
import torch
from torch import nn

class TemporalBlock(nn.Module):
    def __init__(self, cin, cout):
        super(TemporalBlock, self).__init__()
        self.conv = nn.Conv2d(cin, cout, (7, 1), bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.bnf = nn.BatchNorm1d(cout)
        self.bnx = nn.BatchNorm1d(cout)
        self.bny = nn.BatchNorm1d(cout)

    def bn(self, x):
        x[:, :, :, 0] = self.bnf(x[:, :, :, 0])
        x[:, :, :, 1] = self.bnx(x[:, :, :, 1])
        x[:, :, :, 2] = self.bny(x[:, :, :, 2])
        return x

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class FusionBlock(nn.Module):
    def __init__(self, cin, cout):
        super(FusionBlock, self).__init__()
        self.conv = nn.Conv2d(cin, cout, (1, 3), bias=False)
        self.bn = nn.BatchNorm2d(cout)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Classifier(nn.Module):
    def __init__(self, cin):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(cin*2, cin//2)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(cin//2, 2)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


class PostLinker(nn.Module):
    def __init__(self):
        super(PostLinker, self).__init__()
        self.TemporalModule_1 = nn.Sequential(
            TemporalBlock(1, 32),
            TemporalBlock(32, 64),
            TemporalBlock(64, 128),
            TemporalBlock(128, 256)
        )
        self.TemporalModule_2 = nn.Sequential(
            TemporalBlock(1, 32),
            TemporalBlock(32, 64),
            TemporalBlock(64, 128),
            TemporalBlock(128, 256)
        )
        self.FusionBlock_1 = FusionBlock(256, 256)
        self.FusionBlock_2 = FusionBlock(256, 256)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = Classifier(256)

    def forward(self, x1, x2):
        x1 = x1[:, :, :, :3]
        x2 = x2[:, :, :, :3]
        x1 = self.TemporalModule_1(x1)  # [B,1,30,3] -> [B,256,6,3]
        x2 = self.TemporalModule_2(x2)
        x1 = self.FusionBlock_1(x1)
        x2 = self.FusionBlock_2(x2)
        x1 = self.pooling(x1).squeeze(-1).squeeze(-1)
        x2 = self.pooling(x2).squeeze(-1).squeeze(-1)
        y = self.classifier(x1, x2)
        if not self.training:
            y = torch.softmax(y, dim=1)
        return y


if __name__ == '__main__':
    x1 = torch.ones((3, 1, 30, 3))
    x2 = torch.ones((3, 1, 30, 3))
    m = PostLinker()
    m.eval()
    # 提取第一个维度的第二个元素作为置信度(0表示第一个维度,1表示该维度的索引)
    y1 = m(x1, x2)[0, 1].detach().cpu().numpy()
    print(y1)

输入两个tracklets,维度分别是[1, 1, 30, 3];

||

v
Temporal module: 特征层维度[1, 256, 6, 3];

||

v
Fusion module: 特征层维度[1, 256, 6, 1];

||

v
Pooling + Squeeze:特征层维度[1, 256];

||

v
Classifier:特征层维度[1, 2];

  • AFLink的推理:
    AFLink是一个离线模块,在目标检测+跟踪推理完成后,将每一帧的track信息(帧数,位置)保存到txt中,然后离线的使用AFLink读取这个txt来得到最终的推理结果,也将序列信息保存到txt中。我们可以根据txt中的帧数和坐标信息来将检测结果可视化在视频中。

8、StrongSORT++的GSI模块

GSI模块即高斯平滑插值,为了弥补跟踪之后的检测结果在帧上不够连续的情况。这一模块个人感觉对于跟踪的结果上没有影响,在完全跟对的情况下,能够改善跟踪的连续性的视觉效果,但是如果跟踪错误的前提下,GSI之后,结果会变得很奇怪(框开始飘来飘去)。

在源码中,大致上分为两步骤:

  • 设定一个插帧得阈值,小于这个阈值得两个不连续的帧之间我们会一帧帧插值,而大鱼这个阈值的,我们只插这个阈值数量的帧在中间。
  • 高斯平滑,来使得插的帧的跟踪信息显得更加的连续平滑。
相关推荐
gogoMark5 小时前
口播视频怎么剪!利用AI提高口播视频剪辑效率并增强”网感”
人工智能·音视频
2201_754918416 小时前
OpenCV 特征检测全面解析与实战应用
人工智能·opencv·计算机视觉
love530love7 小时前
Windows避坑部署CosyVoice多语言大语言模型
人工智能·windows·python·语言模型·自然语言处理·pycharm
985小水博一枚呀8 小时前
【AI大模型学习路线】第二阶段之RAG基础与架构——第七章(【项目实战】基于RAG的PDF文档助手)技术方案与架构设计?
人工智能·学习·语言模型·架构·大模型
白熊1888 小时前
【图像生成大模型】Wan2.1:下一代开源大规模视频生成模型
人工智能·计算机视觉·开源·文生图·音视频
weixin_514548898 小时前
一种开源的高斯泼溅实现库——gsplat: An Open-Source Library for Gaussian Splatting
人工智能·计算机视觉·3d
四口鲸鱼爱吃盐8 小时前
BMVC2023 | 多样化高层特征以提升对抗迁移性
人工智能·深度学习·cnn·vit·对抗攻击·迁移攻击
Echo``9 小时前
3:OpenCV—视频播放
图像处理·人工智能·opencv·算法·机器学习·视觉检测·音视频
Douglassssssss9 小时前
【深度学习】使用块的网络(VGG)
网络·人工智能·深度学习
okok__TXF9 小时前
SpringBoot3+AI
java·人工智能·spring