SAM2跟踪的理解14——mask decoder

目录

一、前言

四、MaskDecoder

[4.2 回到MaskDecoder.forward](#4.2 回到MaskDecoder.forward)

五、回到_forward_sam_heads

[5.1 obj_ptr = self.obj_ptr_proj(sam_output_token)](#5.1 obj_ptr = self.obj_ptr_proj(sam_output_token))

[5.2 对象记忆的"褪色与重置"机制](#5.2 对象记忆的"褪色与重置"机制)

什么意思?

软模式和硬模式?

这个到底学习什么?


一、前言

上一篇中我们走出了MaskDecoder.forward的self._dynamic_multimask_via_stability函数,这一篇我们看看接下来发生什么,其实就是选择了一个稳定掩码之后返回这个掩码masks、掩码分数iou_pred、目标存在性分数object_score_logits,还有把之前我们得到的(1,4,256)的mask_tokens_out取第0个token给到变量sam_tokens_out作为物体记忆token也返回 。

masks: torch.Size([1, 1, 256, 256])

iou_pred:tensor([[0.9436]], device='cuda:0') 数值只是示例

sam_tokens_out: torch.Size([1, 1, 256])

object_score_logits: torch.Size([1, 1]) 即 tensor([[24.0962]], device='cuda:0') 数值只是示例

然后就回到_forward_sam_heads,主要操作就是把256的mask上采样到1024,然后就是之前的(1,4,256)的mask_tokens_out取第0个token得到的(1,1,256)的sam_output_token,然后取第0个得到(1,256)的sam_output_token,就是那4个掩码token的第0个。然后经过MLP变成obj_ptr,然后如果物体消失的话obj_ptr要加self.no_obj_ptr(self.no_obj_ptr是一个(1,256)的可学习的矩阵),obj_ptr就是物体的记忆指纹,如果物体消失的话记忆要归零。

最后返回的就是低分辨的256的mask,高分辨率的1024的mask,ious就是最佳mask对应的置信度,obj_ptr就是用于记忆物体的,object_score_logits就是用于判断物体是否存在的(如果大于0就是存在)

low_res_multimasks: torch.Size([1, 1, 256, 256])

high_res_multimasks: torch.Size([1, 1, 1024, 1024])

ious: tensor([[0.9436]], device='cuda:0')

low_res_masks: torch.Size([1, 1, 256, 256])

low_res_masks: torch.Size([1, 1, 1024, 1024])

obj_ptr: torch.Size([1, 256])

object_score_logits: object_score_logits: tensor([[24.0962]], device='cuda:0')

感觉现在还没有完全理解SAM2跟踪对物体消失的处理方式,而且我觉得需要新开一篇文章将SAM2记忆机制与DeepSeek OCR的记忆机制进行比对。

四、MaskDecoder

4.2 回到MaskDecoder.forward

sam2/modeling/sam/mask_decoder.py

python 复制代码
def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.
        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.
        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
          torch.Tensor: batched SAM token for mask output
        """
 
        # 输入:   
        # image_embeddings: torch.Size([1, 256, 64, 64])
        # image_pe:torch.Size([1, 256, 64, 64])
        # sparse_embeddings: torch.Size([1, 3, 256])
        # dense_embeddings : torch.Size([1, 256, 64, 64])
        # multimask_output:False
        # repeat_image: False
        # high_res_features:[
        #         torch.Size([1, 32, 256, 256]),
        #         torch.Size([1, 64, 128, 128])
        # ]
 
        # >>> 1. 先把所有 embedding 喂给 mask decoder,拿到 4 个输出
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )
        # mask: torch.Size([1, 4, 256, 256])
        # iou_pred: tensor([[0.9436, 0.9098, 0.9337, 0.9457]], device='cuda:0')
        # mask_tokens_out: torch.Size([1, 4, 256])
        # object_score_logits: torch.Size([1, 1])  比如 tensor([[24.0962]], device='cuda:0')
 
        # Select the correct mask or masks for output
        # multimask_output:False
        if multimask_output:
            # >>> 2-a. 训练/多 mask 模式:只要后 3 个 mask(跳过第 0 个"默认" mask)
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
            # iou_pred: tensor([[0.9436]], device='cuda:0')
 
        # self.dynamic_multimask_via_stability:True  self.training:False
        elif self.dynamic_multimask_via_stability and not self.training:
            # >>> 2-b. 测试阶段且开 stability 筛 mask:自动挑一个最稳的
            # masks: torch.Size([1, 4, 256, 256])
            # iou_pred: torch.Size([1, 4])
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
            # masks: torch.Size([1, 1, 256, 256])
            # iou_scores_out: tensor([[0.9436]], device='cuda:0')
 
        else:
            # >>> 2-c. 默认单 mask 模式:直接取第 0 个通道
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]
 
        # multimask_output: False  self.use_multimask_token_for_obj_ptr:True
        if multimask_output and self.use_multimask_token_for_obj_ptr:
            # >>> 3-a. 多 mask 且要把 token 当 object pointer:用后 3 个 token
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # >>> 3-b. 其余情况(包括单 mask)一律用第 0 个 token 当"物体记忆"
            # 获取掩码输出token。这里我们*总是*使用单掩码输出的token。
            # 在测试时,即使我们在1次点击后进行跟踪(且设置multimask_output=True),
            # 这里仍然使用单掩码token。理由是:在训练期间我们总是在多次点击后才进行跟踪,
            # 因此训练过程中看到的过往token始终是单掩码token(我们将它作为对象记忆token)。
            # mask_tokens_out: torch.Size([1, 4, 256])
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
            # sam_tokens_out: torch.Size([1, 1, 256])
 
        # Prepare output
        # masks: torch.Size([1, 1, 256, 256])
        # iou_pred:tensor([[0.8732]], device='cuda:0')
        # sam_tokens_out: torch.Size([1, 1, 256])
        # object_score_logits: torch.Size([1, 1])  即 tensor([[24.0962]], device='cuda:0')
        return masks, iou_pred, sam_tokens_out, object_score_logits

代码整体流程一句话总结

用 predict_masks 一次性生成 4 组 mask 及其对应 IoU、token、objectness。

根据 multimask_output 标志和 dynamic_multimask_via_stability 策略,决定到底留几个 mask:

训练/多 mask 模式 → 留 3 个;

测试开 stability → 自动挑 1 个最稳的;

其余 → 直接拿第 0 个。

再按同样逻辑挑一个(或 3 个)token 作为后续跟踪用的"物体记忆"。

把最终 mask、IoU、token、objectness 分数一起返回。

4.2.2 sam_tokens_out是什么?

你困惑的核心是:sam_tokens_out到底是什么?为什么要这样选?

这是SAM在视频跟踪 场景下的关键设计,我把它叫 "物体身份证"


1. sam_tokens_out的本质

复制代码
sam_tokens_out = tensor([[[ 0.2, -0.5, 1.1, ..., 0.8 ]]])  # 形状: [1, 1, 256]
  • 这是一个 256维的向量 ,是模型对"这个物体是谁"的抽象编码

  • 就像人的身份证号,同一物体在不同帧中保持相同的身份特征

  • 在视频跟踪时,这个token会被保存下来,在下一帧作为"记忆"提示模型


2. 为什么总是用第0个token?(最关键)

复制代码
# 即使 multimask_output=True,这里仍然用第0个
sam_tokens_out = mask_tokens_out[:, 0:1]

这是SAM刻意训练的结果

阶段 发生了什么 记忆token是谁
训练时 用户点击物体5次 → 模型输出单掩码 → 保存token 第0个token
测试时 第1帧:用户点击1次 → 模型输出多掩码 仍然用第0个token做记忆
跟踪时 第2-N帧:用保存的token识别物体,无需再点击 第0个token持续使用

核心逻辑

  • 模型在训练时只见过第0个token代表物体身份

  • 如果测试时突然用第1-3个token,相当于给模型一个它不认识的新身份证

  • 结果:身份混乱,跟踪会漂移(drift)


3. 形象的比喻:海关安检

复制代码
# 第1帧(首次见到物体):
海关(模型)问:"你是谁?"
游客(物体)拿出护照(token 0):"我是张三,护照号12345"
海关记录:张三 → 12345

# 第2帧(跟踪):
海关问:"你是谁?"
游客拿出护照(token 0):"我是张三,护照号12345"
海关核对记录:"对上了,是同一个人,放行"

# 如果游客拿出token 1的护照:
"我是张三,护照号67890"
海关:"???我记录的是12345,你换了护照,不是同一个人"

问题:为什么解码掩码用多掩码,但身份用单掩码?

复制代码
# 解码掩码(当前帧"长什么样")
masks = multimask_logits  # 候选1: 全身, 候选2: 头部, 候选3: 尾巴

# 保存身份("你是谁")
sam_tokens_out = token_0  # 不管长什么样,身份证始终是0号

分离设计

  • 掩码 :描述外观(可变,多候选)

  • Token :描述身份(固定,单身份证)


4. 什么时候用第1-3个token?

复制代码
if multimask_output and self.use_multimask_token_for_obj_ptr:
    sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c]

这只在纯多掩码输出模式 下启用,不用于跟踪。比如:

  • 用户只想看当前帧的多个可能结果

  • 不保存身份信息,看完就丢弃


总结

sam_tokens_out 就是物体的"记忆指纹",为什么选第0个?

因为SAM在训练时就是这么教的 :"第0个token,你给我记住,你就是物体的唯一身份证,不管后续有多少个候选掩码,身份不能乱。"

这就是为什么注释里强调 "we always use the token for single mask output" ------这是SAM能稳定跟踪的秘密武器

五、回到_forward_sam_heads

sam2/modeling/sam2_base.py

python 复制代码
@torch.inference_mode()
def _forward_sam_heads(
    self,
    backbone_features,
    point_inputs=None,
    mask_inputs=None,
    high_res_features=None,
    multimask_output=False,
):
    """
    完整的 SAM 风格「提示编码 + 掩膜解码」一条龙:
    把 backbone 特征与用户提示(点/框/mask)一起送 SAM,
    输出多组或单组掩膜 logits、IoU 估计、对象指针等。
    参数:
        backbone_features (Tensor):
            已融合记忆的图像嵌入,形状 (B, C, H, W),
            其中 H=W=sam_image_embedding_size(默认 64)。
        point_inputs (dict | None):
            {"point_coords": (B, P, 2), "point_labels": (B, P)}
            坐标为**绝对像素**(会在内部归一化),label: 1=前, 0=背, -1=pad。
        mask_inputs (Tensor | None):
            低分辨率掩膜提示 (B,1,H',W'),与点提示互斥。
        high_res_features (list[Tensor] | None):
            额外两层更高分辨率特征 [4H,4W] 与 [2H,2W],供 decoder refine 边缘。
        multimask_output (bool):
            True → 输出 3 候选掩膜 + 3 IoU;False → 1 掩膜 + 1 IoU。
    返回:
        tuple:
            0. low_res_multimasks -- (B,M,H/4,W/4)  低分多掩膜 logits
            1. high_res_multimasks -- (B,M,H,W)     高分多掩膜 logits
            2. ious -- (B,M)                        各掩膜 IoU 估计
            3. low_res_masks -- (B,1,H/4,W/4)       **最佳**低分掩膜
            4. high_res_masks -- (B,1,H,W)          **最佳**高分掩膜
            5. obj_ptr -- (B,C)                     对象指针(用于记忆)
            6. object_score_logits -- (B,)          对象出现置信度(可软可硬)
    """
    # backbone_features: torch.Size([1, 256, 64, 64])
    # point_inputs: {
    #                   'point_coords': tensor([[[421.3333, 540.4445], 
    #                                           [654.9333, 921.6000]]], device='cuda:0'),
    #                   'point_labels': tensor([[2, 3]], 
    #                                            device='cuda:0', dtype=torch.int32)
    #               }
    # mask_inputs:None
    # high_res_features:[
    #                        torch.Size([1, 32, 256, 256]), 
    #                        torch.Size([1, 64, 128, 128])
    #                   ]
    # multimask_output: False
 
    B = backbone_features.size(0)
    # B: 1
 
    device = backbone_features.device
    # device: device(type='cuda', index=0) 
 
    # --- 1. 输入断言:尺寸必须匹配 SAM 预设 ---
    assert backbone_features.size(1) == self.sam_prompt_embed_dim
    assert backbone_features.size(2) == self.sam_image_embedding_size
    assert backbone_features.size(3) == self.sam_image_embedding_size
 
    # --- 2. 构造点提示 ---
    if point_inputs is not None:
        sam_point_coords = point_inputs["point_coords"]  # (B, P, 2)
        # sam_point_coords: torch.Size([1, 2, 2])
        # sam_point_coords: tensor([[[421.3333, 540.4445], 
        #                             [654.9333, 921.6000]]], device='cuda:0'),
 
        sam_point_labels = point_inputs["point_labels"]  # (B, P)
        # sam_point_labels: torch.Size([1, 2])
        # sam_point_labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
 
        assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
    else:
        # 无点提示时,用 1 个 pad 点(label=-1)占位,保证 prompt encoder 正常 forward
        sam_point_coords = torch.zeros(B, 1, 2, device=device)
        sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
 
    # --- 3. 构造掩膜提示 ---
    # mask_inputs: None     
    if mask_inputs is not None:
        # 若外部 mask 分辨率不符,先双线性下采样到 prompt encoder 期望尺寸
        assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
        if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
            sam_mask_prompt = F.interpolate(
                mask_inputs.float(),
                size=self.sam_prompt_encoder.mask_input_size,
                align_corners=False,
                mode="bilinear",
                antialias=True,  # 抗锯齿,减少下采样 aliasing
            )
        else:
            sam_mask_prompt = mask_inputs
    else:
        # 无 mask 时,prompt encoder 内部会加 learned `no_mask_embed`
        sam_mask_prompt = None
 
    # --- 4. 送进 SAM 提示编码器 ---
    sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
        points=(sam_point_coords, sam_point_labels),
        boxes=None,  # 框提示在外部已转成 2 点,不再走这里
        masks=sam_mask_prompt,
    )
    # sparse_embeddings: torch.Size([1, 3, 256])
    # dense_embeddings: torch.Size([1, 256, 64, 64]) 
 
    # backbone_features: torch.Size([1, 256, 64, 64])
    # multimask_output: False
    # high_res_features:[
    #                        torch.Size([1, 32, 256, 256]), 
    #                        torch.Size([1, 64, 128, 128])
    #                   ]
    
    # --- 5. 送进 SAM mask 解码器 ---
    (
        low_res_multimasks,   # (B, M, H/4, W/4)  M=3 or 1
        ious,                 # (B, M)             IoU 估计
        sam_output_tokens,    # (B, M, C)          解码器输出 token
        object_score_logits,  # (B,)               对象出现/消失 logits
    ) = self.sam_mask_decoder(
        image_embeddings=backbone_features,
        image_pe=self.sam_prompt_encoder.get_dense_pe(),  # 固定 2D 位置编码
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
        repeat_image=False,  # image 已 batch,无需再 repeat
        high_res_features=high_res_features,  # 高分特征供 refine 边缘
    )
    # low_res_multimasks: torch.Size([1, 1, 256, 256])
    # ious: tensor([[0.9436]], device='cuda:0')
    # sam_output_tokens: torch.Size([1, 1, 256])
    # object_score_logits: torch.Size([1, 1])  即 tensor([[24.0962]], device='cuda:0')
 
    # --- 6. 对象 score 后处理:若模型预测"无对象",把掩膜 logits 置为 NO_OBJ_SCORE ---
    # self.pred_obj_scores: True
    if self.pred_obj_scores:
        # object_score_logits: tensor([[24.0962]], device='cuda:0') 
        is_obj_appearing = object_score_logits > 0  # 硬阈值
        # is_obj_appearing: tensor([[True]], device='cuda:0')

        # NO_OBJ_SCORE: -1024.0
        # low_res_multimasks: torch.Size([1, 1, 256, 256])
        # 记忆用掩膜必须**硬**选择:有对象才保留,否则置极大负值
        low_res_multimasks = torch.where(
            is_obj_appearing[:, None, None],
            low_res_multimasks,
            NO_OBJ_SCORE,
        )
 
    # --- 7. 数据类型转换:bf16/fp16 -> fp32(老版本 PyTorch interpolate 不支持 bf16)---
    low_res_multimasks = low_res_multimasks.float()
    # 上采样到图像原分辨率(stride=1)
    # self.image_size: 1024
    high_res_multimasks = F.interpolate(
        low_res_multimasks,
        size=(self.image_size, self.image_size),
        mode="bilinear",
        align_corners=False,
    )
    # high_res_multimasks: torch.Size([1, 1, 1024, 1024])
 
    # --- 8. 选取最佳掩膜 ---
    # sam_output_tokens: torch.Size([1, 1, 256])
    sam_output_token = sam_output_tokens[:, 0]  # 默认取第 1 个 token(单掩膜时即自身)
    # sam_output_token :torch.Size([1, 256])

    # multimask_output: False
    if multimask_output:
        # 多掩膜时,选 IoU 估计最高的那个
        best_iou_inds = torch.argmax(ious, dim=-1)  # (B,)
        batch_inds = torch.arange(B, device=device)
        low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        # 若解码器输出了多个 token,同样要选最佳
        if sam_output_tokens.size(1) > 1:
            sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
    else:
        # 单掩膜时,最佳即唯一
        low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
        # low_res_masks: torch.Size([1, 1, 256, 256])
        # low_res_masks: torch.Size([1, 1, 1024, 1024])
 
    # --- 9. 从最佳 token 提取对象指针(用于记忆)---
    # sam_output_token: torch.Size([1, 256])
    obj_ptr = self.obj_ptr_proj(sam_output_token)  # (B, C)
    # obj_ptr: torch.Size([1, 256])
 
    # --- 10. 对象指针后处理:若模型认为"无对象",指针也被削弱或替换 ---
    # self.pred_obj_scores: True
    if self.pred_obj_scores:
        # self.soft_no_obj_ptr: False
        if self.soft_no_obj_ptr:
            # 软削弱:用 sigmoid 概率加权
            assert not self.teacher_force_obj_scores_for_mem
            lambda_is_obj_appearing = object_score_logits.sigmoid()
        else:
            # 硬削弱:0/1 加权
            # is_obj_appearing: tensor([[True]], device='cuda:0')
            lambda_is_obj_appearing = is_obj_appearing.float()
            # lambda_is_obj_appearing: tensor([[1.]], device='cuda:0')
 
        # self.fixed_no_obj_ptr: True
        if self.fixed_no_obj_ptr:
            obj_ptr = lambda_is_obj_appearing * obj_ptr
            # obj_ptr: torch.Size([1, 256])

        # 剩余权重用"无对象指针"补齐,保证指针和为 1
        obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
        # obj_ptr: torch.Size([1, 256])
 
    # low_res_multimasks: torch.Size([1, 1, 256, 256])
    # high_res_multimasks: torch.Size([1, 1, 1024, 1024])
    # ious: tensor([[0.9436]], device='cuda:0')
    # low_res_masks: torch.Size([1, 1, 256, 256])
    # low_res_masks: torch.Size([1, 1, 1024, 1024])
    # obj_ptr: torch.Size([1, 256])
    # object_score_logits: object_score_logits: tensor([[24.0962]], device='cuda:0') 

    # --- 11. 返回打包结果 ---
    return (
        low_res_multimasks,      # 0 低分多掩膜
        high_res_multimasks,     # 1 高分多掩膜
        ious,                    # 2 各掩膜 IoU 估计
        low_res_masks,           # 3 最佳低分掩膜
        high_res_masks,          # 4 最佳高分掩膜
        obj_ptr,                 # 5 对象指针(记忆 key/query)
        object_score_logits,     # 6 对象出现置信度 logits
    )

_forward_sam_heads 核心总结:视频对象跟踪的"推理引擎"

这是一个纯推理 函数(@torch.inference_mode()),负责将用户提示(点击、框、前一帧掩码)与历史记忆融合,生成当前帧的掩码预测和对象身份指针。


核心逻辑:5步流水线

第1-3步:输入准备(预处理)

复制代码
# 输入:backbone特征 + 用户提示(点/掩码)+ 高分辨率细节特征
# 输出:标准化的稀疏/密集嵌入
  • 点提示 :坐标归一化到 [0,1],标签 1=前景, 0=背景, -1=填充

  • 掩码提示 :下采样到 256×256 输入尺寸

  • 无输入时:用默认填充值保证流程不中断

第4-5步:SAM核心推理

复制代码
# 输入:预处理后的嵌入
# 输出:4个候选掩码 + IoU分数 + 4个token + 对象分数
  • 提示编码器:将用户点击/掩码编码成256维向量

  • 掩码解码器 :生成低分辨率 掩码(64×64)和高分辨率细节

  • 多候选机制:默认输出4个token(0=单掩码主输出,1-3=多掩码备选)

第6-7步:后处理与分辨率提升

复制代码
# 输入:低分辨率掩码
# 输出:原始分辨率掩码 + 处理无对象情况
  • 对象分数过滤 :如果模型认为"没有物体",掩码置为极大负值(NO_OBJ_SCORE

  • 上采样 :双线性插值到原始图像尺寸(1024×1024

第8步:掩码选择策略

复制代码
# 关键决策:单掩码 vs 多掩码
  • 单掩码模式multimask_output=False):直接用token 0的掩码

  • 多掩码模式multimask_output=True):从4个候选中选IoU最高的那个

  • **dynamic_multimask_via_stability **:如果token 0不稳定(轮廓模糊),自动回退到最佳多掩码

第9-10步:生成"对象身份证"(核心创新)

复制代码
# 输入:最佳token(256维)
# 输出:对象指针(256维,用于记忆机制)
  • 对象指针obj_ptr):从最佳token投影得到的身份特征向量

  • 跨帧记忆:这个指针会被存入记忆库,下一帧作为"我是谁"的提示

  • 动态加权 :根据对象分数(object_score_logits)削弱指针强度

    • soft_no_obj_ptr=True:用sigmoid概率软性削弱

    • fixed_no_obj_ptr=True:用0/1硬性削弱


设计哲学:为什么这样设计?

设计点 原因 类比
总是用token 0做记忆 训练时只见过token 0代表身份,其他token是"短期临时工" 护照号 vs 临时通行证
多掩码候选 处理歧义点击(边界、重叠区域) 专家会诊,取最佳方案
对象指针 将掩码"外观"与物体"身份"解耦 人脸(可变化) vs 身份证号(不变)
动态削弱 物体消失时,记忆应衰减避免干扰 记忆随时间褪色

输入输出全貌

输入

  • backbone_features:融合历史记忆的图像特征 [1, 256, 64, 64]

  • point_inputs:用户点击坐标和标签

  • high_res_features:边缘细节特征 [1, 32, 256, 256][1, 64, 128, 128]

输出(7个值):

  1. low_res_multimasks:4个低分辨率候选 [1, 4, 256, 256]

  2. high_res_multimasks:4个高分辨率候选 [1, 4, 1024, 1024]

  3. ious:4个IoU分数 [1, 4]

  4. low_res_masks:最佳低分辨率掩码 [1, 1, 256, 256]

  5. high_res_masks:最佳高分辨率掩码 [1, 1, 1024, 1024]

  6. obj_ptr对象指针 [1, 256]用于下一帧记忆

  7. object_score_logits:对象出现置信度 [1, 1]


一句话总结

这个函数是视频SAM的推理中枢 :它接收带记忆的图像用户提示 ,输出当前帧掩码 的同时,生成一个256维的身份指针,告诉记忆系统"这个物体长这样,下次见到要认出来"。

5.1 obj_ptr = self.obj_ptr_proj(sam_output_token)

这个会进入MLP

python 复制代码
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/
 
transformer/transformer_predictor.py # noqa
 
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output
        self.act = activation()
 
    def forward(self, x):
        # x: torch.Size([1, 256])
 
        for i, layer in enumerate(self.layers):
            # self.num_layers:3
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
            # i=0 x: torch.Size([1, 256])
            # i=1 x: torch.Size([1, 256])
            # i=2 x: torch.Size([1, 256])
 
        # self.sigmoid_output: False
        if self.sigmoid_output:
            x = F.sigmoid(x)
         
        return x

我们去sam2/modeling/sam2_base.py,在这个文件搜索self.obj_ptr_proj,发现它在_build_sam_heads函数中。而且确实在类SAM2Base的初始化中存在self._build_sam_heads()。

if self.use_obj_ptrs_in_encoder:

a linear projection on SAM output tokens to turn them into object pointers

self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)

if self.use_mlp_for_obj_ptr_proj:
self.obj_ptr_proj = MLP(

# self.hidden_dim:256
self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
)

else:

self.obj_ptr_proj = torch.nn.Identity()

5.2 对象记忆的"褪色与重置"机制

--- 9. 从最佳 token 提取对象指针(用于记忆)---

sam_output_token: torch.Size([1, 256])

obj_ptr = self.obj_ptr_proj(sam_output_token) # (B, C)

obj_ptr: torch.Size([1, 256])

--- 10. 对象指针后处理:若模型认为"无对象",指针也被削弱或替换 ---

self.pred_obj_scores: True

if self.pred_obj_scores:

self.soft_no_obj_ptr: False

if self.soft_no_obj_ptr:

软削弱:用 sigmoid 概率加权

assert not self.teacher_force_obj_scores_for_mem

lambda_is_obj_appearing = object_score_logits.sigmoid()

else:

硬削弱:0/1 加权

is_obj_appearing: tensor([[True]], device='cuda:0')

lambda_is_obj_appearing = is_obj_appearing.float()

lambda_is_obj_appearing: tensor([[1.]], device='cuda:0')

self.fixed_no_obj_ptr: True

if self.fixed_no_obj_ptr:

obj_ptr = lambda_is_obj_appearing * obj_ptr

obj_ptr: torch.Size([1, 256])

剩余权重用"无对象指针"补齐,保证指针和为 1

obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

obj_ptr: torch.Size([1, 256])

什么意思?

这段代码是对象记忆的"褪色与重置"机制,防止物体消失后残留的错误记忆干扰后续跟踪。


核心思想:物体消失时,记忆应该"归零"或"模糊化"

obj_ptr 是物体的"记忆指纹",但如果物体离开画面,这个指纹会变成 "幽灵记忆" ------下一帧会错误地匹配背景或其他物体。这段代码就是动态削弱或替换这个指纹


分步拆解

第1步:提取初始对象指针

复制代码
obj_ptr = self.obj_ptr_proj(sam_output_token)  # [1, 256]
  • sam_output_token:最佳掩码对应的token(物体的"身份编码")

  • obj_ptr_proj:线性投影层,将token转换为记忆指针(用于下一帧检索)

第2步:计算"存在权重" lambda_is_obj_appearing

复制代码
# 硬模式(当前配置)
lambda_is_obj_appearing = is_obj_appearing.float()  # 0.0 或 1.0
# 结果:tensor([[1.0]]) 表示"物体存在"

lambda 的含义:

  • λ = 1.0 :物体明确存在 → 记忆指针100%保留

  • λ = 0.0 :物体消失 → 记忆指针被削弱或替换


第3步:削弱或替换指针(关键操作)

python 复制代码
if self.fixed_no_obj_ptr:  # True
    obj_ptr = lambda_is_obj_appearing * obj_ptr  # 物体消失时乘以0,直接清零

# 用"空指针"补全剩余权重
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

两种模式对比

模式 soft_no_obj_ptr=False (硬) soft_no_obj_ptr=True (软)
λ计算 is_obj_appearing.float() (0/1) object_score_logits.sigmoid() (0~1连续值)
物体存在 λ=1.0 → obj_ptr不变 λ≈1.0 → obj_ptr几乎不变
物体消失 λ=0.0 → obj_ptr = 0 + no_obj_ptr λ≈0.0 → obj_ptr很弱 + no_obj_ptr主导
优点 干净利落的开关,无记忆残留 平滑过渡,避免突变抖动
缺点 可能因误判导致记忆突然中断 物体消失后仍残留微弱记忆

什么是 self.no_obj_ptr

这是一个可学习的参数向量 ,形状 [B,256],代表 "没有物体时的默认记忆状态"

类比 :就像数据库里的 "空值(NULL)""默认用户",当物体不存在时,记忆系统应该指向这个"空身份",而不是一个过期的物体身份。


完整流程可视化

场景A:物体正常存在

复制代码
is_obj_appearing = True → λ = 1.0

obj_ptr = 1.0 * "猫的指纹" + (1-1.0) * "空指针"
        = "猫的指纹"  # 完全保留记忆

场景B:物体消失(硬模式)

复制代码
is_obj_appearing = False → λ = 0.0

obj_ptr = 0.0 * "猫的指纹" + (1-0.0) * "空指针"
        = "空指针"  # 彻底清空记忆,避免"猫"的幽灵残留

场景C:物体可能消失(软模式)

复制代码
object_score_logits = -2.0 → λ ≈ 0.12

obj_ptr = 0.12 * "猫的指纹" + 0.88 * "空指针"
        = "模糊的、几乎空的指针"  # 记忆很弱,但仍有微弱痕迹

为什么要这么设计?

防止"记忆污染"

如果没有这个机制:

复制代码
# 第1帧:跟踪猫
obj_ptr = "猫的指纹"

# 第10帧:猫跑出画面,但记忆未清空
# 第11帧:出现一只狗
obj_ptr 仍然 ≈ "猫的指纹" → 模型误以为狗是猫 → **跟踪失败**

有了这个机制:

复制代码
# 第10帧:模型检测不到猫 → λ=0
obj_ptr = "空指针"

# 第11帧:出现狗
obj_ptr = "空指针" → 模型需要用户重新提示 → **避免误匹配**

一句话总结

这段代码实现了记忆的"生命周期管理" :物体存在时强化记忆,物体消失时主动遗忘,防止过时的记忆成为下一帧的"幽灵干扰"。

在sam2/modeling/sam2_base.py搜索self.no_obj_ptr,发现在SAM2Base初始化中,

条件:只有当模型需要预测对象分数 AND 在编码器中使用对象指针时,才创建这个参数

if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:

创建一个形状为 [1, hidden_dim] 的可学习参数(默认 hidden_dim=256)

self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))

用截断正态分布重新初始化(均值=0,标准差=0.02)

trunc_normal_(self.no_obj_ptr, std=0.02)
这段代码是初始化"无对象指针" ,让它成为模型的一个可学习参数

为什么这样做?

  1. torch.nn.Parameter

    no_obj_ptr 注册为模型参数,意味着:

    • 它会被优化器更新(可学习

    • 会出现在 model.parameters()

    • 会被保存到检查点(checkpoint)

  2. torch.zeros 然后 trunc_normal_

    • 先清零:确保初始值不会太大,避免训练初期梯度爆炸

    • 再加微小噪声std=0.02 的截断正态分布,让初始值在零附近小幅扰动

    • 目的 :模型能自动学习什么样的"空状态"最合适,而不是人为设定死值

  3. 为什么不是固定值?

    如果 no_obj_ptr 是固定零向量

    复制代码
    self.no_obj_ptr = torch.zeros(1, 256)  # 固定值
    • 模型无法适应不同场景

    • 可能与其他真实对象指针"撞车"(距离太近)

    作为可学习参数,模型会学会:

    • 远离所有真实对象的指针分布

    • 成为记忆空间中的 "无效身份"锚点


类比: learnable "NULL" token

这就像数据库设计中的 "可学习的NULL值"

复制代码
# 固定NULL
NULL = [0, 0, 0, ..., 0]  # 死板,可能与某些数据冲突

# 可学习NULL
NULL = [0.01, -0.02, 0.005, ...]  # 训练后自动调整,确保唯一性

一句话总结

no_obj_ptr 不是随意设定的零向量,而是一个可训练的参数,模型通过反向传播自动学会"当物体不存在时,记忆系统应该处于什么状态",从而更智能地清除过时记忆。

所以其实就是obj_ptr是sam_output_token经过MLP后得到的,如果没消失的话obj_ptr就不用加self.no_obj_ptr,如果消失的话就要加。我知道(1,4,256)的mask_tokens_out取第0个token得到的(1,1,256)sam_output_token,然后取第0个得到(1,256)的sam_output_token,但是这个sam_output_token是啥?是用于判断物体是否存在的token吗?(不是,判断物体是否存在是object_score_logits,这个是那4个输出掩码的第一个)。那为什么要经过MLP变成obj_ptr,然后为什么物体消失要加self.no_obj_ptr,我理解物体是有可能是慢慢消失的,就是物体的一部分慢慢变小直至消失,它是怎么记忆这一点呢?

你问到了视频跟踪最核心的设计,我来逐层拆解这三个关键概念:


1. sam_output_token 到底是什么?

复制代码
# 形状: [1, 256]
sam_output_token = tensor([-0.2, 0.5, 1.2, ..., -0.8])

本质 :这是解码器对"这个物体长什么样"的语义指纹,不是"存在判断器"。

类比:

  • 人脸识别:128维的脸部特征向量(描述五官特征)

  • 不是:一个开关信号(0=没人,1=有人)

它不负责说"物体在不在",只负责说"如果物体在,它的身份特征是什么"。


2. 为什么要经过MLP变成 obj_ptr

复制代码
obj_ptr = self.obj_ptr_proj(sam_output_token)  # 线性投影

关键原因:空间对齐 + 可学习记忆空间

对比 sam_output_token obj_ptr
来源 掩码解码器输出 经过投影层
空间 解码空间(偏外观描述) 记忆空间(偏身份检索)
用途 生成掩码 检索/存储记忆
可学习性 固定语义 可训练调整
  • 翻译官作用:MLP把"外观描述"翻译成"身份ID",让记忆系统能看懂

  • 训练效果 :自动拉近同物体不同帧的 obj_ptr 距离,推远不同物体的距离

如果没有投影

复制代码
# 直接存储 token
记忆库 = { "第1帧": token_0, "第5帧": token_0, "第10帧": token_0 }

# 第11帧检索
查询 = token_0_new
结果 = 找与 token_0_new 最相似的记忆 → 可能匹配错误(token空间≠记忆空间)

3. 物体消失机制:怎么感知"慢慢消失"?

这是独立判断的! 物体消失由 object_score_logits 分支 负责,不是 sam_output_token

工作流程:

第N帧:物体完整

复制代码
object_score_logits = 5.0      # λ = sigmoid(5.0) ≈ 0.99
sam_output_token = "猫的完整身份编码"
obj_ptr = 0.99 * "猫" + 0.01 * "空指针"  # 几乎全保留

第N+5帧:被遮挡50%

复制代码
object_score_logits = 1.0      # λ ≈ 0.73
sam_output_token = "猫的模糊身份"
obj_ptr = 0.73 * "模糊猫" + 0.27 * "空指针"  # 记忆开始模糊

第N+10帧:完全消失

复制代码
object_score_logits = -3.0     # λ ≈ 0.05
sam_output_token = "垃圾编码(解码器乱猜)"
obj_ptr = 0.05 * "垃圾" + 0.95 * "空指针"  # 几乎清空

4. 为什么需要 self.no_obj_ptr

如果物体消失后 obj_ptr 直接清零:

复制代码
# 第10帧:obj_ptr = [0, 0, ..., 0]
# 第11帧:新物体出现
obj_ptr_new = "狗的身份"
记忆检索:找与 [0,0,...,0] 最相似的特征 → 可能匹配到背景 → 跟踪漂移

正确设计

复制代码
# 第10帧:obj_ptr = 学习的"空指针"(远离所有真实物体)
# 第11帧:新物体出现
记忆检索:找不到有效匹配 → 触发"需要用户重新提示" → 避免误匹配

self.no_obj_ptr 就是记忆空间里的"无人区" ,告诉模型:"这里什么都没有,别瞎匹配"。


总结:三个变量的区别

变量 形状 作用 判断是否消失
sam_output_token [256] 身份编码(物体是谁) ❌ 不判断
object_score_logits [1] 存在置信度(物体在不在) ✅ 输出λ
obj_ptr [256] 记忆指针(用于检索) ❌ 但会被λ削弱

关系链

复制代码
物体外观 → sam_output_token → obj_ptr_proj → obj_ptr → 存入记忆
   ↓
object_score_logits → λ → 控制 obj_ptr 的削弱程度

一句话回答

物体消失不是由token判断,而是由独立的 object_score_logits 分支 输出λ值;当λ→0时,obj_ptr 会向学习的"空指针" 渐进靠拢,实现软性遗忘,避免硬清零导致记忆空间污染。

软模式和硬模式?

代码真相:加权和 vs 硬切换

这段代码做的是向量插值,不是"硬清零":

复制代码
# 当 λ=0.8 时
obj_ptr = 0.8 * "猫的指针" + 0.2 * "空指针"  # 混合状态

# 当 λ=0.0 时
obj_ptr = 0.0 * "猫的指针" + 1.0 * "空指针"  # 完全切换

关键看 lambda_is_obj_appearing 的取值方式:


两种模式:硬 vs 软

硬模式(你代码里的配置)

复制代码
self.fixed_no_obj_ptr = True
self.soft_no_obj_ptr = False

# λ = 0.0 或 1.0(离散)
lambda_is_obj_appearing = is_obj_appearing.float()  # tensor([[0.0]]) 或 tensor([[1.0]])

行为

  • 物体存在 :λ=1.0 → obj_ptr = 猫的指针 (纯猫)

  • 物体消失 :λ=0.0 → obj_ptr = 空指针 (纯空)

确实是"突然切换",没有中间态。物体消失的瞬间,记忆指针立刻变成"无效身份"。


软模式

复制代码
self.soft_no_obj_ptr = True

# λ = 0.0 ~ 1.0(连续)
lambda_is_obj_appearing = object_score_logits.sigmoid()  # tensor([[0.73]])

行为

  • 物体清晰 :λ≈0.99 → obj_ptr ≈ 猫的指针

  • 物体模糊 :λ≈0.50 → obj_ptr = 0.5*猫 + 0.5*空 (混合状态)

  • 物体消失 :λ≈0.01 → obj_ptr ≈ 空指针

这才是"慢慢消失" ,通过 object_score_logits 的平滑变化实现渐进过渡。


为什么你看到的是硬模式?

因为 fixed_no_obj_ptr=True默认配置 ,它在论文和官方代码中被验证最稳定。原因:

  1. 避免模糊记忆:软模式可能导致"半透明幽灵"------物体消失后,记忆处于猫和空的叠加态,下一帧可能匹配到背景

  2. 训练更简单:硬开关让梯度更干净,没有中间态的歧义

  3. 实际足够好object_score_logits 本身可以快速响应 ,0.0 和 1.0 之间的切换在时序上是平滑的(因为 logits 是连续变化的)

条件:只有当模型需要预测对象分数 AND 在编码器中使用对象指针时,才创建这个参数

if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:

创建一个形状为 [1, hidden_dim] 的可学习参数(默认 hidden_dim=256)

self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))

用截断正态分布重新初始化(均值=0,标准差=0.02)

trunc_normal_(self.no_obj_ptr, std=0.02)

这个到底学习什么?

这就是 视频记忆系统的"空值锚点" ,它学习的东西非常精妙:


它学习的是:记忆空间中的"无人区"坐标

想象一个256维的向量空间,每个点代表一个可能的物体身份

  • 区域A:猫的指针(如 [0.5, -0.2, ..., 1.0]

  • 区域B:狗的指针(如 [-0.8, 0.3, ..., -0.5]

  • 区域C:车的指针(如 [0.1, 0.9, ..., 0.4]

no_obj_ptr 学习的就是这个空间里"不属于任何物体"的那个点


为什么要 学习 这个?(而不是用固定的零向量)

问题1:零向量可能"太普通"

复制代码
# 如果 no_obj_ptr = [0, 0, ..., 0]
# 猫的指针 = [0.01, -0.02, 0.005, ...]  # 小噪声

距离(猫, 空) = ||[0.01, -0.02, ...]|| ≈ 很小
→ 模型可能分不清"猫"和"空"

问题2:零向量可能"太特殊"

某些激活函数(如ReLU)会让零区域梯度消失,导致"空状态"无法有效传播。

解决方案:可学习的空指针

复制代码
# 训练后,no_obj_ptr 可能变成:
no_obj_ptr = [3.5, -3.5, 3.5, -3.5, ...]  # 远离所有物体分布

距离(猫, 空) = 很大  # 明确分离
距离(狗, 空) = 很大  # 明确分离

它到底怎么学习?(训练动力学)

Step 1: 数据驱动

  • 当视频中物体消失 时,λ → 0,记忆指针被强制靠近 no_obj_ptr

  • 下一帧如果仍然没有物体,模型损失会惩罚错误的匹配

  • 反向传播 :梯度更新 no_obj_ptr,让它成为一个"不会导致误匹配"的状态

Step 2: 竞争学习

复制代码
# 损失函数隐含的竞争关系:
# 猫的指针 ← 靠近"猫的所有历史帧"
# 狗的指针 ← 靠近"狗的所有历史帧"
# 空指针   ← 远离"所有物体",靠近"所有无物体帧"

Step 3: 收敛到最优解

经过大量数据训练后,no_obj_ptr 自动找到:

  • 最佳位置 :记忆空间中距离所有真实物体最远的区域

  • 最佳编码 :能够让下一帧的预测误差最小的向量


类比:学习数据的"聚类中心"

就像K-means聚类,但反着来

  • 普通聚类 :找到数据最密集的中心点

  • no_obj_ptr :找到最远离所有数据外围点

图示:

复制代码
向量空间:
   猫 ●
       \ 
        \ 
         ● 狗
        /
       /
   车 ●
       
   空 ● ← no_obj_ptr 学习到这里(数据荒漠)

为什么用 trunc_normal_ 初始化?

复制代码
trunc_normal_(self.no_obj_ptr, std=0.02)
  • 小方差(0.02):初始值接近0但不精确为0,给优化留出空间

  • 截断分布:防止初始值过大,远离原点,破坏记忆空间结构

  • 从零开始:让模型自己决定"空"应该是什么,而不是预设偏见


一句话总结

no_obj_ptr 学习的是记忆空间中的"数据荒漠"坐标 ------它通过反向传播自动找到一个远离所有真实物体、且最有利于消除误匹配的向量表示,作为物体消失时记忆的"安全锚点"。

相关推荐
HyperAI超神经2 小时前
入选NeurIPS 2025,智源/北大/北邮提出多流控制视频生成框架,基于音频解混实现精确音画同步
人工智能·ai·音视频·视频生成·neurlps 2025
草莓熊Lotso2 小时前
C++ 异常完全指南:从语法到实战,优雅处理程序错误
android·java·开发语言·c++·人工智能·经验分享·后端
yi个名字2 小时前
智能编码新时代:Vibe Coding与MCP驱动的工作流IDE革命
ide·人工智能
IT_陈寒2 小时前
Python性能优化实战:7个让代码提速300%的冷门技巧(附基准测试)
前端·人工智能·后端
熊猫钓鱼>_>2 小时前
多智能体协作:构建下一代高智能应用的技术范式
人工智能·ai·去中心化·wpf·agent·多智能体·multiagent
likeshop 好像科技2 小时前
AI知识库架构深度解析:智能体记忆与学习的智慧核心
人工智能·学习·架构
啊阿狸不会拉杆2 小时前
《数字图像处理》第 12 章 - 图像模式分类
图像处理·人工智能·算法·机器学习·计算机视觉·分类·数据挖掘
Robot侠2 小时前
ROS1从入门到精通 15: 机器人视觉 - 图像处理与计算机视觉
图像处理·人工智能·计算机视觉·机器人·ros·机器人操作系统
Robot侠2 小时前
赋予 AI 记忆:在 RTX 3090 上搭建本地 RAG 知识库问答系统
人工智能·langchain·llm·llama·qwen·rag·chromadb