This article systematically explains the core principles of video object tracking, the mainstream algorithms (SORT/DeepSORT/ByteTrack/BoT-SORT), and complete hands-on Python code, helping you quickly master this core computer vision technique.
1. What Is Video Object Tracking
1.1 Task Definition
Video Object Tracking (VOT) is the task of continuously localizing and identifying specific targets across a video sequence. Its core requirement is to assign each target a unique ID and keep that ID consistent throughout the video.
Tracking task illustration:
Frame 1: 3 people detected → assign IDs: Person_1, Person_2, Person_3
Frame 2: 3 people detected → associate: Person_1, Person_2, Person_3 (IDs preserved)
Frame 3: Person_2 occluded → keep: Person_1, Person_3
Frame 4: Person_2 reappears → recover: Person_1, Person_2, Person_3
1.2 Tracking vs. Detection vs. Segmentation
| Task | Output | Temporal | Typical use |
|---|---|---|---|
| Object detection | Per-frame boxes + classes | No | Image understanding |
| Semantic segmentation | Per-pixel classes | No | Scene parsing |
| Object tracking | Cross-frame boxes + unique IDs | Yes | Behavior analysis, counting |
| Instance segmentation + tracking | Cross-frame pixel masks + IDs | Yes | Video editing, VFX |
1.3 Application Scenarios
Intelligent transportation: vehicle counting, trajectory analysis, violation detection, traffic flow statistics.
Security surveillance: pedestrian tracking, abnormal-behavior detection, cross-camera tracking (ReID).
Autonomous driving: continuously tracking surrounding vehicles and pedestrians and predicting their trajectories.
Sports analytics: player tracking, tactical analysis, performance statistics.
Industrial inspection: tracking products on production lines, tracing defective items.
Underwater robotics: tracking marine life, continuous monitoring of underwater targets.
2. The Two Paradigms of Object Tracking
2.1 Single Object Tracking (SOT)
Given the target's location in the first frame, the tracker follows that target through all subsequent frames.
SOT workflow:
Frame 1: user/algorithm specifies the target box [x, y, w, h]
↓
Frame 2: search around the target and find the most similar region
↓
Frame 3: keep searching...
↓
...continue until the video ends or the target is lost
Representative algorithms: SiamFC, SiamRPN, SiamMask, TransT, OSTrack
Pros: no detector required, lightweight, can track targets of arbitrary classes
Cons: tracks only a single target and is prone to drift
2.2 Multi-Object Tracking (MOT)
MOT tracks multiple targets in a video simultaneously and is the mainstream industrial approach today.
MOT workflow (tracking by detection; a minimal code sketch follows this list):
Each frame:
1. Detector → bounding boxes for all targets
2. Feature extraction → appearance features for each detection (optional)
3. Motion prediction → a Kalman filter predicts the new position of each existing track
4. Data association → match detections to existing tracks
5. Track management → create new tracks, delete lost ones
Representative algorithms: SORT, DeepSORT, ByteTrack, BoT-SORT, OC-SORT
Pros: tracks multiple targets, robust
Cons: depends on detector quality, heavier compute
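A minimal sketch of that per-frame loop. The `detect` callable and `tracker` object are placeholders here (concrete implementations of every step appear in Sections 4-7); the point is only the order of operations:
```python
# Tracking-by-detection skeleton; each step is implemented later in the article.
def track_video(frames, detect, tracker):
    results = []
    for frame in frames:
        detections = detect(frame)           # 1-2. detect (+ optional features)
        tracks = tracker.update(detections)  # 3-5. predict, associate, manage
        results.append(tracks)
    return results
```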
3. Core Algorithms in Detail
3.1 SORT: Simple Online and Realtime Tracking
SORT (Simple Online and Realtime Tracking) is the foundational multi-object tracking algorithm, and its core idea is remarkably simple.
Core components:
- Kalman filter: predicts each target's position in the next frame
- Hungarian algorithm: optimally matches detections to tracks
- IoU distance: measures similarity between detections and predicted boxes
SORT pipeline:
Detections
        ↓
┌─────────────────────────────────────────┐
│ Kalman filter prediction                │
│ existing tracks → predicted positions   │
└─────────────────────────────────────────┘
        ↓
┌─────────────────────────────────────────┐
│ Build the IoU cost matrix               │
│ Cost[i,j] = 1 - IoU(detection_i,        │
│                     prediction_j)       │
└─────────────────────────────────────────┘
        ↓
┌─────────────────────────────────────────┐
│ Hungarian matching                      │
│ assignment minimizing the total cost    │
└─────────────────────────────────────────┘
        ↓
┌─────────────────────────────────────────┐
│ Track management                        │
│ - matched: update the track             │
│ - unmatched detection: create new track │
│ - unmatched track: mark lost / delete   │
└─────────────────────────────────────────┘
SORT's limitation: it relies purely on motion cues (IoU), so IDs switch easily when targets are occluded or move fast.
3.2 DeepSORT: Adding Appearance Features
DeepSORT builds on SORT by introducing deep appearance features, which markedly improves tracking under occlusion.
Key improvements:
- Appearance features: a ReID network extracts an appearance descriptor for each target
- Cascade matching: recently seen tracks are matched first
- Mahalanobis distance: a motion-based measure that is fused with the appearance cue
DeepSORT matching strategy (a fused-cost sketch follows this list):
Cost matrix = λ × motion distance (Mahalanobis) + (1-λ) × appearance distance (cosine)

Example distance matrix (tracks × detections):
     D1   D2   D3   D4
T1 [0.2, 0.8, 0.9, 0.7]  ← distances from track T1 to each detection
T2 [0.7, 0.1, 0.6, 0.8]
T3 [0.8, 0.7, 0.2, 0.9]

Cascade matching:
Round 1: match tracks last updated 1 frame ago
Round 2: match tracks last updated 2 frames ago
...
Round N: match tracks lost for N frames
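A minimal sketch of the fused cost. The function name and the λ value here are illustrative, not DeepSORT's exact implementation:
```python
import numpy as np

def fused_cost(motion_dist, appearance_dist, lam=0.5):
    """Blend motion (Mahalanobis) and appearance (cosine) distance matrices.

    Both inputs have shape [num_tracks, num_detections]; lam weights the
    motion term. In practice DeepSORT-style trackers often weight the
    appearance term heavily.
    """
    return lam * motion_dist + (1.0 - lam) * appearance_dist

# Toy example: 3 tracks x 4 detections
motion = np.array([[0.2, 0.8, 0.9, 0.7],
                   [0.7, 0.1, 0.6, 0.8],
                   [0.8, 0.7, 0.2, 0.9]])
appearance = np.random.rand(3, 4)  # placeholder cosine distances
print(fused_cost(motion, appearance, lam=0.5))
```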
3.3 ByteTrack: Exploiting Low-Score Detections
ByteTrack's key insight: low-confidence detections still carry useful information and should not simply be discarded.
ByteTrack's two-stage matching (see the sketch after this description):
Stage 1: high-score detections
- keep detections with confidence > τ_high
- IoU-match them against all tracks
- update the matched tracks
Stage 2: low-score detections
- keep detections with τ_low < confidence < τ_high
- match them only against tracks left unmatched in stage 1
- used to recover occluded targets
Typical thresholds: τ_high = 0.6, τ_low = 0.1
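A compact sketch of the score split and the two association passes. `iou_match` is a placeholder for IoU + Hungarian matching (Section 5); the full tracker is in Section 6.2:
```python
import numpy as np

def byte_associate(tracks, dets, iou_match, t_high=0.6, t_low=0.1):
    """Two-stage ByteTrack-style association (sketch).

    dets: [N, 5] array, columns [x1, y1, x2, y2, score].
    iou_match(tracks, dets) -> (matches, unmatched_tracks, unmatched_dets)
    stands in for the IoU cost matrix + Hungarian solver.
    """
    scores = dets[:, 4]
    high = dets[scores >= t_high]                      # stage-1 candidates
    low = dets[(scores >= t_low) & (scores < t_high)]  # stage-2 candidates

    # Stage 1: high-score detections vs. all tracks
    matches1, leftover_tracks, unmatched_high = iou_match(tracks, high)

    # Stage 2: low-score detections vs. still-unmatched tracks only
    matches2, _, _ = iou_match(leftover_tracks, low)
    return matches1, matches2, unmatched_high
```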
ByteTrack's advantages:
- no ReID network, so it is fast
- handles the low-score detections caused by occlusion
- achieved SOTA on the MOT17, MOT20 and other benchmarks
3.4 BoT-SORT: A Current Top Performer
BoT-SORT combines several complementary techniques and is among the best overall trackers available today.
Core components:
- Improved Kalman filter: adds camera motion compensation (CMC)
- IoU-ReID fusion: combines motion and appearance cues
- Track state management: a finer-grained track lifecycle
BoT-SORT's improvements:
1. Camera motion compensation (CMC) — a code sketch follows this list
   - estimate the global motion between adjacent frames (camera translation, rotation)
   - compensate for it before the Kalman prediction step
   - improves accuracy when the camera itself moves
2. IoU-ReID fusion
   Cost = min(Cost_IoU, Cost_ReID)  # take the smaller value
   rather than DeepSORT's weighted fusion
3. Track confidence
   - adjusted dynamically from the matching history
   - low-confidence tracks are deleted sooner
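A minimal CMC sketch: track sparse corners with optical flow and fit a global affine transform with RANSAC. This is one illustrative OpenCV-based approach, not BoT-SORT's exact GMC module:
```python
import cv2
import numpy as np

def estimate_camera_motion(prev_gray, curr_gray):
    """Estimate a 2x3 affine transform describing global camera motion.

    Apply the returned transform to track states before the Kalman
    predict step to compensate for camera movement.
    """
    pts = cv2.goodFeaturesToTrack(prev_gray, maxCorners=200,
                                  qualityLevel=0.01, minDistance=10)
    if pts is None:
        return np.eye(2, 3, dtype=np.float32)  # no features: assume static
    nxt, status, _ = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, pts, None)
    good_prev = pts[status.flatten() == 1]
    good_next = nxt[status.flatten() == 1]
    if len(good_prev) < 4:
        return np.eye(2, 3, dtype=np.float32)
    # Partial affine = rotation + translation + uniform scale, RANSAC-robust
    M, _ = cv2.estimateAffinePartial2D(good_prev, good_next, method=cv2.RANSAC)
    return M if M is not None else np.eye(2, 3, dtype=np.float32)
```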
3.5 Algorithm Comparison
| Algorithm | Speed | Accuracy | Occlusion handling | Needs ReID | Best for |
|---|---|---|---|---|---|
| SORT | ⭐⭐⭐⭐⭐ | ⭐⭐ | Poor | No | Simple scenes, hard real-time |
| DeepSORT | ⭐⭐⭐ | ⭐⭐⭐⭐ | Good | Yes | Heavy occlusion, ReID needed |
| ByteTrack | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | Fair | No | General scenes, balanced performance |
| BoT-SORT | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | Very good | Optional | Highest accuracy |
| OC-SORT | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | Fair | No | Nonlinear motion |
4. The Kalman Filter in Detail
The Kalman filter is the core motion model in tracking: it predicts each target's position in the next frame.
4.1 State Vector Definition
```python
# State vector x = [x, y, a, h, vx, vy, va, vh]
#   x, y: bounding-box center coordinates
#   a:    aspect ratio (a = w / h)
#   h:    height
#   vx, vy, va, vh: the corresponding velocities

# Observation vector z = [x, y, a, h]
# Only positions are observed; velocities are never measured directly.
```
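For concreteness, a small helper converting a detection box [x, y, w, h] into this measurement vector (the same conversion reappears inside the Track class in Section 6.1):
```python
import numpy as np

def bbox_to_measurement(bbox):
    """[x, y, w, h] (top-left based) -> [cx, cy, a, h]."""
    x, y, w, h = bbox
    return np.array([x + w / 2, y + h / 2, w / h, h])

print(bbox_to_measurement([100, 50, 40, 80]))  # -> [120.  90.  0.5 80.]
```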
4.2 Kalman Filter Implementation
```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve


class KalmanFilter:
    """
    Kalman filter tailored for bounding-box tracking.

    State vector:       [x, y, a, h, vx, vy, va, vh]
    Observation vector: [x, y, a, h]
    """

    def __init__(self):
        ndim = 4  # observation dimension
        dt = 1.0  # time step

        # State transition matrix F
        # constant-velocity model: x' = x + vx * dt
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt

        # Observation matrix H
        # positions are observed, velocities are not
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Process-noise weights
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """
        Initialize a track from its first observation.
        Args:
            measurement: [x, y, a, h] bounding-box observation
        Returns:
            mean: state mean
            covariance: state covariance
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        # Initial covariance: low positional uncertainty, high velocity uncertainty
        std = [
            2 * self._std_weight_position * measurement[3],   # x
            2 * self._std_weight_position * measurement[3],   # y
            1e-2,                                             # a
            2 * self._std_weight_position * measurement[3],   # h
            10 * self._std_weight_velocity * measurement[3],  # vx
            10 * self._std_weight_velocity * measurement[3],  # vy
            1e-5,                                             # va
            10 * self._std_weight_velocity * measurement[3],  # vh
        ]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """
        Prediction step: propagate the state through the motion model.
            x' = F * x
            P' = F * P * F^T + Q
        """
        # Process noise Q
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3],
        ]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        # Predict
        mean = np.dot(self._motion_mat, mean)
        covariance = np.linalg.multi_dot([
            self._motion_mat, covariance, self._motion_mat.T
        ]) + motion_cov
        return mean, covariance

    def project(self, mean, covariance):
        """
        Project the state into observation space.
            z = H * x
            S = H * P * H^T + R
        """
        # Observation noise R
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3],
        ]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot([
            self._update_mat, covariance, self._update_mat.T
        ]) + innovation_cov
        return mean, covariance

    def update(self, mean, covariance, measurement):
        """
        Update step: fuse prediction and observation.
            K = P * H^T * S^(-1)
            x = x + K * (z - H * x)
            P = (I - K * H) * P
        """
        # Project into observation space
        projected_mean, projected_cov = self.project(mean, covariance)

        # Kalman gain via a Cholesky solve (S is symmetric positive definite)
        chol, lower = cho_factor(projected_cov, lower=True, check_finite=False)
        kalman_gain = cho_solve(
            (chol, lower),
            np.dot(covariance, self._update_mat.T).T,
            check_finite=False
        ).T

        # Update
        innovation = measurement - projected_mean
        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot([
            kalman_gain, projected_cov, kalman_gain.T
        ])
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements, only_position=False):
        """
        Squared Mahalanobis distance to a set of measurements (used for gating).
        """
        projected_mean, projected_cov = self.project(mean, covariance)
        if only_position:
            projected_mean = projected_mean[:2]
            projected_cov = projected_cov[:2, :2]
            measurements = measurements[:, :2]

        chol_factor = np.linalg.cholesky(projected_cov)
        d = measurements - projected_mean
        z = np.linalg.solve(chol_factor, d.T).T
        return np.sum(z * z, axis=1)
```
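A quick smoke test of the predict/update cycle (the numbers are illustrative):
```python
import numpy as np

kf = KalmanFilter()

# Initialize from a first observation [cx, cy, a, h]
mean, cov = kf.initiate(np.array([120., 90., 0.5, 80.]))

# Predict one frame ahead, then correct with a new observation
mean, cov = kf.predict(mean, cov)
mean, cov = kf.update(mean, cov, np.array([123., 92., 0.5, 81.]))

print(mean[:4])  # filtered box state
print(mean[4:])  # estimated velocities
```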
5. Data Association Algorithms
5.1 Computing IoU Distance
```python
import numpy as np


def bbox_iou(bbox1, bbox2):
    """
    IoU of two bounding boxes.
    Args:
        bbox1: [x1, y1, x2, y2]
        bbox2: [x1, y1, x2, y2]
    Returns:
        iou: intersection over union
    """
    # Intersection
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])
    inter_area = max(0, x2 - x1) * max(0, y2 - y1)

    # Union
    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    union_area = area1 + area2 - inter_area

    return inter_area / (union_area + 1e-6)


def iou_distance(tracks, detections):
    """
    IoU distance matrix between tracks and detections.
    Args:
        tracks: list of tracks
        detections: list of detections
    Returns:
        cost_matrix: [num_tracks, num_detections]
    """
    cost_matrix = np.zeros((len(tracks), len(detections)))
    for t, track in enumerate(tracks):
        track_bbox = track.to_tlbr()  # convert to [x1, y1, x2, y2]
        for d, det in enumerate(detections):
            det_bbox = det.to_tlbr()
            iou = bbox_iou(track_bbox, det_bbox)
            cost_matrix[t, d] = 1 - iou  # higher IoU -> lower cost
    return cost_matrix
```
5.2 The Hungarian Algorithm
```python
from scipy.optimize import linear_sum_assignment


def hungarian_matching(cost_matrix, threshold=0.7):
    """
    Optimal assignment via the Hungarian algorithm.
    Args:
        cost_matrix: cost matrix [num_tracks, num_detections]
        threshold: costs above this value are rejected
    Returns:
        matches: list of pairs [(track_idx, det_idx), ...]
        unmatched_tracks: indices of unmatched tracks
        unmatched_detections: indices of unmatched detections
    """
    if cost_matrix.size == 0:
        return [], list(range(cost_matrix.shape[0])), list(range(cost_matrix.shape[1]))

    # Solve the assignment problem
    row_indices, col_indices = linear_sum_assignment(cost_matrix)

    matches = []
    unmatched_tracks = list(range(cost_matrix.shape[0]))
    unmatched_detections = list(range(cost_matrix.shape[1]))

    for row, col in zip(row_indices, col_indices):
        if cost_matrix[row, col] > threshold:
            continue
        matches.append((row, col))
        unmatched_tracks.remove(row)
        unmatched_detections.remove(col)

    return matches, unmatched_tracks, unmatched_detections
```
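A toy run on the 3×4 distance matrix from Section 3.2 (the 0.5 threshold here is illustrative):
```python
import numpy as np

cost = np.array([[0.2, 0.8, 0.9, 0.7],
                 [0.7, 0.1, 0.6, 0.8],
                 [0.8, 0.7, 0.2, 0.9]])
matches, un_tracks, un_dets = hungarian_matching(cost, threshold=0.5)
print(matches)  # [(0, 0), (1, 1), (2, 2)]
print(un_dets)  # [3] -> detection D4 would start a new track
```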
5.3 Cascade Matching (DeepSORT)
```python
def cascade_matching(tracks, detections, track_indices, detection_indices,
                     cost_function, threshold, max_age=30):
    """
    Cascade matching: recently seen tracks get matched first.
    Args:
        tracks: list of tracks
        detections: list of detections
        track_indices: indices of tracks to match
        detection_indices: indices of detections to match
        cost_function: function computing the cost matrix
        threshold: matching threshold
        max_age: maximum number of missed frames
    """
    matches = []
    unmatched_detections = list(detection_indices)

    # Group tracks by frames-since-update, most recent first
    for age in range(max_age):
        if len(unmatched_detections) == 0:
            break

        # Tracks at the current age
        track_indices_age = [
            t for t in track_indices
            if tracks[t].time_since_update == age + 1
        ]
        if len(track_indices_age) == 0:
            continue

        # Cost matrix for this round
        cost_matrix = cost_function(
            [tracks[t] for t in track_indices_age],
            [detections[d] for d in unmatched_detections]
        )

        # Hungarian matching
        matched, _, unmatched_det = hungarian_matching(cost_matrix, threshold)

        # Collect results
        for t_idx, d_idx in matched:
            matches.append((track_indices_age[t_idx], unmatched_detections[d_idx]))
        unmatched_detections = [unmatched_detections[d] for d in unmatched_det]

    unmatched_tracks = [t for t in track_indices if t not in [m[0] for m in matches]]
    return matches, unmatched_tracks, unmatched_detections
```
6. A Complete Tracker Implementation
6.1 The Track Class
```python
from enum import Enum

import numpy as np


class TrackState(Enum):
    """Track lifecycle states"""
    Tentative = 1  # awaiting confirmation
    Confirmed = 2  # confirmed
    Deleted = 3    # deleted


class Track:
    """A single target's track"""

    _count = 0  # global ID counter

    def __init__(self, detection, track_id=None, n_init=3, max_age=30):
        """
        Args:
            detection: initial detection [x, y, w, h, score, class_id]
            track_id: track ID (optional)
            n_init: consecutive matches required for confirmation
            max_age: maximum number of missed frames
        """
        # Assign an ID
        if track_id is None:
            Track._count += 1
            self.track_id = Track._count
        else:
            self.track_id = track_id

        # Kalman filter
        self.kf = KalmanFilter()

        # Initialize the state
        measurement = self._bbox_to_measurement(detection[:4])
        self.mean, self.covariance = self.kf.initiate(measurement)

        # Track attributes
        self.hits = 1               # number of matches
        self.age = 1                # frames since creation
        self.time_since_update = 0  # frames since last update
        self.state = TrackState.Tentative

        # Parameters
        self.n_init = n_init
        self.max_age = max_age

        # Detection info
        self.score = detection[4] if len(detection) > 4 else 1.0
        self.class_id = int(detection[5]) if len(detection) > 5 else 0

        # Appearance feature history (for ReID)
        self.features = []

    def _bbox_to_measurement(self, bbox):
        """[x, y, w, h] -> [cx, cy, a, h]"""
        x, y, w, h = bbox
        return np.array([x + w / 2, y + h / 2, w / h, h])

    def _measurement_to_bbox(self, measurement):
        """[cx, cy, a, h] -> [x, y, w, h]"""
        cx, cy, a, h = measurement
        w = a * h
        return np.array([cx - w / 2, cy - h / 2, w, h])

    def predict(self):
        """Predict the next-frame state"""
        self.mean, self.covariance = self.kf.predict(self.mean, self.covariance)
        self.age += 1
        self.time_since_update += 1

    def update(self, detection, feature=None):
        """
        Update the track with a new detection.
        Args:
            detection: [x, y, w, h, score, class_id]
            feature: appearance feature vector (optional)
        """
        measurement = self._bbox_to_measurement(detection[:4])
        self.mean, self.covariance = self.kf.update(
            self.mean, self.covariance, measurement
        )

        # Update attributes
        self.hits += 1
        self.time_since_update = 0
        self.score = detection[4] if len(detection) > 4 else self.score

        # Update appearance features
        if feature is not None:
            self.features.append(feature)
            if len(self.features) > 100:  # keep the most recent 100
                self.features = self.features[-100:]

        # State transition
        if self.state == TrackState.Tentative and self.hits >= self.n_init:
            self.state = TrackState.Confirmed

    def mark_missed(self):
        """Mark the track as missed this frame"""
        if self.state == TrackState.Tentative:
            self.state = TrackState.Deleted
        elif self.time_since_update > self.max_age:
            self.state = TrackState.Deleted

    def is_tentative(self):
        return self.state == TrackState.Tentative

    def is_confirmed(self):
        return self.state == TrackState.Confirmed

    def is_deleted(self):
        return self.state == TrackState.Deleted

    def to_tlwh(self):
        """Return the box as [x, y, w, h]"""
        return self._measurement_to_bbox(self.mean[:4])

    def to_tlbr(self):
        """Return the box as [x1, y1, x2, y2]"""
        bbox = self.to_tlwh()
        return np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]])
```
6.2 The ByteTrack Tracker
```python
class ByteTracker:
    """
    ByteTrack multi-object tracker.
    Key idea: use low-score detections to recover occluded targets.
    """

    def __init__(self,
                 track_thresh=0.5,       # high-score threshold
                 track_buffer=30,        # frames to keep lost tracks
                 match_thresh=0.8,       # matching threshold
                 low_thresh=0.1,         # low-score threshold
                 new_track_thresh=0.6):  # threshold for creating new tracks
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.low_thresh = low_thresh
        self.new_track_thresh = new_track_thresh

        self.tracks = []          # confirmed tracks
        self.lost_tracks = []     # lost tracks
        self.removed_tracks = []  # removed tracks
        self.frame_id = 0

        Track._count = 0  # reset the ID counter

    def update(self, detections):
        """
        Process one frame of detections.
        Args:
            detections: numpy array of shape [N, 6],
                one row per detection: [x, y, w, h, score, class_id]
        Returns:
            outputs: tracking results
                [[x1, y1, x2, y2, track_id, class_id, score], ...]
        """
        self.frame_id += 1

        # Split detections into high- and low-score sets
        if len(detections) > 0:
            scores = detections[:, 4]
            # High-score detections
            high_mask = scores >= self.track_thresh
            high_dets = detections[high_mask]
            # Low-score detections
            low_mask = (scores < self.track_thresh) & (scores >= self.low_thresh)
            low_dets = detections[low_mask]
        else:
            high_dets = np.empty((0, 6))
            low_dets = np.empty((0, 6))

        # Pool confirmed and lost tracks
        all_tracks = self.tracks + self.lost_tracks

        # Predict every track one frame forward
        for track in all_tracks:
            track.predict()

        # ========== Stage 1: match high-score detections ==========
        if len(high_dets) > 0 and len(all_tracks) > 0:
            # IoU cost matrix
            cost_matrix = iou_distance(all_tracks, self._dets_to_tracks(high_dets))
            # Hungarian matching
            matches, unmatched_tracks, unmatched_dets = hungarian_matching(
                cost_matrix, threshold=self.match_thresh
            )
            # Update the matched tracks
            for track_idx, det_idx in matches:
                all_tracks[track_idx].update(high_dets[det_idx])
            # Collect unmatched tracks and detections
            remain_tracks = [all_tracks[i] for i in unmatched_tracks]
            remain_high_dets = high_dets[unmatched_dets]
        else:
            remain_tracks = all_tracks
            remain_high_dets = high_dets

        # ========== Stage 2: match low-score detections ==========
        if len(low_dets) > 0 and len(remain_tracks) > 0:
            cost_matrix = iou_distance(remain_tracks, self._dets_to_tracks(low_dets))
            matches_low, unmatched_tracks_low, _ = hungarian_matching(
                cost_matrix, threshold=0.5  # looser threshold for low scores
            )
            # Update the matched tracks
            for track_idx, det_idx in matches_low:
                remain_tracks[track_idx].update(low_dets[det_idx])
            remain_tracks = [remain_tracks[i] for i in unmatched_tracks_low]

        # ========== Track management ==========
        # Mark still-unmatched tracks as missed
        for track in remain_tracks:
            track.mark_missed()

        # Start new tracks from unmatched high-score detections
        for det in remain_high_dets:
            if det[4] >= self.new_track_thresh:
                new_track = Track(det, max_age=self.track_buffer)
                self.tracks.append(new_track)

        # Refresh the track lists
        self.tracks = [t for t in self.tracks + self.lost_tracks if not t.is_deleted()]
        self.lost_tracks = [t for t in self.tracks if not t.is_confirmed()]
        self.tracks = [t for t in self.tracks if t.is_confirmed()]

        # ========== Build the outputs ==========
        outputs = []
        for track in self.tracks:
            if track.is_confirmed() and track.time_since_update == 0:
                bbox = track.to_tlbr()
                outputs.append([
                    bbox[0], bbox[1], bbox[2], bbox[3],
                    track.track_id, track.class_id, track.score
                ])
        return np.array(outputs) if len(outputs) > 0 else np.empty((0, 7))

    def _dets_to_tracks(self, detections):
        """Wrap detections in temporary Track objects (only for IoU computation).

        Passing an explicit track_id keeps the global ID counter untouched.
        """
        tracks = []
        for det in detections:
            track = Track(det, track_id=-1)
            tracks.append(track)
        return tracks
```
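A synthetic smoke test (detections are made up). Note that with the default n_init=3, a track only appears in the output once it has been matched three times, so the first two calls print empty results:
```python
import numpy as np

tracker = ByteTracker()

# One object drifting right; rows are [x, y, w, h, score, class_id]
for frame_id in range(5):
    dets = np.array([[100 + 5 * frame_id, 50, 40, 80, 0.9, 0]])
    out = tracker.update(dets)
    print(frame_id, out)  # non-empty from the 3rd call (hits >= n_init)
```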
7. A Complete Inference System
7.1 The Video Tracking Pipeline
```python
import cv2
import time
import numpy as np
from ultralytics import YOLO


class VideoTracker:
    """
    Complete video object tracking system:
    YOLO detection + ByteTrack association.
    """

    # Per-ID color palette
    COLORS = np.random.randint(0, 255, size=(100, 3), dtype=np.uint8)

    def __init__(self,
                 detector_path='yolov8n.pt',
                 conf_thresh=0.25,
                 iou_thresh=0.45,
                 classes=None,
                 device='cuda'):
        """
        Args:
            detector_path: YOLO model path
            conf_thresh: detection confidence threshold
            iou_thresh: NMS IoU threshold
            classes: class IDs to track; None means all
            device: inference device
        """
        # Detector
        self.detector = YOLO(detector_path)
        self.conf_thresh = conf_thresh
        self.iou_thresh = iou_thresh
        self.classes = classes
        self.device = device

        # Tracker
        self.tracker = ByteTracker(
            track_thresh=0.5,
            track_buffer=30,
            match_thresh=0.8
        )

        # Statistics
        self.frame_count = 0
        self.total_time = 0

    def detect(self, frame):
        """
        Run the detector on one frame.
        Returns:
            detections: [N, 6] - [x, y, w, h, score, class_id]
        """
        results = self.detector(
            frame,
            conf=self.conf_thresh,
            iou=self.iou_thresh,
            classes=self.classes,
            device=self.device,
            verbose=False
        )[0]

        # Parse the detections
        boxes = results.boxes
        if len(boxes) == 0:
            return np.empty((0, 6))

        detections = []
        for box in boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            score = box.conf[0].cpu().numpy()
            class_id = int(box.cls[0].cpu().numpy())
            # Convert to [x, y, w, h, score, class_id]
            detections.append([
                x1, y1, x2 - x1, y2 - y1, score, class_id
            ])
        return np.array(detections)

    def process_frame(self, frame):
        """
        Process a single frame.
        Returns:
            tracks: [M, 7] - [x1, y1, x2, y2, track_id, class_id, score]
            fps: instantaneous frame rate
        """
        start_time = time.time()

        # Detect
        detections = self.detect(frame)

        # Track
        tracks = self.tracker.update(detections)

        # Statistics
        self.frame_count += 1
        elapsed = time.time() - start_time
        self.total_time += elapsed
        fps = 1.0 / elapsed if elapsed > 0 else 0

        return tracks, fps

    def draw_tracks(self, frame, tracks):
        """Visualize tracking results."""
        frame = frame.copy()
        for track in tracks:
            x1, y1, x2, y2 = map(int, track[:4])
            track_id = int(track[4])
            class_id = int(track[5])
            score = track[6]

            # Per-ID color
            color = tuple(map(int, self.COLORS[track_id % 100]))

            # Bounding box
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            # Label
            label = f'ID:{track_id} {self.detector.names[class_id]} {score:.2f}'
            (label_w, label_h), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
            )
            cv2.rectangle(
                frame,
                (x1, y1 - label_h - baseline - 5),
                (x1 + label_w, y1),
                color, -1
            )
            cv2.putText(
                frame, label, (x1, y1 - baseline - 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
            )
        return frame

    def process_video(self, input_path, output_path=None, show=True):
        """
        Process a video file.
        Args:
            input_path: input video path
            output_path: output video path (optional)
            show: display frames while processing
        """
        cap = cv2.VideoCapture(input_path)

        # Video metadata
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video info: {width}x{height}, {fps}fps, {total_frames} frames")

        # Video writer
        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        # Reset the tracker
        self.tracker = ByteTracker()
        self.frame_count = 0
        self.total_time = 0

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Process the frame
                tracks, current_fps = self.process_frame(frame)

                # Visualize
                vis_frame = self.draw_tracks(frame, tracks)

                # Overlay FPS and track count
                cv2.putText(
                    vis_frame, f'FPS: {current_fps:.1f}',
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
                )
                cv2.putText(
                    vis_frame, f'Tracks: {len(tracks)}',
                    (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
                )

                # Save
                if writer:
                    writer.write(vis_frame)

                # Display
                if show:
                    cv2.imshow('Tracking', vis_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

                # Progress
                if self.frame_count % 100 == 0:
                    print(f"Progress: {self.frame_count}/{total_frames} "
                          f"({100*self.frame_count/total_frames:.1f}%)")
        finally:
            cap.release()
            if writer:
                writer.release()
            if show:
                cv2.destroyAllWindows()

        # Summary
        avg_fps = self.frame_count / self.total_time if self.total_time > 0 else 0
        print(f"\nDone!")
        print(f"Total frames: {self.frame_count}")
        print(f"Average FPS: {avg_fps:.1f}")
        print(f"Total time: {self.total_time:.1f}s")

    def process_camera(self, camera_id=0, output_path=None):
        """Process a live camera stream."""
        cap = cv2.VideoCapture(camera_id)
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(output_path, fourcc, 30, (1280, 720))

        # Reset
        self.tracker = ByteTracker()

        print("Press 'q' to quit")
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                tracks, fps = self.process_frame(frame)
                vis_frame = self.draw_tracks(frame, tracks)

                cv2.putText(
                    vis_frame, f'FPS: {fps:.1f} | Tracks: {len(tracks)}',
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
                )

                if writer:
                    writer.write(vis_frame)

                cv2.imshow('Camera Tracking', vis_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            cap.release()
            if writer:
                writer.release()
            cv2.destroyAllWindows()
```
7.2 Usage Examples
```python
# ================== Basic usage ==================
# Create the tracker
tracker = VideoTracker(
    detector_path='yolov8n.pt',  # detection model
    conf_thresh=0.25,            # confidence threshold
    classes=[0, 2, 5, 7],        # track only: person, car, bus, truck
    device='cuda'
)

# Process a video file
tracker.process_video(
    input_path='input_video.mp4',
    output_path='output_video.mp4',
    show=True
)

# Process a camera stream
tracker.process_camera(camera_id=0)

# ================== YOLOv8 built-in tracking ==================
from ultralytics import YOLO

# Simplest route: a single call
model = YOLO('yolov8n.pt')
results = model.track(
    source='input_video.mp4',
    show=True,
    tracker='bytetrack.yaml',  # or 'botsort.yaml'
    save=True
)

# Frame-by-frame processing
import cv2

model = YOLO('yolov8n.pt')
cap = cv2.VideoCapture('input_video.mp4')

while True:
    ret, frame = cap.read()
    if not ret:
        break

    # persist=True keeps track state across frames
    results = model.track(frame, persist=True)

    # Read out the tracking results
    if results[0].boxes.id is not None:
        boxes = results[0].boxes.xyxy.cpu().numpy()
        track_ids = results[0].boxes.id.cpu().numpy().astype(int)
        for box, track_id in zip(boxes, track_ids):
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, f'ID:{track_id}', (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    cv2.imshow('YOLOv8 Tracking', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
```
8. Trajectory Analysis and Applications
8.1 Storing and Analyzing Trajectories
```python
import json
from collections import defaultdict


class TrajectoryAnalyzer:
    """Trajectory analyzer: store, analyze, and visualize trajectories."""

    def __init__(self):
        # {track_id: [{'frame': ..., 'cx': ..., 'cy': ..., 'w': ..., 'h': ...}, ...]}
        self.trajectories = defaultdict(list)
        self.frame_id = 0

    def update(self, tracks):
        """
        Append one frame of tracking output.
        Args:
            tracks: [M, 7] - [x1, y1, x2, y2, track_id, class_id, score]
        """
        self.frame_id += 1
        for track in tracks:
            x1, y1, x2, y2, track_id = track[:5]
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            w, h = x2 - x1, y2 - y1
            self.trajectories[int(track_id)].append({
                'frame': self.frame_id,
                'cx': float(cx),
                'cy': float(cy),
                'w': float(w),
                'h': float(h)
            })

    def get_trajectory(self, track_id):
        """Trajectory for a given ID."""
        return self.trajectories.get(track_id, [])

    def compute_speed(self, track_id, fps=30, pixels_per_meter=100):
        """
        Average speed of a target.
        Args:
            track_id: track ID
            fps: video frame rate
            pixels_per_meter: pixels-to-meters conversion factor
        Returns:
            speed: average speed (m/s)
        """
        traj = self.trajectories.get(track_id, [])
        if len(traj) < 2:
            return 0.0

        total_distance = 0
        for i in range(1, len(traj)):
            dx = traj[i]['cx'] - traj[i-1]['cx']
            dy = traj[i]['cy'] - traj[i-1]['cy']
            total_distance += np.sqrt(dx**2 + dy**2)

        # Unit conversion
        distance_meters = total_distance / pixels_per_meter
        time_seconds = len(traj) / fps
        return distance_meters / time_seconds if time_seconds > 0 else 0

    def compute_direction(self, track_id):
        """
        Overall direction of motion.
        Returns:
            direction: angle in degrees (0-360, 0 pointing right)
        """
        traj = self.trajectories.get(track_id, [])
        if len(traj) < 2:
            return None

        # Direction from the first point to the last
        dx = traj[-1]['cx'] - traj[0]['cx']
        dy = traj[-1]['cy'] - traj[0]['cy']
        angle = np.arctan2(dy, dx) * 180 / np.pi
        return (angle + 360) % 360

    def draw_trajectories(self, frame, max_length=50):
        """Draw trajectories onto a frame."""
        frame = frame.copy()
        for track_id, traj in self.trajectories.items():
            if len(traj) < 2:
                continue

            # Only the most recent points
            recent_traj = traj[-max_length:]

            # Per-ID color
            color = tuple(map(int, VideoTracker.COLORS[track_id % 100]))

            # Polyline, thicker toward the most recent point
            points = [(int(t['cx']), int(t['cy'])) for t in recent_traj]
            for i in range(1, len(points)):
                thickness = int(2 * i / len(points)) + 1
                cv2.line(frame, points[i-1], points[i], color, thickness)
        return frame

    def save_trajectories(self, filepath):
        """Save trajectories to JSON."""
        data = {
            'total_frames': self.frame_id,
            'trajectories': dict(self.trajectories)
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    def load_trajectories(self, filepath):
        """Load trajectories from JSON."""
        with open(filepath, 'r') as f:
            data = json.load(f)
        self.frame_id = data['total_frames']
        # JSON object keys are strings; restore integer track IDs
        self.trajectories = defaultdict(
            list, {int(k): v for k, v in data['trajectories'].items()}
        )

    def count_objects(self):
        """Total number of distinct targets seen."""
        return len(self.trajectories)

    def get_active_tracks(self, current_frame, threshold=10):
        """IDs of tracks seen within the last `threshold` frames."""
        active = []
        for track_id, traj in self.trajectories.items():
            if traj and current_frame - traj[-1]['frame'] <= threshold:
                active.append(track_id)
        return active
```
8.2 Region Counting (Line Crossing)
```python
class LineCrossCounter:
    """Line-crossing counter."""

    def __init__(self, line_start, line_end):
        """
        Args:
            line_start: (x1, y1) start of the counting line
            line_end: (x2, y2) end of the counting line
        """
        self.line_start = np.array(line_start)
        self.line_end = np.array(line_end)
        self.previous_positions = {}  # {track_id: (cx, cy)}
        self.count_up = 0             # crossings bottom-to-top
        self.count_down = 0           # crossings top-to-bottom
        self.crossed_ids = set()

    def _cross_product(self, point):
        """Cross product telling which side of the line a point is on."""
        line_vec = self.line_end - self.line_start
        point_vec = point - self.line_start
        return np.cross(line_vec, point_vec)

    def update(self, tracks):
        """
        Update the counters.
        Args:
            tracks: [M, 7] - [x1, y1, x2, y2, track_id, class_id, score]
        Returns:
            crossed_tracks: track_ids that crossed the line this frame
        """
        crossed_tracks = []
        current_positions = {}

        for track in tracks:
            x1, y1, x2, y2, track_id = track[:5]
            track_id = int(track_id)
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            current_pos = np.array([cx, cy])
            current_positions[track_id] = current_pos

            # Check for a crossing
            if track_id in self.previous_positions and track_id not in self.crossed_ids:
                prev_pos = self.previous_positions[track_id]
                prev_side = self._cross_product(prev_pos)
                curr_side = self._cross_product(current_pos)

                # A sign change means the line was crossed
                if prev_side * curr_side < 0:
                    self.crossed_ids.add(track_id)
                    crossed_tracks.append(track_id)
                    # Crossing direction
                    if curr_side > 0:
                        self.count_up += 1
                    else:
                        self.count_down += 1

        self.previous_positions = current_positions
        return crossed_tracks

    def draw(self, frame):
        """Draw the counting line and the totals."""
        frame = frame.copy()

        # Counting line
        cv2.line(frame,
                 tuple(self.line_start.astype(int)),
                 tuple(self.line_end.astype(int)),
                 (0, 0, 255), 3)

        # Totals
        text = f'Up: {self.count_up} | Down: {self.count_down} | Total: {self.count_up + self.count_down}'
        cv2.putText(frame, text, (10, 100),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        return frame

    def reset(self):
        """Reset the counter."""
        self.count_up = 0
        self.count_down = 0
        self.crossed_ids.clear()
        self.previous_positions.clear()
```
8.3 Region Dwell Detection
```python
class RegionDwellDetector:
    """Detects targets that linger inside a region."""

    def __init__(self, region_points, dwell_threshold=90):
        """
        Args:
            region_points: polygon vertices [(x1, y1), (x2, y2), ...]
            dwell_threshold: dwell-time threshold in frames
        """
        self.region = np.array(region_points, dtype=np.int32)
        self.dwell_threshold = dwell_threshold
        self.dwell_times = defaultdict(int)  # {track_id: frames_in_region}
        self.alerts = set()                  # IDs that already triggered an alert

    def _point_in_region(self, point):
        """Is the point inside the region?"""
        return cv2.pointPolygonTest(self.region, point, False) >= 0

    def update(self, tracks):
        """
        Update dwell detection.
        Returns:
            dwelling_tracks: track_ids whose dwell time crossed the threshold
        """
        current_ids = set()
        dwelling_tracks = []

        for track in tracks:
            x1, y1, x2, y2, track_id = track[:5]
            track_id = int(track_id)
            cx, cy = (x1 + x2) / 2, y2  # use the bottom-center point
            current_ids.add(track_id)

            if self._point_in_region((cx, cy)):
                self.dwell_times[track_id] += 1
                if (self.dwell_times[track_id] >= self.dwell_threshold
                        and track_id not in self.alerts):
                    self.alerts.add(track_id)
                    dwelling_tracks.append(track_id)
            else:
                self.dwell_times[track_id] = 0

        # Forget targets that left the scene
        for track_id in list(self.dwell_times.keys()):
            if track_id not in current_ids:
                del self.dwell_times[track_id]
                self.alerts.discard(track_id)

        return dwelling_tracks

    def draw(self, frame):
        """Draw the detection region and alerts."""
        frame = frame.copy()

        # Region overlay
        overlay = frame.copy()
        cv2.fillPoly(overlay, [self.region], (0, 255, 255))
        cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
        cv2.polylines(frame, [self.region], True, (0, 255, 255), 2)

        # Flag dwelling targets
        for track_id in self.alerts:
            if track_id in self.dwell_times:
                text = f'ALERT: ID {track_id}'
                cv2.putText(frame, text, (10, 130),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        return frame
```
9. Performance Optimization Tips
9.1 Detector Optimizations
```python
# 1. Use a smaller model
model = YOLO('yolov8n.pt')  # nano variant, fastest

# 2. Lower the input resolution
results = model(frame, imgsz=640)  # default 640; can go as low as 320

# 3. Accelerate with TensorRT
model.export(format='engine', device=0)  # export a TensorRT engine
model = YOLO('yolov8n.engine')

# 4. Half-precision inference
results = model(frame, half=True)

# 5. Skip-frame detection (see the sketch below)
frame_skip = 2  # detect once every 2 frames
if frame_id % frame_skip == 0:
    detections = detect(frame)
```
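On skipped frames the tracker can still advance its Kalman predictions instead of going blind. A sketch of that pattern, using the `ByteTracker` and `Track` classes defined above (`detect` remains a placeholder, and this reuse policy is one possible choice, not the only one):
```python
def track_with_frame_skip(frames, detect, tracker, frame_skip=2):
    """Detect every `frame_skip` frames; in between, coast on the motion
    model. Boxes on skipped frames come from prediction only, so they
    are less accurate but nearly free to compute."""
    results = []
    for frame_id, frame in enumerate(frames):
        if frame_id % frame_skip == 0:
            tracks = tracker.update(detect(frame))  # full detect + associate
        else:
            for t in tracker.tracks + tracker.lost_tracks:
                t.predict()                          # Kalman prediction only
            tracks = [t.to_tlbr() for t in tracker.tracks]
        results.append(tracks)
    return results
```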
9.2 Tracker Optimizations
```python
# 1. Skip appearance features (use ByteTrack instead of DeepSORT)
tracker = ByteTracker()  # no ReID network required

# 2. Cap the number of live tracks
max_tracks = 100
if len(tracker.tracks) > max_tracks:
    # keep the most recent tracks, drop the oldest
    tracker.tracks = sorted(tracker.tracks, key=lambda t: t.age)[:max_tracks]

# 3. Restrict the matching range:
# only consider (track, detection) pairs that are spatially close
def spatial_filtering(tracks, detections, threshold=200):
    """Pre-filter pairs by center distance; pairs farther apart than
    `threshold` pixels are excluded from the cost matrix."""
    pairs = []
    for t, track in enumerate(tracks):
        tx, ty, tw, th = track.to_tlwh()
        for d, det in enumerate(detections):
            dx, dy = det[0] + det[2] / 2, det[1] + det[3] / 2
            if np.hypot(tx + tw / 2 - dx, ty + th / 2 - dy) <= threshold:
                pairs.append((t, d))
    return pairs
```
9.3 Multithreaded Processing
```python
from queue import Queue, Empty
import threading


class AsyncVideoTracker:
    """Asynchronous video tracker: detection runs in a worker thread."""

    def __init__(self, detector_path='yolov8n.pt'):
        self.detector = YOLO(detector_path)
        self.tracker = ByteTracker()
        self.frame_queue = Queue(maxsize=10)
        self.result_queue = Queue(maxsize=10)
        self.running = False

    def _detection_worker(self):
        """Detection thread: consume frames, emit detections."""
        while self.running:
            try:
                frame_id, frame = self.frame_queue.get(timeout=0.1)
            except Empty:
                continue  # queue empty; re-check the running flag
            detections = self.detect(frame)
            self.result_queue.put((frame_id, frame, detections))

    def detect(self, frame):
        """Run the detector (same output format as VideoTracker.detect)."""
        results = self.detector(frame, verbose=False)[0]
        boxes = results.boxes
        if len(boxes) == 0:
            return np.empty((0, 6))
        xyxy = boxes.xyxy.cpu().numpy()
        out = np.zeros((len(boxes), 6))
        out[:, 0] = xyxy[:, 0]                 # x
        out[:, 1] = xyxy[:, 1]                 # y
        out[:, 2] = xyxy[:, 2] - xyxy[:, 0]    # w
        out[:, 3] = xyxy[:, 3] - xyxy[:, 1]    # h
        out[:, 4] = boxes.conf.cpu().numpy()   # score
        out[:, 5] = boxes.cls.cpu().numpy()    # class_id
        return out

    def start(self):
        """Start the worker thread."""
        self.running = True
        self.detection_thread = threading.Thread(target=self._detection_worker)
        self.detection_thread.start()

    def stop(self):
        """Stop the worker thread."""
        self.running = False
        self.detection_thread.join()
```
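Feeding it frames might look like this (a sketch; the association step stays on the main thread so track state is never touched concurrently):
```python
async_tracker = AsyncVideoTracker('yolov8n.pt')
async_tracker.start()

cap = cv2.VideoCapture('input_video.mp4')
frame_id = 0
while True:
    ret, frame = cap.read()
    if not ret:
        break
    async_tracker.frame_queue.put((frame_id, frame))
    frame_id += 1

    # Consume finished detections and associate on the main thread
    while not async_tracker.result_queue.empty():
        fid, f, dets = async_tracker.result_queue.get()
        tracks = async_tracker.tracker.update(dets)

async_tracker.stop()
cap.release()
```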
10. Common Problems and Solutions
10.1 Frequent ID Switches
Problem: the same target keeps getting assigned new IDs.
Causes:
- unstable detections
- the target moves too fast
- occlusion makes detections drop out
Solutions:
```python
# 1. Lower the detection threshold to keep more detections
conf_thresh = 0.15  # down from 0.25

# 2. Keep lost tracks alive longer
tracker = ByteTracker(track_buffer=60)  # up from 30

# 3. Use a ReID-based tracker (assumes an external DeepSORT
#    implementation, e.g. the deep-sort-realtime package; the class
#    is not defined in this article)
tracker = DeepSORT(max_age=60, nn_budget=100)
```
10.2 ID Changes After Occlusion
Problem: a target reappears after occlusion with a different ID.
Solutions:
```python
# 1. Use DeepSORT's appearance features
# 2. Increase the max_age parameter
# 3. Use ByteTrack's low-score matching
tracker = ByteTracker(
    low_thresh=0.1,   # accept low-score detections
    track_buffer=60   # keep lost tracks longer
)
```
10.3 Detection Jitter
Problem: bounding boxes jitter badly.
Solutions:
```python
# The Kalman filter already smooths the output; increasing the observation
# noise makes the filter trust the (jittery) detections less
kf._std_weight_position = 1. / 10  # larger position uncertainty

# Or apply extra exponential smoothing to the output boxes
def smooth_bbox(current, previous, alpha=0.7):
    return alpha * current + (1 - alpha) * previous
```
11. Summary
Video object tracking is a composite technique; this article walked through it from principles to practice.
Key takeaways:
- The MOT paradigm: detection + association is the mainstream approach
- Kalman filter: predicts each target's motion state
- Hungarian algorithm: optimally matches detections to tracks
- ByteTrack: low-score detections boost performance
Practical advice:
- Simple scenes: use YOLOv8's built-in tracking directly
- Custom requirements: build on ByteTrack
- Heavy occlusion: consider DeepSORT or BoT-SORT
- Hard real-time constraints: SORT + a lightweight detector
Suggested learning path:
SORT → ByteTrack → DeepSORT → BoT-SORT → custom optimization
I hope this article helps; questions are welcome in the comments!
References:
- Bewley, A., et al. "Simple Online and Realtime Tracking." ICIP 2016.
- Wojke, N., et al. "Simple Online and Realtime Tracking with a Deep Association Metric." ICIP 2017.
- Zhang, Y., et al. "ByteTrack: Multi-Object Tracking by Associating Every Detection Box." ECCV 2022.
- Aharon, N., et al. "BoT-SORT: Robust Associations Multi-Pedestrian Tracking." arXiv 2022.
Author: Jia
For more technical articles, follow my CSDN blog!