OpenCV(五十四):车辆检测

项目核心流程与算法原理

视频车流量统计的本质是目标检测(Object Detection) + 目标追踪(Object Tracking) + 越线判定(Line Crossing Detection)。整个项目的核心流水线(Pipeline)如下:

bash 复制代码
[视频输入] ──> [背景建模与去噪] ──> [轮廓检测与筛选] ──> [质心计算与追踪] ──> [碰撞线计数]

图像预处理与背景建模

在一幅交通视频中,马路、护栏、树木是静止的(背景),而车辆是运动的(前景)。我们采用 混合高斯模型(MOG2, Mixture of Gaussians) 动态分离前景。

  • 背景减除法(Background Subtraction):MOG2 会对图像中的每个像素点进行概率建模。当新的一帧进来时,如果某个像素的值与历史模型相差较大,则判定为前景(车辆)。
  • 形态学处理(Morphological Operations) :背景减除后会产生大量噪点(如树叶晃动、光影变化),且车辆内部可能出现空洞。我们通过 腐蚀(Erode) 消除细小噪点,通过 膨胀(Dilate) or 闭运算(Close) 填补车辆内部空洞,使车辆连成一个完成的连通域。

车辆定位与质心提取

  • 轮廓外接矩形 :通过 cv2.findContours 寻找处理后图像中的所有连通域边缘。为了过滤掉行人和小噪点,我们会设定一个面积阈值(如 Area>800Area > 800Area>800 像素),只有大于该面积的轮廓才被认定为车辆。
  • 质心(Centroid)计算 :通过矩(Moments)几何特征计算出车辆外接矩形的中心点坐标 (cx,cy)(cx, cy)(cx,cy)。追踪质心比追踪整辆车要高效得多

车辆追踪与越线计数

  • 动态追踪(简易数据关联):对于前后两帧,我们计算当前帧所有新质心与上一帧已有车辆质心之间的欧氏距离。如果距离小于预设阈值(如 20 像素),则认为这是同一辆车,并更新其轨迹。
  • 基准线相交判定 :在画面中人为绘制一条"虚拟计数线"(线段 Y=YlineY = Y_{line}Y=Yline)。当某辆车的质心在前一帧位于线上方,而在当前帧位于线下方(或反之),则触发计数器 count += 1

Python 代码实现

bash 复制代码
import cv2
import numpy as np
import time
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# ==========================================
# 1. 配置类:集中管理所有超参数
# ==========================================
@dataclass
class Config:
    """集中管理所有配置参数"""
    # 视频配置
    video_path: str = "traffic.mp4"
    
    # 车辆检测参数
    min_width: int = 40
    min_height: int = 40
    max_width: int = 300  # 新增:最大宽度过滤
    max_height: int = 200  # 新增:最大高度过滤
    
    # 计数线配置
    line_height_ratio: float = 0.6  # 使用相对位置(视频高度的60%)
    line_height: Optional[int] = None  # 动态计算
    
    # 追踪器参数
    max_tracking_distance: int = 35  # 最大追踪距离(像素)
    max_trajectory_length: int = 60  # 轨迹最大长度
    trajectory_timeout: int = 45     # 轨迹超时帧数
    
    # MOG2 背景减除参数
    mog_history: int = 500
    mog_var_threshold: int = 50
    mog_detect_shadows: bool = True
    
    # 形态学操作参数
    erode_kernel_size: Tuple[int, int] = (3, 3)
    dilate_kernel_size: Tuple[int, int] = (7, 7)
    erode_iterations: int = 1
    dilate_iterations: int = 3
    
    # 显示参数
    show_debug_windows: bool = True  # 是否显示调试窗口
    fps_display: bool = True         # 是否显示FPS
    draw_trajectories: bool = True   # 是否绘制轨迹
    trajectory_color: Tuple[int, int, int] = (0, 255, 255)  # 轨迹颜色(黄色)
    
    # 计数模式
    count_direction: str = "bidirectional"  # "downward", "upward", "bidirectional"
    
    def __post_init__(self):
        """验证配置"""
        if self.count_direction not in ["downward", "upward", "bidirectional"]:
            raise ValueError(f"无效的计数方向: {self.count_direction}")


