[sam2图像分割] 视频追踪API | VideoPredictor | `inference_state`记忆

第二章:SAM2视频预测器(视频追踪API)

欢迎回来

第一章:SAM2图像预测器(图像推理API)中,我们学习了SAM2ImagePredictor如何帮助我们精确地从单张图片中分割对象。它就像是为静态图像配备了一位超级智能的修图师。

但如果你的目标对象不是静止的呢?如果你的宠物猫正在视频中奔跑,而你希望每一帧都能突出显示它,手动在数百甚至数千帧中点击猫咪简直是噩梦!

这时,我们的下一个强大工具SAM2VideoPredictor就派上用场了。你可以把它想象成SAM-2的专属视频追踪导演 。它不仅能够编辑单张图像,还能智能地追踪并分割视频中移动的对象。

解决的问题

SAM2VideoPredictor的核心任务是视频对象分割(VOS),即在视频的所有帧中找到并勾勒出特定对象。

假设有一段繁忙街道的视频,想从红车出现的那一刻开始追踪,直到它驶出画面。这非常具有挑战性,因为:

  • 对象会移动和变形:车辆可能转弯、靠近或被部分遮挡。
  • 光照变化:阴影、阳光或夜晚会改变其外观。
  • 遮挡问题:其他车辆或物体可能暂时挡住红车。

SAM2VideoPredictor通过记忆对象的外观、预测其位置,并根据新帧调整预测来解决这些问题。它就像一位专业的视频编辑,能够智能地跟随并高亮对象,即使它暂时消失又出现!

视频追踪导演

