import os
import cv2
import time
import torch
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from EfficientLoFTR.src.loftr import LoFTR, full_default_cfg, opt_default_cfg, reparameter
import matplotlib.pyplot as plt
class LoFTRMatcher:
    """Feature matcher built on EfficientLoFTR.

    Wraps checkpoint loading, image preprocessing, image-pair matching,
    match visualisation and locating a query image in the frames of a video.

    Args:
        model_type: ``'full'`` builds the network from ``full_default_cfg``;
            any other value uses ``opt_default_cfg``.
        precision: ``'fp32'`` (default), ``'mp'`` (sets ``cfg['mp']`` and runs
            the forward pass under CUDA autocast when on GPU) or ``'fp16'``
            (sets ``cfg['half']`` and casts model and inputs to half).
            Any other value is treated like ``'fp32'``.
        weight_path: path to a checkpoint containing a ``'state_dict'`` entry.
            Required; ``None`` raises ``ValueError``.
    """

    def __init__(self, model_type='full', precision='fp32', weight_path=None):
        self.model_type = model_type
        self.precision = precision
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.matcher = self._load_model(weight_path)

    def _load_model(self, weight_path):
        """Build the network, load ``weight_path`` and prepare it for inference.

        Raises:
            ValueError: if ``weight_path`` is None.
        """
        # Validate the argument before paying for network construction.
        if weight_path is None:
            raise ValueError("必须指定模型权重路径")
        cfg = deepcopy(full_default_cfg if self.model_type == 'full' else opt_default_cfg)
        if self.precision == 'mp':
            cfg['mp'] = True
        elif self.precision == 'fp16':
            cfg['half'] = True
        matcher = LoFTR(config=cfg)
        # Load the full checkpoint (weights_only=False, i.e. full pickled
        # contents). SECURITY: this unpickles arbitrary objects, so only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(weight_path, map_location=self.device, weights_only=False)
        matcher.load_state_dict(checkpoint['state_dict'])
        # EfficientLoFTR helper; presumably converts re-parameterisable blocks
        # into their inference form (TODO confirm against the library).
        matcher = reparameter(matcher)
        if self.precision == 'fp16':
            matcher = matcher.half()
        matcher.eval().to(self.device)
        return matcher

    def _preprocess_image(self, img, to_gray=True):
        """Turn an image path or array into a LoFTR input tensor.

        The image is converted to grayscale and resized so that width and
        height are rounded down to multiples of 32.

        Args:
            img: image file path, or a ``numpy.ndarray`` that is grayscale
                (2-D), BGR or BGRA (3-D).
            to_gray: must be True for ndarray input (otherwise ``ValueError``).
                NOTE(review): with a path and ``to_gray=False`` the image is
                read as 3-channel BGR, producing a (1, 1, H, W, 3) tensor that
                the matcher presumably cannot consume.

        Returns:
            ``(tensor, resized)``: a (1, 1, H', W') tensor on ``self.device``
            divided by 255 (half precision when ``precision == 'fp16'``), and
            the resized image array. Keypoints returned by the matcher are in
            the pixel grid of ``resized``.

        Raises:
            FileNotFoundError: if a path cannot be read by OpenCV.
            ValueError: for unsupported array layouts or images smaller than
                32 pixels in either dimension.
            TypeError: if ``img`` is neither a path nor an ndarray.
        """
        if isinstance(img, str):
            path = img
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if to_gray else cv2.IMREAD_COLOR)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise FileNotFoundError(f"无法读取图像:{path}")
        elif isinstance(img, np.ndarray):
            if not to_gray or img.ndim not in (2, 3):
                raise ValueError("图像应为灰度或 BGR 图像")
            if img.ndim == 3:
                code = cv2.COLOR_BGRA2GRAY if img.shape[2] == 4 else cv2.COLOR_BGR2GRAY
                img = cv2.cvtColor(img, code)
        else:
            raise TypeError("输入应为图像路径或 numpy.ndarray")

        h, w = img.shape[:2]
        new_w, new_h = w // 32 * 32, h // 32 * 32
        if new_w == 0 or new_h == 0:
            # Would otherwise surface as an opaque cv2.resize error.
            raise ValueError(f"图像尺寸过小:{w}x{h},宽和高都需至少 32 像素")
        img = cv2.resize(img, (new_w, new_h))
        tensor = torch.from_numpy(img)[None][None]
        tensor = tensor.half() if self.precision == 'fp16' else tensor.float()
        return tensor.to(self.device) / 255., img

    def _run_matcher(self, batch):
        """Run the network on ``batch`` and return its matches as numpy arrays.

        The network writes its outputs into ``batch``. For ``precision == 'mp'``
        the forward pass runs under CUDA autocast, but only when the model is
        on a CUDA device (on CPU it would just be disabled with a warning).

        Returns:
            ``(mkpts0, mkpts1, mconf)`` as float32 numpy arrays. Half-precision
            outputs are upcast because OpenCV geometry routines such as
            ``cv2.findHomography`` require float32/float64 points.
        """
        with torch.no_grad():
            if self.precision == 'mp' and self.device.type == 'cuda':
                with torch.autocast(enabled=True, device_type='cuda'):
                    self.matcher(batch)
            else:
                self.matcher(batch)
        mkpts0 = batch['mkpts0_f'].float().cpu().numpy()
        mkpts1 = batch['mkpts1_f'].float().cpu().numpy()
        mconf = batch['mconf'].float().cpu().numpy()
        return mkpts0, mkpts1, mconf

    def match_images(self, img0, img1):
        """Match two images.

        Args:
            img0, img1: image paths or arrays accepted by ``_preprocess_image``.

        Returns:
            ``(mkpts0, mkpts1, mconf, img0_raw, img1_raw)``: matched keypoints
            (one (x, y) row per match) in the coordinates of the returned
            resized grayscale images ``img0_raw`` / ``img1_raw``, and the raw
            (un-normalised) match confidences.
        """
        img0_tensor, img0_raw = self._preprocess_image(img0)
        img1_tensor, img1_raw = self._preprocess_image(img1)
        batch = {'image0': img0_tensor, 'image1': img1_tensor}
        mkpts0, mkpts1, mconf = self._run_matcher(batch)
        return mkpts0, mkpts1, mconf, img0_raw, img1_raw

    def normalize_confidence(self, mconf):
        """Rescale raw match confidences into [0, 1].

        For ``model_type == 'opt'`` scores are mapped linearly using a lower
        bound of ``min(20, mconf.min())`` and an upper bound of
        ``max(30, mconf.max())``, then clipped to [0, 1]. For other model types
        a per-call min-max normalisation is applied; nearly-constant scores
        (spread <= 1e-5) map to all zeros.

        NOTE(review): min-max scaling is relative, so the best match of every
        call becomes 1.0 and ``conf_thresh`` in ``match_video`` acts as a
        relative rather than an absolute threshold.

        Args:
            mconf: numpy array of raw confidences (one per match); may be empty.

        Returns:
            Array of the same shape; an empty input is returned unchanged.
        """
        if mconf.size == 0:
            # min()/max() would raise on an empty array.
            return mconf
        if self.model_type == 'opt':
            lower = min(20.0, mconf.min())
            upper = max(30.0, mconf.max())
            return np.clip((mconf - lower) / (upper - lower), 0, 1)
        mconf_min, mconf_max = mconf.min(), mconf.max()
        if mconf_max - mconf_min > 1e-5:
            return (mconf - mconf_min) / (mconf_max - mconf_min)
        return np.zeros_like(mconf)

    @staticmethod
    def _as_bgr(img):
        """Return ``img`` as a 3-channel image (grayscale is converted, BGR passes through)."""
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        return img

    def draw_matches(self, img0, img1, mkpts0, mkpts1, output_path=None):
        """Draw ``img0`` and ``img1`` side by side with a green line per match.

        Args:
            img0, img1: grayscale (2-D) or 3-channel BGR uint8 images, e.g. the
                resized images returned by ``match_images``.
            mkpts0, mkpts1: matched (x, y) points in each image's own pixel grid.
            output_path: if given, the canvas is also written to this file.

        Returns:
            The BGR canvas.
        """
        h0, w0 = img0.shape[:2]
        h1, w1 = img1.shape[:2]
        canvas = np.zeros((max(h0, h1), w0 + w1, 3), dtype=np.uint8)
        canvas[:h0, :w0] = self._as_bgr(img0)
        canvas[:h1, w0:] = self._as_bgr(img1)
        for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
            # Plain Python ints are the most portable point type across OpenCV versions.
            p0 = (int(round(float(x0))), int(round(float(y0))))
            # Points of img1 are shifted right by img0's width on the canvas.
            p1 = (int(round(float(x1))) + w0, int(round(float(y1))))
            cv2.line(canvas, p0, p1, (0, 255, 0), 1, cv2.LINE_AA)
            cv2.circle(canvas, p0, 3, (0, 255, 0), -1, cv2.LINE_AA)
            cv2.circle(canvas, p1, 3, (0, 255, 0), -1, cv2.LINE_AA)
        cv2.putText(canvas, f"Matches: {len(mkpts0)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 0), 2, cv2.LINE_AA)
        if output_path:
            cv2.imwrite(output_path, canvas)
            print(f"[保存] 匹配图像写入到:{output_path}")
        return canvas

    @staticmethod
    def _rescale_points(points, from_hw, to_hw):
        """Scale (x, y) points from an image of size ``from_hw`` to one of size ``to_hw``.

        Sizes are ``(height, width)``. This is a plain proportional scaling
        (sub-pixel resize offsets are ignored). Returns a float32 array.
        """
        scale = np.array([to_hw[1] / from_hw[1], to_hw[0] / from_hw[0]], dtype=np.float32)
        return np.asarray(points, dtype=np.float32) * scale

    @staticmethod
    def _draw_box(frame, corners, label=None):
        """Draw the red axis-aligned bounding box of ``corners`` (and an optional label) on ``frame``."""
        x, y, w, h = cv2.boundingRect(np.round(corners).astype(np.int32))
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
        if label is not None:
            cv2.putText(frame, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

    def _locate_query(self, query_tensor, query_corners, frame, conf_thresh, min_match_points):
        """Match the query against one BGR frame and project its outline into the frame.

        Returns:
            ``(corners, n_matches)``: ``corners`` is a (4, 2) array in the
            ORIGINAL frame's pixel coordinates, or None when too few matches
            pass ``conf_thresh`` or no homography is found; ``n_matches`` is
            the number of matches above the threshold.
        """
        frame_tensor, frame_gray = self._preprocess_image(frame)
        mkpts0, mkpts1, mconf = self._run_matcher({'image0': query_tensor, 'image1': frame_tensor})
        # An empty match set simply yields no detection for this frame.
        mconf = self.normalize_confidence(mconf)
        valid = mconf > conf_thresh
        mkpts0, mkpts1 = mkpts0[valid], mkpts1[valid]
        n_matches = len(mkpts0)
        # A homography needs at least 4 correspondences regardless of min_match_points.
        if n_matches < max(min_match_points, 4):
            return None, n_matches
        # LoFTR saw a frame resized to multiples of 32; map its keypoints back
        # to the full-resolution frame that the box is drawn on.
        mkpts1 = self._rescale_points(mkpts1, frame_gray.shape[:2], frame.shape[:2])
        H, _ = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 5.0)
        if H is None:
            return None, n_matches
        return cv2.perspectiveTransform(query_corners[None], H)[0], n_matches

    def match_video(self, query_img_path, video_path, output_path, frame_interval=10, conf_thresh=0.7, min_match_points=10):
        """Locate a query image in a video and write an annotated copy.

        Every ``frame_interval``-th frame is matched against the query. When
        enough matches (at least ``min_match_points``, and never fewer than 4)
        have normalised confidence above ``conf_thresh``, a RANSAC homography
        projects the query outline into the frame and its bounding box is drawn
        in red with a match-count label. Frames in between reuse the last box.
        Sampled frames without a detection clear the box.

        Args:
            query_img_path: query image path or array.
            video_path: input video readable by ``cv2.VideoCapture``.
            output_path: output file written with the ``mp4v`` codec.
            frame_interval: run matching on every n-th frame (>= 1).
            conf_thresh: threshold on ``normalize_confidence`` output.
            min_match_points: minimum confident matches required for a box.

        Raises:
            ValueError: if ``frame_interval`` < 1.
            RuntimeError: if the video cannot be opened or the output video
                cannot be created.
        """
        if frame_interval < 1:
            raise ValueError("frame_interval 必须为正整数")
        video_start_time = time.time()
        query_tensor, query_raw = self._preprocess_image(query_img_path)
        h0, w0 = query_raw.shape[:2]
        # Outline of the resized query image, the coordinate frame of mkpts0.
        query_corners = np.array([[0, 0], [w0, 0], [w0, h0], [0, h0]], dtype=np.float32)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError("视频无法打开")
        writer = None
        pbar = None
        frame_count = 0
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                # Some containers/streams report 0 (or NaN) FPS; fall back to an
                # arbitrary 25 so the writer gets a usable frame rate.
                fps = 25.0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
            if not writer.isOpened():
                raise RuntimeError(f"无法创建输出视频:{output_path}")
            # The frame count can be unknown (<= 0); tqdm then runs without a total.
            pbar = tqdm(total=total_frames if total_frames > 0 else None, desc='处理视频')
            last_corners = None
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                vis_frame = frame.copy()
                if frame_count % frame_interval == 0:
                    last_corners, n_matches = self._locate_query(
                        query_tensor, query_corners, frame, conf_thresh, min_match_points)
                    if last_corners is not None:
                        self._draw_box(vis_frame, last_corners, f"Match: {n_matches}")
                elif last_corners is not None:
                    self._draw_box(vis_frame, last_corners)
                writer.write(vis_frame)
                frame_count += 1
                pbar.update(1)
        finally:
            # Release capture/writer/progress bar even if matching raised.
            cap.release()
            if writer is not None:
                writer.release()
            if pbar is not None:
                pbar.close()
        video_end_time = time.time()
        total_time = video_end_time - video_start_time
        fps_real = frame_count / total_time if total_time > 0 else 0
        print(f"[完成] 视频已保存到:{output_path}")
        print(f"[统计] 总耗时: {total_time:.2f} 秒, 实际平均帧率: {fps_real:.2f} FPS")
# Example usage (can be moved into a separate test_loftr.py).
# Install EfficientLoFTR by following the instructions in its git repository.
# Author's note: EfficientLoFTR runs more efficiently than the original LoFTR
# and detects more matches.
if __name__ == "__main__":
    matcher = LoFTRMatcher(
        model_type='full',
        precision='fp32',
        weight_path='E:/test/EfficientLoFTR/weights/eloftr_outdoor.ckpt'
    )
    # Image-pair matching test.
    mkpts0, mkpts1, mconf, img0, img1 = matcher.match_images(
        img0="E:/test/files/car.png",
        img1="E:/test/files/00001.jpg"
    )
    matcher.draw_matches(img0, img1, mkpts0, mkpts1, output_path="E:/test/sam2-main/files/match_result.jpg")
    # Video matching test.
    matcher.match_video(
        query_img_path="E:/test/files/car.png",
        video_path="E:/test/files/c003.mp4",
        output_path="E:/test/files/out.mp4",
        frame_interval=2,
        conf_thresh=0.75,
        min_match_points=12
    )