# ==========================================
# 2. 车辆追踪器类:封装追踪逻辑
# ==========================================
class VehicleTracker:
    """
    改进的车辆追踪器
    - 使用平方距离避免开方运算
    - 添加轨迹超时清理
    - 支持双向计数
    """
    
    def __init__(self, config: Config):
        self.config = config
        self.trajectories: Dict[int, dict] = {}  # 使用ID字典而非列表
        self.next_id: int = 0
        self.down_counter: int = 0  # 向下计数
        self.up_counter: int = 0    # 向上计数
        
    def update(self, current_centroids: List[Tuple[int, int]], line_y: int) -> Tuple[int, int]:
        """
        更新追踪状态
        
        Args:
            current_centroids: 当前帧检测到的车辆质心列表
            line_y: 计数线Y坐标
            
        Returns:
            (向下计数, 向上计数)
        """
        # 构建旧轨迹的匹配映射
        matched_old_ids = set()
        new_trajectories = {}
        
        # 计算平方距离阈值,避免开方
        max_dist_sq = self.config.max_tracking_distance ** 2
        
        for cx, cy in current_centroids:
            best_match_id = None
            best_dist_sq = max_dist_sq
            
            # 寻找最佳匹配(贪心策略)
            for traj_id, traj in self.trajectories.items():
                if traj_id in matched_old_ids:
                    continue
                    
                last_point = traj['points'][-1]
                dist_sq = (cx - last_point[0])**2 + (cy - last_point[1])**2
                
                if dist_sq < best_dist_sq:
                    best_dist_sq = dist_sq
                    best_match_id = traj_id
            
            if best_match_id is not None:
                # 匹配到已有轨迹
                matched_old_ids.add(best_match_id)
                traj = self.trajectories[best_match_id]
                traj['points'].append((cx, cy))
                traj['last_frame'] = len(traj['points'])
                
                # 检查是否跨越计数线
                self._check_crossing(traj, line_y)
                
                new_trajectories[best_match_id] = traj
            else:
                # 新建轨迹
                new_id = self.next_id
                self.next_id += 1
                
                new_traj = {
                    'points': [(cx, cy)],
                    'last_frame': 1,
                    'counted_down': False,
                    'counted_up': False
                }
                new_trajectories[new_id] = new_traj
        
        # 保留未匹配但未满超时的轨迹
        for traj_id, traj in self.trajectories.items():
            if traj_id not in matched_old_ids:
                timeout_frames = self.config.trajectory_timeout
                if len(traj['points']) < timeout_frames:
                    new_trajectories[traj_id] = traj
        
        # 清理过长的轨迹
        self.trajectories = {
            tid: t for tid, t in new_trajectories.items()
            if len(t['points']) < self.config.max_trajectory_length
        }
        
        return self.down_counter, self.up_counter
    
    def _check_crossing(self, traj: dict, line_y: int):
        """检查轨迹是否跨越计数线"""
        points = traj['points']
        if len(points) < 2:
            return
        
        p1 = points[-2]
        p2 = points[-1]
        
        # 向下穿越(从上往下)
        if self.config.count_direction in ["downward", "bidirectional"]:
            if p1[1] < line_y <= p2[1] and not traj['counted_down']:
                self.down_counter += 1
                traj['counted_down'] = True
                logger.debug(f"车辆向下穿越计数线,总数: {self.down_counter}")
        
        # 向上穿越(从下往上)
        if self.config.count_direction in ["upward", "bidirectional"]:
            if p2[1] < line_y <= p1[1] and not traj['counted_up']:
                self.up_counter += 1
                traj['counted_up'] = True
                logger.debug(f"车辆向上穿越计数线,总数: {self.up_counter}")
    
    def get_total_count(self) -> int:
        """获取总计数"""
        if self.config.count_direction == "downward":
            return self.down_counter
        elif self.config.count_direction == "upward":
            return self.up_counter
        else:  # bidirectional
            return self.down_counter + self.up_counter
    
    def draw_trajectories(self, frame: np.ndarray):
        """在帧上绘制所有轨迹"""
        if not self.config.draw_trajectories:
            return
        
        for traj_id, traj in self.trajectories.items():
            points = traj['points']
            if len(points) < 2:
                continue
            
            # 绘制轨迹线
            points_arr = np.array(points, np.int32)
            points_arr = points_arr.reshape(-1, 1, 2)
            cv2.polylines(frame, [points_arr], False, self.config.trajectory_color, 1)