让我们拆解SAM2VideoPredictor如何完成这项复杂任务,就像导演指挥一部电影:

  1. 场景设置(inference_state

    在开始追踪之前,导演需要一个"项目文件"。inference_state是一个特殊的存储区,记录视频和待分割对象的所有重要信息,包括视频帧、初始提示(如点击红车)、历史预测以及描述对象随时间变化的"记忆特征"。随着追踪的进行,这个状态会不断更新。

  2. 初始选角(添加点击/掩膜)

    你告诉导演要追踪哪个对象,通常通过在第一帧点击红车或绘制一个粗略的掩膜来完成。导演会将这些信息记录在inference_state中。

  3. 故事推进(视频追踪)

    一旦有了初始提示,导演就会接管工作。它利用inference_state中的信息预测对象在下一帧 的位置和形状,更新记忆,并继续处理后续帧。它甚至可以反向 追踪时间,这个过程称为分割传播

简而言之,SAM2VideoPredictor根据初始指引,自动在每一帧中找到目标对象,使视频分割变得高效实用。

如何使用SAM2VideoPredictor

让我们通过一个简单示例来学习如何追踪视频中的对象

步骤1:加载视频预测器

与图像预测器类似,我们需要准备SAM2VideoPredictor,这包括加载核心SAM-2模型并将其封装为视频预测工具。

python 复制代码
from sam2.build_sam import build_sam2_video_predictor_hf
import torch

# 指定设备(通常为NVIDIA GPU的"cuda")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. 加载专为视频追踪训练的SAM-2模型
# "facebook/sam2-hiera-base-plus"是一个示例模型ID。
predictor = build_sam2_video_predictor_hf(
    model_id="facebook/sam2-hiera-base-plus",
    device=device
)

说明 :我们使用build_sam2_video_predictor_hf(类似于第一章的build_sam2_hf)加载必要组件。现在,predictor就是我们的视频追踪导演,准备就绪

步骤2:初始化视频追踪项目(inference_state

接下来,我们为导演提供视频。预测器会加载帧并设置"项目文件"(inference_state)。

python 复制代码
import os
import numpy as np
# 假设你有一个名为'my_video_frames'的文件夹,包含JPEG图像
# 例如:my_video_frames/00000.jpg, my_video_frames/00001.jpg等
video_dir = "my_video_frames" # 替换为你的视频帧路径

# 为此示例创建虚拟视频目录和帧
os.makedirs(video_dir, exist_ok=True)
dummy_image = np.zeros((256, 256, 3), dtype=np.uint8)
from PIL import Image
Image.fromarray(dummy_image).save(os.path.join(video_dir, "00000.jpg"))
# 添加另一帧用于追踪
Image.fromarray(dummy_image).save(os.path.join(video_dir, "00001.jpg"))

# 用视频帧初始化追踪状态
inference_state = predictor.init_state(video_path=video_dir)

print(f"视频帧数:{inference_state['num_frames']}")
print(f"视频分辨率:{inference_state['video_height']}x{inference_state['video_width']}")

说明init_state()准备inference_state,加载视频帧(或其路径),确定视频尺寸,并设置存储对象数据和追踪结果的内部字典

通过处理第一帧的图像特征进行"预热",加速后续步骤

步骤3:为对象添加初始提示(点击/掩膜)

现在,我们告诉导演追踪哪个对象。通常在第一帧(索引0)点击或绘制掩膜,并为对象分配唯一ID(如1)

python 复制代码
# 假设我们在帧0的(x=100, y=150)处点击对象
ann_frame_idx = 0
ann_obj_id = 1 # 待追踪对象的唯一ID
points = np.array([[100, 150]], dtype=np.float32) # 点击坐标
labels = np.array([1], np.int32) # 标签1表示前景点

# 将此提示添加到预测器
frame_idx_out, obj_ids_out, masks_out = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

print(f"帧{frame_idx_out}的掩膜(对象{obj_ids_out})形状:{masks_out.shape}")

说明add_new_points_or_box()接收你的提示(此处为点击),并将其应用到指定帧和对象。内部调用类似SAM2ImagePredictor的组件,在单帧 中分割对象,结果(掩膜)存储在inference_state中,作为对象1在帧0的起点。

步骤4:在视频中传播分割

最后,我们让导演开始追踪!propagate_in_video方法会逐帧处理整个视频,利用记忆跟随对象。

python 复制代码
all_tracked_masks = {}

# 'propagate_in_video'是一个Python生成器,
# 逐帧生成处理结果。
for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(inference_state):
    # 'video_res_masks'包含当前帧所有追踪对象的掩膜,
    # 已调整为原始视频分辨率。
    # 我们可以存储或显示这些掩膜。
    all_tracked_masks[frame_idx] = video_res_masks
    print(f"已处理帧{frame_idx}。掩膜形状:{video_res_masks.shape}")

print(f"成功追踪{len(all_tracked_masks)}帧。")
# 循环结束后,'all_tracked_masks'将包含所有追踪帧的分割对象。

说明propagate_in_video()遍历视频帧。对于每帧,它利用累积的inference_state(包含对象外观和运动历史)预测当前帧的掩膜,更新inference_state并返回结果。这是视频追踪的核心。

技术

让我们深入幕后,了解SAM2VideoPredictor的魔法。

🎢工作流程

SAM2VideoPredictor想象成一位经验丰富的导演,配备智能助手(inference_state)。

  1. **你(用户)**将视频(JPEG图像文件夹)交给导演(SAM2VideoPredictor)。
  2. 导演让助手(inference_state)准备整个视频。助手加载所有帧,并为每个对象和帧创建空文件,同时从第一帧提取初始"精华"(image_features)。
  3. 在初始帧(如帧0)点击目标对象。
  4. 导演处理此点击(类似SAM2ImagePredictor的方式),获取掩膜,并将对象的第一张分割图像存入inference_state,形成"初始外观档案"。导演还使用记忆编码器计算并存储此帧的"记忆特征"。
  5. 你发出指令:"在整个视频中追踪此对象!"(propagate_in_video)。
  6. 对于后续每帧:
    • 导演从inference_state获取对象的最新"外观档案"和"记忆特征"。
    • 结合当前帧图像和对象历史(记忆特征),预测对象的当前位置。此步骤利用强大的SAM2基础模型及其记忆注意力组件。
    • 优化预测,填补小孔,并将新掩膜和更新的"记忆特征"存回inference_state
    • 展示当前帧的分割对象。

这种"预测、更新记忆、保存、移至下一帧"的循环,使SAM2VideoPredictor能够稳健地追踪视频中的对象。

以下是简化的工作流程图:

代码

让我们看看sam2/sam2_video_predictor.py中如何实现这些步骤。

  1. 初始化(init_state

    python 复制代码
    def init_state(self, video_path, **kwargs):
        images, video_height, video_width = load_video_frames(
            video_path=video_path, image_size=self.image_size, **kwargs
        )
        
        inference_state = {
            "images": images,  # 存储所有视频帧
            "num_frames": len(images),
            "video_height": video_height,
            "video_width": video_width,
            "device": self.device,
            "point_inputs_per_obj": {},  # 每帧对象的点击输入
            "mask_inputs_per_obj": {},    # 每帧对象的掩膜输入
            "output_dict_per_obj": {},    # 追踪结果(掩膜、记忆特征)
            "obj_id_to_idx": OrderedDict(),  # 对象ID到内部索引的映射
            "obj_idx_to_id": OrderedDict(),
            "obj_ids": []
        }
        
        # 预热第一帧的图像编码器
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state

    说明init_state设置inference_state字典,加载视频帧,存储原始尺寸,并初始化对象数据和追踪结果的存储区。它还预处理第一帧以确保模型就绪。

  2. 添加初始提示(add_new_points_or_box

    python 复制代码
    def add_new_points_or_box(self, inference_state, frame_idx, obj_id, points=None, labels=None, **kwargs):
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)  # 对象ID映射
        
        # 存储点击输入
        inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = concat_points(
            inference_state["point_inputs_per_obj"][obj_idx].get(frame_idx, None), points, labels
        )
    
        # 运行单帧推理(类似SAM2ImagePredictor)
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=inference_state["output_dict_per_obj"][obj_idx],
            frame_idx=frame_idx,
            batch_size=1,
            is_init_cond_frame=True,  # 标记为初始输入帧
            point_inputs=inference_state["point_inputs_per_obj"][obj_idx][frame_idx],
            mask_inputs=None,
            reverse=False,
            run_mem_encoder=False,  # 记忆编码器稍后运行
            prev_sam_mask_logits=None,
        )
        
        # 临时存储当前输出掩膜
        inference_state["temp_output_dict_per_obj"][obj_idx]["cond_frame_outputs"][frame_idx] = current_out
    
        # 返回调整到原始视频分辨率的掩膜
        consolidated_out = self._consolidate_temp_output_across_obj(inference_state, frame_idx, is_cond=True)
        _, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
        return frame_idx, inference_state["obj_ids"], video_res_masks

    说明 :此方法将对象ID映射到内部索引,存储点击数据,并调用_run_single_frame_inference生成单帧掩膜。结果掩膜临时存入inference_state,并返回调整后的掩膜。

  3. 传播分割(propagate_in_video

    python 复制代码
    def propagate_in_video(self, inference_state, start_frame_idx=None, **kwargs):
        self.propagate_in_video_preflight(inference_state)  # 预处理初始输入
        
        for frame_idx in processing_order:  # 按顺序处理每帧
            pred_masks_per_obj = []
            for obj_idx in range(batch_size):
                if frame_idx in obj_output_dict["cond_frame_outputs"]:
                    current_out = obj_output_dict["cond_frame_outputs"][frame_idx]
                else:
                    # 运行追踪推理
                    current_out, pred_masks = self._run_single_frame_inference(
                        inference_state=inference_state,
                        output_dict=obj_output_dict,
                        frame_idx=frame_idx,
                        batch_size=1,
                        is_init_cond_frame=False,
                        point_inputs=None,
                        mask_inputs=None,
                        reverse=False,
                        run_mem_encoder=True,  # 启用记忆编码器更新记忆
                    )
                    obj_output_dict["non_cond_frame_outputs"][frame_idx] = current_out
                
                pred_masks_per_obj.append(pred_masks)
    
            # 返回调整后的掩膜
            all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
            _, video_res_masks = self._get_orig_video_res_output(inference_state, all_pred_masks)
            yield frame_idx, obj_ids, video_res_masks

    说明 :此方法预处理初始输入后,逐帧追踪对象。对于每帧,它调用_run_single_frame_inference(启用记忆编码器),利用对象的历史记忆和当前帧特征预测新掩膜。结果掩膜调整后返回,形成连续追踪。

总结

SAM2VideoPredictor是一款复杂但用户友好的工具,将SAM-2强大的分割能力从单张图像扩展到整个视频

通过管理持久的inference_state并利用历史信息逐帧智能传播对象分割,将手动视频标注转变为高效的自动化过程。它是处理动态对象时间维度的理想解决方案。

现在,我们已经了解了SAM2ImagePredictorSAM2VideoPredictor如何提供高级API与SAM-2交互,接下来让我们深入探索其核心智能:SAM2基础模型

下一章:SAM2基础模型

相关推荐
YangYang9YangYan3 小时前
大专计算机技术专业就业方向:解读、规划与提升指南
大数据·人工智能·数据分析
mwq301233 小时前
GPT监督微调SFT:在损失计算中屏蔽指令和填充 Token
人工智能
富唯智能3 小时前
智慧物流新篇章:复合机器人重塑装配车间物料配送
人工智能·工业机器人·复合机器人
递归不收敛4 小时前
四、高效注意力机制与模型架构
人工智能·笔记·自然语言处理·架构
扫地的小何尚4 小时前
AI创新的火花:NVIDIA DGX Spark开箱与深度解析
大数据·人工智能·spark·llm·gpu·nvidia·dgx
AI科技星4 小时前
接近光速运动下的光速不变性:基于张祥前统一场论的推导与验证
数据结构·人工智能·经验分享·算法·计算机视觉
864记忆4 小时前
opencv图像预处理函数的功能与作用
人工智能·opencv·计算机视觉
m0_650108244 小时前
【论文精读】VITRON:统一的像素级视觉大语言模型
计算机视觉·论文精读·视觉理解·视觉生成·多模态通用模型·视觉分割·视觉编辑