# YOLOv5 + OpenCV real-time cat/dog detection, tracking and counting script.
import cv2
import torch
import numpy as np
# -------------------------- Edit these 2 parameters (required) --------------------------
model_path = "best.pt" # path to the trained YOLOv5 weights (e.g. runs/train/exp/weights/best.pt)
video_path = 0 # 0 = default webcam; can also be a video file path (e.g. "test.mp4")
# --------------------------------------------------------------------------
# 1. Load the custom-trained YOLOv5 model via torch.hub (no retraining needed).
#    NOTE(review): force_reload=True re-clones the hub repo on every run — confirm that is intended.
model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, force_reload=True)
# Confidence threshold: the model wrapper discards detections scoring below 0.4.
model.conf = 0.4
# Class names must match the training dataset order (0 = cat, 1 = dog); do not reorder.
class_names = ['cat', 'dog']
# 2. Open the video source and request a 1280x720 capture size (the driver may ignore it).
cap = cv2.VideoCapture(video_path)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) # requested frame width
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) # requested frame height
# Running totals of distinct cats/dogs counted so far.
cat_count = 0
dog_count = 0
# Active tracks: key = track ID, value = last known state (centre position and class) of that object.
tracks = {}
track_id = 0 # next unique track ID to assign (monotonically increasing)
# Per-class drawing colours in BGR: cat = blue, dog = red.
colors = [(255, 0, 0), (0, 0, 255)]
# 3. Real-time processing loop: read frames until the stream ends or the user quits.
MATCH_DIST = 50  # max centre distance (px) for a detection to be matched to an existing track

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break  # end of stream or read failure

    # 4. Run YOLOv5 on the current frame (original resolution is preserved).
    results = model(frame)
    # Each detection row: xmin, ymin, xmax, ymax, conf, class, name
    detections = results.pandas().xyxy[0].values

    # 5. Tracking + counting: associate each detection with an existing track
    #    so a moving animal is counted exactly once.
    current_tracks = []  # track IDs seen this frame; stale tracks are pruned below
    for det in detections:
        xmin, ymin, xmax, ymax, conf, cls, name = det
        # Box centre is the tracking feature.
        center_x = int((xmin + xmax) / 2)
        center_y = int((ymin + ymax) / 2)
        cls = int(cls)  # class index (0 = cat, 1 = dog)

        # Nearest-neighbour association: choose the CLOSEST unclaimed track of
        # the same class within MATCH_DIST. (The original greedy first-match
        # could pair a detection with a farther track, and let two detections
        # claim the same track in one frame.)
        best_id = None
        best_dist = MATCH_DIST
        for tid, (tx, ty, tcls, _prev) in tracks.items():
            if tcls != cls or tid in current_tracks:
                continue  # wrong class, or track already claimed this frame
            dist = np.hypot(center_x - tx, center_y - ty)
            if dist < best_dist:
                best_dist = dist
                best_id = tid

        if best_id is not None:
            # Matched: update the track and remember its previous centre so the
            # trail line connects THIS object's last position. (The original
            # code drew lines between tracks with adjacent IDs, i.e. between
            # two different animals.)
            prev_pt = (tracks[best_id][0], tracks[best_id][1])
            tracks[best_id] = (center_x, center_y, cls, prev_pt)
            current_tracks.append(best_id)
        else:
            # New object: register a fresh track and bump the class counter once.
            tracks[track_id] = (center_x, center_y, cls, None)
            current_tracks.append(track_id)
            if cls == 0:
                cat_count += 1
            else:
                dog_count += 1
            track_id += 1

    # 6. Drop tracks that vanished this frame to bound memory use.
    tracks = {k: v for k, v in tracks.items() if k in current_tracks}

    # 7. Visualisation: detection boxes, labels, trails, and running counts.
    for det in detections:
        xmin, ymin, xmax, ymax, conf, cls, name = det
        # OpenCV drawing needs integer pixel coordinates.
        xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
        cls = int(cls)
        cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), colors[cls], 2)
        label = f"{name} {conf:.2f}"  # class name + 2-decimal confidence
        cv2.putText(frame, label, (xmin, ymin - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, colors[cls], 2)

    for _tid, (cx, cy, tcls, prev_pt) in tracks.items():
        # Filled dot at the current centre, coloured by class.
        cv2.circle(frame, (cx, cy), 5, colors[tcls], -1)
        if prev_pt is not None:
            # Trail segment from this object's previous centre to its current one.
            cv2.line(frame, prev_pt, (cx, cy), colors[tcls], 2)

    # Running counts in the top-left corner, colour-matched to the classes.
    cv2.putText(frame, f"Cat: {cat_count}", (20, 40),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
    cv2.putText(frame, f"Dog: {dog_count}", (20, 80),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

    # 8. Show the annotated frame; press 'q' to quit.
    cv2.imshow("YOLOv5 + OpenCV 实时计数+追踪", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
# Release the capture device and close all OpenCV windows to free resources.
cap.release()
cv2.destroyAllWindows()
# Print the final tally to the console for record keeping.
print(f"最终检测计数:猫 {cat_count} 只,狗 {dog_count} 只")