视频目标追踪完全指南:从原理到实战部署

本文系统讲解视频目标追踪的核心原理、主流算法(SORT/DeepSORT/ByteTrack/BoT-SORT)及完整PyTorch实战代码,助你快速掌握这一计算机视觉核心技术。


一、什么是视频目标追踪

1.1 任务定义

视频目标追踪(Video Object Tracking, VOT)是指在视频序列中持续定位并识别特定目标的技术。其核心任务是为每个目标分配唯一的ID,并在整个视频中保持ID的一致性。

复制代码
追踪任务示意:

帧1: 检测到3个人 → 分配ID: Person_1, Person_2, Person_3
帧2: 检测到3个人 → 关联: Person_1, Person_2, Person_3 (保持ID)
帧3: Person_2被遮挡 → 保持: Person_1, Person_3
帧4: Person_2重新出现 → 恢复: Person_1, Person_2, Person_3

1.2 追踪 vs 检测 vs 分割

任务 输出 时序关系 应用场景
目标检测 单帧边界框+类别 图像理解
语义分割 像素级类别 场景解析
目标追踪 跨帧边界框+唯一ID 行为分析、计数
实例分割+追踪 跨帧像素级mask+ID 视频编辑、VFX

1.3 应用场景

智能交通:车辆计数、轨迹分析、违章检测、交通流量统计。

安防监控:行人追踪、异常行为检测、跨摄像头追踪(ReID)。

自动驾驶:周围车辆和行人的持续追踪,预测运动轨迹。

体育分析:运动员追踪、战术分析、运动数据统计。

工业检测:生产线产品追踪、缺陷品溯源。

水下机器人:海洋生物追踪、水下目标持续监控。


二、目标追踪的两大范式

2.1 单目标追踪(SOT)

给定第一帧中目标的位置,在后续帧中持续追踪该目标。

复制代码
SOT工作流程:

第1帧: 用户/算法指定目标框 [x, y, w, h]
       ↓
第2帧: 在目标周围搜索,找到最相似区域
       ↓
第3帧: 继续搜索...
       ↓
...持续追踪直到视频结束或目标丢失

代表算法:SiamFC、SiamRPN、SiamMask、TransT、OSTrack

优点:无需检测器,计算量小,可追踪任意类别目标

缺点:只能追踪单个目标,容易漂移

2.2 多目标追踪(MOT)

同时追踪视频中的多个目标,是目前工业界主流方案。

复制代码
MOT工作流程(Tracking by Detection):

每一帧:
  1. 检测器 → 获得所有目标的边界框
  2. 特征提取 → 提取每个检测框的外观特征(可选)
  3. 运动预测 → 用卡尔曼滤波预测已有轨迹的新位置
  4. 数据关联 → 匹配检测框与已有轨迹
  5. 轨迹管理 → 创建新轨迹、删除丢失轨迹

代表算法:SORT、DeepSORT、ByteTrack、BoT-SORT、OC-SORT

优点:可追踪多目标,鲁棒性强

缺点:依赖检测器质量,计算量较大


三、核心算法详解

3.1 SORT:简单在线实时追踪

SORT(Simple Online and Realtime Tracking)是多目标追踪的基石算法,核心思想极其简洁。

核心组件

  1. 卡尔曼滤波:预测目标下一帧位置

  2. 匈牙利算法:最优匹配检测框与轨迹

  3. IoU距离:衡量检测框与预测框的相似度

    SORT算法流程:

    复制代码
                     检测结果
                        ↓

    ┌─────────────────────────────────────────┐
    │ 卡尔曼滤波预测 │
    │ 已有轨迹 → 预测下一帧位置 │
    └─────────────────────────────────────────┘

    ┌─────────────────────────────────────────┐
    │ 计算IoU代价矩阵 │
    │ Cost[i,j] = 1 - IoU(detection_i, │
    │ prediction_j) │
    └─────────────────────────────────────────┘

    ┌─────────────────────────────────────────┐
    │ 匈牙利算法匹配 │
    │ 找到使总代价最小的匹配方案 │
    └─────────────────────────────────────────┘

    ┌─────────────────────────────────────────┐
    │ 轨迹管理 │
    │ - 匹配成功: 更新轨迹 │
    │ - 检测未匹配: 创建新轨迹 │
    │ - 轨迹未匹配: 标记丢失/删除 │
    └─────────────────────────────────────────┘

SORT的局限:纯粹依赖运动信息(IoU),当目标被遮挡或快速运动时容易ID切换。

3.2 DeepSORT:引入外观特征

DeepSORT在SORT基础上引入深度外观特征,显著提升遮挡场景下的追踪性能。

核心改进

  1. 外观特征:使用ReID网络提取目标外观描述子

  2. 级联匹配:优先匹配最近出现的轨迹

  3. 马氏距离:结合运动和外观信息

    DeepSORT匹配策略:

    代价矩阵 = λ × 运动距离(马氏距离) + (1-λ) × 外观距离(余弦距离)

    复制代码
               检测框
          D1   D2   D3   D4
     T1 [0.2, 0.8, 0.9, 0.7]    ← 轨迹1与各检测框的距离
     T2 [0.7, 0.1, 0.6, 0.8]
     T3 [0.8, 0.7, 0.2, 0.9]

    级联匹配:
    第1轮: 匹配最近1帧内出现的轨迹
    第2轮: 匹配最近2帧内出现的轨迹
    ...
    第N轮: 匹配丢失N帧的轨迹

3.3 ByteTrack:利用低分检测框

ByteTrack的核心洞察:低置信度检测框也包含有价值信息,不应直接丢弃。

复制代码
ByteTrack两阶段匹配:

第一阶段:高分检测框匹配
  - 筛选置信度 > τ_high 的检测框
  - 与所有轨迹进行IoU匹配
  - 匹配成功的轨迹更新状态

