【深度学习|目标跟踪】StrongSORT 详解(以及StrongSORT++)

StrongSort详解

1、论文及源码

论文地址:https://arxiv.org/pdf/2202.13514

源码地址:https://github.com/dyhBUPT/StrongSORT?tab=readme-ov-file

2、DeepSORT回顾

参考此篇博客

3、StrongSORT的EMA

EMA即指数加权移动平均,在StrongSort中,self.track在update时,会将当前提取到的目标特征向量进行如下两步操作:

  • 将特征向量归一化
  • 指数加权移动平均

在代码中的体现,在track 脚本文件中,Track 类:

相比较于DeepSort中将每次的匹配上的目标的特征向量直接存储到对应track的gallery中,StrongSort的做法可以有效的平滑视频中间由于遮挡,噪声或其他不利因素导致的目标特征的衰减,使得在特征在匹配时的准确率能够有效提升。

Trackerupdate 的最后,我们会更新对应track 的特征库:

nn_matching 脚本中的NearestNeighborDistanceMetric 类中:

这里的self.budget可以在config.yaml中指定,意思是存储当前帧的前多少帧的特征进特征库。本质上还是和DeepSort一样,给每个track弄了个特征库,只不过特征库中的特征向量是从第一帧开始就进行指数加权移动平均并归一化的特征向量。

4、StrongSORT的NSA Kalman

在DeepSort中使用的是一个普通的kalman filter,即通过状态量直接估计下一时刻的状态,只有位置信息,而StrongSort中融合了目标的置信度信息,在计算噪声的均值和方差时,加入了track对应的检测目标的置信度信息,自适应的调整噪声。自适应计算噪声协方差的公式:

其中,Rk是预设的常数测量噪声协方差,Ck 是状态 k 下的检测置信度分数。即当置信度高时,意味着这次检测的结果有较小的噪声,对下次的状态预测的影响较小。

现在来看一下源码中关于NSA Kalman的流程:

  • strong_sort.py中StrongSort类的self.tracker.predict()
  • 进入tracker.py中的Tracker类的predict()方法,遍历self.tracks然后进入track的predict()预测,然后进入track.py中的Track类的predict()预测,然后进入kalman_filter.py中的KalmanFilter类中的predict()方法来计算均值和方差。
  • 然后回到strong_sort.py中的self.tracker.update(),进入tracker.py的Tracker类的update()方法,经过self._match()匹配完成之后,会得到匹配列表,未匹配的跟踪,未匹配的检测。然后我们遍历匹配列表来,跟新我们已经匹配上的track,进入track.py中的Track类的update()方法,在这里我们会看到加入了检测框的confidence来更新均值和方差:
  • 进入kalman_filter.py中的KalmanFilter类中的update()方法,进入self.project()方法计算出融合了confidence后的均值和方差。

这便是代码中NSA Kalman的一个实现过程。

5、StrongSORT的MC

在DeepSort中,虽然说结合了外观特征和运动特征来进行跟踪,但是DeepSort的lambda权重是设置成0或1的,因此并没有真正的结合外观特征追踪和运动特征追踪,只是将他们分成两个阶段分别匹配。而在StrongSort中,在外观特征匹配阶段,引入了一个mc_lambda权重作用在运动特征的门控矩阵上,结合外观代价矩阵来计算得到最后的cost_matrix,公式如下:

代码如下:

这是在_match()方法中的gated_metric()方法中调用。

6、StrongSORT的BOT特征提取器

这是一个关于行人重识别的特征提取网络,具体的还没有深入了解过,我个人觉得这里的特征提取器用什么不是StrongSort的亮点,毕竟这个模块是个即插即用的模块,用不同的分类网络或者孪生网络训练出来的特征提取网络都可以用在这个地方。

AFLink提出将两段30帧(可以自己调整,修改AFLink的分类网络重新训练即可)的时空序列作为输入,这两段tracklets分别是1 * 30 * 3维度的一个特征向量,其中30表示帧数,3分别表示帧数(时间),x,y(空间),即结合了时空的一段序列。网络将会输出这两段tracklets的相似置信度。

  • AFLink网络的定义:
python 复制代码
"""
@Author: Du Yunhao
@Filename: model.py
@Contact: dyh_bupt@163.com
@Time: 2021/12/28 14:13
@Discription: model
"""
import torch
from torch import nn

