SAM2跟踪的理解15——第一帧解码器之后

目录

一、前言

二、回到track_step

三、回到_run_single_frame_inference

[3.1 fill_holes_in_mask_scores](#3.1 fill_holes_in_mask_scores)

[3.1.1 get_connected_components](#3.1.1 get_connected_components)

[3.2 _get_maskmem_pos_enc](#3.2 _get_maskmem_pos_enc)


一、前言

SAM跟踪的第一帧我们看到了MaskDecoder.forward出来以后又回到_forward_sam_heads处理了一些像上采样、obj_ptr的操作,然后就会回到了track_step,返回sam_outputs,然后只是把这些解码器输出的东西存到了current_out里面(因为第一帧没有记忆特征,所以这里没有进入记忆编码器),现在current_out里面已经相当多东西了

current_out:

{

"point_inputs": {

'point_coords':tensor([[[418.6667, 543.2889], [656.0000, 924.4445]]], device='cuda:0')

'point_labels': tensor([[2, 3]], device='cuda:0', dtype=torch.int32)

},

"mask_inputs": None,

"pred_masks": torch.Size([1, 1, 256, 256]),

"pred_masks_high_res": torch.Size([1, 1, 1024, 1024]),

"obj_ptr": torch.Size([1, 256]),

"maskmem_features": None,

"maskmem_pos_enc": None

}

随后又会回到_run_single_frame_inference,主要操作就是给掩码的背景区域的孔洞用0.1填充,最后返回:

compact_current_out:{

"maskmem_features": None

"maskmem_pos_enc": None

"pred_masks": torch.Size([1, 1, 256, 256])

"obj_ptr": torch.Size([1, 256])

}

pred_masks_gpu: torch.Size([1, 1, 256, 256])

二、回到track_step

python 复制代码
def track_step(
        self,
        frame_idx,                       # 当前帧的全局序号(从 0 开始)
        is_init_cond_frame,              # 是否为初始条件帧(第一帧或用户重新标注的帧)
        current_vision_feats,            # 当前帧的 backbone 特征列表,长度=层数,每层形状 (HW, B, C)
        current_vision_pos_embeds,       # 对应的位置编码,与 current_vision_feats 一一对应
        feat_sizes,                      # 每层特征图的空间分辨率,如 [(256,256), (128,128)]
        point_inputs,                    # 用户点击/提示,格式 dict{"point_coords": (B,K,2), "point_labels": (B,K)}
        mask_inputs,                     # 用户输入的 mask(可选),形状 (B,1,H,W) 或 None
        output_dict,                     # 全局状态字典,包含 memory bank、已编码的记忆等
        num_frames,                      # 整个视频的总帧数
        track_in_reverse=False,          # 是否按倒序跟踪(demo 中回退播放时使用)
        run_mem_encoder=True,            # 是否要把当前预测结果编码成记忆供后续帧使用
        prev_sam_mask_logits=None,       # 上一帧预测的 low-res mask logits(demo 连续点击时累积用)
    ):
        """
        单帧跟踪函数。
        1. 融合历史记忆与当前视觉特征 -> 得到带记忆的 pix_feat
        2. 用 SAM head 解码出 mask
        3. 可选地把预测 mask 再编码成新的记忆
        返回 current_out,包含本帧所有需要保存/上传的结果。
        """
        # current_vision_feats: [
        #                   torch.Size([65536, 1, 32])
        #                   torch.Size([16384, 1, 64])
        #                   torch.Size([4096, 1, 256])
        # ]
        # current_vision_pos_embeds: [
        #                        torch.Size([65536, 1, 256])
        #                        torch.Size([16384, 1, 256])
        #                        torch.Size([4096, 1, 256])
        #                    ]
        # feat_sizes': [
        #                  (256, 256), 
        #                  (128, 128), 
        #                   (64, 64)
        #              ]
        #]
 
        # ------------------------------------------------------------------
        # 1. 组装本帧的输入字典,方便后续上传或可视化
        # ------------------------------------------------------------------
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # current_out 示例:
        # {
        #   "point_inputs": {
        #       "point_coords": tensor([[[416.5333, 541.3926], 
        #                                       [656.5333, 926.340]]], device='cuda:0'),
        #       "point_labels": tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
        #   },
        #   "mask_inputs": None
        # }
 
        # ------------------------------------------------------------------
        # 2. 取出 backbone 的多层特征,转成 BCHW 供 SAM 的高分辨率分支使用
        #    最后一层特征留给记忆融合,其余层作为 high_res_features
        # ------------------------------------------------------------------
        if len(current_vision_feats) > 1:
            # 将 (HW,B,C) -> (B,C,H,W)
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
            # high_res_features:[
            #                        torch.Size([1, 32, 256, 256]), 
            #                        torch.Size([1, 64, 128, 128])
            #]
        else:
            high_res_features = None
 
        # ------------------------------------------------------------------
        # 3. 如果用户直接给了 GT mask 且配置要求跳过 SAM,则直接把它当输出
        # ------------------------------------------------------------------
        # mask_inputs:None  use_mask_input_as_output_without_sam:True
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # 取最后一层特征作为像素嵌入
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)          # (B,C,H*W)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])  # (B,C,H,W)
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # ------------------------------------------------------------------
            # 4. 正常路径:把当前特征与记忆库融合,得到带历史信息的 pix_feat
            # ------------------------------------------------------------------
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],   # 只用最后一层
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
 
            # ------------------------------------------------------------------
            # 5. 处理 demo 场景:如果之前已有 SAM 预测结果,把它作为 mask prompt
            # ------------------------------------------------------------------
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None, \
                    "prev_sam_mask_logits 仅在有点击且无外部 mask 时生效"
                mask_inputs = prev_sam_mask_logits
 
            # ------------------------------------------------------------------
            # 6. 决定 SAM 是否输出多 mask(初始帧或仅单点击时通常允许多 mask)
            # ------------------------------------------------------------------
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
 
            # ------------------------------------------------------------------
            # 7. 真正调用 SAM 解码头,得到 low/high 分辨率 mask 及 object pointer
            # ------------------------------------------------------------------
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
            # sam_outputs 是一个 tuple,按顺序包含:
            # 0 low_res_multimasks (B,num_masks,256,256)
            # 1 high_res_multimasks(B,num_masks,1024,1024)
            # 2 ious                 (B,num_masks)
            # 3 low_res_masks        (B,1,256,256)        ← 最终选中的单 mask
            # 4 high_res_masks       (B,1,1024,1024)
            # 5 obj_ptr              (B,256)               ← 用于记忆的 object token
            # 6 object_score_logits  (B,1)                 ← 该帧目标存在置信度

            # 示例:
            # low_res_multimasks: torch.Size([1, 1, 256, 256])
            # high_res_multimasks: torch.Size([1, 1, 1024, 1024])
            # ious: tensor([[0.9436]], device='cuda:0')
            # low_res_masks: torch.Size([1, 1, 256, 256])
            # low_res_masks: torch.Size([1, 1, 1024, 1024])
            # obj_ptr: torch.Size([1, 256])
            # object_score_logits: object_score_logits: tensor([[24.0962]], device='cuda:0') 
 
        # ------------------------------------------------------------------
        # 8. 解压 SAM 输出,把关键字段写回 current_out
        # ------------------------------------------------------------------
        (
            _,                       # low_res_multimasks
            _,                       # high_res_multimasks
            _,                       # ious
            low_res_masks,           # 最终单 mask (B,1,256,256)
            high_res_masks,          # 对应高分辨率 (B,1,1024,1024)
            obj_ptr,                 # object pointer (B,256)
            _,                       # object_score_logits
        ) = sam_outputs
 
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        # current_out 此时新增:
        # {
        #   ...,
        #   "pred_masks": torch.Size([1, 1, 256, 256]),
        #   "pred_masks_high_res": torch.Size([1, 1, 1024, 1024]),
        #   "obj_ptr": torch.Size([1, 256])
        # }
 
        # ------------------------------------------------------------------
        # 9. 可选:把当前预测 mask 编码成新的记忆,供后续帧使用
        #    若 run_mem_encoder=False 或 num_maskmem=0 则跳过
        # ------------------------------------------------------------------
        # run_mem_encoder: False  self.num_maskmem: 7
        if run_mem_encoder and self.num_maskmem > 0:
            # 用高分辨率 mask 进行记忆编码,细节保留更好
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                is_mask_from_pts=(point_inputs is not None),  # 区分点击 vs 外部 mask
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            # 不编码记忆时置空,调用方据此跳过 memory bank 更新
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None
 
        # current_out:
        # {
        #   "point_inputs": {
        #                       'point_coords':tensor([[[418.6667, 543.2889],                   [656.0000, 924.4445]]], device='cuda:0')
        #                       'point_labels': tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
        #                   },
        #   "mask_inputs": None,
        #   "pred_masks": torch.Size([1, 1, 256, 256]),
        #   "pred_masks_high_res": torch.Size([1, 1, 1024, 1024]),
        #   "obj_ptr": torch.Size([1, 256]),
        #   "maskmem_features": None,
        #   "maskmem_pos_enc": None
        # }
 
        return current_out

一句话

把"当前帧图像特征 + 用户给的点击/遮罩提示 + 历史记忆"喂给 SAM 解码器,生成当前帧的高质量分割结果,同时(可选)把这一帧的预测再编码成新的"记忆",供后续帧继续跟踪

三、回到_run_single_frame_inference

python 复制代码
@torch.inference_mode()
def _run_single_frame_inference(
    self,
    output_dict,
    frame_idx,
    batch_size,
    is_init_cond_frame,
    point_inputs,
    mask_inputs,
    reverse,
    run_mem_encoder,
    prev_sam_mask_logits=None,
):
    """
    对**单帧**执行一次"跟踪+解码"推理。
    核心作用:
        1. 取出该帧图像特征
        2. 调用 track_step 完成真正推理(prompt encoder + mask decoder + 记忆读写)
        3. 把输出整理成紧凑格式并搬到指定存储设备(GPU/CPU)
        4. 返回两份掩膜:一份用于后续跟踪(GPU),一份用于用户可视化/保存(CPU)
    参数:
        output_dict (dict):
            当前对象的输出仓库(cond/non_cond 子字典),track_step 会把结果写回这里。
        frame_idx (int):
            当前帧编号。
        batch_size (int):
            当前提示 batch 大小(= 对象数)。
        is_init_cond_frame (bool):
            True -> 该帧是"初始条件帧",将强制使用纯 SAM 路径,不读记忆。
        point_inputs (dict | None):
            点提示 {"point_coords": (B,N,2), "point_labels": (B,N)}。
        mask_inputs (torch.Tensor | None):
            低分辨率掩膜提示 (B,1,H',W');与 point_inputs 互斥。
        reverse (bool):
            是否按**时间倒序**跟踪(用于反向修正)。
        run_mem_encoder (bool):
            是否立即把新掩膜编码进记忆。用户交互阶段设为 False,延迟到 propagate 阶段。
        prev_sam_mask_logits (torch.Tensor | None):
            上一版掩膜 logits,作为 mask decoder 的额外提示(用于"迭代 refinement")。
    返回:
        tuple(dict, torch.Tensor):
            1. compact_current_out -- 紧凑版输出,已搬到 storage_device
            2. pred_masks_gpu      -- 当前帧掩膜 logits,仍留在 GPU,供下一帧跟踪使用
    """
 
    # 1. 取出该帧图像特征(已 expand 到 B 份)
    (
        _,
        _,
        current_vision_feats,      # backbone FPN 特征列表
        current_vision_pos_embeds, # 对应 positional embedding
        feat_sizes,                # 各层特征尺寸
    ) = self._get_image_feature(frame_idx, batch_size)
    # current_vision_feats: [
    #                   torch.Size([65536, 1, 32])
    #                   torch.Size([16384, 1, 64])
    #                   torch.Size([4096, 1, 256])
    # ]
    # current_vision_pos_embeds: [
    #                        torch.Size([65536, 1, 256])
    #                        torch.Size([16384, 1, 256])
    #                        torch.Size([4096, 1, 256])
    #                    ]
    # feat_sizes': [
    #                  (256, 256), 
    #                  (128, 128), 
    #                   (64, 64)
    #              ]
    #]

    # 2. 点提示与掩膜提示不能同时出现(SAM 规定)
    assert point_inputs is None or mask_inputs is None
 
    # frame_idx: 0 
    # is_init_cond_frame: True
    # point_inputs: {
    #                    'point_coords':tensor([[[416.5333, 541.3926], 
    #                                       [656.5333, 926.340]]], device='cuda:0')
    #                    'point_labels': tensor([[2, 3]], device='cuda:0', 
    #                                      dtype=torch.int32)
    # },
    # mask_inputs:None
    # output_dict: {'cond_frame_outputs': {}, 'non_cond_frame_outputs': {} }
    # self.condition_state["num_frames"]: 1
    # reverse:False
    # run_mem_encoder: False
    # prev_sam_mask_logits: None

    # 3. 真正推理一步:prompt encoder → mask decoder → 记忆读写(可选)
    current_out = self.track_step(
        frame_idx=frame_idx,
        is_init_cond_frame=is_init_cond_frame,
        current_vision_feats=current_vision_feats,
        current_vision_pos_embeds=current_vision_pos_embeds,
        feat_sizes=feat_sizes,
        point_inputs=point_inputs,    
        mask_inputs=mask_inputs,     
        output_dict=output_dict,               # 结果写回这里
        num_frames=self.condition_state["num_frames"],
        track_in_reverse=reverse,
        run_mem_encoder=run_mem_encoder,       # 交互阶段=False,延迟编码
        prev_sam_mask_logits=prev_sam_mask_logits,
    )
    # current_out:{
    #            'point_inputs': {
    #                    'point_coords':tensor([[[416.5333, 541.3926], 
    #                                           [656.5333, 926.340]]], device='cuda:0')
    #                    'point_labels': tensor([[2, 3]], device='cuda:0', 
    #                                      dtype=torch.int32)
    #            },
    #            'mask_inputs': None,
    #            'pred_masks': torch.Size([1, 1, 256, 256]), 
    #            'pred_masks_high_res': torch.Size([1, 1, 1024, 1024]),
    #            'obj_ptr': torch.Size([1, 256]),
    #            'maskmem_features':None,
    #            'maskmem_pos_enc': None
    # }
 
    # 4. 根据用户设置,把部分数据搬到 CPU 以节省 GPU 显存
    storage_device = self.condition_state["storage_device"]
 
    # 4-a 记忆特征(用于后续帧)-> bfloat16 + 搬到 storage_device
    maskmem_features = current_out["maskmem_features"]
    # maskmem_features:None
    if maskmem_features is not None:
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
 
    # 4-b 当前帧掩膜 -> 先填小洞(可选)-> 搬到 storage_device
    pred_masks_gpu = current_out["pred_masks"]  # 仍留在 GPU
    # self.fill_hole_area: 8
    if self.fill_hole_area > 0:
        pred_masks_gpu = fill_holes_in_mask_scores(
            pred_masks_gpu, self.fill_hole_area
        )
    pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
    # pred_masks: torch.Size([1, 1, 256, 256])
 
    # 4-c 位置编码同一帧只需一份(节省重复拷贝)
    maskmem_pos_enc = self._get_maskmem_pos_enc(current_out)
    # 第一帧 maskmem_pos_enc: None

    # 4-d object pointer(小向量,<1KB)常驻 GPU,保证下次访问速度
    obj_ptr = current_out["obj_ptr"]
    # obj_ptr: torch.Size([1, 256])

    # 5. 组装紧凑版输出,减少后续 propagate 时的状态体积
    compact_current_out = {
        "maskmem_features": maskmem_features,  # 已 bfloat16 + 已搬
        "maskmem_pos_enc": maskmem_pos_enc,
        "pred_masks": pred_masks,              # 已搬
        "obj_ptr": obj_ptr,                    # 仍在 GPU
    }
    # compact_current_out:{
    #     "maskmem_features": None
    #     "maskmem_pos_enc": None      
    #     "pred_masks": torch.Size([1, 1, 256, 256]) 
    #     "obj_ptr":  torch.Size([1, 256])
    # }
    # pred_masks_gpu: torch.Size([1, 1, 256, 256])

    # 6. 返回两份掩膜:
    #    compact_current_out -- 给用户保存/可视化(CPU 或 bfloat16)
    #    pred_masks_gpu      -- 仍留在 GPU,供下一帧跟踪继续 refinement
    return compact_current_out, pred_masks_gpu

3.1 fill_holes_in_mask_scores

sam2/utils/misc.py

if self.fill_hole_area > 0:

pred_masks_gpu = fill_holes_in_mask_scores(

pred_masks_gpu, self.fill_hole_area

)

python 复制代码
def fill_holes_in_mask_scores(mask, max_area):
    """
    后处理函数:填充掩码分数中小于max_area的小孔洞
    """

    # mask: torch.Size([1, 1, 256, 256])  
    # 输入的掩码分数张量

    # max_area: 8
    # 最大孔洞面积阈值
    
    # 孔洞定义:背景中面积小于等于max_area的连通区域
    # (背景区域是指掩码分数小于等于0的区域)
    assert max_area > 0, "max_area must be positive"
    
    # 获取背景区域的连通组件及其面积
    # labels标记每个连通组件的编号,areas记录每个像素所属组件的面积

    # 我的mask数值在-23到-16左右,mask<=0之后,背景像素都变成了True
    labels, areas = get_connected_components(mask <= 0)

    # labels: torch.Size([1, 1, 256, 256])
    # 连通组件标签图,每个像素值为其所属组件的编号

    # areas: torch.Size([1, 1, 256, 256])
    # 面积图,每个像素值为其所属组件的总面积
    
    # 识别孔洞:属于某个连通组件(labels > 0)且面积不超过max_area的区域
    is_hole = (labels > 0) & (areas <= max_area)

    # is_hole: torch.Size([1, 1, 256, 256])
    # 布尔掩码,标记哪些像素位置是待填充的孔洞
    
    # 使用torch.where填充孔洞:在孔洞位置填充一个小的正分数0.1,将其变为前景
    # 非孔洞位置保持原始掩码分数不变
    mask = torch.where(is_hole, 0.1, mask)

    # mask: torch.Size([1, 1, 256, 256])
    # 返回填充后的掩码分数张量
    return mask

这段代码是一个用于图像分割后处理的函数,主要功能是填充掩码图中的小孔洞

输入参数:

  • mask: 4D PyTorch张量(形状为[N, C, H, W]),表示像素级的掩码分数图,其中负值或0表示背景,正值表示前景

  • max_area: 整数,定义了被认为是"小孔洞"的最大面积阈值

处理流程:

  1. 识别背景区域 :通过 mask <= 0 找出所有背景像素

  2. 连通组件分析 :调用 get_connected_components 函数对背景区域进行连通组件标记,得到:

    • labels: 标记每个像素属于哪个连通组件(0表示非背景)

    • areas: 记录每个像素所在连通组件的总面积

  3. 孔洞检测 :找出所有面积不超过 max_area 的背景连通区域,这些被视作需要填充的孔洞

  4. 填充操作 :使用 torch.where 在孔洞位置填充数值0.1,将其转换为前景

输出结果:

  • 返回与输入同形状的 mask 张量,其中原本的小孔洞被填充为0.1,其余区域保持不变

应用场景: 在图像分割任务中,模型可能生成带有小孔洞或裂缝的掩码,此函数可有效消除这些噪声,使目标区域更加完整连续。

3.1.1 get_connected_components

python 复制代码
def get_connected_components(mask):
    """
    获取形状为 (N, 1, H, W) 的二值掩码的连通组件(8连通)。

    输入:
    - mask: 形状为 (N, 1, H, W) 的二值掩码张量,其中 1 表示前景(对于该函数来说),0 表示背景。

    输出:
    - labels: 形状为 (N, 1, H, W) 的张量,包含前景像素的连通组件标签,背景像素为 0。
    - counts: 形状为 (N, 1, H, W) 的张量,包含前景像素所属连通组件的面积,背景像素为 0。
    """
    # 导入SAM2的C++扩展模块,提供高效的连通域分析实现
    from sam2 import _C

    # mask: torch.Size([N, 1, H, W]) - 输入二值掩码
    # mask: torch.Size([1, 1, 256, 256])
    # 将掩码转换为uint8类型并确保内存布局连续,以满足C++接口要求
    # 调用C++实现的连通域分析函数
    return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
  • mask: 形状为 (N, 1, H, W) 的二值张量,其中 N 是批次大小,HW 是图像高和宽。值为1表示前景,0表示背景。

输出结果

  • labels: 与输入同形状的张量,每个前景像素赋予其所属连通组件的标签编号(从1开始),背景像素保持为0

  • counts: 与输入同形状的张量,每个前景像素值为该像素所属连通组件的总面积,背景像素为0

实现细节

  1. C++ 扩展 :通过 sam2._C 调用底层C++实现,保证计算效率

  2. 类型转换 :将输入转换为 torch.uint8 以满足C++接口要求

  3. 内存连续 :调用 .contiguous() 确保张量在内存中连续存储,避免底层访问错误

我的mask数值在-23到-16左右,mask<=0之后,背景像素都变成了True

labels, areas = get_connected_components(mask <= 0)

这句话是在做什么呢?mask是(1,1,256,256),然后mask<=0就是说那些属于背景的像素会变成True,属于前景的像素会变成False,然后转成uint8就是把True变成1,False变成0,然后get_connected_components就是把那些元素为1的区域(注意是对于函数输入参数来说的,也就是背景元素区域)进行分析,分析背景区域的连通性,然后会返回:

  • labels: 标记每个背景连通组件的编号(如1,2,3...)

  • areas: 记录每个背景连通组件的面积

假设我们背景是连通的,你就会看到labels是一个除了前景目标以外,元素值全为1的(1,1,256,256),然后areas是一个除了前景目标以外,元素值全为62130的(1,1,256,256),为什么是62130呢,这个数值只是个例子,就是说这个背景它只有一块连通域,且它的面积是62130像素数。下面这句就是把筛选出面积小的连通域,就是我们要修复的孔洞。

is_hole = (labels > 0) & (areas <= max_area)
我来用一个具体例子逐行拆解这两句代码,彻底讲清孔洞识别的逻辑:


场景设定

假设我们有一个5x5的分数掩码(已简化):

复制代码
mask = torch.tensor([[[[ 0.8,  0.9,  0.0,  0.7,  0.6],
                       [ 0.7,  0.0,  0.0,  0.8,  0.5],
                       [ 0.9,  0.8,  0.8,  0.7,  0.6],
                       [ 0.0,  0.0,  0.7,  0.8,  0.9],
                       [ 0.8,  0.7,  0.6,  0.9,  0.8]]]])
  • 前景(正分数):大部分区域

  • 背景 (≤0):4个像素,形成2个孤岛(孔洞)


第1步:get_connected_components(mask <= 0) 的结果

执行后返回两个张量:

labels(组件编号图)

复制代码
[[[[0, 0, 1, 0, 0],
   [0, 2, 2, 0, 0],
   [0, 0, 0, 0, 0],
   [3, 3, 0, 0, 0],
   [0, 0, 0, 0, 0]]]]
  • 0:非背景区域(原始前景)

  • 1,2,3:三个背景连通组件的编号

areas(组件面积图)

复制代码
[[[[0, 0, 1, 0, 0],
   [0, 2, 2, 0, 0],
   [0, 0, 0, 0, 0],
   [2, 2, 0, 0, 0],
   [0, 0, 0, 0, 0]]]]
  • 每个像素记录其所属组件的总面积

  • 组件1面积=1(单个像素)

  • 组件2面积=2(两个像素相连)

  • 组件3面积=2(两个像素相连)


第2句:is_hole = (labels > 0) & (areas <= max_area)

假设 max_area = 2,这句代码分三步执行:

labels > 0:找出所有背景组件像素

复制代码
(labels > 0) → 布尔张量
[[[[False, False,  True, False, False],
   [False,  True,  True, False, False],
   [False, False, False, False, False],
   [ True,  True, False, False, False],
   [False, False, False, False, False]]]]
  • True:属于某个背景组件(即原始前景中的"孔洞"区域)

  • False:原始前景或孤立背景像素

areas <= max_area:筛选小面积组件

复制代码
(areas <= 2) → 布尔张量
[[[[ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True]]]]
  • 实际上,所有背景组件面积都≤2,所以全为True

  • 若某个组件面积>2(如大面积背景),对应位置会为False

& 操作:取交集,精确定位孔洞

复制代码
is_hole = (labels > 0) & (areas <= 2)

结果:

复制代码
[[[[False, False,  True, False, False],
   [False,  True,  True, False, False],
   [False, False, False, False, False],
   [ True,  True, False, False, False],
   [False, False, False, False, False]]]]

含义:同时满足以下两个条件的像素:

  • 属于某个背景连通组件(原始mask中的孔洞位置)

  • 该组件面积≤max_area(排除大面积真实背景)


第2句:mask = torch.where(is_hole, 0.1, mask)

这是填充操作的核心,逐像素判断:

复制代码
mask = torch.where(is_hole, 0.1, mask)

等价于:

复制代码
如果 is_hole[i,j] == True:
    mask[i,j] = 0.1          # 填充孔洞
否则:
    mask[i,j] = 原值          # 保持不变

执行结果

复制代码
[[[[ 0.8,  0.9,  0.1,  0.7,  0.6],   # (0,2)从0.0→0.1
   [ 0.7,  0.1,  0.1,  0.8,  0.5],   # (1,1),(1,2)从0.0→0.1
   [ 0.9,  0.8,  0.8,  0.7,  0.6],
   [ 0.1,  0.1,  0.7,  0.8,  0.9],   # (3,0),(3,1)从0.0→0.1
   [ 0.8,  0.7,  0.6,  0.9,  0.8]]]]

一句话总结

is_hole 精准标记了"背景中的小孤岛" ,然后 torch.where 把这些孤岛的海拔从"负/零"提升到"0.1",让它们从海底变成了陆地。

为什么一定要属于某个背景连通组件,是不是get_connected_components会对所有背景元素都进行一个组件划分,哪怕只有一个像素也可能会构成一个组件?

你的理解完全正确get_connected_components 确实会对每个孤立的背景像素都创建一个独立的组件 ,哪怕只有一个像素。

为什么必须加 labels > 0

因为连通组件的输出中,labels=0labels>0完全不同的语义

labels值 含义 在孔洞填充场景中的对应
0 不属于任何组件(原始掩码值为0的像素) 原始前景区域
1,2,3... 属于某个连通组件(原始掩码值为1的像素) 原始背景区域

极简示例说明

假设 mask <= 0 生成如下背景掩码:

复制代码
背景掩码 (1=背景, 0=前景):
[[[1, 0, 1]]]  # 3个像素:左背景,中前景,右背景

get_connected_components 的输出

复制代码
labels:
[[[1, 0, 2]]]  # 左背景→组件1,中前景→0,右背景→组件2

areas:
[[[1, 0, 1]]]  # 组件1面积=1,前景面积=0,组件2面积=1

分析

  • 中间像素 labels=0:原始值为0(前景),不是背景组件

  • 左右像素 labels>0:原始值为1(背景),是背景组件

  • 即使左右背景像素不相连,也各自成为面积为1的独立组件

如果去掉 labels > 0 会怎样?

复制代码
# 错误写法:缺少 labels > 0
is_hole = (areas <= max_area)  # 错误!

这会错误地把原始前景区域也当成孔洞

  • areas=0 的像素(原始前景)也会满足 areas <= max_area

  • 导致整个前景区域被误判为"孔洞"并填充,完全破坏结果

一句话总结

labels > 0"背景身份认证" :确保我们只处理真正的背景像素,排除原始前景区域的干扰。没有这个条件,连通组件分析就毫无意义。

3.2 _get_maskmem_pos_enc

sam2/sam2_camera_predictor.py

位置编码同一帧只需一份(节省重复拷贝)

maskmem_pos_enc = self._get_maskmem_pos_enc(current_out)

python 复制代码
def _get_maskmem_pos_enc(self, current_out):
    """
    `maskmem_pos_enc` 在所有帧和对象之间都是相同的,因此我们将其作为
    常量缓存在推理会话中,以减少会话存储大小。
    """
    model_constants = self.condition_state["constants"]

    # model_constants: {}
    # 获取模型常量字典,用于缓存共享参数

    # "out_maskmem_pos_enc" 应该是张量列表或 None
    out_maskmem_pos_enc = current_out["maskmem_pos_enc"]

    # out_maskmem_pos_enc: None
    # 从当前输出中获取掩码记忆位置编码

    if out_maskmem_pos_enc is not None:
        # 如果尚未缓存 maskmem_pos_enc,则进行缓存
        if "maskmem_pos_enc" not in model_constants:
            assert isinstance(out_maskmem_pos_enc, list)
            # 只取单个对象的切片,因为它在所有对象间都相同
            # 克隆张量以避免原始数据被修改
            maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
            model_constants["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            # 从缓存中读取已存储的 maskmem_pos_enc
            maskmem_pos_enc = model_constants["maskmem_pos_enc"]
        
        # 将缓存的 maskmem_pos_enc 扩展到实际的批大小
        # 通过 expand 操作复制张量以匹配当前批次大小
        batch_size = out_maskmem_pos_enc[0].size(0)
        expanded_maskmem_pos_enc = [
            x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
        ]
    else:
        # 如果当前输出中没有位置编码,则返回 None
        expanded_maskmem_pos_enc = None
    
    return expanded_maskmem_pos_enc

这段代码是一个缓存优化方法 ,用于处理视频目标分割中的掩码记忆位置编码 (maskmem_pos_enc)。

核心思想 :由于位置编码在所有帧和不同对象之间是完全相同的,因此只需在第一次计算时缓存单个对象的编码,后续直接复用并扩展到所需批次大小,从而显著减少内存占用。

工作流程

  1. 读取缓存容器 :从 model_constants 中获取常量存储字典

  2. 获取当前编码 :从 current_out 中提取位置编码(可能是列表或 None)

  3. 缓存逻辑

    • 首次调用:若缓存中不存在,则取第一个对象的切片进行克隆并缓存

    • 后续调用:直接从缓存读取已存储的编码

  4. 批处理扩展 :根据当前批次大小,使用 expand 将缓存的编码广播到所有样本

  5. 空值处理:如果输入为 None,则直接返回 None

技术细节

  • 使用 [0:1] 切片而非 [0] 是为了保持张量的维度不变

  • clone() 确保缓存的张量与原始计算图分离,避免潜在的梯度问题

  • expand() 操作不分配新内存,只是创建新的视图,非常高效

应用场景:在视频推理过程中,此方法可避免因重复计算和存储相同的位置编码而导致内存随帧数和对象数线性增长的问题。

这个函数第一帧直接返回的是None

相关推荐
说私域2 小时前
链动2+1模式AI智能名片商城小程序:裂变过程驱动的商业新生态构建
人工智能·小程序
Miku162 小时前
Qwen3-8B vLLM 部署实践教程(AutoDL 平台)
人工智能
小宇的天下2 小时前
电子封装表面处理工艺
人工智能
Aevget2 小时前
DevExpress JS & ASP.NET Core v25.1新版亮点 - 新增AI文本编辑功能
javascript·人工智能·asp.net·界面控件·devexpress·ui开发
Niuguangshuo2 小时前
PyTorch优化器完全指南
人工智能·pytorch·python
子夜江寒2 小时前
深度学习入门
深度学习·神经网络
寻道模式2 小时前
【时间之外】创业踩坑指南(7)-方向盘哲学
人工智能·创业
CoovallyAIHub2 小时前
YOLO11算法深度解析:四大工业场景实战,开源数据集助力AI质检落地
深度学习·算法·计算机视觉