# ==========================================
# 3. 车辆检测器类:封装检测逻辑
# ==========================================
class VehicleDetector:
    """封装车辆检测逻辑"""
    
    def __init__(self, config: Config):
        self.config = config
        self.bg_subtractor = cv2.createBackgroundSubtractorMOG2(
            history=config.mog_history,
            varThreshold=config.mog_var_threshold,
            detectShadows=config.mog_detect_shadows
        )
        self.kernel_erode = cv2.getStructuringElement(
            cv2.MORPH_RECT, config.erode_kernel_size
        )
        self.kernel_dilate = cv2.getStructuringElement(
            cv2.MORPH_RECT, config.dilate_kernel_size
        )
    
    def detect(self, frame: np.ndarray) -> Tuple[List[Tuple[int, int]], np.ndarray, np.ndarray]:
        """
        检测车辆并返回质心
        
        Returns:
            (质心列表, 前景掩膜, 形态学处理后图像)
        """
        # 灰度化和高斯模糊
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        blur = cv2.GaussianBlur(gray, (5, 5), 0)
        
        # 背景减除
        fg_mask = self.bg_subtractor.apply(blur)
        
        # 阈值过滤阴影
        _, threshed = cv2.threshold(fg_mask, 200, 255, cv2.THRESH_BINARY)
        
        # 形态学操作
        eroded = cv2.erode(threshed, self.kernel_erode, iterations=self.config.erode_iterations)
        dilated = cv2.dilate(eroded, self.kernel_dilate, iterations=self.config.dilate_iterations)
        closing = cv2.morphologyEx(dilated, cv2.MORPH_CLOSE, self.kernel_dilate)
        
        # 轮廓检测
        contours, _ = cv2.findContours(closing, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        # 提取质心
        centroids = []
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            
            # 尺寸过滤
            if (self.config.min_width <= w <= self.config.max_width and
                self.config.min_height <= h <= self.config.max_height):
                cx = x + w // 2
                cy = y + h // 2
                centroids.append((cx, cy))
                
                # 绘制检测框
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                cv2.circle(frame, (cx, cy), 4, (0, 0, 255), -1)
        
        return centroids, fg_mask, closing


# ==========================================
# 4. 主程序类
# ==========================================
class TrafficCounter:
    """交通计数主程序"""
    
    def __init__(self, config: Config):
        self.config = config
        self.detector = VehicleDetector(config)
        self.tracker = VehicleTracker(config)
        self.fps: float = 0.0
        self.frame_count: int = 0
        self.start_time: float = 0.0
    
    def run(self):
        """运行主循环"""
        cap = cv2.VideoCapture(self.config.video_path)
        
        if not cap.isOpened():
            logger.error(f"无法打开视频文件: {self.config.video_path}")
            return
        
        try:
            # 获取视频高度以计算计数线位置
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            if self.config.line_height is None:
                self.config.line_height = int(frame_height * self.config.line_height_ratio)
            
            logger.info(f"开始处理视频,计数线Y坐标: {self.config.line_height}")
            self.start_time = time.time()
            
            while True:
                ret, frame = cap.read()
                if not ret:
                    logger.info("视频播放结束")
                    break
                
                # 计算FPS
                self.frame_count += 1
                elapsed = time.time() - self.start_time
                if elapsed > 0:
                    self.fps = self.frame_count / elapsed
                
                # 检测车辆
                centroids, fg_mask, closing = self.detector.detect(frame)
                
                # 更新追踪
                down_count, up_count = self.tracker.update(centroids, self.config.line_height)
                total_count = self.tracker.get_total_count()
                
                # 绘制计数线
                cv2.line(
                    frame, (0, self.config.line_height),
                    (frame.shape[1], self.config.line_height),
                    (0, 0, 255), 3
                )
                
                # 绘制轨迹
                self.tracker.draw_trajectories(frame)
                
                # 绘制统计信息
                self._draw_info(frame, total_count, down_count, up_count)
                
                # 显示窗口
                cv2.imshow("Traffic Counter", frame)
                
                if self.config.show_debug_windows:
                    cv2.imshow("Foreground Mask", fg_mask)
                    cv2.imshow("Morphological Closing", closing)
                
                # 检查退出键
                key = cv2.waitKey(1) & 0xFF
                if key == 27:  # ESC
                    logger.info("用户按下ESC退出")
                    break
                elif key == ord('s'):  # 截图
                    timestamp = int(time.time())
                    cv2.imwrite(f"screenshot_{timestamp}.jpg", frame)
                    logger.info(f"截图已保存: screenshot_{timestamp}.jpg")
                
        except Exception as e:
            logger.error(f"运行时错误: {e}", exc_info=True)
        finally:
            cap.release()
            cv2.destroyAllWindows()
            logger.info(f"处理完成。总计数: {self.tracker.get_total_count()}, FPS: {self.fps:.2f}")
    
    def _draw_info(self, frame: np.ndarray, total: int, down: int, up: int):
        """绘制统计信息"""
        y_offset = 40
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.7
        thickness = 2
        
        # 标题
        cv2.putText(frame, "TRAFFIC COUNTER", (20, y_offset),
                    font, 1.0, (255, 255, 255), 2)
        y_offset += 35
        
        # 总计数
        cv2.putText(frame, f"TOTAL: {total}", (20, y_offset),
                    font, font_scale, (0, 255, 0), thickness)
        y_offset += 25
        
        # 分方向计数
        if self.config.count_direction == "bidirectional":
            cv2.putText(frame, f"DOWN: {down}", (20, y_offset),
                        font, font_scale, (0, 165, 255), thickness)
            y_offset += 25
            cv2.putText(frame, f"UP: {up}", (20, y_offset),
                        font, font_scale, (255, 165, 0), thickness)
            y_offset += 25
        
        # FPS
        if self.config.fps_display:
            cv2.putText(frame, f"FPS: {self.fps:.1f}", (20, y_offset),
                        font, font_scale, (255, 255, 255), thickness)
            y_offset += 25
        
        # 追踪车辆数
        cv2.putText(frame, f"TRACKING: {len(self.tracker.trajectories)}", 
                    (20, y_offset), font, font_scale, (255, 255, 255), thickness)


# ==========================================
# 5. 入口函数
# ==========================================
def main():
    """主入口函数"""
    # 创建配置
    config = Config(
        video_path="traffic.mp4",
        min_width=40,
        min_height=40,
        max_width=300,
        max_height=200,
        line_height_ratio=0.6,
        max_tracking_distance=35,
        trajectory_timeout=45,
        count_direction="bidirectional",  # 可选: "downward", "upward", "bidirectional"
        show_debug_windows=True,
        draw_trajectories=True,
        fps_display=True
    )
    
    # 运行
    counter = TrafficCounter(config)
    counter.run()


if __name__ == "__main__":
    main()

视频流 I/O 与基础控制技术

这是所有视觉项目的入口和出口。

  • cv2.VideoCapture(视频流读取) :用于建立视频输入管道。它不仅能读取本地的 .mp4.avi 视频文件,还可以通过传入 01 调用本地摄像头,或者传入 rtsp://... 连接网络监控摄像头。
  • cv2.waitKey(delay)(键盘拦截与帧率控制) :这是 OpenCV 维持视窗刷新的核心函数。它有两大作用:一是控制视频播放的刷新间隔(毫秒);二是捕捉键盘输入(例如代码中的 key == 27 拦截 ESC 键),实现非阻塞式的安全退出。

图像空间转换与平滑去噪

在进行高级分析前,必须降低数据维度并滤除传感器噪声。

  • cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)(彩色转灰度)

    将三通道的 BGR 彩色图像转换为单通道的灰度图。由于车辆的运动轮廓 只取决于像素的亮度变化,与颜色无关,转为灰度图可以将计算量直接暴降到原来的 13\frac{1}{3}31,极大地提升了算法的实时性。

  • cv2.GaussianBlur(高斯模糊)

    图像在传输或拍摄过程中会有很多高频噪点(如雪花点、蚊子噪)。高斯模糊利用二维高斯分布矩阵(核大小如 5×55 \times 55×5)对图像进行平滑卷积,让噪声"融化"到周围像素中,防止后面的算法把噪点误判为运动车辆。