第二阶段:低分检测框匹配
  - 筛选 τ_low < 置信度 < τ_high 的检测框
  - 仅与第一阶段未匹配的轨迹进行匹配
  - 用于恢复被遮挡的目标

典型阈值:τ_high = 0.6, τ_low = 0.1

ByteTrack的优势

  1. 无需ReID网络,速度快
  2. 有效处理遮挡导致的低分检测
  3. 在MOT17、MOT20等榜单上取得SOTA

3.4 BoT-SORT:当前最强方案

BoT-SORT结合了多种技术优势,是目前综合性能最好的追踪器之一。

核心组件

  1. 改进的卡尔曼滤波:使用相机运动补偿(CMC)

  2. IoU-ReID融合:结合运动和外观信息

  3. 轨迹状态管理:更精细的轨迹生命周期管理

    BoT-SORT改进点:

    1. 相机运动补偿(CMC)

      • 检测相邻帧的全局运动(如相机平移、旋转)
      • 在预测前补偿相机运动
      • 提升动态相机场景的追踪精度
    2. IoU-ReID融合
      Cost = min(Cost_IoU, Cost_ReID) # 取较小值
      而非 DeepSORT的加权融合

    3. 轨迹置信度

      • 根据匹配历史动态调整轨迹置信度
      • 低置信度轨迹更容易被删除

3.5 算法对比

算法 速度 精度 遮挡处理 需要ReID 适用场景
SORT ⭐⭐⭐⭐⭐ ⭐⭐ 简单场景、实时性要求高
DeepSORT ⭐⭐⭐ ⭐⭐⭐⭐ 遮挡多、需要ReID
ByteTrack ⭐⭐⭐⭐ ⭐⭐⭐⭐ 较好 通用场景、平衡性能
BoT-SORT ⭐⭐⭐ ⭐⭐⭐⭐⭐ 很好 可选 高精度要求
OC-SORT ⭐⭐⭐⭐ ⭐⭐⭐⭐ 较好 非线性运动场景

四、卡尔曼滤波详解

卡尔曼滤波是目标追踪的核心组件,用于预测目标的下一帧位置。

4.1 状态向量定义

python 复制代码
# 状态向量 x = [x, y, a, h, vx, vy, va, vh]
# x, y: 边界框中心坐标
# a: 宽高比 (aspect ratio = w/h)
# h: 高度
# vx, vy, va, vh: 对应的速度

# 观测向量 z = [x, y, a, h]
# 只能观测到位置,无法直接观测速度

4.2 卡尔曼滤波实现

python 复制代码
import numpy as np
from scipy.linalg import block_diag


