SAM2跟踪的理解5——prompt encoder

目录

一、前言

二、track_step

三、_prepare_memory_conditioned_features

四、_use_multimask

五、_forward_sam_heads

[5.1 PromptEncoder.forward](#5.1 PromptEncoder.forward)

[5.1.1 _get_batch_size](#5.1.1 _get_batch_size)

[5.1.2 _embed_points](#5.1.2 _embed_points)

为什么加0.5就是挪到像素中心

[5.1.2.1 PositionEmbeddingRandom.forward_with_coords](#5.1.2.1 PositionEmbeddingRandom.forward_with_coords)

_pe_encoding函数中为什么要进行这些操作,为什么维度是这样变化的?

[为什么torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])?](#为什么torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])?)

如何理解随机高斯矩阵投影与放大2π

[先归一化→随机线性投影→2π 放大→sin/cos 拼接,这有必然的顺序吗?比如先2π 放大再随机线性投影可以吗?](#先归一化→随机线性投影→2π 放大→sin/cos 拼接,这有必然的顺序吗?比如先2π 放大再随机线性投影可以吗?)

[5.1.2.2 self.pe_layer.forward_with_coords之后做了什么](#5.1.2.2 self.pe_layer.forward_with_coords之后做了什么)

[point_embedding[labels == -1] = 0.0 什么意思](#point_embedding[labels == -1] = 0.0 什么意思)

这什么意思

[5.1.3 torch.cat([sparse_embeddings, point_embeddings], dim=1)什么意思?](#5.1.3 torch.cat([sparse_embeddings, point_embeddings], dim=1)什么意思?)

[5.1.4 dense_embeddings](#5.1.4 dense_embeddings)


一、前言

从《SAM2跟踪的理解2》开始,我们一直在看第一帧做了什么,其实就是通过load_first_frame函数中的_get_image_feature的image encoder提取图像特征,然后再通过add_new_prompt将第一帧中用户给定提示输入到SAM2中进行推理获取掩码,推理的函数就是add_new_prompt里面的 _run_single_frame_inference里面的track_step。

这一篇我们就是看track_step里面做了什么 ,其实track_step先调用了_prepare_memory_conditioned_features函数,这个函数就是把当前帧的视觉特征与"记忆"做融合,如果有记忆的话,这个函数是会调用memory attention的 (如果你想知道memory attention在干嘛,这一篇是没有的,等到我们说第二帧发生了什么的时候就会有),然而我们现在是在第一帧,是没有"记忆"的 ,那它做什么事?就是做了下面这些事:(第三节的内容)

  1. 当前帧的视觉特征: current_vision_feats[-1]: torch.Size([4096, 1, 256])

  2. "无记忆"embedding:self.no_mem_embed: torch.Size([1, 1, 256])

  3. 相加: pix_feat_with_mem: torch.Size([4096, 1, 256])

  4. HW,B,C\] -\> \[B,C,H,W\]: pix_feat_with_mem: torch.Size(\[1, 256, 64, 64\])

下面是第一帧情况下的函数调用顺序,可以看到我们在2.14个地方,后面主要经历的是prompt encoder和mask decoder,本篇是讲prompt encoder(第五节内容)

注:如果注释中涉及到数值,比如矩形提示的位置坐标,这种因为调试分了很多次,所以会有些变化,不用在意这些细节,只要看维度变化即可。

2.12 <重点> add_new_prompt

2.13 <重点> _run_single_frame_inference

2.14 <重点> track_step (这篇开头在这)

2.15 <重点> _prepare_memory_conditioned_features

2.16 _use_multimask

2.17 <重点> _forward_sam_heads

2.18 提示编码器:类PromptEncoder.forward

2.19 类PositionEmbeddingRandom.forward_with_coords(这篇结束在这)

2.20 类PromptEncoder.get_dense_pe

2.21 掩码解码器 类MaskDecoder.forward

2.22 类MaskDecoder.predict_masks

2.23 TwoWayTransformer.forward

2.24 TwoWayAttentionBlock.forward

2.25 Attention.forward

2.26 <重点> MLP.forward

2.27 Attention.forward

2.28 LayerNorm2d.forward

2.29 MaskDecoder._dynamic_multimask_via_stability

2.30 MaskDecoder._get_stability_scores

2.31 fill_holes_in_mask_scores

2.32 _get_maskmem_pos_enc

2.33 _consolidate_temp_output_across_obj

2.34 _get_orig_video_res_output

prompt encoder是如何把输入的提示点编码成向量的呢?其实是下面这样,像点提示、矩形提示这些它会编码成torch.Size([B, N, 256])的****sparse_embeddings(N表示提示点数,包括padding点) ,mask提示会编码成torch.Size([B, 256, 64, 64])的****dense_embeddings (即使没有masks提示也会整一个),最后就是返回sparse_embeddings和dense_embeddings

1. 输入的提示(比如输入两个提示点,第三个是防止不给提示加的全0的点,我们叫padding点)

coords: torch.Size([1, 3, 2]) # 3表示点的数量,2表示点有x和y

比如 coords: tensor([[[0.4093, 0.5264], [0.6349, 0.8959], [0.0000, 0.0000]]])

2. 用随机高斯矩阵做线性变换:2D 坐标 → 128 维"频率向量"

self.positional_encoding_gaussian_matrix: torch.Size([2, 128])

coords = coords @ self.positional_encoding_gaussian_matrix

torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])

3. 放大 2π,让 sin/cos 充分振荡

coords = 2 * np.pi * coords

coords: torch.Size([1, 3, 128])

4. 拼接 sin 和 cos,得到 256 维最终编码(128×2)

torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

位置编码 point_embedding: torch.Size([1, 3, 256])

5. 算法如何区分提示点是正提示点还是负提示点,还是padding点?加上****角色向量

padding 虚点 → 全 0 + "not-a-point"向量

point_embedding[labels == -1] = 0.0

point_embedding[labels == -1] += self.not_a_point_embed.weight # padding点

真实提示点 → 位置编码 + 角色向量

point_embedding[labels == 0] += self.point_embeddings[0].weight # 负提示点

point_embedding[labels == 1] += self.point_embeddings[1].weight # 正提示点

point_embedding[labels == 2] += self.point_embeddings[2].weight # 矩形左上角点

point_embedding[labels == 3] += self.point_embeddings[3].weight # 矩形右下角点

维度不变 point_embedding: torch.Size([1, 3, 256])

6. sparse_embeddings:torch.Size([1, 3, 256])

sparse_embeddings: torch.Size([1, 0, 256])

sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)

sparse_embeddings: torch.Size([1, 3, 256])

7.处理 mask 提示

有 mask → 用轻量 CNN 压成 256×64×64 的"dense embedding";

dense_embeddings = self._embed_masks(masks)

# 无 mask → 用可学习的"no-mask"向量铺满 64×64 空间,保持形状一致

self.no_mask_embed: nn.Embedding(1, 256)

self.image_embedding_size: (64, 64)

dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(

bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]

)

(1,256) => (1,256,1,1) => (1,256, 64, 64)

8. dense_embeddings: torch.Size([1, 256, 64, 64])

二、track_step

python 复制代码
def track_step(
        self,
        frame_idx,                       # 当前帧的全局序号(从 0 开始)
        is_init_cond_frame,              # 是否为初始条件帧(第一帧或用户重新标注的帧)
        current_vision_feats,            # 当前帧的 backbone 特征列表,长度=层数,每层形状 (HW, B, C)
        current_vision_pos_embeds,       # 对应的位置编码,与 current_vision_feats 一一对应
        feat_sizes,                      # 每层特征图的空间分辨率,如 [(256,256), (128,128)]
        point_inputs,                    # 用户点击/提示,格式 dict{"point_coords": (B,K,2), "point_labels": (B,K)}
        mask_inputs,                     # 用户输入的 mask(可选),形状 (B,1,H,W) 或 None
        output_dict,                     # 全局状态字典,包含 memory bank、已编码的记忆等
        num_frames,                      # 整个视频的总帧数
        track_in_reverse=False,          # 是否按倒序跟踪(demo 中回退播放时使用)
        run_mem_encoder=True,            # 是否要把当前预测结果编码成记忆供后续帧使用
        prev_sam_mask_logits=None,       # 上一帧预测的 low-res mask logits(demo 连续点击时累积用)
    ):
        """
        单帧跟踪函数。
        1. 融合历史记忆与当前视觉特征 -> 得到带记忆的 pix_feat
        2. 用 SAM head 解码出 mask
        3. 可选地把预测 mask 再编码成新的记忆
        返回 current_out,包含本帧所有需要保存/上传的结果。
        """
        # current_vision_feats: [
        #                   torch.Size([65536, 1, 32])
        #                   torch.Size([16384, 1, 64])
        #                   torch.Size([4096, 1, 256])
        # ]
        # current_vision_pos_embeds: [
        #                        torch.Size([65536, 1, 256])
        #                        torch.Size([16384, 1, 256])
        #                        torch.Size([4096, 1, 256])
        #                    ]
        # feat_sizes': [
        #                  (256, 256), 
        #                  (128, 128), 
        #                   (64, 64)
        #              ]
        #]
 
        # ------------------------------------------------------------------
        # 1. 组装本帧的输入字典,方便后续上传或可视化
        # ------------------------------------------------------------------
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # current_out 示例:
        # {
        #   "point_inputs": {
        #       "point_coords": tensor([[[416.5333, 541.3926], 
        #                                       [656.5333, 926.340]]], device='cuda:0'),
        #       "point_labels": tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
        #   },
        #   "mask_inputs": None
        # }
 
        # ------------------------------------------------------------------
        # 2. 取出 backbone 的多层特征,转成 BCHW 供 SAM 的高分辨率分支使用
        #    最后一层特征留给记忆融合,其余层作为 high_res_features
        # ------------------------------------------------------------------
        if len(current_vision_feats) > 1:
            # 将 (HW,B,C) -> (B,C,H,W)
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
            # high_res_features:[
            #                        torch.Size([1, 32, 256, 256]), 
            #                        torch.Size([1, 64, 128, 128])
            #]
        else:
            high_res_features = None
 
        # ------------------------------------------------------------------
        # 3. 如果用户直接给了 GT mask 且配置要求跳过 SAM,则直接把它当输出
        # ------------------------------------------------------------------
        # mask_inputs:None  use_mask_input_as_output_without_sam:True
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # 取最后一层特征作为像素嵌入
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)          # (B,C,H*W)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])  # (B,C,H,W)
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # ------------------------------------------------------------------
            # 4. 正常路径:把当前特征与记忆库融合,得到带历史信息的 pix_feat
            # ------------------------------------------------------------------
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],   # 只用最后一层
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
 
            # ------------------------------------------------------------------
            # 5. 处理 demo 场景:如果之前已有 SAM 预测结果,把它作为 mask prompt
            # ------------------------------------------------------------------
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None, \
                    "prev_sam_mask_logits 仅在有点击且无外部 mask 时生效"
                mask_inputs = prev_sam_mask_logits
 
            # ------------------------------------------------------------------
            # 6. 决定 SAM 是否输出多 mask(初始帧或仅单点击时通常允许多 mask)
            # ------------------------------------------------------------------
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
 
            # ------------------------------------------------------------------
            # 7. 真正调用 SAM 解码头,得到 low/high 分辨率 mask 及 object pointer
            # ------------------------------------------------------------------
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
            # sam_outputs 是一个 tuple,按顺序包含:
            # 0 low_res_multimasks (B,num_masks,256,256)
            # 1 high_res_multimasks(B,num_masks,1024,1024)
            # 2 ious                 (B,num_masks)
            # 3 low_res_masks        (B,1,256,256)        ← 最终选中的单 mask
            # 4 high_res_masks       (B,1,1024,1024)
            # 5 obj_ptr              (B,256)               ← 用于记忆的 object token
            # 6 object_score_logits  (B,1)                 ← 该帧目标存在置信度
 
        # ------------------------------------------------------------------
        # 8. 解压 SAM 输出,把关键字段写回 current_out
        # ------------------------------------------------------------------
        (
            _,                       # low_res_multimasks
            _,                       # high_res_multimasks
            _,                       # ious
            low_res_masks,           # 最终单 mask (B,1,256,256)
            high_res_masks,          # 对应高分辨率 (B,1,1024,1024)
            obj_ptr,                 # object pointer (B,256)
            _,                       # object_score_logits
        ) = sam_outputs
 
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        # current_out 此时新增:
        # {
        #   ...,
        #   "pred_masks": torch.Size([1, 1, 256, 256]),
        #   "pred_masks_high_res": torch.Size([1, 1, 1024, 1024]),
        #   "obj_ptr": torch.Size([1, 256])
        # }
 
        # ------------------------------------------------------------------
        # 9. 可选:把当前预测 mask 编码成新的记忆,供后续帧使用
        #    若 run_mem_encoder=False 或 num_maskmem=0 则跳过
        # ------------------------------------------------------------------
        if run_mem_encoder and self.num_maskmem > 0:
            # 用高分辨率 mask 进行记忆编码,细节保留更好
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                is_mask_from_pts=(point_inputs is not None),  # 区分点击 vs 外部 mask
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            # 不编码记忆时置空,调用方据此跳过 memory bank 更新
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None
 
        # current_out:
        # {
        #   "point_inputs": {...},
        #   "mask_inputs": None,
        #   "pred_masks": torch.Size([1, 1, 256, 256]),
        #   "pred_masks_high_res": torch.Size([1, 1, 1024, 1024]),
        #   "obj_ptr": torch.Size([1, 256]),
        #   "maskmem_features": None,
        #   "maskmem_pos_enc": None
        # }
 
        return current_out

一句话

把"当前帧图像特征 + 用户给的点击/遮罩提示 + 历史记忆"喂给 SAM 解码器,生成当前帧的高质量分割结果,同时(可选)把这一帧的预测再编码成新的"记忆",供后续帧继续跟踪。

三、_prepare_memory_conditioned_features

在track_step中,_prepare_memory_conditioned_features传入的参数,值得注意的是current_vision_feats和current_vision_pos_embeds和feat_sizes都是传的最后一层

------------------------------------------------------------------

4. 正常路径:把当前特征与记忆库融合,得到带历史信息的 pix_feat

------------------------------------------------------------------

pix_feat_with_mem = self._prepare_memory_conditioned_features(

frame_idx=frame_idx,

is_init_cond_frame=is_init_cond_frame,

current_vision_feats=current_vision_feats[-1:], # 只用最后一层

current_vision_pos_embeds=current_vision_pos_embeds[-1:],

feat_sizes=feat_sizes[-1:],

output_dict=output_dict,

num_frames=num_frames,

track_in_reverse=track_in_reverse,

)

下面就是_prepare_memory_conditioned_features函数,注意现在我们是在初始条件帧:没有过去记忆,只能"凭空"编码,直接使把self.no_mem_embed: torch.Size([1, 1, 256]) 加到current_vision_feats[-1]: torch.Size([4096, 1, 256])得到torch.Size([4096, 1, 256]),然后再经过:

HW,B,C\] -\> \[B,C,H,W

得到:pix_feat_with_mem: torch.Size([1, 256, 64, 64]),就返回了。没有后面的memory attention

python 复制代码
def _prepare_memory_conditioned_features(
        self,
        frame_idx,                       # 当前帧在整个视频序列中的绝对序号
        is_init_cond_frame,              # 当前帧是否属于"初始条件帧"(即带人工标注的 key-frame)
        current_vision_feats,            # 当前帧的视觉主干输出,list[tensor],每层分辨率一个
        current_vision_pos_embeds,       # 当前帧对应的空间位置编码,与上一项一一对应
        feat_sizes,                      # 每项特征图的 (H, W)
        output_dict,                     # 之前所有帧已经算好的输出字典,含 memory 与 object pointer
        num_frames,                      # 整个视频总帧数(用于边界检查)
        track_in_reverse=False,          # 演示模式下倒着跟踪(从后往前)
    ):
        """
        把当前帧的视觉特征与"记忆"做融合,得到 memory-conditioned 特征图。
        如果 num_maskmem==0 则退化为普通 SAM(无记忆,单图分割)。
        """
        # current_vision_feats: [torch.Size([4096, 1, 256]),]
        B = current_vision_feats[-1].size(1)  # batch size
        # B: 1

        C = self.hidden_dim                   # 隐层维数
        # C: 256

        H, W = feat_sizes[-1]                 # 最低分辨率特征图尺寸
        # H: 64 W:64

        device = current_vision_feats[-1].device
        # device: device(type='cuda', index=0)

        # 1) 无记忆模式:直接返回当前帧顶层特征(复现 SAM 单图效果)
        # self.num_maskmem: 7
        if self.num_maskmem == 0:
            pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
            return pix_feat

        num_obj_ptr_tokens = 0   # 记录后面要塞进 transformer 的 object pointer token 数

        # 2) 非初始条件帧:需要把"过去记忆"拿出来做 cross-attention
        # is_init_cond_frame: Ture 
        if not is_init_cond_frame:
            to_cat_memory, to_cat_memory_pos_embed = [], []

            # 2.1 先选最多 max_cond_frames_in_attn 个"离当前帧最近"的初始条件帧
            assert len(output_dict["cond_frame_outputs"]) > 0
            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx, cond_outputs, self.max_cond_frames_in_attn
            )
            # t_pos=0 表示这些条件帧在时序上被视为"原点"
            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]

            # 2.2 再取 (self.num_maskmem-1) 个"普通记忆帧"(非条件帧)
            r = self.memory_temporal_stride_for_eval  # 采样 stride
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos      # 离当前帧多远
                if t_rel == 1:
                    # 紧邻帧(t_rel==1)一定取
                    prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
                else:
                    # 更远的帧按 stride r 跳着取
                    if track_in_reverse:
                        prev_frame_idx = -(-(frame_idx + 2) // r) * r + (t_rel - 2) * r
                    else:
                        prev_frame_idx = ((frame_idx - 2) // r) * r - (t_rel - 2) * r
                # 先查非条件帧字典,没有再查"未被选中"的条件帧
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) or \
                      unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out))

            # 2.3 把选中的记忆特征和对应的位置编码拼到一起
            for t_pos, prev in t_pos_and_prevs:
                if prev is None:
                    continue            # 该帧不存在,跳过(padding)
                feats = prev["maskmem_features"].to(device, non_blocking=True)
                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))   # [HW,B,C]
                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                # 再加上"时序位置编码":越早的记忆 t_pos 越大,embedding 越靠后
                maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
                to_cat_memory_pos_embed.append(maskmem_enc)

            # 2.4 构造"对象指针" object pointer(用于保持实例身份)
            if self.use_obj_ptrs_in_encoder:
                max_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder)
                # 选哪些条件帧的 pointer:eval 时可限制只取"过去"的
                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                    ptr_cond_outputs = {
                        t: out for t, out in selected_cond_outputs.items()
                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                    }
                else:
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [(abs(frame_idx - t), out["obj_ptr"])
                                for t, out in ptr_cond_outputs.items()]
                # 继续往前(或往后)取普通帧的 pointer
                for t_diff in range(1, max_ptrs):
                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if t < 0 or (num_frames is not None and t >= num_frames):
                        break
                    out = output_dict["non_cond_frame_outputs"].get(
                        t, unselected_cond_outputs.get(t, None))
                    if out is not None:
                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))

                if pos_and_ptrs:            # 至少有一个 pointer
                    pos_list, ptrs_list = zip(*pos_and_ptrs)
                    obj_ptrs = torch.stack(ptrs_list, dim=0)            # [ptr_len, B, C]
                    if self.add_tpos_enc_to_obj_ptrs:                   # 给 pointer 也加时序位置
                        t_diff_max = max_ptrs - 1
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                        obj_pos = get_1d_sine_pe(torch.tensor(pos_list, device=device) / t_diff_max,
                                                 dim=tpos_dim)
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                    else:
                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)

                    # 如果 mem_dim < C,把一个 pointer 拆成多个 token
                    if self.mem_dim < C:
                        obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) \
                                          .permute(0, 2, 1, 3).flatten(0, 1)
                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)

                    to_cat_memory.append(obj_ptrs)
                    to_cat_memory_pos_embed.append(obj_pos)
                    num_obj_ptr_tokens = obj_ptrs.shape[0]

        # 3) 初始条件帧:没有过去记忆,只能"凭空"编码
        else:
            # self.directly_add_no_mem_embed:True
            if self.directly_add_no_mem_embed:
                # 简单地把"无记忆"embedding 加到特征上,跳过 transformer
                # current_vision_feats[-1]: torch.Size([4096, 1, 256])
                # self.no_mem_embed: torch.Size([1, 1, 256])
                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
                # pix_feat_with_mem: torch.Size([4096, 1, 256])

                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
                # pix_feat_with_mem: torch.Size([1, 256, 64, 64])
                # 输出形状:[HW,B,C] -> [B,C,H,W]
                return pix_feat_with_mem

            # 否则喂一个 dummy memory token 给 transformer,防止空输入
            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

        # 4) 把以上所有记忆 token 拼成序列,送进 memory_attention 层
        memory = torch.cat(to_cat_memory, dim=0)
        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

        pix_feat_with_mem = self.memory_attention(
            curr=current_vision_feats,           # 当前帧多层特征
            curr_pos=current_vision_pos_embeds,  # 当前帧位置编码
            memory=memory,                       # 过去记忆 + object pointer
            memory_pos=memory_pos_embed,         # 对应位置编码
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        # 输出形状:[HW,B,C] -> [B,C,H,W]
        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem

一句话总结

该函数是 SAM 2 "记忆解码器"的核心:

对于当前帧,先把"最近若干初始条件帧"和"按 stride 采样到的普通记忆帧"的特征、位置编码以及 object pointer 收集起来,拼成一个记忆序列;随后用轻量级 transformer(memory_attention)让当前帧特征与这段记忆做 cross-attention,从而把"过去看到的掩码和对象身份"注入到当前特征图里,最终返回一张融合后的 B×C×H×W 特征,供后续掩码解码器使用。

四、_use_multimask

继续看track_step后面发生了什么

------------------------------------------------------------------

5. 处理 demo 场景:如果之前已有 SAM 预测结果,把它作为 mask prompt

------------------------------------------------------------------

prev_sam_mask_logits: None

if prev_sam_mask_logits is not None:

assert point_inputs is not None and mask_inputs is None, \

"prev_sam_mask_logits 仅在有点击且无外部 mask 时生效"

mask_inputs = prev_sam_mask_logits

------------------------------------------------------------------

6. 决定 SAM 是否输出多 mask(初始帧或仅单点击时通常允许多 mask)

------------------------------------------------------------------

is_init_cond_frame: True

multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)

