优化图像拼接算法思路

一、多线程优化

1、多线程并行图像拼接,视频数据流检测并将每节车厢的帧数保存,然后直接使用线程调用拼接函数,不影响主程序视频流的检测。

2、基于get_img_asyncio.py程序中修改run.py

python 复制代码
def run(self, video_num=0):
        # 基本数据初始化
        run_state = False
        img_list = []
        cg_list = [[], []]
        cut = []
        img_num_list = []
        begin_num = 0
        total_res = []
        det_res_list = []
        
        # 线程池相关变量
        stich_futures = []  # 存储所有拼接任务的Future对象
        stich_tasks = []    # 存储任务元数据
        
        try:
            cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
            if not cap.isOpened():
                logging.error(f"{self.task_id} 无法打开视频源: {self.source}")
                self.running = False
                return total_res
                
            cap_info = self.get_cap_info(cap)
            self.total_stich_side.basic_shape = cap_info['shape']
            self.total_stich_side.cut_shape = [self.center_box[3], self.center_box[2]-self.center_box[0]]
            self.total_stich_side.x_area = [self.center_box[0], self.center_box[2]]
            self.total_stich_side.basic_num = self.train_num_base
            self.total_stich_side.type = self.type

            self.total_stich_up.basic_shape = cap_info['shape']
            self.total_stich_up.cut_shape = [self.center_box[3], self.center_box[2]-self.center_box[0]]
            self.total_stich_up.x_area = [self.center_box[0], self.center_box[2]]
            self.total_stich_up.basic_num = self.train_num_base
            self.total_stich_up.type = self.type

            self.image_shape = cap_info['shape']
            pre_dedtect = {'stich': False, 'hitnum': 0}
            end_dedtect = 0
            logging.info(f"{self.task_id} 成功打开视频源: {self.source}")
            stop = 0
            start_det = 0
            
            # 创建线程池(根据需求调整max_workers)
            thread_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix=f"Sticher-{self.task_id}")
            
            for total_num in range(cap_info['frames']):
                # 检查已完成的拼接任务
                completed_futures = []
                for future in stich_futures:
                    if future.done():
                        try:
                            res = future.result()
                            if res is not None:
                                total_res.append(res)
                        except Exception as e:
                            logging.error(f"拼接任务出错: {str(e)}")
                            traceback.print_exc()
                        completed_futures.append(future)
                
                # 移除已完成的任务
                for future in completed_futures:
                    stich_futures.remove(future)
                
                ret = cap.grab()
                if not ret:
                    break
                    
                if total_num % 100 == 0:
                    logging.info(f'{self.task_id}  total_num:{total_num} ')
                    
                ret, frame = cap.retrieve()
                if pre_dedtect['hitnum'] == 0 and total_num % 5 != 0:
                    continue
                    
                det_res = self.det.infer(frame)
                
                # 第一节车厢拼接,获取开头相应位置以及初始化必要信息
                if not run_state:
                    if (video_num == 0 and total_num < 100) or len(det_res[1]) == 0:
                        continue
                    else:
                        if not pre_dedtect['stich']:
                            if self.flitter_railways(det_res):
                                pre_dedtect['hitnum'] += 1
                                if pre_dedtect['hitnum'] > 5:
                                    pre_dedtect['stich'] = True
                            else:
                                if pre_dedtect['hitnum'] > 0:
                                    pre_dedtect['hitnum'] -= 1
                        else:
                            standart_H, intersection = self.calculate_union_and_check_intersection(
                                det_res[0], det_res[1], self.center_box, run_state)
                            if not intersection:
                                continue
                            begin_num = total_num
                            run_state = True
                else:
                    standart_H, intersection = self.calculate_union_and_check_intersection(
                        det_res[0], det_res[1], self.center_box, run_state)
                    
                    if not intersection:
                        if len(cg_list[0]) == 0:
                            stop += 1
                            if stop >= 15:
                                break
                            continue
                        
                        new_list = [int(abs(x - (self.center_box[2] - self.center_box[0]) / 2)) for x in cg_list[0][0]]
                        if len(new_list) == 0:
                            if end_dedtect < 10:
                                end_dedtect += 1
                                continue
                            else:
                                break
                                
                        id = new_list.index(min(new_list))
                        end = img_num_list.index(cg_list[1][id])
                        cut.append([end, cg_list[0][id]])
                        
                        if self.total_stich is None:
                            self.total_stich = self.total_stich_side
                            
                        # 创建独立的任务数据副本
                        task_imgs = img_list[:end + 1].copy()
                        task_dets = det_res_list[:end + 1].copy()
                        task_cuts = cut.copy()
                        
                        # 创建新的拼接器实例避免状态冲突
                        if self.type == 'up':
                            task_sticher = self.total_stich_up
                        else:
                            task_sticher = self.total_stich_side
                        
                        # 创建拼接器副本(浅拷贝)
                        task_sticher = copy.copy(task_sticher)
                        
                        # 设置任务专属参数
                        init_params = {
                            'basic_shape': task_sticher.basic_shape,
                            'cut_shape': task_sticher.cut_shape,
                            'x_area': task_sticher.x_area,
                            'save_path': task_sticher.save_path,
                            'train_num': task_sticher.train_num_base + len(stich_tasks),
                            'type': task_sticher.type
                        }
                        
                        # 延迟初始化任务专属状态
                        if hasattr(task_sticher, 'initialize'):
                            task_sticher.initialize(init_params)
                        
                        # 提交任务到线程池
                        future = thread_pool.submit(
                            stich_worker_thread,
                            task_imgs,
                            task_dets,
                            task_cuts,
                            task_sticher
                        )
                        
                        stich_futures.append(future)
                        stich_tasks.append({
                            'end_index': end,
                            'train_num': self.train_num_base + len(stich_tasks)
                        })
                        
                        # 更新数据结构
                        img_list = img_list[end:]
                        det_res_list = det_res_list[end:]
                        img_num_list = img_num_list[end:]
                        cut = [[0, cut[-1][1]]]
                        cg_list = [[], []]
                        self.begin_cut = True
                        
                    else:
                        det_box, cg = self.calculate_intersection_of_union_with_input(det_res, self.center_box)
                        img_list.append(frame)
                        img_num_list.append(total_num)
                        det_res_list.append(det_res)
                        
                        if cg[0] > 0:
                            if len(cut) > 0:
                                if total_num - img_num_list[0] < 30:
                                    continue
                            cg_list[0].append(cg)
                            cg_list[1].append(total_num)
                        else:
                            if len(cg_list[0]) != 0:
                                new_list = [int(abs(x[0] - (self.center_box[2] - self.center_box[0]) / 2)) for x in cg_list[0]]
                                id = new_list.index(min(new_list))
                                
                                if (total_num - begin_num) - len(cg_list[0]) < 50 and not self.begin_cut:
                                    logging.info(f"{self.task_id}开始裁切起始位置")
                                    begin = img_num_list.index(cg_list[1][id])
                                    img_list = img_list[begin:]
                                    det_res_list = det_res_list[begin:]
                                    img_num_list = img_num_list[begin:]
                                    cut.append([begin, cg_list[0][id]])
                                    cg_list = [[], []]
                                    self.begin_cut = True
                                else:
                                    end = img_num_list.index(cg_list[1][id])
                                    cut.append([end, cg_list[0][id]])
                                    
                                    if self.total_stich is None:
                                        self.total_stich = self.total_stich_side
                                    
                                    # 创建独立的任务数据副本
                                    task_imgs = img_list[:end + 1].copy()
                                    task_dets = det_res_list[:end + 1].copy()
                                    task_cuts = cut.copy()
                                    
                                    # 创建新的拼接器实例
                                    if self.type == 'up':
                                        task_sticher = self.total_stich_up
                                    else:
                                        task_sticher = self.total_stich_side
                                    
                                    task_sticher = copy.copy(task_sticher)
                                    
                                    # 设置任务专属参数
                                    init_params = {
                                        'basic_shape': task_sticher.basic_shape,
                                        'cut_shape': task_sticher.cut_shape,
                                        'x_area': task_sticher.x_area,
                                        'save_path': task_sticher.save_path,
                                        'train_num': task_sticher.train_num_base + len(stich_tasks),
                                        'type': task_sticher.type
                                    }
                                    
                                    # 延迟初始化
                                    if hasattr(task_sticher, 'initialize'):
                                        task_sticher.initialize(init_params)
                                    
                                    # 提交任务
                                    future = thread_pool.submit(
                                        stich_worker_thread,
                                        task_imgs,
                                        task_dets,
                                        task_cuts,
                                        task_sticher
                                    )
                                    
                                    stich_futures.append(future)
                                    stich_tasks.append({
                                        'end_index': end,
                                        'train_num': self.train_num_base + len(stich_tasks)
                                    })
                                    
                                    # 更新数据结构
                                    img_list = img_list[end:]
                                    det_res_list = det_res_list[end:]
                                    img_num_list = img_num_list[end:]
                                    cut = [[0, cut[-1][1]]]
                                    cg_list = [[], []]
            
            # 等待所有剩余任务完成
            for future in stich_futures:
                try:
                    res = future.result(timeout=300)  # 5分钟超时
                    if res is not None:
                        total_res.append(res)
                except TimeoutError:
                    logging.warning(f"{self.task_id} 拼接任务超时")
                except Exception as e:
                    logging.error(f"{self.task_id} 拼接任务失败: {str(e)}")
                    traceback.print_exc()
            
            # 关闭线程池
            thread_pool.shutdown(wait=True)
            
            if cap is not None:
                cap.release()
                logging.info(f"{self.task_id} 释放视频捕获资源")
                
            return total_res
            
        except Exception as e:
            logging.error(f"{self.task_id} 运行异常: {str(e)}")
            traceback.print_exc()
            
            # 尝试关闭线程池
            try:
                thread_pool.shutdown(wait=False, cancel_futures=True)
            except:
                pass
                
            return total_res

