优化图像拼接算法思路

一、多线程优化

1、多线程并行图像拼接,视频数据流检测并将每节车厢的帧数保存,然后直接使用线程调用拼接函数,不影响主程序视频流的检测。

2、基于get_img_asyncio.py程序中修改run.py

python 复制代码
def run(self, video_num=0):
        # 基本数据初始化
        run_state = False
        img_list = []
        cg_list = [[], []]
        cut = []
        img_num_list = []
        begin_num = 0
        total_res = []
        det_res_list = []
        
        # 线程池相关变量
        stich_futures = []  # 存储所有拼接任务的Future对象
        stich_tasks = []    # 存储任务元数据
        
        try:
            cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
            if not cap.isOpened():
                logging.error(f"{self.task_id} 无法打开视频源: {self.source}")
                self.running = False
                return total_res
                
            cap_info = self.get_cap_info(cap)
            self.total_stich_side.basic_shape = cap_info['shape']
            self.total_stich_side.cut_shape = [self.center_box[3], self.center_box[2]-self.center_box[0]]
            self.total_stich_side.x_area = [self.center_box[0], self.center_box[2]]
            self.total_stich_side.basic_num = self.train_num_base
            self.total_stich_side.type = self.type

            self.total_stich_up.basic_shape = cap_info['shape']
            self.total_stich_up.cut_shape = [self.center_box[3], self.center_box[2]-self.center_box[0]]
            self.total_stich_up.x_area = [self.center_box[0], self.center_box[2]]
            self.total_stich_up.basic_num = self.train_num_base
            self.total_stich_up.type = self.type

            self.image_shape = cap_info['shape']
            pre_dedtect = {'stich': False, 'hitnum': 0}
            end_dedtect = 0
            logging.info(f"{self.task_id} 成功打开视频源: {self.source}")
            stop = 0
            start_det = 0
            
            # 创建线程池(根据需求调整max_workers)
            thread_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix=f"Sticher-{self.task_id}")
            
            for total_num in range(cap_info['frames']):
                # 检查已完成的拼接任务
                completed_futures = []
                for future in stich_futures:
                    if future.done():
                        try:
                            res = future.result()
                            if res is not None:
                                total_res.append(res)
                        except Exception as e:
                            logging.error(f"拼接任务出错: {str(e)}")
                            traceback.print_exc()
                        completed_futures.append(future)
                
                # 移除已完成的任务
                for future in completed_futures:
                    stich_futures.remove(future)
                
                ret = cap.grab()
                if not ret:
                    break
                    
                if total_num % 100 == 0:
                    logging.info(f'{self.task_id}  total_num:{total_num} ')
                    
                ret, frame = cap.retrieve()
                if pre_dedtect['hitnum'] == 0 and total_num % 5 != 0:
                    continue
                    
                det_res = self.det.infer(frame)
                
                # 第一节车厢拼接,获取开头相应位置以及初始化必要信息
                if not run_state:
                    if (video_num == 0 and total_num < 100) or len(det_res[1]) == 0:
                        continue
                    else:
                        if not pre_dedtect['stich']:
                            if self.flitter_railways(det_res):
                                pre_dedtect['hitnum'] += 1
                                if pre_dedtect['hitnum'] > 5:
                                    pre_dedtect['stich'] = True
                            else:
                                if pre_dedtect['hitnum'] > 0:
                                    pre_dedtect['hitnum'] -= 1
                        else:
                            standart_H, intersection = self.calculate_union_and_check_intersection(
                                det_res[0], det_res[1], self.center_box, run_state)
                            if not intersection:
                                continue
                            begin_num = total_num
                            run_state = True
                else:
                    standart_H, intersection = self.calculate_union_and_check_intersection(
                        det_res[0], det_res[1], self.center_box, run_state)
                    
                    if not intersection:
                        if len(cg_list[0]) == 0:
                            stop += 1
                            if stop >= 15:
                                break
                            continue
                        
                        new_list = [int(abs(x - (self.center_box[2] - self.center_box[0]) / 2)) for x in cg_list[0][0]]
                        if len(new_list) == 0:
                            if end_dedtect < 10:
                                end_dedtect += 1
                                continue
                            else:
                                break
                                
                        id = new_list.index(min(new_list))
                        end = img_num_list.index(cg_list[1][id])
                        cut.append([end, cg_list[0][id]])
                        
                        if self.total_stich is None:
                            self.total_stich = self.total_stich_side
                            
                        # 创建独立的任务数据副本
                        task_imgs = img_list[:end + 1].copy()
                        task_dets = det_res_list[:end + 1].copy()
                        task_cuts = cut.copy()
                        
                        # 创建新的拼接器实例避免状态冲突
                        if self.type == 'up':
                            task_sticher = self.total_stich_up
                        else:
                            task_sticher = self.total_stich_side
                        
                        # 创建拼接器副本(浅拷贝)
                        task_sticher = copy.copy(task_sticher)
                        
                        # 设置任务专属参数
                        init_params = {
                            'basic_shape': task_sticher.basic_shape,
                            'cut_shape': task_sticher.cut_shape,
                            'x_area': task_sticher.x_area,
                            'save_path': task_sticher.save_path,
                            'train_num': task_sticher.train_num_base + len(stich_tasks),
                            'type': task_sticher.type
                        }
                        
                        # 延迟初始化任务专属状态
                        if hasattr(task_sticher, 'initialize'):
                            task_sticher.initialize(init_params)
                        
                        # 提交任务到线程池
                        future = thread_pool.submit(
                            stich_worker_thread,
                            task_imgs,
                            task_dets,
                            task_cuts,
                            task_sticher
                        )
                        
                        stich_futures.append(future)
                        stich_tasks.append({
                            'end_index': end,
                            'train_num': self.train_num_base + len(stich_tasks)
                        })
                        
                        # 更新数据结构
                        img_list = img_list[end:]
                        det_res_list = det_res_list[end:]
                        img_num_list = img_num_list[end:]
                        cut = [[0, cut[-1][1]]]
                        cg_list = [[], []]
                        self.begin_cut = True
                        
                    else:
                        det_box, cg = self.calculate_intersection_of_union_with_input(det_res, self.center_box)
                        img_list.append(frame)
                        img_num_list.append(total_num)
                        det_res_list.append(det_res)
                        
                        if cg[0] > 0:
                            if len(cut) > 0:
                                if total_num - img_num_list[0] < 30:
                                    continue
                            cg_list[0].append(cg)
                            cg_list[1].append(total_num)
                        else:
                            if len(cg_list[0]) != 0:
                                new_list = [int(abs(x[0] - (self.center_box[2] - self.center_box[0]) / 2)) for x in cg_list[0]]
                                id = new_list.index(min(new_list))
                                
                                if (total_num - begin_num) - len(cg_list[0]) < 50 and not self.begin_cut:
                                    logging.info(f"{self.task_id}开始裁切起始位置")
                                    begin = img_num_list.index(cg_list[1][id])
                                    img_list = img_list[begin:]
                                    det_res_list = det_res_list[begin:]
                                    img_num_list = img_num_list[begin:]
                                    cut.append([begin, cg_list[0][id]])
                                    cg_list = [[], []]
                                    self.begin_cut = True
                                else:
                                    end = img_num_list.index(cg_list[1][id])
                                    cut.append([end, cg_list[0][id]])
                                    
                                    if self.total_stich is None:
                                        self.total_stich = self.total_stich_side
                                    
                                    # 创建独立的任务数据副本
                                    task_imgs = img_list[:end + 1].copy()
                                    task_dets = det_res_list[:end + 1].copy()
                                    task_cuts = cut.copy()
                                    
                                    # 创建新的拼接器实例
                                    if self.type == 'up':
                                        task_sticher = self.total_stich_up
                                    else:
                                        task_sticher = self.total_stich_side
                                    
                                    task_sticher = copy.copy(task_sticher)
                                    
                                    # 设置任务专属参数
                                    init_params = {
                                        'basic_shape': task_sticher.basic_shape,
                                        'cut_shape': task_sticher.cut_shape,
                                        'x_area': task_sticher.x_area,
                                        'save_path': task_sticher.save_path,
                                        'train_num': task_sticher.train_num_base + len(stich_tasks),
                                        'type': task_sticher.type
                                    }
                                    
                                    # 延迟初始化
                                    if hasattr(task_sticher, 'initialize'):
                                        task_sticher.initialize(init_params)
                                    
                                    # 提交任务
                                    future = thread_pool.submit(
                                        stich_worker_thread,
                                        task_imgs,
                                        task_dets,
                                        task_cuts,
                                        task_sticher
                                    )
                                    
                                    stich_futures.append(future)
                                    stich_tasks.append({
                                        'end_index': end,
                                        'train_num': self.train_num_base + len(stich_tasks)
                                    })
                                    
                                    # 更新数据结构
                                    img_list = img_list[end:]
                                    det_res_list = det_res_list[end:]
                                    img_num_list = img_num_list[end:]
                                    cut = [[0, cut[-1][1]]]
                                    cg_list = [[], []]
            
            # 等待所有剩余任务完成
            for future in stich_futures:
                try:
                    res = future.result(timeout=300)  # 5分钟超时
                    if res is not None:
                        total_res.append(res)
                except TimeoutError:
                    logging.warning(f"{self.task_id} 拼接任务超时")
                except Exception as e:
                    logging.error(f"{self.task_id} 拼接任务失败: {str(e)}")
                    traceback.print_exc()
            
            # 关闭线程池
            thread_pool.shutdown(wait=True)
            
            if cap is not None:
                cap.release()
                logging.info(f"{self.task_id} 释放视频捕获资源")
                
            return total_res
            
        except Exception as e:
            logging.error(f"{self.task_id} 运行异常: {str(e)}")
            traceback.print_exc()
            
            # 尝试关闭线程池
            try:
                thread_pool.shutdown(wait=False, cancel_futures=True)
            except:
                pass
                
            return total_res

