CoTracker 环境配置&与ORB 特征点提取结合实现视频特征点追踪

CoTracker 环境配置&与ORB 特征点提取结合实现视频特征点追踪

文章目录

Meta 新开源 CoTracker :跟踪任意长视频中的任意多个点,并且可以随时添加新的点进行跟踪!并且性能上直接超越了谷歌的 OmniMotion
我所做的项目是对相机捕获的图像进行实时追踪。当时没有研究过这个网络,所以想着配一下环境,看看后续可不可以应用在相机上。
但是:事与愿违,配好了环境,并且在 Demo 里面也可以获取视频,对视频第一帧进行 ORB 特征点识别然后在全局视频里面进行追踪,可是发现没有办法进行相机的实时跟踪处理。
后面在大致看过网络结构(其实)以及相关文献之后,终于确定 ,这个牛逼的 CoTracker 因为其网路输入只能是
视频格式的长时间数据
因此并不能进行相机的实时处理。所以如果后面的小伙伴也要用相机去做,建议搜索 LightGlue 等其他的方法(光流法、或者神经网络)等等进行实时追踪。
想继续了解 CoTracker 原理的小伙伴可以参考这一篇博文相关链接: CoTracker跟踪器 - CoTracker: It is Better to Track Together
CoTracker 项目的源代码链接也在这里,可自行下载: co-tracker

Step1:配置 CoTracker 环境

首先下载 conda,然后安装虚拟环境。

bash 复制代码
	conda craete -n cotracker python=3.8
	conda activate cotracker

然后根据官方提示从 Github 上面下载源码。

参考官方的提示,这个项目支持在 CPU 和 GPU 上运行,因此在配置环境时建议同时安装支持 CUDA 的 PyTorch 和 TorchVision

官方链接的终端命令贴出来了,需要可自行粘帖。

bash 复制代码
	git clone https://github.com/facebookresearch/co-tracker
	cd co-tracker
	pip install -e .
	pip install matplotlib flow_vis tqdm tensorboard

因为官方有已经训练好的权重文件,我们只需要下载下来就可以在 Demo 里面直接调用。命令也在此处。

bash 复制代码
	mkdir -p checkpoints
	cd checkpoints
	wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth
	cd ..

当然,这个 CoTracker 在配置环境过程中肯定会有一些库的版本不对,因此需要重新卸载再安装一些库的版本。

以下是我的 cotracker 虚拟环境 里面需要的库版本(只摘出来 Setup.py 文件 里安装的,以及通过命令行安装的库)。大家可自行对照。

bash 复制代码
	matplotlib                    3.7.3
	flow-vis                      0.1

	opencv-python                 4.8.1.78

	torch                         2.1.1
	torchaudio                    2.1.1
	torchsummary                  1.5.1
	torchvision                   0.16.1
	tqdm                          4.66.1
	tensorboard#(没找到,不过并不影响 CoTracker 的使用)

Step2:运行官方的例程

官方有一份 demo.py 文件可以直接调用一些接口,方便进行视频的处理,但是为了更好的了解里面的一些借口的参数。建议可以参考项目里面的 demo.ipynb 文件,按照里面的步骤,自己重新写一个 demo 文件。

Step3:结合 ORB 特征点提取

为了下一步进行视频帧追踪预演,提前编写了一个针对连续图像读取并追踪的代码(注意:代码里面输入的不是一个视频,而是将一连串连续的图片转换成张量的数据格式传入了 GPU,所以虽然不是视频,但是效果差不多)。如下所示:

python 复制代码
import os
import cv2
import torch
import argparse
import numpy as np
from base64 import b64encode
from PIL import Image
import matplotlib.pyplot as plt
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.predictor import CoTrackerPredictor
import torch.nn.functional as F


