import os
import cv2
import time
import torch
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from EfficientLoFTR.src.loftr import LoFTR, full_default_cfg, opt_default_cfg, reparameter
import matplotlib.pyplot as plt
class LoFTRMatcher:
def __init__(self, model_type='full', precision='fp32', weight_path=None):
self.model_type = model_type
self.precision = precision
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.matcher = self._load_model(weight_path)
def _load_model(self, weight_path):
cfg = deepcopy(full_default_cfg if self.model_type == 'full' else opt_default_cfg)
if self.precision == 'mp':
cfg['mp'] = True
elif self.precision == 'fp16':
cfg['half'] = True
matcher = LoFTR(config=cfg)
if weight_path is None:
raise ValueError("必须指定模型权重路径")
# 加载带完整结构的 checkpoint(不加 weights_only=True)
checkpoint = torch.load(weight_path, map_location=self.device, weights_only=False)
matcher.load_state_dict(checkpoint['state_dict'])
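        # Reparameterize the backbone for inference; per the EfficientLoFTR repo, skipping this degrades performance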
matcher = reparameter(matcher)
if self.precision == 'fp16':
matcher = matcher.half()
matcher.eval().to(self.device)
return matcher
def _preprocess_image(self, img, to_gray=True):
        if isinstance(img, str):
            path = img
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if to_gray else cv2.IMREAD_COLOR)
            if img is None:
                raise FileNotFoundError(f"Failed to read image: {path}")
        elif isinstance(img, np.ndarray):
            if to_gray and img.ndim == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            elif img.ndim not in (2, 3):
                raise ValueError("Image must be a grayscale or BGR array")
        else:
            raise TypeError("Input must be an image path or a numpy.ndarray")
img = cv2.resize(img, (img.shape[1] // 32 * 32, img.shape[0] // 32 * 32))
tensor = torch.from_numpy(img)[None][None]
tensor = tensor.half() if self.precision == 'fp16' else tensor.float()
return tensor.to(self.device) / 255., img
def match_images(self, img0, img1):
img0_tensor, img0_raw = self._preprocess_image(img0)
img1_tensor, img1_raw = self._preprocess_image(img1)
batch = {'image0': img0_tensor, 'image1': img1_tensor}
with torch.no_grad():
if self.precision == 'mp':
with torch.autocast(enabled=True, device_type='cuda'):
self.matcher(batch)
else:
self.matcher(batch)
mkpts0 = batch['mkpts0_f'].cpu().numpy()
mkpts1 = batch['mkpts1_f'].cpu().numpy()
mconf = batch['mconf'].cpu().numpy()
return mkpts0, mkpts1, mconf, img0_raw, img1_raw
def normalize_confidence(self, mconf):
if mconf.size == 0:
            return mconf  # return the empty array as-is to avoid min/max errors below
if self.model_type == 'opt':
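            # The (20, 30) range constants mirror the confidence normalization in the EfficientLoFTR demo for the 'opt' model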
            low = min(20.0, mconf.min())
            high = max(30.0, mconf.max())
            mconf = np.clip((mconf - low) / (high - low), 0, 1)
else:
mconf_min, mconf_max = mconf.min(), mconf.max()
if mconf_max - mconf_min > 1e-5:
mconf = (mconf - mconf_min) / (mconf_max - mconf_min)
else:
mconf = np.zeros_like(mconf)
return mconf
def draw_matches(self, img0, img1, mkpts0, mkpts1, output_path=None):
img0_color = cv2.cvtColor(img0, cv2.COLOR_GRAY2BGR)
img1_color = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
h0, w0 = img0.shape[:2]
canvas = np.zeros((max(h0, img1.shape[0]), w0 + img1.shape[1], 3), dtype=np.uint8)
canvas[:h0, :w0] = img0_color
canvas[:img1.shape[0], w0:] = img1_color
for pt0, pt1 in zip(mkpts0, mkpts1):
pt0 = tuple(np.round(pt0).astype(int))
pt1 = tuple(np.round(pt1).astype(int))
cv2.line(canvas, pt0, (pt1[0] + w0, pt1[1]), (0, 255, 0), 1, cv2.LINE_AA)
cv2.circle(canvas, pt0, 3, (0, 255, 0), -1, cv2.LINE_AA)
cv2.circle(canvas, (pt1[0] + w0, pt1[1]), 3, (0, 255, 0), -1, cv2.LINE_AA)
cv2.putText(canvas, f"Matches: {len(mkpts0)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
(255, 255, 0), 2, cv2.LINE_AA)
if output_path:
cv2.imwrite(output_path, canvas)
print(f"[保存] 匹配图像写入到:{output_path}")
return canvas
def match_video(self, query_img_path, video_path, output_path, frame_interval=10, conf_thresh=0.7, min_match_points=10):
video_start_time = time.time()
query_tensor, query_raw = self._preprocess_image(query_img_path)
h0, w0 = query_raw.shape[:2]
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise RuntimeError("视频无法打开")
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        pbar = tqdm(total=total_frames, desc='Processing video')
frame_count = 0
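        # Cache the most recent projected corners so frames skipped by frame_interval reuse the last detection box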
last_corners = None
while True:
ret, frame = cap.read()
if not ret:
break
vis_frame = frame.copy()
if frame_count % frame_interval == 0:
frame_tensor, frame_gray = self._preprocess_image(frame)
batch = {'image0': query_tensor, 'image1': frame_tensor}
                with torch.no_grad():
                    if self.precision == 'mp':
                        with torch.autocast(enabled=True, device_type='cuda'):
                            self.matcher(batch)
                    else:
                        self.matcher(batch)
mkpts0 = batch['mkpts0_f'].cpu().numpy()
mkpts1 = batch['mkpts1_f'].cpu().numpy()
mconf = batch['mconf'].cpu().numpy()
                if mconf.size == 0:
                    print("[Abort] No match points could be established between the query image and the video frame; try a clearer or larger query image.")
                    cap.release()
                    writer.release()
                    pbar.close()
                    return
mconf = self.normalize_confidence(mconf)
valid = mconf > conf_thresh
mkpts0, mkpts1 = mkpts0[valid], mkpts1[valid]
if len(mkpts0) >= min_match_points:
                    H, _ = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 5.0)
                    if H is not None:
                        corners = np.array([[0, 0], [w0, 0], [w0, h0], [0, h0]], dtype=np.float32)
                        last_corners = cv2.perspectiveTransform(corners[None], H)[0]
                        # Matching ran on the 32-aligned resized frame; rescale the corners to the original frame size before drawing
                        last_corners *= np.array([width / frame_gray.shape[1], height / frame_gray.shape[0]], dtype=np.float32)
                        x, y, w, h = cv2.boundingRect(np.round(last_corners).astype(int))
                        cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
                        cv2.putText(vis_frame, f"Match: {len(mkpts0)}", (x, y - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
else:
last_corners = None
else:
last_corners = None
elif last_corners is not None:
x, y, w, h = cv2.boundingRect(np.round(last_corners).astype(int))
cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
writer.write(vis_frame)
frame_count += 1
pbar.update(1)
cap.release()
writer.release()
pbar.close()
video_end_time = time.time()
total_time = video_end_time - video_start_time
fps_real = frame_count / total_time if total_time > 0 else 0
print(f"[完成] 视频已保存到:{output_path}")
print(f"[统计] 总耗时: {total_time:.2f} 秒, 实际平均帧率: {fps_real:.2f} FPS")
# Example usage (could live in a test_loftr.py)
if __name__ == "__main__":
matcher = LoFTRMatcher(
model_type='full',
precision='fp32',
weight_path='E:/test/EfficientLoFTR/weights/eloftr_outdoor.ckpt'
)
    # Image matching test
mkpts0, mkpts1, mconf, img0, img1 = matcher.match_images(
img0="E:/test/files/car.png",
img1="E:/test/files/00001.jpg"
)
matcher.draw_matches(img0, img1, mkpts0, mkpts1, output_path="E:/test/sam2-main/files/match_result.jpg")
    # Video matching test
matcher.match_video(
query_img_path="E:/test/files/car.png",
video_path="E:/test/files/c003.mp4",
output_path="E:/test/files/out.mp4",
frame_interval=2,
conf_thresh=0.75,
min_match_points=12
)
Note: see the EfficientLoFTR GitHub repository for installation instructions. EfficientLoFTR is more efficient than the original LoFTR and typically detects more match points.
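As a further illustration, the matches returned by match_images can be fed to cv2.findHomography to align the two images, the same way match_video localizes the query inside each frame. The sketch below assumes the class definition above; the file paths, confidence threshold, and RANSAC reprojection error are illustrative placeholders, not values from the original code.

matcher = LoFTRMatcher(model_type='full', precision='fp32',
                       weight_path='E:/test/EfficientLoFTR/weights/eloftr_outdoor.ckpt')
mkpts0, mkpts1, mconf, img0, img1 = matcher.match_images("query.png", "target.png")

# Keep only confident matches before fitting the homography
mconf = matcher.normalize_confidence(mconf)
keep = mconf > 0.5  # illustrative threshold
if keep.sum() >= 4:  # cv2.findHomography needs at least 4 correspondences
    H, inliers = cv2.findHomography(mkpts0[keep], mkpts1[keep], cv2.RANSAC, 5.0)
    if H is not None:
        # Warp the (resized, grayscale) query image into the target image's coordinate frame
        warped = cv2.warpPerspective(img0, H, (img1.shape[1], img1.shape[0]))
        cv2.imwrite("warped_query.png", warped)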