class KalmanFilter:
    """
    目标追踪专用卡尔曼滤波器
    状态向量: [x, y, a, h, vx, vy, va, vh]
    观测向量: [x, y, a, h]
    """
    
    def __init__(self):
        ndim = 4  # 观测维度
        dt = 1.0  # 时间步长
        
        # 状态转移矩阵 F
        # 假设匀速运动模型: x' = x + vx * dt
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        
        # 观测矩阵 H
        # 只能观测位置,不能观测速度
        self._update_mat = np.eye(ndim, 2 * ndim)
        
        # 过程噪声权重
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160
    
    def initiate(self, measurement):
        """
        从第一次观测初始化轨迹
        
        Args:
            measurement: [x, y, a, h] 边界框观测
        Returns:
            mean: 状态均值
            covariance: 状态协方差
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]
        
        # 初始协方差:位置不确定性小,速度不确定性大
        std = [
            2 * self._std_weight_position * measurement[3],   # x
            2 * self._std_weight_position * measurement[3],   # y
            1e-2,                                              # a
            2 * self._std_weight_position * measurement[3],   # h
            10 * self._std_weight_velocity * measurement[3],  # vx
            10 * self._std_weight_velocity * measurement[3],  # vy
            1e-5,                                              # va
            10 * self._std_weight_velocity * measurement[3],  # vh
        ]
        covariance = np.diag(np.square(std))
        
        return mean, covariance
    
    def predict(self, mean, covariance):
        """
        预测步骤:根据运动模型预测下一状态
        
        x' = F * x
        P' = F * P * F^T + Q
        """
        # 过程噪声 Q
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3],
        ]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
        
        # 预测
        mean = np.dot(self._motion_mat, mean)
        covariance = np.linalg.multi_dot([
            self._motion_mat, covariance, self._motion_mat.T
        ]) + motion_cov
        
        return mean, covariance
    
    def project(self, mean, covariance):
        """
        将状态投影到观测空间
        
        z = H * x
        S = H * P * H^T + R
        """
        # 观测噪声 R
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3],
        ]
        innovation_cov = np.diag(np.square(std))
        
        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot([
            self._update_mat, covariance, self._update_mat.T
        ]) + innovation_cov
        
        return mean, covariance
    
    def update(self, mean, covariance, measurement):
        """
        更新步骤:融合预测和观测
        
        K = P * H^T * S^(-1)
        x = x + K * (z - H * x)
        P = (I - K * H) * P
        """
        # 投影到观测空间
        projected_mean, projected_cov = self.project(mean, covariance)
        
        # 卡尔曼增益
        chol_factor = np.linalg.cholesky(projected_cov)
        kalman_gain = np.linalg.solve(
            chol_factor,
            np.linalg.solve(chol_factor, 
                           np.dot(covariance, self._update_mat.T).T).T
        ).T
        
        # 更新
        innovation = measurement - projected_mean
        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot([
            kalman_gain, projected_cov, kalman_gain.T
        ])
        
        return new_mean, new_covariance
    
    def gating_distance(self, mean, covariance, measurements, only_position=False):
        """
        计算马氏距离(用于门控)
        """
        projected_mean, projected_cov = self.project(mean, covariance)
        
        if only_position:
            projected_mean = projected_mean[:2]
            projected_cov = projected_cov[:2, :2]
            measurements = measurements[:, :2]
        
        chol_factor = np.linalg.cholesky(projected_cov)
        d = measurements - projected_mean
        z = np.linalg.solve(chol_factor, d.T).T
        
        return np.sum(z * z, axis=1)

五、数据关联算法

5.1 IoU距离计算

python 复制代码
import numpy as np


def bbox_iou(bbox1, bbox2):
    """
    计算两个边界框的IoU
    
    Args:
        bbox1: [x1, y1, x2, y2]
        bbox2: [x1, y1, x2, y2]
    Returns:
        iou: 交并比
    """
    # 计算交集
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])
    
    inter_area = max(0, x2 - x1) * max(0, y2 - y1)
    
    # 计算并集
    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    union_area = area1 + area2 - inter_area
    
    return inter_area / (union_area + 1e-6)


def iou_distance(tracks, detections):
    """
    计算轨迹与检测框之间的IoU距离矩阵
    
    Args:
        tracks: 轨迹列表
        detections: 检测框列表
    Returns:
        cost_matrix: [num_tracks, num_detections]
    """
    cost_matrix = np.zeros((len(tracks), len(detections)))
    
    for t, track in enumerate(tracks):
        track_bbox = track.to_tlbr()  # 转换为 [x1, y1, x2, y2]
        for d, det in enumerate(detections):
            det_bbox = det.to_tlbr()
            iou = bbox_iou(track_bbox, det_bbox)
            cost_matrix[t, d] = 1 - iou  # IoU越大,代价越小
    
    return cost_matrix

5.2 匈牙利算法

python 复制代码
from scipy.optimize import linear_sum_assignment


def hungarian_matching(cost_matrix, threshold=0.7):
    """
    使用匈牙利算法进行最优匹配
    
    Args:
        cost_matrix: 代价矩阵 [num_tracks, num_detections]
        threshold: 匹配阈值,超过此值不匹配
    Returns:
        matches: 匹配对列表 [(track_idx, det_idx), ...]
        unmatched_tracks: 未匹配的轨迹索引
        unmatched_detections: 未匹配的检测框索引
    """
    if cost_matrix.size == 0:
        return [], list(range(cost_matrix.shape[0])), list(range(cost_matrix.shape[1]))
    
    # 匈牙利算法求解
    row_indices, col_indices = linear_sum_assignment(cost_matrix)
    
    matches = []
    unmatched_tracks = list(range(cost_matrix.shape[0]))
    unmatched_detections = list(range(cost_matrix.shape[1]))
    
    for row, col in zip(row_indices, col_indices):
        if cost_matrix[row, col] > threshold:
            continue
        matches.append((row, col))
        unmatched_tracks.remove(row)
        unmatched_detections.remove(col)
    
    return matches, unmatched_tracks, unmatched_detections

5.3 级联匹配(DeepSORT)

python 复制代码
def cascade_matching(tracks, detections, track_indices, detection_indices,
                     cost_function, threshold, max_age=30):
    """
    级联匹配:优先匹配最近出现的轨迹
    
    Args:
        tracks: 轨迹列表
        detections: 检测框列表
        track_indices: 待匹配的轨迹索引
        detection_indices: 待匹配的检测框索引
        cost_function: 代价函数
        threshold: 匹配阈值
        max_age: 最大丢失帧数
    """
    matches = []
    unmatched_detections = list(detection_indices)
    
    # 按丢失帧数分组,优先匹配最近的
    for age in range(max_age):
        if len(unmatched_detections) == 0:
            break
        
        # 筛选当前age的轨迹
        track_indices_age = [
            t for t in track_indices
            if tracks[t].time_since_update == age + 1
        ]
        
        if len(track_indices_age) == 0:
            continue
        
        # 计算代价矩阵
        cost_matrix = cost_function(
            [tracks[t] for t in track_indices_age],
            [detections[d] for d in unmatched_detections]
        )
        
        # 匈牙利匹配
        matched, _, unmatched_det = hungarian_matching(cost_matrix, threshold)
        
        # 更新结果
        for t_idx, d_idx in matched:
            matches.append((track_indices_age[t_idx], unmatched_detections[d_idx]))
        
        unmatched_detections = [unmatched_detections[d] for d in unmatched_det]
    
    unmatched_tracks = [t for t in track_indices if t not in [m[0] for m in matches]]
    
    return matches, unmatched_tracks, unmatched_detections

六、完整追踪器实现

6.1 轨迹类

python 复制代码
from enum import Enum


class TrackState(Enum):
    """轨迹状态"""
    Tentative = 1   # 待确认
    Confirmed = 2   # 已确认
    Deleted = 3     # 已删除


class Track:
    """单个目标轨迹"""
    
    _count = 0  # 全局ID计数器
    
    def __init__(self, detection, track_id=None, n_init=3, max_age=30):
        """
        Args:
            detection: 初始检测框 [x, y, w, h, score, class_id]
            track_id: 轨迹ID(可选)
            n_init: 确认所需的连续匹配次数
            max_age: 最大丢失帧数
        """
        # 分配ID
        if track_id is None:
            Track._count += 1
            self.track_id = Track._count
        else:
            self.track_id = track_id
        
        # 卡尔曼滤波器
        self.kf = KalmanFilter()
        
        # 初始化状态
        measurement = self._bbox_to_measurement(detection[:4])
        self.mean, self.covariance = self.kf.initiate(measurement)
        
        # 轨迹属性
        self.hits = 1                    # 匹配次数
        self.age = 1                     # 存在帧数
        self.time_since_update = 0       # 距上次更新的帧数
        self.state = TrackState.Tentative
        
        # 参数
        self.n_init = n_init
        self.max_age = max_age
        
        # 检测信息
        self.score = detection[4] if len(detection) > 4 else 1.0
        self.class_id = int(detection[5]) if len(detection) > 5 else 0
        
        # 外观特征历史(用于ReID)
        self.features = []
    
    def _bbox_to_measurement(self, bbox):
        """将[x, y, w, h]转换为[cx, cy, a, h]"""
        x, y, w, h = bbox
        return np.array([x + w/2, y + h/2, w/h, h])
    
    def _measurement_to_bbox(self, measurement):
        """将[cx, cy, a, h]转换为[x, y, w, h]"""
        cx, cy, a, h = measurement
        w = a * h
        return np.array([cx - w/2, cy - h/2, w, h])
    
    def predict(self):
        """预测下一帧状态"""
        self.mean, self.covariance = self.kf.predict(self.mean, self.covariance)
        self.age += 1
        self.time_since_update += 1
    
    def update(self, detection, feature=None):
        """
        用新检测更新轨迹
        
        Args:
            detection: [x, y, w, h, score, class_id]
            feature: 外观特征向量(可选)
        """
        measurement = self._bbox_to_measurement(detection[:4])
        self.mean, self.covariance = self.kf.update(
            self.mean, self.covariance, measurement
        )
        
        # 更新属性
        self.hits += 1
        self.time_since_update = 0
        self.score = detection[4] if len(detection) > 4 else self.score
        
        # 更新外观特征
        if feature is not None:
            self.features.append(feature)
            if len(self.features) > 100:  # 保留最近100个特征
                self.features = self.features[-100:]
        
        # 状态转换
        if self.state == TrackState.Tentative and self.hits >= self.n_init:
            self.state = TrackState.Confirmed
    
    def mark_missed(self):
        """标记为丢失"""
        if self.state == TrackState.Tentative:
            self.state = TrackState.Deleted
        elif self.time_since_update > self.max_age:
            self.state = TrackState.Deleted
    
    def is_tentative(self):
        return self.state == TrackState.Tentative
    
    def is_confirmed(self):
        return self.state == TrackState.Confirmed
    
    def is_deleted(self):
        return self.state == TrackState.Deleted
    
    def to_tlwh(self):
        """返回 [x, y, w, h] 格式"""
        return self._measurement_to_bbox(self.mean[:4])
    
    def to_tlbr(self):
        """返回 [x1, y1, x2, y2] 格式"""
        bbox = self.to_tlwh()
        return np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]])

6.2 ByteTrack追踪器

python 复制代码
class ByteTracker:
    """
    ByteTrack多目标追踪器
    
    特点:利用低分检测框恢复被遮挡目标
    """
    
    def __init__(self, 
                 track_thresh=0.5,      # 高分阈值
                 track_buffer=30,       # 轨迹保持帧数
                 match_thresh=0.8,      # 匹配阈值
                 low_thresh=0.1,        # 低分阈值
                 new_track_thresh=0.6): # 新建轨迹阈值
        
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.low_thresh = low_thresh
        self.new_track_thresh = new_track_thresh
        
        self.tracks = []           # 已确认轨迹
        self.lost_tracks = []      # 丢失轨迹
        self.removed_tracks = []   # 已删除轨迹
        
        self.frame_id = 0
        Track._count = 0  # 重置ID计数
    
    def update(self, detections):
        """
        处理一帧检测结果
        
        Args:
            detections: numpy array, shape [N, 6]
                       每行: [x, y, w, h, score, class_id]
        Returns:
            outputs: 追踪结果 [[x1, y1, x2, y2, track_id, class_id, score], ...]
        """
        self.frame_id += 1
        
        # 分离高分和低分检测
        if len(detections) > 0:
            scores = detections[:, 4]
            
            # 高分检测
            high_mask = scores >= self.track_thresh
            high_dets = detections[high_mask]
            
            # 低分检测
            low_mask = (scores < self.track_thresh) & (scores >= self.low_thresh)
            low_dets = detections[low_mask]
        else:
            high_dets = np.empty((0, 6))
            low_dets = np.empty((0, 6))
        
        # 合并已确认和丢失的轨迹
        all_tracks = self.tracks + self.lost_tracks
        
        # 预测所有轨迹
        for track in all_tracks:
            track.predict()
        
        # ========== 第一阶段:高分检测匹配 ==========
        if len(high_dets) > 0 and len(all_tracks) > 0:
            # 计算IoU代价矩阵
            cost_matrix = iou_distance(all_tracks, self._dets_to_tracks(high_dets))
            
            # 匈牙利匹配
            matches, unmatched_tracks, unmatched_dets = hungarian_matching(
                cost_matrix, threshold=self.match_thresh
            )
            
            # 更新匹配的轨迹
            for track_idx, det_idx in matches:
                all_tracks[track_idx].update(high_dets[det_idx])
            
            # 获取未匹配的轨迹和检测
            remain_tracks = [all_tracks[i] for i in unmatched_tracks]
            remain_high_dets = high_dets[unmatched_dets]
        else:
            remain_tracks = all_tracks
            remain_high_dets = high_dets
            matches = []
        
        # ========== 第二阶段:低分检测匹配 ==========
        if len(low_dets) > 0 and len(remain_tracks) > 0:
            cost_matrix = iou_distance(remain_tracks, self._dets_to_tracks(low_dets))
            
            matches_low, unmatched_tracks_low, _ = hungarian_matching(
                cost_matrix, threshold=0.5  # 低分匹配使用更宽松的阈值
            )
            
            # 更新匹配的轨迹
            for track_idx, det_idx in matches_low:
                remain_tracks[track_idx].update(low_dets[det_idx])
            
            remain_tracks = [remain_tracks[i] for i in unmatched_tracks_low]
        
        # ========== 轨迹管理 ==========
        # 标记丢失的轨迹
        for track in remain_tracks:
            track.mark_missed()
        
        # 为未匹配的高分检测创建新轨迹
        for det in remain_high_dets:
            if det[4] >= self.new_track_thresh:
                new_track = Track(det, max_age=self.track_buffer)
                self.tracks.append(new_track)
        
        # 更新轨迹列表
        self.tracks = [t for t in self.tracks + self.lost_tracks if not t.is_deleted()]
        self.lost_tracks = [t for t in self.tracks if not t.is_confirmed()]
        self.tracks = [t for t in self.tracks if t.is_confirmed()]
        
        # ========== 输出结果 ==========
        outputs = []
        for track in self.tracks:
            if track.is_confirmed() and track.time_since_update == 0:
                bbox = track.to_tlbr()
                outputs.append([
                    bbox[0], bbox[1], bbox[2], bbox[3],
                    track.track_id, track.class_id, track.score
                ])
        
        return np.array(outputs) if len(outputs) > 0 else np.empty((0, 7))
    
    def _dets_to_tracks(self, detections):
        """将检测结果转换为临时Track对象(用于计算IoU)"""
        tracks = []
        for det in detections:
            track = Track(det)
            tracks.append(track)
        return tracks

七、完整推理系统

7.1 视频追踪推理器

python 复制代码
import cv2
import time
import numpy as np
from ultralytics import YOLO


class VideoTracker:
    """
    完整的视频目标追踪系统
    集成YOLO检测 + ByteTrack追踪
    """
    
    # 类别颜色映射
    COLORS = np.random.randint(0, 255, size=(100, 3), dtype=np.uint8)
    
    def __init__(self, 
                 detector_path='yolov8n.pt',
                 conf_thresh=0.25,
                 iou_thresh=0.45,
                 classes=None,
                 device='cuda'):
        """
        Args:
            detector_path: YOLO模型路径
            conf_thresh: 检测置信度阈值
            iou_thresh: NMS IoU阈值
            classes: 追踪的类别ID列表,None表示全部
            device: 推理设备
        """
        # 初始化检测器
        self.detector = YOLO(detector_path)
        self.conf_thresh = conf_thresh
        self.iou_thresh = iou_thresh
        self.classes = classes
        self.device = device
        
        # 初始化追踪器
        self.tracker = ByteTracker(
            track_thresh=0.5,
            track_buffer=30,
            match_thresh=0.8
        )
        
        # 统计信息
        self.frame_count = 0
        self.total_time = 0
    
    def detect(self, frame):
        """
        运行目标检测
        
        Returns:
            detections: [N, 6] - [x, y, w, h, score, class_id]
        """
        results = self.detector(
            frame,
            conf=self.conf_thresh,
            iou=self.iou_thresh,
            classes=self.classes,
            device=self.device,
            verbose=False
        )[0]
        
        # 解析检测结果
        boxes = results.boxes
        if len(boxes) == 0:
            return np.empty((0, 6))
        
        detections = []
        for box in boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            score = box.conf[0].cpu().numpy()
            class_id = int(box.cls[0].cpu().numpy())
            
            # 转换为 [x, y, w, h, score, class_id]
            detections.append([
                x1, y1, x2 - x1, y2 - y1, score, class_id
            ])
        
        return np.array(detections)
    
    def process_frame(self, frame):
        """
        处理单帧
        
        Returns:
            tracks: [M, 7] - [x1, y1, x2, y2, track_id, class_id, score]
            fps: 当前帧率
        """
        start_time = time.time()
        
        # 检测
        detections = self.detect(frame)
        
        # 追踪
        tracks = self.tracker.update(detections)
        
        # 统计
        self.frame_count += 1
        elapsed = time.time() - start_time
        self.total_time += elapsed
        fps = 1.0 / elapsed if elapsed > 0 else 0
        
        return tracks, fps
    
    def draw_tracks(self, frame, tracks, show_trajectory=True):
        """
        可视化追踪结果
        """
        frame = frame.copy()
        
        for track in tracks:
            x1, y1, x2, y2 = map(int, track[:4])
            track_id = int(track[4])
            class_id = int(track[5])
            score = track[6]
            
            # 获取颜色
            color = tuple(map(int, self.COLORS[track_id % 100]))
            
            # 绘制边界框
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            
            # 绘制标签
            label = f'ID:{track_id} {self.detector.names[class_id]} {score:.2f}'
            (label_w, label_h), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
            )
            cv2.rectangle(
                frame, 
                (x1, y1 - label_h - baseline - 5),
                (x1 + label_w, y1),
                color, -1
            )
            cv2.putText(
                frame, label, (x1, y1 - baseline - 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
            )
        
        return frame
    
    def process_video(self, input_path, output_path=None, show=True):
        """
        处理视频文件
        
        Args:
            input_path: 输入视频路径
            output_path: 输出视频路径(可选)
            show: 是否实时显示
        """
        cap = cv2.VideoCapture(input_path)
        
        # 获取视频信息
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        
        print(f"视频信息: {width}x{height}, {fps}fps, {total_frames}帧")
        
        # 初始化视频写入器
        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        
        # 重置追踪器
        self.tracker = ByteTracker()
        self.frame_count = 0
        self.total_time = 0
        
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                
                # 处理帧
                tracks, current_fps = self.process_frame(frame)
                
                # 可视化
                vis_frame = self.draw_tracks(frame, tracks)
                
                # 添加FPS信息
                cv2.putText(
                    vis_frame, f'FPS: {current_fps:.1f}',
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
                )
                cv2.putText(
                    vis_frame, f'Tracks: {len(tracks)}',
                    (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
                )
                
                # 保存
                if writer:
                    writer.write(vis_frame)
                
                # 显示
                if show:
                    cv2.imshow('Tracking', vis_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                
                # 进度
                if self.frame_count % 100 == 0:
                    print(f"处理进度: {self.frame_count}/{total_frames} "
                          f"({100*self.frame_count/total_frames:.1f}%)")
        
        finally:
            cap.release()
            if writer:
                writer.release()
            if show:
                cv2.destroyAllWindows()
        
        # 统计信息
        avg_fps = self.frame_count / self.total_time if self.total_time > 0 else 0
        print(f"\n处理完成!")
        print(f"总帧数: {self.frame_count}")
        print(f"平均FPS: {avg_fps:.1f}")
        print(f"总耗时: {self.total_time:.1f}s")
    
    def process_camera(self, camera_id=0, output_path=None):
        """
        处理摄像头实时流
        """
        cap = cv2.VideoCapture(camera_id)
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        
        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(output_path, fourcc, 30, (1280, 720))
        
        # 重置
        self.tracker = ByteTracker()
        
        print("按 'q' 退出")
        
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                
                tracks, fps = self.process_frame(frame)
                vis_frame = self.draw_tracks(frame, tracks)
                
                cv2.putText(
                    vis_frame, f'FPS: {fps:.1f} | Tracks: {len(tracks)}',
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
                )
                
                if writer:
                    writer.write(vis_frame)
                
                cv2.imshow('Camera Tracking', vis_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        
        finally:
            cap.release()
            if writer:
                writer.release()
            cv2.destroyAllWindows()

7.2 使用示例

python 复制代码
# ================== 基础使用 ==================

# 创建追踪器
tracker = VideoTracker(
    detector_path='yolov8n.pt',  # 检测模型
    conf_thresh=0.25,            # 置信度阈值
    classes=[0, 2, 5, 7],        # 只追踪:人、车、公交、卡车
    device='cuda'
)

# 处理视频文件
tracker.process_video(
    input_path='input_video.mp4',
    output_path='output_video.mp4',
    show=True
)

# 处理摄像头
tracker.process_camera(camera_id=0)


# ================== 使用YOLOv8内置追踪 ==================

from ultralytics import YOLO

# 最简单的方式:一行代码
model = YOLO('yolov8n.pt')
results = model.track(
    source='input_video.mp4',
    show=True,
    tracker='bytetrack.yaml',  # 或 'botsort.yaml'
    save=True
)

# 逐帧处理
model = YOLO('yolov8n.pt')
cap = cv2.VideoCapture('input_video.mp4')

while True:
    ret, frame = cap.read()
    if not ret:
        break
    
    # persist=True 保持跨帧追踪
    results = model.track(frame, persist=True)
    
    # 获取追踪结果
    if results[0].boxes.id is not None:
        boxes = results[0].boxes.xyxy.cpu().numpy()
        track_ids = results[0].boxes.id.cpu().numpy().astype(int)
        
        for box, track_id in zip(boxes, track_ids):
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, f'ID:{track_id}', (x1, y1-10),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    
    cv2.imshow('YOLOv8 Tracking', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

八、轨迹分析与应用

8.1 轨迹存储与分析

python 复制代码
import json
from collections import defaultdict


class TrajectoryAnalyzer:
    """轨迹分析器:存储、分析、可视化轨迹"""
    
    def __init__(self):
        self.trajectories = defaultdict(list)  # {track_id: [(frame_id, x, y, w, h), ...]}
        self.frame_id = 0
    
    def update(self, tracks):
        """
        更新轨迹数据
        
        Args:
            tracks: [M, 7] - [x1, y1, x2, y2, track_id, class_id, score]
        """
        self.frame_id += 1
        
        for track in tracks:
            x1, y1, x2, y2, track_id = track[:5]
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            w, h = x2 - x1, y2 - y1
            
            self.trajectories[int(track_id)].append({
                'frame': self.frame_id,
                'cx': float(cx),
                'cy': float(cy),
                'w': float(w),
                'h': float(h)
            })
    
    def get_trajectory(self, track_id):
        """获取指定ID的轨迹"""
        return self.trajectories.get(track_id, [])
    
    def compute_speed(self, track_id, fps=30, pixels_per_meter=100):
        """
        计算目标平均速度
        
        Args:
            track_id: 轨迹ID
            fps: 视频帧率
            pixels_per_meter: 像素/米转换系数
        Returns:
            speed: 平均速度 (m/s)
        """
        traj = self.trajectories.get(track_id, [])
        if len(traj) < 2:
            return 0.0
        
        total_distance = 0
        for i in range(1, len(traj)):
            dx = traj[i]['cx'] - traj[i-1]['cx']
            dy = traj[i]['cy'] - traj[i-1]['cy']
            total_distance += np.sqrt(dx**2 + dy**2)
        
        # 转换单位
        distance_meters = total_distance / pixels_per_meter
        time_seconds = len(traj) / fps
        
        return distance_meters / time_seconds if time_seconds > 0 else 0
    
    def compute_direction(self, track_id):
        """
        计算目标运动方向
        
        Returns:
            direction: 角度 (0-360度, 0为正右方)
        """
        traj = self.trajectories.get(track_id, [])
        if len(traj) < 2:
            return None
        
        # 使用起点和终点计算总体方向
        dx = traj[-1]['cx'] - traj[0]['cx']
        dy = traj[-1]['cy'] - traj[0]['cy']
        
        angle = np.arctan2(dy, dx) * 180 / np.pi
        return (angle + 360) % 360
    
    def draw_trajectories(self, frame, max_length=50):
        """
        在帧上绘制轨迹
        """
        frame = frame.copy()
        
        for track_id, traj in self.trajectories.items():
            if len(traj) < 2:
                continue
            
            # 只绘制最近的轨迹点
            recent_traj = traj[-max_length:]
            
            # 获取颜色
            color = tuple(map(int, VideoTracker.COLORS[track_id % 100]))
            
            # 绘制轨迹线
            points = [(int(t['cx']), int(t['cy'])) for t in recent_traj]
            for i in range(1, len(points)):
                thickness = int(2 * i / len(points)) + 1
                cv2.line(frame, points[i-1], points[i], color, thickness)
        
        return frame
    
    def save_trajectories(self, filepath):
        """保存轨迹到JSON"""
        data = {
            'total_frames': self.frame_id,
            'trajectories': dict(self.trajectories)
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)
    
    def load_trajectories(self, filepath):
        """从JSON加载轨迹"""
        with open(filepath, 'r') as f:
            data = json.load(f)
        self.frame_id = data['total_frames']
        self.trajectories = defaultdict(list, data['trajectories'])
    
    def count_objects(self):
        """统计出现的目标总数"""
        return len(self.trajectories)
    
    def get_active_tracks(self, current_frame, threshold=10):
        """获取当前活跃的轨迹"""
        active = []
        for track_id, traj in self.trajectories.items():
            if traj and current_frame - traj[-1]['frame'] <= threshold:
                active.append(track_id)
        return active

8.2 区域计数(越线计数)

python 复制代码
class LineCrossCounter:
    """越线计数器"""
    
    def __init__(self, line_start, line_end):
        """
        Args:
            line_start: (x1, y1) 计数线起点
            line_end: (x2, y2) 计数线终点
        """
        self.line_start = np.array(line_start)
        self.line_end = np.array(line_end)
        
        self.previous_positions = {}  # {track_id: (cx, cy)}
        self.count_up = 0    # 从下往上
        self.count_down = 0  # 从上往下
        self.crossed_ids = set()
    
    def _cross_product(self, point):
        """计算点到线的叉积(判断点在线的哪一侧)"""
        line_vec = self.line_end - self.line_start
        point_vec = point - self.line_start
        return np.cross(line_vec, point_vec)
    
    def update(self, tracks):
        """
        更新计数
        
        Args:
            tracks: [M, 7] - [x1, y1, x2, y2, track_id, class_id, score]
        Returns:
            crossed_tracks: 本帧越线的track_id列表
        """
        crossed_tracks = []
        current_positions = {}
        
        for track in tracks:
            x1, y1, x2, y2, track_id = track[:5]
            track_id = int(track_id)
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            current_pos = np.array([cx, cy])
            current_positions[track_id] = current_pos
            
            # 检查是否越线
            if track_id in self.previous_positions and track_id not in self.crossed_ids:
                prev_pos = self.previous_positions[track_id]
                
                prev_side = self._cross_product(prev_pos)
                curr_side = self._cross_product(current_pos)
                
                # 符号改变表示越线
                if prev_side * curr_side < 0:
                    self.crossed_ids.add(track_id)
                    crossed_tracks.append(track_id)
                    
                    # 判断方向
                    if curr_side > 0:
                        self.count_up += 1
                    else:
                        self.count_down += 1
        
        self.previous_positions = current_positions
        return crossed_tracks
    
    def draw(self, frame):
        """绘制计数线和计数结果"""
        frame = frame.copy()
        
        # 绘制计数线
        cv2.line(frame, 
                tuple(self.line_start.astype(int)),
                tuple(self.line_end.astype(int)),
                (0, 0, 255), 3)
        
        # 绘制计数结果
        text = f'Up: {self.count_up} | Down: {self.count_down} | Total: {self.count_up + self.count_down}'
        cv2.putText(frame, text, (10, 100),
                   cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        
        return frame
    
    def reset(self):
        """重置计数器"""
        self.count_up = 0
        self.count_down = 0
        self.crossed_ids.clear()
        self.previous_positions.clear()

8.3 区域停留检测

python 复制代码
class RegionDwellDetector:
    """区域停留检测器"""
    
    def __init__(self, region_points, dwell_threshold=90):
        """
        Args:
            region_points: 区域多边形顶点 [(x1,y1), (x2,y2), ...]
            dwell_threshold: 停留帧数阈值
        """
        self.region = np.array(region_points, dtype=np.int32)
        self.dwell_threshold = dwell_threshold
        
        self.dwell_times = defaultdict(int)  # {track_id: frames_in_region}
        self.alerts = set()  # 已触发警报的ID
    
    def _point_in_region(self, point):
        """判断点是否在区域内"""
        return cv2.pointPolygonTest(self.region, point, False) >= 0
    
    def update(self, tracks):
        """
        更新停留检测
        
        Returns:
            dwelling_tracks: 停留超阈值的track_id列表
        """
        current_ids = set()
        dwelling_tracks = []
        
        for track in tracks:
            x1, y1, x2, y2, track_id = track[:5]
            track_id = int(track_id)
            cx, cy = (x1 + x2) / 2, y2  # 使用底部中心点
            
            current_ids.add(track_id)
            
            if self._point_in_region((cx, cy)):
                self.dwell_times[track_id] += 1
                
                if (self.dwell_times[track_id] >= self.dwell_threshold 
                    and track_id not in self.alerts):
                    self.alerts.add(track_id)
                    dwelling_tracks.append(track_id)
            else:
                self.dwell_times[track_id] = 0
        
        # 清理离开的目标
        for track_id in list(self.dwell_times.keys()):
            if track_id not in current_ids:
                del self.dwell_times[track_id]
                self.alerts.discard(track_id)
        
        return dwelling_tracks
    
    def draw(self, frame):
        """绘制检测区域"""
        frame = frame.copy()
        
        # 绘制区域
        overlay = frame.copy()
        cv2.fillPoly(overlay, [self.region], (0, 255, 255))
        cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
        cv2.polylines(frame, [self.region], True, (0, 255, 255), 2)
        
        # 标记停留目标
        for track_id in self.alerts:
            if track_id in self.dwell_times:
                text = f'ALERT: ID {track_id}'
                cv2.putText(frame, text, (10, 130),
                           cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        
        return frame

九、性能优化技巧

9.1 检测器优化

python 复制代码
# 1. 使用更小的模型
model = YOLO('yolov8n.pt')  # nano版本,最快

# 2. 降低输入分辨率
results = model(frame, imgsz=640)  # 默认640,可降至320

# 3. 使用TensorRT加速
model.export(format='engine', device=0)  # 导出TensorRT
model = YOLO('yolov8n.engine')

# 4. 使用半精度推理
results = model(frame, half=True)

# 5. 跳帧检测
frame_skip = 2  # 每2帧检测1次
if frame_id % frame_skip == 0:
    detections = detect(frame)

9.2 追踪器优化

python 复制代码
# 1. 减少特征提取(使用ByteTrack而非DeepSORT)
tracker = ByteTracker()  # 无需ReID特征

# 2. 限制轨迹数量
max_tracks = 100
if len(tracker.tracks) > max_tracks:
    # 删除最老的轨迹
    tracker.tracks = sorted(tracker.tracks, key=lambda t: t.age)[-max_tracks:]

# 3. 减少匹配范围
# 只匹配空间上接近的检测和轨迹
def spatial_filtering(tracks, detections, threshold=200):
    # 根据距离预过滤
    pass

9.3 多线程处理

python 复制代码
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import threading


class AsyncVideoTracker:
    """异步视频追踪器"""
    
    def __init__(self, detector_path='yolov8n.pt'):
        self.detector = YOLO(detector_path)
        self.tracker = ByteTracker()
        
        self.frame_queue = Queue(maxsize=10)
        self.result_queue = Queue(maxsize=10)
        self.running = False
    
    def _detection_worker(self):
        """检测线程"""
        while self.running:
            if not self.frame_queue.empty():
                frame_id, frame = self.frame_queue.get()
                detections = self.detect(frame)
                self.result_queue.put((frame_id, frame, detections))
    
    def start(self):
        """启动异步处理"""
        self.running = True
        self.detection_thread = threading.Thread(target=self._detection_worker)
        self.detection_thread.start()
    
    def stop(self):
        """停止异步处理"""
        self.running = False
        self.detection_thread.join()

十、常见问题与解决方案

10.1 ID频繁切换

问题:同一目标的ID不断变化

原因

  • 检测不稳定
  • 目标移动过快
  • 遮挡导致检测丢失

解决方案

python 复制代码
# 1. 降低检测阈值,保留更多检测
conf_thresh = 0.15  # 从0.25降低

# 2. 增加轨迹保持时间
tracker = ByteTracker(track_buffer=60)  # 从30增加到60

# 3. 使用带ReID的追踪器
tracker = DeepSORT(max_age=60, nn_budget=100)

10.2 遮挡后ID变化

问题:目标被遮挡后重新出现,ID改变

解决方案

python 复制代码
# 1. 使用DeepSORT的外观特征
# 2. 增加max_age参数
# 3. 使用ByteTrack的低分检测匹配
tracker = ByteTracker(
    low_thresh=0.1,     # 接受低分检测
    track_buffer=60     # 更长的保持时间
)

10.3 检测抖动

问题:边界框抖动严重

解决方案

python 复制代码
# 卡尔曼滤波本身有平滑效果
# 可以增加观测噪声,使预测更平滑
kf._std_weight_position = 1. / 10  # 增加位置不确定性

# 或者对输出进行额外平滑
def smooth_bbox(current, previous, alpha=0.7):
    return alpha * current + (1 - alpha) * previous

十一、总结

视频目标追踪是一项综合性技术,本文从原理到实战进行了全面讲解:

核心知识点

  1. MOT范式:检测+关联是主流方案
  2. 卡尔曼滤波:预测目标运动状态
  3. 匈牙利算法:最优匹配检测与轨迹
  4. ByteTrack:利用低分检测提升性能

实践建议

  • 简单场景:直接使用YOLOv8内置追踪
  • 需要定制:基于ByteTrack实现
  • 遮挡严重:考虑DeepSORT或BoT-SORT
  • 实时性要求高:使用SORT + 轻量检测器

推荐学习路径

复制代码
SORT → ByteTrack → DeepSORT → BoT-SORT → 自定义优化

希望这篇文章对你有帮助,如有问题欢迎评论区交流!


参考资料

  1. Bewley A, et al. "Simple Online and Realtime Tracking." ICIP 2016.
  2. Wojke N, et al. "Simple Online and Realtime Tracking with a Deep Association Metric." ICIP 2017.
  3. Zhang Y, et al. "ByteTrack: Multi-Object Tracking by Associating Every Detection Box." ECCV 2022.
  4. Aharon N, et al. "BoT-SORT: Robust Associations Multi-Pedestrian Tracking." arXiv 2022.

作者:Jia

更多技术文章,欢迎关注我的CSDN博客!

相关推荐
@高手2 小时前
AI应用开发基础
人工智能
wechat_Neal2 小时前
供应商合作模式中以产品中心取向的转型要点2
人工智能·汽车·devops
一个处女座的程序猿2 小时前
AI之xAI:《WTF is happening at xAI》解读:从 Sulaiman Ghori 的访谈看 xAI 的节奏、架构与“人类模拟器”愿景
人工智能·架构·xai
编码小哥2 小时前
OpenCV DNN模块:深度学习模型部署实战
深度学习·opencv·dnn
一招定胜负2 小时前
项目案例:指纹匹配,图像拼接
人工智能·深度学习·计算机视觉
凤希AI伴侣2 小时前
凤希AI积分系统上线与工具哲学思考-2026年1月24日
人工智能·凤希ai伴侣
逐梦苍穹2 小时前
一键推送AI项目到GitHub的完全指南
人工智能·github
HZjiangzi2 小时前
航空航天大部件检测革新:思看科技无贴点跟踪扫描方案
人工智能·科技·制造
薛定e的猫咪2 小时前
基于大型语言模型的多智能体制造系统用于智能车间
人工智能·机器学习·语言模型·制造