训练自己的多目标跟踪特征提取网络——DeepSort代码篇

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

beginning

上期给小伙伴讲解了多目标追踪DeepSort算法的基本原理与工作流程,还没看过的小伙伴赶紧来瞅一瞅叭➡ 一文详解DeepSort多目标追踪算法------原理篇。但读万卷书也要走万里路,只懂书本上的原理是万万不够滴,所以本期就带来了DeepSort的代码讲解,包括如何训练自己的特征提取网络用于目标追踪,从代码实战上深入了解。如果你也对此感兴趣,想动手实现一个自己的目标追踪模型,废话不多说,让我们一起愉快的学习叭🎈🎈🎈


1.DeepSort代码讲解

网上有很多版本的DeepSort代码,本文使用的是一位up主基于yolov5的DeepSort 代码版本➡源码+数据集(提取码flng)🌈🌈🌈

代码的目录大体如下:

可以看到代码分为三个部分:

  1. 追踪相关代码和权重:放的是追踪的代码,里面最主要的是deep_sort.py文件(后面会重点讲解)
  2. 检测相关代码和权重:这里使用的是基于yolov5的检测算法,这里可能就有小伙伴问了,追踪为啥要用到目标检测算法腻?因为首先要把目标定位检测到一个框框里,然后才能对其进行追踪呀
  3. 调用追踪和检测的相关py文件:包括检测器和追踪器

下面来看一下代码里的各个参数设置:

<math xmlns="http://www.w3.org/1998/Math/MathML"> d e e p _ s o r t / c o n f i g s / d e e p _ s o r t . y a m l 文件目录下🎈🎈🎈 \color{blue}{deep\_sort/configs/deep\_sort.yaml文件目录下🎈🎈🎈} </math>deep_sort/configs/deep_sort.yaml文件目录下🎈🎈🎈:

  • 特征提取网络权重的路径,我们下面训练特征提取网络之后,就会生成一个这样的权重文件
  • 用于级联匹配的最大余弦距离,如果大于这个阈值则忽略
  • 置信度的阈值------首先要对物体进行一个检测,如果大于这个阈值,检测框就保留下来
  • 非极大抑制阈值------检测到的框有好多个,我们取其中最好的框,若设置为1代表不进行抑制
  • 最大IOU阈值
  • 最大寿命------假设物体发生了遮挡,检测时失帧,经过MAX_AGE帧没有追踪到该物体,那么轨迹就会变成不确定态,就删除该轨迹
  • 最高击中次数------若预测框和检测框连续击中该次数,就从不确定态转为确定态
  • 最大保存特征帧数------每次进行检测时要进行特征提取,我们会将提取的特征进行保存,每个轨迹在删除之前最高保存100帧(可以更改)

<math xmlns="http://www.w3.org/1998/Math/MathML"> d e e p _ s o r t / d e e p _ s o r t / s o r t / 文件目录下🎈🎈🎈 \color{blue}{deep\_sort/deep\_sort/sort/文件目录下🎈🎈🎈} </math>deep_sort/deep_sort/sort/文件目录下🎈🎈🎈:

  • detection.py:用来保存经过目标检测后检测到的框框,包括框的置信度以及获取的特征
  • iou_matching.py:计算预测框和检测框的交并比
  • kalman_filter.py:卡尔曼滤波的相关代码,用来预测预测框的轨迹
  • linear_assignment.py:匈牙利算法的相关代码,用来匹配预测的轨迹框和检测框的最佳匹配效果
  • nn_matching.py:计算欧氏距离、余弦距离等来计算最近邻距离
  • preprocessing.py:非极大值抑制代码
  • track.py:存储轨迹信息,包括轨迹框的位置、速度、ID、状态等
  • tracker.py:保存所有的轨迹信息,负责初始化第一帧、卡尔曼滤波的预测和更新、级联匹配、IOU匹配等,这个模块几乎负责整个的过程,是很重要滴

理解了各个参数之后,咱们重点看一下deep_sort.py文件中的代码