python 复制代码
def _use_multimask(self, is_init_cond_frame, point_inputs):
    """
    根据当前帧类型和提示点数量,决定 SAM mask-decoder 是否输出**多掩膜**(multimask)。
    背景:
        SAM 默认对**模糊提示**(单点/粗框)同时给出 3 个候选掩膜,让用户或后续模块挑选;
        对**清晰提示**(多点、已有较好 mask)则只输出 1 个精炼掩膜,节省计算与显存。
    本函数即实现这一策略开关。
    参数:
        is_init_cond_frame (bool):
            True → 当前帧是初始条件帧(用户首次点击或重新初始化)。
        point_inputs (dict | None):
            点提示字典,含 "point_labels" (B, N) 和 "point_coords" (B, N, 2)。
            若为 None,表示无点提示。
    返回:
        bool:
            True  → 使用 multimask 输出(3 个候选);
            False → 只输出单掩膜。
    """
 
    # 1. 计算当前提示点数量(不含背景点)
    num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
    # point_labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
    # num_pts: 2   对于粗框,提示有2个,矩形左上角和矩形右下角

    # 2. 同时满足以下 4 个条件才开启 multimask:
    #    a) 全局开关打开;
    #    b) 当前是初始条件帧,**或**配置允许跟踪阶段也 multimask;
    #    c) 点数量在 [min, max] 区间内;
    #    d) 实际上只要有提示点就会触发,没点时 num_pts=0 自动不满足。
    multimask_output = (
        # self.multimask_output_in_sam:True
        self.multimask_output_in_sam  # 全局超参:是否启用 multimask 功能
        and 
        # is_init_cond_frame:True  self.multimask_output_for_tracking:True
        (is_init_cond_frame or self.multimask_output_for_tracking)  # 帧类型限制
        and 
        # self.multimask_min_pt_num:0   num_pt:2 self.multimask_max_pt_num:1
        (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)  # 点数限制
    )
    # multimask_output:False
 
    return multimask_output