背景减除与运动前景提取

这是传统视觉识别运动目标的核心武器。

  • cv2.createBackgroundSubtractorMOG2(混合高斯模型背景减除) : 传统的帧差法(Frame Difference)无法解决车辆静止、光照渐变的问题。MOG2 算法为每个像素点建立多个高斯分布,动态学习视频的"静止背景"(马路、建筑)。当车辆驶入时,像素值偏离了背景模型,从而被精准地剥离为前景掩膜(Foreground Mask)
  • cv2.threshold(二值化变换): 由于 MOG2 在提取前景时会将车身投射的阴影(Shadows)标记为特殊的灰色(灰度值 127),我们通过二值化强行将大于 200 灰度值的像素设为 255(纯白),小于的设为 0(纯黑),从而干净地剔除了地面阴影的干扰。

数学形态学(Morphological Operations)

用于修补二值化图像的缺陷。

  • cv2.getStructuringElement(结构元素构建) :用于定义形态学操作的"画笔"形状和大小(如 3×33 \times 33×3 或 7×77 \times 77×7 的矩形核)。
  • cv2.erode(腐蚀):让图像中的白色区域向内收缩一圈。其主要作用是"消除散点",将画面中由于树叶晃动或噪点产生的小白点直接抹去。
  • cv2.dilate(膨胀):让图像中的白色区域向外扩张一圈。其主要作用是"桥接断裂",让同一辆汽车断开的轮廓重新融合成一个整体。
  • cv2.morphologyEx(..., cv2.MORPH_CLOSE)(闭运算):先膨胀后腐蚀的组合拳。它可以填充车辆内部由于挡风玻璃反光造成的"黑色空洞",并平滑车辆的外边缘。

几何特征分析与轮廓提取

将像素级别的白色块转换为具有物理意义的坐标。

  • cv2.findContours(拓扑轮廓寻找): 该算法会扫描整张二值化图像,将所有连续的白色像素块的边缘提取出来,形成一个无序的几何轮廓列表。
  • cv2.boundingRect(最大外接矩形) : 输入一个不规则的轮廓,它会自动计算出能包裹该轮廓的最小正矩形,并返回左上角坐标 (x, y) 以及宽高 (w, h)。这是我们进行车辆几何尺寸筛选(w >= MIN_W)的数学依据。

几何图形绘制与 UI 渲染

用于在视频画面上实时呈现算法结果。

  • cv2.linecv2.rectangle(线与矩形绘制):在图像矩阵上直接修改像素,画出绿色的车辆检测框和红色的虚拟计数线。
  • cv2.circle(圆形绘制) :用于将我们计算出的物理质心 (cx,cy)(cx, cy)(cx,cy) 以红点的形式标刻在画面中。
  • cv2.putText(文本渲染) :将动态累加的 car_counter 变量以指定的字体、大小和颜色实时"写"在视频帧的左上角。