python 复制代码
class DeepSort(object):
    def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7, max_age=70, n_init=3, nn_budget=100, use_cuda=True):
        self.min_confidence = min_confidence # 检测结果置信度阈值 
        self.nms_max_overlap = nms_max_overlap # 非极大抑制阈值,设置为1代表不进行抑制

        self.extractor = Extractor(model_path, use_cuda=use_cuda) # 用于提取一个batch图片对应的特征

        max_cosine_distance = max_dist # 最大余弦距离,用于级联匹配,如果大于该阈值,则忽略
        nn_budget = 100 # 每个类别gallery最多的外观描述子的个数,如果超过,删除旧的
        # NearestNeighborDistanceMetric 最近邻距离度量
        # 对于每个目标,返回到目前为止已观察到的任何样本的最近距离(欧式或余弦)。
        # 由距离度量方法构造一个 Tracker。
        # 第一个参数可选'cosine' or 'euclidean'
        metric = NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
        self.tracker = Tracker(metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init)

    def update(self, bbox_xywh, confidences, ori_img):
        self.height, self.width = ori_img.shape[:2]
        # generate detections
        # 从原图中抠取bbox对应图片并计算得到相应的特征
        features = self._get_features(bbox_xywh, ori_img)
        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)        
        # 筛选掉小于min_confidence的目标,并构造一个Detection对象构成的列表
        detections = [Detection(bbox_tlwh[i], conf, features[i]) for i,conf in enumerate(confidences) if conf>self.min_confidence]

        # run on non-maximum supression
        boxes = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        indices = non_max_suppression(boxes, self.nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # update tracker
        self.tracker.predict() # 将跟踪状态分布向前传播一步
        self.tracker.update(detections) # 执行测量更新和跟踪管理

        # output bbox identities
        outputs = []
        for track in self.tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            box = track.to_tlwh()
            x1,y1,x2,y2 = self._tlwh_to_xyxy(box)
            track_id = track.track_id
            outputs.append(np.array([x1,y1,x2,y2,track_id], dtype=np.int))
        if len(outputs) > 0:
            outputs = np.stack(outputs,axis=0)
        return outputs

其中,在初始化函数init()中,model_path代表外观特征网络的路径,剩下几个参数都是阈值。初始化之后来到了update()函数 ,我们每一个时间步每一帧都要执行这个函数,核心部分是需要把检测器输出的边界框给恢复到实际大小,恢复之后进行一个置信度的筛选,把置信度低的框扔掉,然后再对筛选后的框提取外观特征,实例化每一个检测,把它整合成一个Detection的类。接下来我们的任务就是把检测的框与轨迹进行匹配,最关键的就是这两行------self.tracker.predict()和self.tracker.update(detections),之后就是一些后处理🎶🎶🎶

我们到deep_sort/deep_sort/sort/tracker.py/文件目录下看predict()做了一件什么事情腻?predict()就是将现在我们已经有的轨迹里面的每一个轨迹都进行卡尔曼的更新,预测下一帧的位置和协方差矩阵🌻🌻🌻

下一步进行deep_sort/deep_sort/sort/tracker.py/文件目录下的update更新,检测框一进来首先进行match()的匹配,match()做了一件什么事情腻 ?如下面代码所示,里面定义了gated_metric()函数,这个函数需要好好的品一品 🧐🧐🧐函数有四个参数,分别是所有的轨迹、所有的检测、轨迹的索引和检测的索引,相当于每一次只对tracks里面的这些索引和dets里面的这些索引进行匹配,其他的索引它就不管了,这是一种写法上的差别。我们把每一个检测的特征、轨迹的ID提取出来,就可以算一下检测和轨迹的这个外观特征的代价矩阵......这时候可能你会有疑问了,为什么不直接传以轨迹为列表的那种形式,而是传索引腻?是因为它把度量专门写了一个类,这个类里面专门存了每一个轨迹的外观历史特征信息,所以只需要传索引就可以辽,这种设计有一种空间上的冗余,但是好在挺简洁滴✨✨✨

python 复制代码
    def _match(self, detections):

        def gated_metric(tracks, dets, track_indices, detection_indices):
            features = np.array([dets[i].feature for i in detection_indices])
            targets = np.array([tracks[i].track_id for i in track_indices])
            
            # 通过最近邻(余弦距离)计算出成本矩阵(代价矩阵)
            cost_matrix = self.metric.distance(features, targets)
            # 计算门控后的成本矩阵(代价矩阵)
            cost_matrix = linear_assignment.gate_cost_matrix(
                self.kf, cost_matrix, tracks, dets, track_indices,
                detection_indices)

            return cost_matrix

接着咱们就来运行一下demo.py文件 看看效果叭,为了更好理解代码,设置一下断点 。当运行到if cv2.getWindowProperty(name, cv2.WND_PROP_AUTOSIZE) < 1:时,代码停住了,并且出现了一个如下所示的框框,说明这是我们的第一帧⛳⛳⛳

但是我们发现画面中并没有出现检测框和追踪框,这是怎么个事儿腻 ?因为我们第一帧是对检测的框进行初始化,并不会出现追踪的结果;然后我们再继续第二帧,第二帧结束的时候框也没有画出来,得一直等到第三帧结束后(也就是上期原理篇里讲的,轨迹得和检测框击中3次后才能检测到框),物体的追踪框立马显示出来,如下所示。不断快速的运行,视频就会有追踪的效果啦🌴🌴🌴

2.训练特征提取网络

DeepSort算法效果比较好的其中一个原因就是它有一个特征提取网络,这个特征提取网络可以对比前后的视频帧,保留特征,避免目标丢失的追踪失败,所以咱们就来学习一下如何训练自己的特征提取网络模型叭🌈🌈🌈

首先上一个特征提取网络的代码,就是目录deep_sort/deep_sort/deep/model.py文件。其实特征提取的模型有很多,这里可以选择替换别的模型滴

模型选好之后,准备数据集。这里使用的是Market-1501 数据集 (链接上文中有,大家也可用自己的数据集),由6个摄像头拍摄到的1501个行人和32668个行人矩形框。如下所示:

接下来把下载好的数据集放到deep_sort/deep_sort/deep/路径下,运行prepare_person代码将数据集划分为训练集和测试集 ,运行之后会生成一个名叫pytorch的文件,其中文件下生成train文件和test文件 。最后我们在deep_sort/deep_sort/deep目录下新建一个Market-1501文件夹,将训练集和测试集放到该目录下就ok啦🎈🎈🎈

数据集准备好之后,利用deep_sort/deep_sort/deep/路径下的train.py文件来训练模型 。首先将训练集和测试集的路径填入到第14行default中,注意一定要是绝对路径 ❗❗❗然后也可以通过代码的第182行修改训练的轮数,默认训练轮数为40。最后运行train.py就大功告成辽。训练完之后就会在deep_sort/deep_sort/deep/checkpoint目录下生成一个新的权重文件ckpt.t7,然后你就可以愉快的食用它啦🍭🍭🍭

这里还有两个划线计数的趣味应用------count_car.py和count_person.py,也就是对出入的车辆行人进行计数。感兴趣的一起来玩一玩叭🌟🌟🌟

3.多目标追踪的应用

原理和实战之后,咱们来大开脑洞想一想多目标追踪在日常生活中都有哪些用处呢 ,这里可以有一些有趣的想法喔:

  • 购物体验升级:想象一下,你步入一家智能商店。多目标追踪技术可以识别你,知道你的购物偏好,并在你身边展示你可能感兴趣的商品。它还可以跟踪你拿起的物品,自动添加到你的购物车中,无需排队结账,真正实现"拿了就走"。这对我们女孩子来说简直完美✨✨✨
  • 智能家居:家可以通过多目标追踪变得更加智能。当你走进房间时,灯光和温度可以根据你的位置和偏好进行自动调整。家居设备也可以根据你的活动自动启动或关闭🌞🌞🌞
  • 宠物生活质量提升:你的宠物可以穿戴一个小型追踪器,用于室内外的定位。这将有助于你更好地理解它们的行为和需要,确保它们的快乐和安全,铲屎官听闻大喜😁😁😁
  • 智能运动训练:在锻炼时,多目标追踪可以监测你的动作和姿势,提供实时反馈,帮助你改进体能训练。你还可以与虚拟训练伙伴互动,仿佛在与真人一起运动🧸🧸🧸

这些应用场景只是一些有趣的点子,多目标追踪的未来潜力我相信是无穷无尽滴。


ending

看到这里相信盆友们都对多目标追踪算法DeepSort的实战应用有了更全面深入的了解啦!很开心能把学到的知识以文章的形式分享给大家🌴🌴🌴如果你也觉得我的分享对你有所帮助,please一键三连嗷!!!下期见

相关推荐
MUTA️2 分钟前
RT-DETR学习笔记(2)
人工智能·笔记·深度学习·学习·机器学习·计算机视觉
开发者每周简报36 分钟前
求职市场变化
人工智能·面试·职场和发展
AI前沿技术追踪1 小时前
OpenAI 12天发布会:AI革命的里程碑@附35页PDF文件下载
人工智能
余~~185381628001 小时前
稳定的碰一碰发视频、碰一碰矩阵源码技术开发,支持OEM
开发语言·人工智能·python·音视频
galileo20161 小时前
LLM与金融
人工智能
DREAM依旧2 小时前
隐马尔科夫模型|前向算法|Viterbi 算法
人工智能
GocNeverGiveUp2 小时前
机器学习2-NumPy
人工智能·机器学习·numpy
B站计算机毕业设计超人3 小时前
计算机毕业设计PySpark+Hadoop中国城市交通分析与预测 Python交通预测 Python交通可视化 客流量预测 交通大数据 机器学习 深度学习
大数据·人工智能·爬虫·python·机器学习·课程设计·数据可视化
学术头条3 小时前
清华、智谱团队:探索 RLHF 的 scaling laws
人工智能·深度学习·算法·机器学习·语言模型·计算语言学
18号房客3 小时前
一个简单的机器学习实战例程,使用Scikit-Learn库来完成一个常见的分类任务——**鸢尾花数据集(Iris Dataset)**的分类
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·sklearn