一句话

按"全局开关 + 帧类型 + 提示点数量"三要素,快速决定 SAM 解码器是输出 3 个候选 mask(True)还是只输出 1 个精炼 mask(False)。

可以看到,在我们的例子中最后返回的是False,这是因为我们的例子中粗框提示包括两个,一个是矩形的左上角点,一个是矩形的右下角点,2个提示已经是比较精确了,只有在0或者1个提示的时候,提示才是模糊的,模糊提示情况下才输出多个掩码。

五、_forward_sam_heads

继续看track_step后面发生了什么

------------------------------------------------------------------

7. 真正调用 SAM 解码头,得到 low/high 分辨率 mask 及 object pointer

------------------------------------------------------------------

sam_outputs = self._forward_sam_heads(

backbone_features=pix_feat_with_mem,

point_inputs=point_inputs,

mask_inputs=mask_inputs,

high_res_features=high_res_features,

multimask_output=multimask_output,

)

python 复制代码
@torch.inference_mode()
def _forward_sam_heads(
    self,
    backbone_features,
    point_inputs=None,
    mask_inputs=None,
    high_res_features=None,
    multimask_output=False,
):
    """
    完整的 SAM 风格「提示编码 + 掩膜解码」一条龙:
    把 backbone 特征与用户提示(点/框/mask)一起送 SAM,
    输出多组或单组掩膜 logits、IoU 估计、对象指针等。
    参数:
        backbone_features (Tensor):
            已融合记忆的图像嵌入,形状 (B, C, H, W),
            其中 H=W=sam_image_embedding_size(默认 64)。
        point_inputs (dict | None):
            {"point_coords": (B, P, 2), "point_labels": (B, P)}
            坐标为**绝对像素**(会在内部归一化),label: 1=前, 0=背, -1=pad。
        mask_inputs (Tensor | None):
            低分辨率掩膜提示 (B,1,H',W'),与点提示互斥。
        high_res_features (list[Tensor] | None):
            额外两层更高分辨率特征 [4H,4W] 与 [2H,2W],供 decoder refine 边缘。
        multimask_output (bool):
            True → 输出 3 候选掩膜 + 3 IoU;False → 1 掩膜 + 1 IoU。
    返回:
        tuple:
            0. low_res_multimasks -- (B,M,H/4,W/4)  低分多掩膜 logits
            1. high_res_multimasks -- (B,M,H,W)     高分多掩膜 logits
            2. ious -- (B,M)                        各掩膜 IoU 估计
            3. low_res_masks -- (B,1,H/4,W/4)       **最佳**低分掩膜
            4. high_res_masks -- (B,1,H,W)          **最佳**高分掩膜
            5. obj_ptr -- (B,C)                     对象指针(用于记忆)
            6. object_score_logits -- (B,)          对象出现置信度(可软可硬)
    """
    # backbone_features: torch.Size([1, 256, 64, 64])
    # point_inputs: {
    #                   'point_coords': tensor([[[421.3333, 540.4445], 
    #                                           [654.9333, 921.6000]]], device='cuda:0'),
    #                   'point_labels': tensor([[2, 3]], 
    #                                            device='cuda:0', dtype=torch.int32)
    #               }
    # mask_inputs:None
    # high_res_features:[
    #                        torch.Size([1, 32, 256, 256]), 
    #                        torch.Size([1, 64, 128, 128])
    #                   ]
    # multimask_output: False

    B = backbone_features.size(0)
    # B: 1

    device = backbone_features.device
    # device: device(type='cuda', index=0) 

    # --- 1. 输入断言:尺寸必须匹配 SAM 预设 ---
    assert backbone_features.size(1) == self.sam_prompt_embed_dim
    assert backbone_features.size(2) == self.sam_image_embedding_size
    assert backbone_features.size(3) == self.sam_image_embedding_size
 
    # --- 2. 构造点提示 ---
    if point_inputs is not None:
        sam_point_coords = point_inputs["point_coords"]  # (B, P, 2)
        # sam_point_coords: torch.Size([1, 2, 2])
        # sam_point_coords: tensor([[[421.3333, 540.4445], 
        #                             [654.9333, 921.6000]]], device='cuda:0'),

        sam_point_labels = point_inputs["point_labels"]  # (B, P)
        # sam_point_labels: torch.Size([1, 2])
        # sam_point_labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)

        assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
    else:
        # 无点提示时,用 1 个 pad 点(label=-1)占位,保证 prompt encoder 正常 forward
        sam_point_coords = torch.zeros(B, 1, 2, device=device)
        sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
 
    # --- 3. 构造掩膜提示 ---
    # mask_inputs: None     
    if mask_inputs is not None:
        # 若外部 mask 分辨率不符,先双线性下采样到 prompt encoder 期望尺寸
        assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
        if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
            sam_mask_prompt = F.interpolate(
                mask_inputs.float(),
                size=self.sam_prompt_encoder.mask_input_size,
                align_corners=False,
                mode="bilinear",
                antialias=True,  # 抗锯齿,减少下采样 aliasing
            )
        else:
            sam_mask_prompt = mask_inputs
    else:
        # 无 mask 时,prompt encoder 内部会加 learned `no_mask_embed`
        sam_mask_prompt = None
 
    # --- 4. 送进 SAM 提示编码器 ---
    sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
        points=(sam_point_coords, sam_point_labels),
        boxes=None,  # 框提示在外部已转成 2 点,不再走这里
        masks=sam_mask_prompt,
    )
 
    # --- 5. 送进 SAM mask 解码器 ---
    (
        low_res_multimasks,   # (B, M, H/4, W/4)  M=3 or 1
        ious,                 # (B, M)             IoU 估计
        sam_output_tokens,    # (B, M, C)          解码器输出 token
        object_score_logits,  # (B,)               对象出现/消失 logits
    ) = self.sam_mask_decoder(
        image_embeddings=backbone_features,
        image_pe=self.sam_prompt_encoder.get_dense_pe(),  # 固定 2D 位置编码
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
        repeat_image=False,  # image 已 batch,无需再 repeat
        high_res_features=high_res_features,  # 高分特征供 refine 边缘
    )
 
    # --- 6. 对象 score 后处理:若模型预测"无对象",把掩膜 logits 置为 NO_OBJ_SCORE ---
    if self.pred_obj_scores:
        is_obj_appearing = object_score_logits > 0  # 硬阈值
        # 记忆用掩膜必须**硬**选择:有对象才保留,否则置极大负值
        low_res_multimasks = torch.where(
            is_obj_appearing[:, None, None],
            low_res_multimasks,
            NO_OBJ_SCORE,
        )
 
    # --- 7. 数据类型转换:bf16/fp16 -> fp32(老版本 PyTorch interpolate 不支持 bf16)---
    low_res_multimasks = low_res_multimasks.float()
    # 上采样到图像原分辨率(stride=1)
    high_res_multimasks = F.interpolate(
        low_res_multimasks,
        size=(self.image_size, self.image_size),
        mode="bilinear",
        align_corners=False,
    )
 
    # --- 8. 选取最佳掩膜 ---
    sam_output_token = sam_output_tokens[:, 0]  # 默认取第 1 个 token(单掩膜时即自身)
    if multimask_output:
        # 多掩膜时,选 IoU 估计最高的那个
        best_iou_inds = torch.argmax(ious, dim=-1)  # (B,)
        batch_inds = torch.arange(B, device=device)
        low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        # 若解码器输出了多个 token,同样要选最佳
        if sam_output_tokens.size(1) > 1:
            sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
    else:
        # 单掩膜时,最佳即唯一
        low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
 
    # --- 9. 从最佳 token 提取对象指针(用于记忆)---
    obj_ptr = self.obj_ptr_proj(sam_output_token)  # (B, C)
 
    # --- 10. 对象指针后处理:若模型认为"无对象",指针也被削弱或替换 ---
    if self.pred_obj_scores:
        if self.soft_no_obj_ptr:
            # 软削弱:用 sigmoid 概率加权
            assert not self.teacher_force_obj_scores_for_mem
            lambda_is_obj_appearing = object_score_logits.sigmoid()
        else:
            # 硬削弱:0/1 加权
            lambda_is_obj_appearing = is_obj_appearing.float()
 
        if self.fixed_no_obj_ptr:
            obj_ptr = lambda_is_obj_appearing * obj_ptr
        # 剩余权重用"无对象指针"补齐,保证指针和为 1
        obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
 
    # --- 11. 返回打包结果 ---
    return (
        low_res_multimasks,      # 0 低分多掩膜
        high_res_multimasks,     # 1 高分多掩膜
        ious,                    # 2 各掩膜 IoU 估计
        low_res_masks,           # 3 最佳低分掩膜
        high_res_masks,          # 4 最佳高分掩膜
        obj_ptr,                 # 5 对象指针(记忆 key/query)
        object_score_logits,     # 6 对象出现置信度 logits
    )