上面的代码中,主要通过:

提交任务到线程池

future = thread_pool.submit(

stich_worker_thread,

task_imgs,

task_dets,

task_cuts,

task_sticher

)

来调用拼接程序。

3、基于get_img_asyncio.py程序中给出一个全局函数stich_worker_thread

python 复制代码
def stich_worker_thread(imgs, dets, cuts, sticher):
    try:
        return sticher.stich_images(imgs, dets, cuts)
    except Exception as e:
        logging.error(f"拼接线程出错: {str(e)}")
        traceback.print_exc()
        return None

4、在TotalStich_sideTotalStich_up类中添加initialize方法

python 复制代码
def initialize(self, init_params):
    """初始化线程专有状态"""
    for key, value in init_params.items():
        setattr(self, key, value)
    
    # 初始化需要延迟创建的资源
    self.weight_map = self.create_weight_map()
    self.img_list = []
    self.cut = False
    # 其他需要重置的状态..
相关推荐
Fcy6481 小时前
算法基础详解(4)双指针算法
开发语言·算法·双指针
xwz小王子2 小时前
Nature Communications从结构到功能:基于Kresling折纸的多模态微型机器人设计
人工智能·算法·机器人
luj_17682 小时前
从R语言想起的,。。。
服务器·c语言·开发语言·经验分享·算法
计算机安禾2 小时前
【数据结构与算法】第29篇:红黑树原理与C语言模拟
c语言·开发语言·数据结构·c++·算法·visual studio
生信研究猿2 小时前
94. 二叉树的中序遍历 (二叉树遍历整理)
数据结构·算法
挂科边缘2 小时前
image-restoration-sde复现,图像修复,使用均值回复随机微分方程进行图像修复,ICML 2023
算法·均值算法·ir-sde·扩散模块图像修复
2301_822703202 小时前
开源鸿蒙跨平台Flutter开发:血氧饱和度数据降噪:基于滑动窗口的滤波算法优化-利用动态列队 (Queue) 与时间窗口平滑光电容积脉搏波 (PPG)
算法·flutter·华为·开源·harmonyos
Vin0sen2 小时前
算法-线段树与树状数组
算法
sycmancia2 小时前
QT——计算器核心算法
开发语言·qt·算法