上面的代码中,主要通过:

提交任务到线程池

future = thread_pool.submit(

stich_worker_thread,

task_imgs,

task_dets,

task_cuts,

task_sticher

)

来调用拼接程序。

3、基于get_img_asyncio.py程序中给出一个全局函数stich_worker_thread

python 复制代码
def stich_worker_thread(imgs, dets, cuts, sticher):
    try:
        return sticher.stich_images(imgs, dets, cuts)
    except Exception as e:
        logging.error(f"拼接线程出错: {str(e)}")
        traceback.print_exc()
        return None

4、在TotalStich_sideTotalStich_up类中添加initialize方法

python 复制代码
def initialize(self, init_params):
    """初始化线程专有状态"""
    for key, value in init_params.items():
        setattr(self, key, value)
    
    # 初始化需要延迟创建的资源
    self.weight_map = self.create_weight_map()
    self.img_list = []
    self.cut = False
    # 其他需要重置的状态..
相关推荐
郑州光合科技余经理4 分钟前
同城O2O海外版二次开发实战:从支付网关到配送算法
开发语言·前端·后端·算法·架构·uni-app·php
d111111111d3 小时前
STM32-UART封装问题解析
笔记·stm32·单片机·嵌入式硬件·学习·算法
Jiangxl~4 小时前
IP数据云如何为不同行业提供精准IP查询与风险防控解决方案?
网络·网络协议·tcp/ip·算法·ai·ip·安全架构
李伟_Li慢慢5 小时前
wolfram详解山峦算法
前端·算法
counting money5 小时前
prim算法最小生成树(java)
算法
澈2075 小时前
C++面向对象:类与对象核心解析
c++·算法
用户690673881925 小时前
基于无人机的单目测距系统,平均误差仅2.12%
算法
dinl_vin5 小时前
LangChain 系列·(四):RAG 基础——给大模型装上“外脑“
人工智能·算法·langchain
探物 AI6 小时前
【感知·医学分割】当 YOLOv11 杀入医学赛道:先检测后分割的级联架构
算法·yolo·计算机视觉·架构
隔壁大炮6 小时前
Day06-08.CNN概述介绍
人工智能·pytorch·深度学习·算法·计算机视觉·cnn·numpy