一句话

把"已融合记忆的图像特征"和"用户给的点/低分 mask 提示"一起塞进 SAM 的提示编码器 + 掩膜解码器,先产出 1 或 3 组低分 mask,再上采样到原图分辨率,挑 IoU 最高的作为最终 mask,并同步给出对应的对象指针与存在置信度,供后续记忆更新使用。

5.1 PromptEncoder.forward

sam2/modeling/sam/prompt_encoder.py

--- 4. 送进 SAM 提示编码器 ---

sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(

points=(sam_point_coords, sam_point_labels), # 走的是这里

boxes=None, # 框提示在外部已转成 2 点,不再走这里

masks=sam_mask_prompt,

)

python 复制代码
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Embeds different types of prompts, returning both sparse and dense
    embeddings.

    Arguments:
      points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
        and labels to embed.
      boxes (torch.Tensor or none): boxes to embed
      masks (torch.Tensor or none): masks to embed

    Returns:
      torch.Tensor: sparse embeddings for the points and boxes, with shape
        BxNx(embed_dim), where N is determined by the number of input points
        and boxes.
      torch.Tensor: dense embeddings for the masks, in the shape
        Bx(embed_dim)x(embed_H)x(embed_W)
    """
    # points:[
    #            tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0'),
    #            tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
    #    ]
    # boxes:None  
    # masks:None

    # 1. 统一推断 batch 大小(以第一个非 None 的提示为准)
    bs = self._get_batch_size(points, boxes, masks)
    # bs: 1

    # 2. 先建一个空张量,后续把点/框嵌入拼进来
    # self.embed_dim: 256
    sparse_embeddings = torch.empty(
        (bs, 0, self.embed_dim), device=self._get_device()
    )
    # sparse_embeddings: torch.Size([1, 0, 256])

    # 3. 处理点提示:转成 256-d 向量并拼到 sparse_embeddings
    if points is not None:
        coords, labels = points
        # coords: tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0')
        # labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)

        # pad=(boxes is None):当没有框提示时,给点序列末尾补一个"padding 点"
        point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
        # point_embeddings: torch.Size([1, 3, 256])

        # sparse_embeddings: torch.Size([1, 0, 256])
        sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        # sparse_embeddings: torch.Size([1, 3, 256])

    # 4. 处理框提示:左上+右下两个角点各生成 256-d 向量并拼接
    # boxes: None
    if boxes is not None:
        box_embeddings = self._embed_boxes(boxes)
        sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

    # 5. 处理 mask 提示:
    #    有 mask → 用轻量 CNN 压成 256×64×64 的"dense embedding";
    #    无 mask → 用可学习的"no-mask"向量铺满 64×64 空间,保持形状一致
    # masks: None
    if masks is not None:
        dense_embeddings = self._embed_masks(masks)
    else:
        # self.no_mask_embed: nn.Embedding(1, 256)
        # self.image_embedding_size: (64, 64)
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        )
        # dense_embeddings: torch.Size([1, 256, 64, 64])

    # 6. 返回:
    #    sparse_embeddings: [B, N_pts+N_boxes, 256]  供 Transformer self/cross attention
    #    dense_embeddings : [B, 256, 64, 64]          与图像特征逐像素相加
    return sparse_embeddings, dense_embeddings
  1. 功能:把三种提示(点、框、mask)全部映射成与 SAM 图像 embedding 同维度的向量。

  2. sparse embeddings:点/框 → 1-D 向量序列,后续给 Transformer 做 cross-attention。

  3. dense embeddings:mask → 2-D 特征图,与 64×64 的图像 embedding 逐像素相加;若无 mask,就用固定的"无 mask"向量铺一张同样大小的图,保证后续计算图不变。

5.1.1 _get_batch_size

python 复制代码
def _get_batch_size(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> int:
    """
    Gets the batch size of the output given the batch size of the input prompts.
    """
    # points:[
    #            tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0'),
    #            tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
    #    ]
    # boxes:None  
    # masks:None

    # 1. 点提示优先级最高:points[0] 是坐标张量,形状 (B, N, 2)
    if points is not None:
        return points[0].shape[0]          # 直接返回第 0 维 batch 大小
        # points[0]:torch.Size([1, 2, 2])   points[0].shape[0]: 1

    # 2. 没有点时看框提示:boxes 形状 (B, 4)
    elif boxes is not None:
        return boxes.shape[0]

    # 3. 也没有框时看 mask 提示:masks 形状 (B, 1, H, W)
    elif masks is not None:
        return masks.shape[0]

    # 4. 三者皆空,默认 batch=1(推理时用户可能啥也不给)
    else:
        return 1

极简工具函数:在三种提示里按"点→框→mask"顺序挑一个现成的张量,把它的 batch 维大小读出来,保证后续 embedding 操作知道要处理多少张图。若没有任何提示,就保守地返回 1。

5.1.2 _embed_points

python 复制代码
def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""

        # 输入:
        # points : tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0')
        # labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
        # pad: True

        # 1. 把坐标从"左上角"挪到像素中心,与 SAM 训练保持一致
        points = points + 0.5  # Shift to center of pixel
        # tensor([[[421.8333, 540.9445], [655.4333, 922.100]]], device='cuda:0')

        if pad:
            # 2. 追加一个坐标 (0,0)、标签 -1 的"虚拟点",防止空序列
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            # padding_point: tensor([[[0., 0.]]], device='cuda:0')

            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            # padding_label: tensor([[-1.]], device='cuda:0')

            points = torch.cat([points, padding_point], dim=1)
            # points: tensor([[[421.8333, 540.9445], [655.4333, 922.100], [0.0000,    0.0000]]], device='cuda:0')

            labels = torch.cat([labels, padding_label], dim=1)
            # labels: tensor([[ 2.,  3., -1.]], device='cuda:0')

        # 3. 正弦-余弦位置编码:坐标 → 256-D 向量
        # points: torch.Size([1, 3, 2])
        # self.input_image_size: (1024, 1024)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        # point_embedding: torch.Size([1, 3, 256])

        # 4. 对 padding 点:先清零位置编码,再加可学习的"not-a-point"向量
        point_embedding[labels == -1] = 0.0
        # point_embedding: torch.Size([1, 3, 256])

        # self.not_a_point_embed.weight: Embedding(1, 256)
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        # point_embedding: torch.Size([1, 3, 256])

        # 5. 按标签值加上对应的角色向量(0/1/2/3 各有一个可学习向量)
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        point_embedding[labels == 2] += self.point_embeddings[2].weight
        point_embedding[labels == 3] += self.point_embeddings[3].weight

        # 返回形状 (B, N, 256) 的嵌入序列
        return point_embedding
  1. 坐标平移:让提示点落在像素中心,与 SAM 训练保持一致。

  2. 可选 padding:防止"空序列"导致模型报错;padding 点的标签为 -1。

  3. 位置编码:正弦-余弦映射,保证坐标信息可区分。

  4. 角色向量:四种标签各自对应一个可学习向量,加到对应点的嵌入上,实现"前景/背景/框角"语义区分。

1. 把坐标从"左上角"挪到像素中心,与 SAM 训练保持一致

points = points + 0.5

为什么加0.5就是挪到像素中心

图像里记录的坐标通常是"左上角"原点、以像素为单位的整数网格:

复制代码
(0,0) ------ (1,0) ------ (2,0) ------ ...
  |        |        |
(0,1) ------ (1,1) ------ (2,1) ------ ...
  • (x, y) 在内存里表示的是这个像素格子的左上角

  • 真正想用的"物理"中心是 (x + 0.5, y + 0.5)

于是给原坐标统一加 0.5,就把"左上角"挪到了该像素正方形的中心,与 SAM 训练时采用的坐标系一致,后续正弦-余弦编码才能正确反映空间位置。

5.1.2.1 PositionEmbeddingRandom.forward_with_coords

segment-anything-2/sam2/modeling/position_encoding.py

3. 正弦-余弦位置编码:坐标 → 256-D 向量

points: torch.Size([1, 3, 2])

self.input_image_size: (1024, 1024)

point_embedding = self.pe_layer.forward_with_coords(

points, self.input_image_size

)

point_embedding: torch.Size([1, 3, 256])

python 复制代码
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # 随机生成 2×128 的高斯矩阵,作为"空间频率"权重,训练期间冻结
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape

        # 输入:
        # coords: torch.Size([1, 3, 2])
        # coords: tensor([[[0.4093, 0.5264], [0.6349, 0.8959],  [0.0000, 0.0000]]], device='cuda:0'))

        # 1. 把 [0,1] 映射到 [-1,1],与 NeRF 等做法一致
        coords = 2 * coords - 1
        # coords: torch.Size([1, 3, 2])
        # coords: tensor([[[-0.1813, 0.0528], [ 0.2697, 0.7917], [-1.0000, -1.0000]]], device='cuda:0')

        # 2. 用随机高斯矩阵做线性变换:2D 坐标 → 128 维"频率向量"
        # coords: torch.Size([1, 3, 2])
        # self.positional_encoding_gaussian_matrix: torch.Size([2, 128])
        coords = coords @ self.positional_encoding_gaussian_matrix
        # coords: torch.Size([1, 3, 128])

        # 3. 放大 2π,让 sin/cos 充分振荡
        coords = 2 * np.pi * coords
        # coords: torch.Size([1, 3, 128])

        # 4. 拼接 sin 和 cos,得到 256 维最终编码(128×2)
        # outputs d_1 x ... x d_n x C shape
        # torch.Size([1, 3, 256])
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        # 构造 [1,1] 起步的累加网格,每个像素坐标 = 整数索引 - 0.5
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5   # 行方向累加 → y 坐标
        x_embed = grid.cumsum(dim=1) - 0.5   # 列方向累加 → x 坐标
        # 归一化到 [0,1]
        y_embed = y_embed / h
        x_embed = x_embed / w

        # 堆成 (H,W,2) 后过 _pe_encoding → (H,W,256)
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""

        # 输入:
        # coords_input: tensor([[[419.1667, 539.0482], [650.1000, 917.359, device='cuda:0')
        # image_size: (1024, 1024)

        # 1. 克隆一份,避免原地修改用户张量
        coords = coords_input.clone()

        # 2. 按图像宽高归一化到 [0,1];注意 x 对应 W,y 对应 H
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        # coords: tensor([[[0.4093, 0.5264], [0.6349, 0.8959],  [0.0000, 0.0000]]], device='cuda:0')

        # 3. 过 _pe_encoding 得到 (B, N, 256) 的位置编码
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
  1. 核心思想:用固定的高斯随机矩阵把 2D 坐标映射到高维向量,再经 sin/cos 得到与坐标平滑相关的位置编码。

  2. forward 给整幅图生成网格编码,形状 256×H×W,与图像特征相加。

  3. forward_with_coords 给任意像素坐标(未归一化)即时计算编码,形状 B×N×256,用于点/框提示。

_pe_encoding函数中为什么要进行这些操作,为什么维度是这样变化的?

一句话总结

"先归一化→随机线性投影→2π 放大→sin/cos 拼接" 这套流程,把任意 2-D 坐标变成 256-D 高频、平滑、可反推的位置编码;维度变化 B×N×2 → B×N×128 → B×N×256 只是通道扩维,Batch 与点数始终不变。

下面按步骤说明"为什么"以及"为什么维度这样走"。


  1. 归一化到 [-1,1]
    coords = 2 * coords - 1
  • 原因:与 NeRF、Transformer 位置编码常用做法一致,把值中心放到 0,使后续 sin/cos 在正负两侧对称振荡,避免偏置。

  • 维度:仍保持 B×N×2,只是数值区间从 [0,1] 变 [-1,1]。

  1. 随机高斯矩阵投影
    coords = coords @ self.positional_encoding_gaussian_matrix # [B,N,2] @ [2,128] → [B,N,128]
  • 原因:

    -- 直接把 2-D 喂进 sin/cos 只能得到 4 个值(sin x, cos x, sin y, cos y),信息量太少。

    -- 用固定但随机的矩阵 R^(2×128) 把 2-D 映射到 128 维"频率空间",每个维度都是原始坐标的不同线性组合,等价于给每个通道分配不同的"波数"。

    -- 矩阵在训练前一次性生成并 register_buffer,不参与梯度更新,保证同一坐标永远得到同一编码。

  • 维度:B×N×2 → B×N×128,通道扩到 128,Batch 大小 B 与点数 N 不变。

  1. 放大 2π
    coords = 2π * coords
  • 原因:sin/cos 周期 2π,若数值范围太小(如 ±1),函数几乎线性,缺乏高频分量;放大后不同坐标差值更明显,提高空间分辨率。

  • 维度:仍为 B×N×128,只是数值放大。

  1. sin/cos 拼接
    return torch.cat([sin(coords), cos(coords)], dim=-1) # [B,N,128] + [B,N,128] → [B,N,256]
  • 原因:

    -- sin 与 cos 相位差 90°,组合后能唯一恢复原始坐标,同时提供平滑、可导、有界的表示。

    -- 拼接后通道翻倍,形成 256-D 向量,与 SAM 的 prompt_embed_dim 对齐。

  • 维度:B×N×128 → B×N×256,最终输出 256 维位置编码。


维度变化小结

复制代码
B×N×2
   │   1. 归一化 [-1,1]  (数值变,shape 不变)
   ▼
B×N×2
   │   2. @ [2,128]  (矩阵乘,通道扩 128)
   ▼
B×N×128
   │   3. ×2π         (数值放大,shape 不变)
   ▼
B×N×128
   │   4. cat(sin,cos) (通道再翻倍)
   ▼
B×N×256

Batch 维(B)和点数维(N)始终不变,只有最后一维从 2 → 128 → 256。

为什么torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])?

PyTorch 的 @torch.matmul)在三维张量上做广播式批量矩阵乘法

  • 左侧 A(1, 3, 2)

  • 右侧 B(2, 128)

规则:把最后两维当成真正的矩阵,前面所有维当作批维(batch)。

这里批维都是 1,于是自动广播成相同的批维形状,然后逐批做矩阵乘:
(3, 2) @ (2, 128) → (3, 128)

批维 1 保持不变,最终得到:
torch.Size([1, 3, 128])

一句话:最后一维是"契约"维度(必须相等),倒数第二维是结果维度;前面所有维只要可广播即可。

如何理解随机高斯矩阵投影与放大2π

把"随机高斯矩阵投影 + 放大 2π"想成"给 2-D 坐标配一把随机但固定的尺子,再把刻度拉得足够细",让 sin/cos 能"看"到高频细节。


  1. 随机高斯矩阵投影 ------ 造一把多刻度的尺子
  • 原始坐标只有 x,y 两个数,信息量极少。

  • 用固定的随机矩阵 R^(2×128) 做一次线性映射,相当于

    对 (x,y) 做 128 组不同的加权求和 ,每组权重服从高斯分布。

    结果:128 个"新特征",每个都是原始坐标的不同线性组合,

    等价于 128 条不同斜率、不同截距的"尺子"去读同一坐标。

  • 因为矩阵固定,同一坐标永远得到同一 128-D 向量,保证可复现。

  • 随机性让不同通道之间不相关,后续 sin/cos 就能产出丰富、不冗余的频率分量。


  1. 放大 2π ------ 把刻度拉密,让 sin/cos 足够敏感
  • sin/cos 的一个完整周期是 2π。

  • 若输入数值范围只有 [-1,1],sin/cos 变化非常缓慢,几乎呈线性,

    导致相邻坐标编码几乎一样,空间分辨率低。

  • 乘以 2π 后,数值范围变成 [-2π, 2π],

    同样坐标差 0.01 在函数值上会被放大到 Δ ≈ 0.06π,

    sin/cos 明显跳变,相邻坐标即可区分。

  • 直观效果:把"粗刻度"变"细刻度",引入高频分量

    使网络既能捕捉大尺度平滑,也能分辨小尺度抖动。


一句话总结

随机投影 = 用 128 把随机尺子量坐标

2π 放大 = 把尺子刻度加密,让 sin/cos 能"看"到更细的空间纹理。

先归一化→随机线性投影→2π 放大→sin/cos 拼接,这有必然的顺序吗?比如先2π 放大再随机线性投影可以吗?

不能调换------"先投影、再 2π 放大"是数学上保证"每条随机方向都能均匀覆盖 sin/cos 全周期"的关键顺序;反过来做会把高频噪声塞进投影矩阵,破坏编码的平滑性与可区分性。


  1. 数学视角

    想让每条随机方向 ω_i 都充分利用 sin/cos 的 2π 周期:
    PE_i = sin(2π ⋅ ω_i^T [x;y])

    必须把 2π 放在最外层 ;否则
    PE_i = sin(ω_i^T (2π [x;y])) = sin((2π ω_i)^T [x;y])

    等价于把随机向量直接放大 2π------

  • 导致同一矩阵元素量级差异巨大(有的方向 0.1,有的方向 6.28),

  • 数值分布不均匀,训练/推理时容易出现梯度爆炸、频率分布失衡。


  1. 几何视角

    "归一化 → 投影"一步把 2-D 点均匀地撒进 128-D 球体;

    "×2π"再把球体半径拉到 2π,使得 sin/cos 在整个球面上均匀振荡。

    若先 ×2π 再投影,相当于把原始坐标先拉到 [-2π,2π],

    再做随机投影------此时投影输入的数值范围已经变大,

  • 随机矩阵无法"感知"原来的 [-1,1] 边界,

  • 不同样本的投影值可能落在 [-20,20] 甚至更大,

    导致 sin/cos 进入多周期混叠,相邻坐标编码反而变得不可区分。


  1. 实验经验

    原版 SAM、NeRF、Transformer 位置编码无一例外采用

    "归一化 → 随机/固定投影 → 2π 放大 → sin/cos"顺序;

    调换后复现的跟踪精度/重建 PSNR 均明显下降。


结论

顺序是刚性的:

  1. 归一化 → 2. 随机投影 → 3. 2π 放大 → 4. sin/cos 拼接

    不可颠倒。

5.1.2.2 self.pe_layer.forward_with_coords之后做了什么

一、self.pe_layer.forward_with_coords 返回之后做了什么?

它只干了两件事:

  1. 把"坐标"变成"256-D 向量"

    返回的 point_embedding 形状 [B, N, 256] 里,每个点已经带上了唯一、平滑、可反推的位置信息(正弦-余弦编码)。

  2. 把"语义角色"塞进同一向量

    随后代码用 labels 做掩码,把 4 个可学习的"角色向量"逐点加进去:

    • label 0 → 背景点击

    • label 1 → 前景点击

    • label 2 → 框左上角

    • label 3 → 框右下角

    • label -1 → padding 虚点(先清零再加 not_a_point_embed

    这样同一物理位置只要角色不同,最终 256-D 向量也不同,Transformer 就能区分"左上角"与"右下角"或"前景"与"背景"。


二、整个 _embed_points 在做什么?

把用户给的"原始像素坐标 + 标签"变成可直接送进 Transformer 的 256-D 提示向量,流程:

  1. 像素中心化 → 对齐 SAM 训练方式

  2. 可选 padding → 防止空序列

  3. 位置编码 → 让网络知道"点在哪儿"

  4. 角色加向量 → 让网络知道"这是什么点"


三、为什么要这么做?

  • 位置必须可区分:同一标签的两个点如果坐标不同,向量必须不同 → 靠正弦-余弦编码。

  • 角色必须可区分:同一坐标的两个点如果标签不同,向量必须不同 → 靠加角色向量。

  • 统一接口 :最终出来的 [B, N, 256] 与 SAM 的 prompt_embed_dim 一致,可直接和图像 token 做 cross-attention,无需额外转换。

一句话:
_embed_points 就是"坐标 + 标签 → 256-D 提示向量"的流水线,保证位置可辨、角色可辨、格式兼容

point_embedding[labels == -1] = 0.0 什么意思
  • labels 形状是 (B=1, N=3),里面只有 3 个标量:
    tensor([[2, 3, -1]])

    它们分别对应第 0 个点、第 1 个点、第 2 个点 的标签,不是 256 维向量里的某个分量

  • point_embedding 形状是 (B=1, N=3, C=256)

    可以看成 3 条 256-D 向量:

    复制代码
    [0] 号向量 → 标签 2
    [1] 号向量 → 标签 3
    [2] 号向量 → 标签 -1

布尔索引 labels == -1 返回同样 (1,3) 的 mask:
tensor([[False, False, True]])

用这条 mask 去挑 point_embedding 时,按前两个维度对齐

  • 只要 mask 为 True 的位置,就把整条 256-D 向量抓出来(共 256 个数)置 0;

  • 其余两条向量完全不动。

所以"256"跟"3"并不直接对应,而是:

3 个点 → 3 条 256-D 向量 → 用 3 个标签值决定哪一整条向量要被清零。

point_embedding[labels == -1] += self.not_a_point_embed.weight 什么意思?

这行代码的意思是:

把"标签为 -1 的点"对应的那一整条 256 维向量,清零后再单独加上一个可学习的"虚拟点"嵌入向量self.not_a_point_embed.weight),从而告诉模型"这是一个 padding 虚点,不是真实用户提示"。

具体步骤:

  1. 前面已经把 point_embedding[labels == -1] 整行置 0;

  2. self.not_a_point_embed 是一个 nn.Embedding(1, 256),即 1 个 256 维的可学习向量;

  3. += 把这 256 个数加到对应行上,使该行不再全 0,而是带有"我是 padding"这一固定语义。

结果:

  • 真实提示点 → 位置编码 + 角色向量

  • padding 虚点 → 全 0 + "not-a-point"向量

模型后续看到这条向量,就能区分"用户没给提示"与"用户给了提示"。

5. 按标签值加上对应的角色向量(0/1/2/3 各有一个可学习向量)

point_embedding[labels == 0] += self.point_embeddings[0].weight

point_embedding[labels == 1] += self.point_embeddings[1].weight

point_embedding[labels == 2] += self.point_embeddings[2].weight

point_embedding[labels == 3] += self.point_embeddings[3].weight

这什么意思

把"角色"这一离散信息写进向量里:

每种标签值(0/1/2/3)都对应一个可学习的 256 维向量self.point_embeddings[i].weight)。

哪条点的标签等于 i,就把对应向量整行加到这点的 256-D 嵌入上,

于是同一坐标、不同角色也能产生不同的最终提示向量,供 Transformer 区分"背景点 / 前景点 / 框左上角 / 框右下角"。

5.1.3 torch.cat([sparse_embeddings, point_embeddings], dim=1)什么意思?

pad=(boxes is None):当没有框提示时,给点序列末尾补一个"padding 点"

point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))

point_embeddings: torch.Size([1, 3, 256])

sparse_embeddings: torch.Size([1, 0, 256])

sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)

sparse_embeddings: torch.Size([1, 3, 256])

torch.cat 在这里就是"把空袋子换成实袋子":

  • 一开始 sparse_embeddings 是形状 (1, 0, 256)------第二维长度为 0,相当于一个"空列表"。

  • point_embeddings 形状 (1, 3, 256) 是刚刚算好的 3 条 256-D 提示向量。

  • torch.cat([sparse_embeddings, point_embeddings], dim=1)

    沿 dim=1(提示点个数维度)把空张量和 3 条向量拼起来,结果变成 (1, 3, 256)

一句话:用拼接把"空提示序列"变成"含有 3 个点的提示序列",后续就能直接送进 Transformer 做交叉注意力。

5.1.4 dense_embeddings

if masks is not None:

dense_embeddings = self._embed_masks(masks)

else:

self.no_mask_embed: nn.Embedding(1, 256)

self.image_embedding_size: (64, 64)

dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(

bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]

)

dense_embeddings: torch.Size([1, 256, 64, 64])

什么意思?

这段代码在生成"mask 提示"的密集嵌入(dense_embeddings),供后续与图像特征逐像素相加:

  1. 如果用户给了 mask
    dense_embeddings = self._embed_masks(masks)

    → 用轻量 CNN 把 [B,1,H,W] 的 mask 压成 [B,256,64,64] 的特征图。

  2. 如果没给 mask

    • self.no_mask_embed 是一个 nn.Embedding(1,256),里面只有 1 个 256-D 向量(可学习)。

    • .reshape(1, -1, 1, 1) 变成 [1,256,1,1]

    • .expand(bs, -1, 64, 64) 沿 batch 和空间维铺平

      → 得到 一张全是同一向量的 64×64 特征图 ,形状 [B,256,64,64]

结果:

无论有没有外部 mask,网络都拿到一张 通道=256、空间=64×64 的密集嵌入,后续可直接与 SAM 的图像 embedding 逐像素相加;没 mask 时就用"可学习的'无 mask'向量"填满整张图,保持计算图一致。

为什么.reshape(1, -1, 1, 1) 变成 [1,256,1,1],为什么 .expand(bs, -1, 64, 64) 沿 batch 和空间维铺平 → 得到 一张全是同一向量的 64×64 特征图,形状 [B,256,64,64],reshape和expand有啥区别?

一句话先答:
reshape 只"重新划分尺寸"不复制数据expand 只"广播重复"不分配新内存 ,两者配合才能把 1 个 256-D 向量"变"成一张 [B,256,64,64] 的常量特征图。


  1. reshape(1, -1, 1, 1) 怎么得到 [1,256,1,1]
  • self.no_mask_embed.weight 原始形状 [256](Embedding 的 weight 是 2-D,但第 0 维=1,所以 view 出来就是 1×256)。

  • reshape(1, -1, 1, 1)-1 代表"自动推断",于是
    256 → (1, 256, 1, 1)

    结果:4-D 张量,只有 256 个数,内存没复制。


  1. expand(bs, -1, 64, 64) 怎么得到 [B,256,64,64]
  • expand 不会真的复制数据,它只在逻辑上把尺寸为 1 的维度"广播"成任意大小。

  • 这里把
    batch=1 → B
    H=1 → 64
    W=1 → 64

    通道 256 本来就是 1 份,保持不变。

    于是形状变成 [B,256,64,64]内部仍共享那 256 个数,几乎不占额外显存。


  1. 为什么要两步
  • reshape 先把向量"立"成 4-D,才能参与广播;

  • expand 再把"单像素"广播成"整张图",保证后续与图像特征相加时尺寸一致。

  • 如果直接用 repeat 会真的复制数据,显存翻倍;expand 是零成本视图。


总结
reshape 负责"改形状",expand 负责"铺平面";一个改维度,一个做广播,两者配合就能把 1 个 256-D 向量零成本地变成 [B,256,64,64] 的常量特征图。

你的意思是256个数,每个数都扩展到了64x64个,但是内存没复制,是这样理解吗?

不是把"每个数"都扩成 64×64,而是:

256 个数整体 当成一条 256 通道的"样板",在空间维度 (64×64) 上重复引用同一份内存 ------逻辑上看起来是 [B,256,64,64],物理上仍然只有那 256 个 float,外加一点点广播元信息。

所以没有产生 256×64×64 份新数据,显存几乎零增长。

相关推荐
深度学习lover2 小时前
<数据集>yolo茶叶嫩芽识别数据集<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·茶叶嫩芽识别
你们补药再卷啦2 小时前
ai(四)基础知识
人工智能
得物技术2 小时前
从0到1搭建一个智能分析OBS埋点数据的AI Agent|得物技术
人工智能·agent
乾元2 小时前
动态路由策略回归测试:把 CI/CD 思想带入网络路由(工程化 · Near-term)
运维·服务器·网络·人工智能·ci/cd·架构·智能路由器
AI视觉网奇2 小时前
live2d 全身数字人
人工智能·计算机视觉
HelloReader2 小时前
用 Spark Shell 做交互式数据分析从入门到自包含应用
人工智能
_codemonster2 小时前
AI大模型入门到实战系列(四)深入理解 Transformer 大语言模型
人工智能·语言模型·transformer
爱笑的眼睛113 小时前
从零构建与深度优化:PyTorch训练循环的工程化实践
java·人工智能·python·ai