class TemporalBlock(nn.Module):
    def __init__(self, cin, cout):
        super(TemporalBlock, self).__init__()
        self.conv = nn.Conv2d(cin, cout, (7, 1), bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.bnf = nn.BatchNorm1d(cout)
        self.bnx = nn.BatchNorm1d(cout)
        self.bny = nn.BatchNorm1d(cout)

    def bn(self, x):
        x[:, :, :, 0] = self.bnf(x[:, :, :, 0])
        x[:, :, :, 1] = self.bnx(x[:, :, :, 1])
        x[:, :, :, 2] = self.bny(x[:, :, :, 2])
        return x

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class FusionBlock(nn.Module):
    def __init__(self, cin, cout):
        super(FusionBlock, self).__init__()
        self.conv = nn.Conv2d(cin, cout, (1, 3), bias=False)
        self.bn = nn.BatchNorm2d(cout)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Classifier(nn.Module):
    def __init__(self, cin):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(cin*2, cin//2)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(cin//2, 2)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


class PostLinker(nn.Module):
    def __init__(self):
        super(PostLinker, self).__init__()
        self.TemporalModule_1 = nn.Sequential(
            TemporalBlock(1, 32),
            TemporalBlock(32, 64),
            TemporalBlock(64, 128),
            TemporalBlock(128, 256)
        )
        self.TemporalModule_2 = nn.Sequential(
            TemporalBlock(1, 32),
            TemporalBlock(32, 64),
            TemporalBlock(64, 128),
            TemporalBlock(128, 256)
        )
        self.FusionBlock_1 = FusionBlock(256, 256)
        self.FusionBlock_2 = FusionBlock(256, 256)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = Classifier(256)

    def forward(self, x1, x2):
        x1 = x1[:, :, :, :3]
        x2 = x2[:, :, :, :3]
        x1 = self.TemporalModule_1(x1)  # [B,1,30,3] -> [B,256,6,3]
        x2 = self.TemporalModule_2(x2)
        x1 = self.FusionBlock_1(x1)
        x2 = self.FusionBlock_2(x2)
        x1 = self.pooling(x1).squeeze(-1).squeeze(-1)
        x2 = self.pooling(x2).squeeze(-1).squeeze(-1)
        y = self.classifier(x1, x2)
        if not self.training:
            y = torch.softmax(y, dim=1)
        return y


if __name__ == '__main__':
    x1 = torch.ones((3, 1, 30, 3))
    x2 = torch.ones((3, 1, 30, 3))
    m = PostLinker()
    m.eval()
    # 提取第一个维度的第二个元素作为置信度(0表示第一个维度,1表示该维度的索引)
    y1 = m(x1, x2)[0, 1].detach().cpu().numpy()
    print(y1)

输入两个tracklets,维度分别是[1, 1, 30, 3];

||

v
Temporal module: 特征层维度[1, 256, 6, 3];

||

v
Fusion module: 特征层维度[1, 256, 6, 1];

||

v
Pooling + Squeeze:特征层维度[1, 256];

||

v
Classifier:特征层维度[1, 2];

  • AFLink的推理:
    AFLink是一个离线模块,在目标检测+跟踪推理完成后,将每一帧的track信息(帧数,位置)保存到txt中,然后离线的使用AFLink读取这个txt来得到最终的推理结果,也将序列信息保存到txt中。我们可以根据txt中的帧数和坐标信息来将检测结果可视化在视频中。

8、StrongSORT++的GSI模块

GSI模块即高斯平滑插值,为了弥补跟踪之后的检测结果在帧上不够连续的情况。这一模块个人感觉对于跟踪的结果上没有影响,在完全跟对的情况下,能够改善跟踪的连续性的视觉效果,但是如果跟踪错误的前提下,GSI之后,结果会变得很奇怪(框开始飘来飘去)。

在源码中,大致上分为两步骤:

  • 设定一个插帧得阈值,小于这个阈值得两个不连续的帧之间我们会一帧帧插值,而大鱼这个阈值的,我们只插这个阈值数量的帧在中间。
  • 高斯平滑,来使得插的帧的跟踪信息显得更加的连续平滑。
相关推荐
www_3dyz_com12 分钟前
人工智能在VR展览中扮演什么角色?
人工智能·vr
刘不二18 分钟前
大模型应用—HivisionIDPhotos 证件照在线制作!支持离线、换装、美颜等
人工智能·开源
feilieren28 分钟前
AI 视频:初识 Pika 2.0,基本使用攻略
人工智能·ai视频
开放知识图谱1 小时前
论文浅尝 | HippoRAG:神经生物学启发的大语言模型的长期记忆(Neurips2024)
人工智能·语言模型·自然语言处理
威化饼的一隅1 小时前
【多模态】swift-3框架使用
人工智能·深度学习·大模型·swift·多模态
人类群星闪耀时2 小时前
大模型技术优化负载均衡:AI驱动的智能化运维
运维·人工智能·负载均衡
编码小哥2 小时前
通过opencv加载、保存视频
人工智能·opencv
机器学习之心2 小时前
BiTCN-BiGRU基于双向时间卷积网络结合双向门控循环单元的数据多特征分类预测(多输入单输出)
深度学习·分类·gru
发呆小天才O.oᯅ2 小时前
YOLOv8目标检测——详细记录使用OpenCV的DNN模块进行推理部署C++实现
c++·图像处理·人工智能·opencv·yolo·目标检测·dnn
lovelin+v175030409662 小时前
智能电商:API接口如何驱动自动化与智能化转型
大数据·人工智能·爬虫·python