def convert_images_to_tensor(image_folder):
    image_files = sorted(os.listdir(image_folder))  # 获取图片文件列表并排序
    first_path = os.path.join(image_folder, image_files[0])
    print(first_path)
    images = []
    n = 0
    for image_file in image_files:
        n += 1
        print(n)
        image_path = os.path.join(image_folder, image_file)
        image = cv2.imread(image_path)  # 使用OpenCV读取图片
        height, width, _ = image.shape
        left_half = image[:, :width//2, :]
        image = cv2.cvtColor(left_half, cv2.COLOR_BGR2RGB)  # 将图片从BGR颜色空间转换为RGB
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()  # 转换为PyTorch张量
        images.append(image_tensor)

    video_tensor = torch.stack(images)
    video_tensor = video_tensor.permute(1, 0, 2, 3, 4)  # 转换成视频张量的形式
    shape = video_tensor.shape
    print(shape[0], shape[1], shape[2], shape[3], shape[4])
    return first_path, video_tensor


# 特征点检测的参数
max_corners = 30
quality_level = 0.1
min_distance = 200


def orb_track_points(first_image_path):
    raw_image = cv2.imread(first_image_path)
    height, width, _ = raw_image.shape
    raw_left_image = raw_image[:, :width // 2, :]                # 只取左边部分
    corners = cv2.goodFeaturesToTrack(cv2.cvtColor(raw_left_image, cv2.COLOR_BGR2GRAY), max_corners, quality_level, min_distance)
    corners = np.int0(corners)
    queries = []
    # 將图像上检测到的特征点,添加到追踪里面
    for corner in corners:
        x, y = corner.ravel()
        # cv2.circle(raw_left_image, (x, y), 2, vector_color[i].tolist(), 2)
        coordinate = [0., float(x), float(y)]
        queries.append(coordinate)
    queries = torch.tensor(queries)
    print(queries)
    # 并将图像上选取的点变成张量输入
    if torch.cuda.is_available():
        queries = queries.cuda()
    # 创建了一个包含四个子图的2x2图像网格,用于可视化查询点的位置,将查询点的帧号提取出来,并转换为整数类型的列表 frame_numbers。帧号将用于在每个子图上显示对应的帧数
    frame_numbers = queries[:, 0].int().tolist()
    # plt.subplots()函数创建了一个图像网格, 并将返回的"轴"对象存储在变量axs中
    fig, axs = plt.subplots(1, 1)
    # 通过调用axs.set_title()设置子图的标题为"Frame {}
    axs.set_title("Frame {}".format(0))
    # 通过enumerate()函数同时迭代查询点(query)和对应的帧号(frame_number)
    for i, (query, frame_number) in enumerate(zip(queries, frame_numbers)):
        # 使用plot()函数在该子图上绘制一个红色的点,其坐标为(query[1].item(), query[2].item())
        axs.plot(query[1].item(), query[2].item(), 'ro')
        # 设置子图的x和y轴范围
        axs.set_xlim(0, video.shape[4])
        axs.set_ylim(0, video.shape[3])
        # 翻转y轴,以与视频的坐标系一致
        axs.invert_yaxis()
    # 调整子图之间的布局
    plt.tight_layout()
    plt.savefig('./saved_videos/image_grid.png')
    return queries


# 指定图片文件夹路径
# images_folder = "./assets/1212/snapSave_p/Cam_2"      # Pitch 俯仰角
images_folder = "./assets/1212/snapSave_r/Cam_2"      # Roll  翻滚角
# images_folder = "./assets/1212/snapSave_y/Cam_2"      # Taw   偏航角

# 调用函数将图片转换为张量
first_im_path, video = convert_images_to_tensor(images_folder)
image_queries = orb_track_points(first_im_path)

model = CoTrackerPredictor(checkpoint=os.path.join(
    './checkpoints/cotracker_stride_4_wind_8.pth')
)

if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()

# 前向
pred_tracks, pred_visibility = model(video, queries=image_queries[None])
print("数据计算完毕")
vis = Visualizer(save_dir='./saved_videos', linewidth=2, mode='cool', tracks_leave_trace=-1)

# tracks_leave_trace = -1 可以显示出跟踪点的轨迹
vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename='orb_track')
print("视频存储完成")
# 原文里面有考虑对相机运动的补偿消除一些影响,但是代码里面这一部分设定为 False,即没有考虑相机运动的影响
# 因此 pred_tracks, pred_visibility 即跟踪真实值
track_save_data = './saved_videos/track_data'
if not os.path.exists(track_save_data):
    os.makedirs(track_save_data)

for i in range(max_corners):
    format_i = "{:02d}".format(i)
    with open(track_save_data + '/save_data_' + str(format_i), 'w') as data_txt:
        for pred_track in pred_tracks[0]:
            point_track = str(pred_track[i][0].item()) + ' ' + \
                          str(pred_track[i][1].item()) + '\n'
            data_txt.write(point_track)
    data_txt.close()

print("数据文件关闭")

结果展示:

orb_track_pred_track

Step4:针对相机进行实时追踪,但失败

还是之前说的,因为 CoTracker 的神经网络本身在训练模型的时候就是以视频作为输入数据进行输入的,因此针对连续图片可以做到追踪,但时如果只是单个图片,那么追踪将无法进行。

下面可能就有小伙伴会想,通过缩小传入视频的帧率再输入。例如将 3 ~ 4 帧的图片作为一个短视频输入进去,然后计算出来结果后,将结果保存并用于下一个短视频的追踪,如此往复,实现相机实时追踪效果。

这个方向我也尝试过,但时 CoTracker 本身在进行视频的特征点计算的时候,就极其消耗算力。而且这个消耗的时间随着传入的视频时间以及要追踪的特征点数量线性增加

我的设备是 RTX4060 和 i7-12650 。性能还算可以。但是在传入一个连续 5 帧 的视频,并追踪 10 个点 的时候,依旧要花费 0.3 ~ 0.4 秒时间 计算。出现的结构就是,视频一卡一卡的,实时跟踪效果很差。
(为什么传入 5 帧? 因为 5 帧已经是网络输入要求的最低帧数了,再小就没有结果输出了。)

代码依旧贴在下面,其实就是在上面视频的基础上进行的改进:

python 复制代码
import os
import cv2
import torch
import argparse
import numpy as np
from base64 import b64encode
from PIL import Image
import matplotlib.pyplot as plt
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.predictor import CoTrackerPredictor
import torch.nn.functional as F
import time


def mkdir():
    if not os.path.exists(saved_videos):
        os.makedirs(saved_videos)


def initialize(first_image):
    n = 5
    i = 0
    images_pytorch = []
    image = cv2.cvtColor(first_image, cv2.COLOR_BGR2RGB)            # 将第一张图片从BGR颜色空间转换为RGB
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()  # 转换为PyTorch张量
    images_pytorch.append(image_tensor)
    while i < 5:
        ret, current_image = cap.read()
        image = cv2.cvtColor(current_image, cv2.COLOR_BGR2RGB)  # 将第一张图片从BGR颜色空间转换为RGB
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()  # 转换为PyTorch张量
        images_pytorch.append(image_tensor)
        i += 1

    # 將图片张量转换成网络输入视频张量的形式
    video_tensor = torch.stack(images_pytorch)
    video_tensor = video_tensor.permute(1, 0, 2, 3, 4)  # 转换成视频张量的形式
    print("video tensor------------------------------------------------------")
    print(video_tensor)
    return images_pytorch, video_tensor


# 特征点检测的参数
max_corners = 5
quality_level = 0.1
min_distance = 100


def orb_track_points(first_image_):
    # # 仿生眼相机图像前处理部分
    # raw_image = cv2.imread(first_image_path)
    # height, width, _ = raw_image.shape
    # raw_left_image = raw_image[:, :width // 2, :]                # 只取左边部分
    # corners = cv2.goodFeaturesToTrack(cv2.cvtColor(raw_left_image, cv2.COLOR_BGR2GRAY), max_corners, quality_level, min_distance)
    # corners = np.int0(corners)
    # # 电脑相机图像处理部分
    corners = cv2.goodFeaturesToTrack(cv2.cvtColor(first_image_, cv2.COLOR_BGR2GRAY), max_corners, quality_level, min_distance)
    corners = np.int0(corners)
    queries = []
    # 將图像上检测到的特征点,添加到追踪里面
    for corner in corners:
        x, y = corner.ravel()
        coordinate = [0., float(x), float(y)]
        queries.append(coordinate)
    queries = torch.tensor(queries)
    # 并将图像上选取的点变成张量输入
    if torch.cuda.is_available():
        queries = queries.cuda()
    return queries


def convert_images_to_tensor(current_image, pre_images_pytorch):
    # # 將当前图像转换成 pytorch 张量,仿生眼相机图像预处理
    # height, width, _ = current_image.shape
    # left_half = current_image[:, :width // 2, :]
    # image = cv2.cvtColor(left_half, cv2.COLOR_BGR2RGB)  # 将图片从BGR颜色空间转换为RGB
    # 將当前图像转换成 pytorch 张量,电脑相机图像预处理
    image = cv2.cvtColor(current_image, cv2.COLOR_BGR2RGB)  # 将图片从BGR颜色空间转换为RGB
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()  # 转换为PyTorch张量
    # 再將图片存入 pre_images 里面进行后续的跟踪计算
    pre_images_pytorch.append(image_tensor)
    update_images_pytorch = pre_images_pytorch[1:]
    # print("update_images_pytorch: %d", update_images_pytorch)
    # 將图片张量转换成网络输入视频张量的形式
    video_tensor = torch.stack(update_images_pytorch)
    video_tensor = video_tensor.permute(1, 0, 2, 3, 4)  # 转换成视频张量的形式
    return update_images_pytorch, video_tensor


if __name__ == '__main__':
    saved_videos = "./assets/saved_videos/"
    mkdir()
    # 开启相机获取图像
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("无法打开视频文件")
        exit()
    ret, first_frame = cap.read()
    if not ret:
        print("无法获取图像")
        exit()
    first_queries = orb_track_points(first_frame)
    first_images_pytorch, first_video = initialize(first_frame)
    print(first_queries)        # 分别是 0, x, y

    # 加载模型文件
    model = CoTrackerPredictor(checkpoint=os.path.join('./checkpoints/cotracker_stride_4_wind_8.pth'))
    print("模型创建完毕")
    # 將视频数据和模型数据转换
    if torch.cuda.is_available():
        model = model.cuda()
        first_video = first_video.cuda()
    # 前向
    first_tracks, first_visibility = model(first_video, queries=first_queries[None])        # 此处的 None 是用来增加维度的
    print("数据计算完毕")
    vis = Visualizer(save_dir=saved_videos, linewidth=2, mode='cool', tracks_leave_trace=-1)  # t_l_t:-1显示跟踪轨迹
    print('----------------------------------------------------------------------pre')
    print(first_tracks[0])
    vis.visualize(video=first_video, tracks=first_tracks, visibility=first_visibility, filename='orb_track')
    print("视频存储完成")

    images_pytorch = first_images_pytorch
    # 跟踪部分
    while True:
        ret, current_frame = cap.read()
        cv2.imshow("current", current_frame)
        cv2.waitKey(20)
        images_pytorch, current_video = convert_images_to_tensor(current_frame, images_pytorch)
        # 將视频数据和模型数据转换
        if torch.cuda.is_available():
            model = model.cuda()
            current_video = current_video.cuda()
        # 前向
        current_tracks, current_visibility = model(current_video, queries=first_queries[None])  # 此处的 None 是用来增加维度的
        print("----------------------------------------------------------")
        print(current_tracks[0][0])
        print("数据计算完毕")

5.内部代码的修改

原本代码里面为了显示跟踪的连续性,在可视化部分 ,将追踪点在不同时间段的轨迹连成了一条线。

我的项目里面之前为了结果的点的轨迹可以清楚一些,因此修改了原本可视化里面连线的部分,该成了画点。如下所示,里面注释掉的部分为曾经画线的代码,下面新增的为画点的代码

python 复制代码
    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        radius = 2  # 半径
        thickness = 2  # 线条宽度
        T, N, _ = tracks.shape
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_x = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                if coord_x[0] != 0 and coord_x[1] != 0:
                    cv2.circle(rgb, coord_x, radius, vector_color[i].tolist(), thickness)   # 直接画出之前轨迹的点
            if self.tracks_leave_trace > 0:
                rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
        #   遍历之前追踪的点集,然后连接相邻两点,画一条直线,构成轨迹图
#         for s in range(T - 1):
#             vector_color = vector_colors[s]
#             original = rgb.copy()
#             alpha = (s / T) ** 2
#             for i in range(N):
#                 coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
#                 coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
#                 if coord_y[0] != 0 and coord_y[1] != 0:
#                     cv2.line(
#                         rgb,
#                         coord_y,
#                         coord_x,
#                         vector_color[i].tolist(),
#                         self.linewidth,
#                         cv2.LINE_AA,
#                     )
#             if self.tracks_leave_trace > 0:
#                 rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)

        return rgb

当然不排除可能是本人技术太菜无法实现 CoTracker 的相机实时性追踪。如果后面有小伙伴实现了,欢迎在评论区里面分享。

相关推荐
大耳朵土土垚20 分钟前
【Linux 】开发利器:深度探索 Vim 编辑器的无限可能
linux·编辑器·vim
sp_fyf_202423 分钟前
浅谈计算机视觉的学习路径1
计算机视觉
极客小张28 分钟前
基于STM32MP157与OpenCV的嵌入式Linux人脸识别系统开发设计流程
linux·stm32·单片机·opencv·物联网
x66ccff33 分钟前
【linux】4张卡,坏了1张,怎么办?
linux·运维·服务器
jjb_2361 小时前
LinuxC高级作业2
linux·bash
OH五星上将1 小时前
OpenHarmony(鸿蒙南向开发)——小型系统内核(LiteOS-A)【扩展组件】上
linux·嵌入式硬件·harmonyos·openharmony·鸿蒙开发·liteos-a·鸿蒙内核
VB.Net1 小时前
EmguCV学习笔记 VB.Net 12.3 OCR
opencv·计算机视觉·c#·ocr·图像·vb.net·emgucv
拾光师2 小时前
linux之网络命令
linux·服务器·网络
ShuQiHere2 小时前
【ShuQiHere】 探索计算机视觉的世界:从基础到应用
人工智能·计算机视觉
我命由我123452 小时前
GPIO 理解(基本功能、模拟案例)
linux·运维·服务器·c语言·c++·嵌入式硬件·c#