目录
三、_prepare_memory_conditioned_features
[5.1 PromptEncoder.forward](#5.1 PromptEncoder.forward)
[5.1.1 _get_batch_size](#5.1.1 _get_batch_size)
[5.1.2 _embed_points](#5.1.2 _embed_points)
[5.1.2.1 PositionEmbeddingRandom.forward_with_coords](#5.1.2.1 PositionEmbeddingRandom.forward_with_coords)
_pe_encoding函数中为什么要进行这些操作,为什么维度是这样变化的?
[为什么torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])?](#为什么torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])?)
[先归一化→随机线性投影→2π 放大→sin/cos 拼接,这有必然的顺序吗?比如先2π 放大再随机线性投影可以吗?](#先归一化→随机线性投影→2π 放大→sin/cos 拼接,这有必然的顺序吗?比如先2π 放大再随机线性投影可以吗?)
[5.1.2.2 self.pe_layer.forward_with_coords之后做了什么](#5.1.2.2 self.pe_layer.forward_with_coords之后做了什么)
[point_embedding[labels == -1] = 0.0 什么意思](#point_embedding[labels == -1] = 0.0 什么意思)
[5.1.3 torch.cat([sparse_embeddings, point_embeddings], dim=1)什么意思?](#5.1.3 torch.cat([sparse_embeddings, point_embeddings], dim=1)什么意思?)
[5.1.4 dense_embeddings](#5.1.4 dense_embeddings)
一、前言

从《SAM2跟踪的理解2》开始,我们一直在看第一帧做了什么,其实就是通过load_first_frame函数中的_get_image_feature的image encoder提取图像特征,然后再通过add_new_prompt将第一帧中用户给定提示输入到SAM2中进行推理获取掩码,推理的函数就是add_new_prompt里面的 _run_single_frame_inference里面的track_step。
这一篇我们就是看track_step里面做了什么 ,其实track_step先调用了_prepare_memory_conditioned_features函数,这个函数就是把当前帧的视觉特征与"记忆"做融合,如果有记忆的话,这个函数是会调用memory attention的 (如果你想知道memory attention在干嘛,这一篇是没有的,等到我们说第二帧发生了什么的时候就会有),然而我们现在是在第一帧,是没有"记忆"的 ,那它做什么事?就是做了下面这些事:(第三节的内容)
当前帧的视觉特征: current_vision_feats[-1]: torch.Size([4096, 1, 256])
"无记忆"embedding:self.no_mem_embed: torch.Size([1, 1, 256])
相加: pix_feat_with_mem: torch.Size([4096, 1, 256])
HW,B,C\] -\> \[B,C,H,W\]: pix_feat_with_mem: torch.Size(\[1, 256, 64, 64\])
下面是第一帧情况下的函数调用顺序,可以看到我们在2.14个地方,后面主要经历的是prompt encoder和mask decoder,本篇是讲prompt encoder(第五节内容)
注:如果注释中涉及到数值,比如矩形提示的位置坐标,这种因为调试分了很多次,所以会有些变化,不用在意这些细节,只要看维度变化即可。
2.12 <重点> add_new_prompt
2.13 <重点> _run_single_frame_inference
2.14 <重点> track_step (这篇开头在这)
2.15 <重点> _prepare_memory_conditioned_features
2.16 _use_multimask
2.17 <重点> _forward_sam_heads
2.18 提示编码器:类PromptEncoder.forward
2.19 类PositionEmbeddingRandom.forward_with_coords(这篇结束在这)
2.20 类PromptEncoder.get_dense_pe
2.21 掩码解码器 类MaskDecoder.forward
2.22 类MaskDecoder.predict_masks
2.23 TwoWayTransformer.forward
2.24 TwoWayAttentionBlock.forward
2.25 Attention.forward
2.26 <重点> MLP.forward
2.27 Attention.forward
2.28 LayerNorm2d.forward
2.29 MaskDecoder._dynamic_multimask_via_stability
2.30 MaskDecoder._get_stability_scores
2.31 fill_holes_in_mask_scores
2.32 _get_maskmem_pos_enc
2.33 _consolidate_temp_output_across_obj
2.34 _get_orig_video_res_output
prompt encoder是如何把输入的提示点编码成向量的呢?其实是下面这样,像点提示、矩形提示这些它会编码成torch.Size([B, N, 256])的****sparse_embeddings(N表示提示点数,包括padding点) ,mask提示会编码成torch.Size([B, 256, 64, 64])的****dense_embeddings (即使没有masks提示也会整一个),最后就是返回sparse_embeddings和dense_embeddings
1. 输入的提示(比如输入两个提示点,第三个是防止不给提示加的全0的点,我们叫padding点)
coords: torch.Size([1, 3, 2]) # 3表示点的数量,2表示点有x和y
比如 coords: tensor([[[0.4093, 0.5264], [0.6349, 0.8959], [0.0000, 0.0000]]])
2. 用随机高斯矩阵做线性变换:2D 坐标 → 128 维"频率向量"
self.positional_encoding_gaussian_matrix: torch.Size([2, 128])
coords = coords @ self.positional_encoding_gaussian_matrix
torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])
3. 放大 2π,让 sin/cos 充分振荡
coords = 2 * np.pi * coords
coords: torch.Size([1, 3, 128])
4. 拼接 sin 和 cos,得到 256 维最终编码(128×2)
torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
位置编码 point_embedding: torch.Size([1, 3, 256])
5. 算法如何区分提示点是正提示点还是负提示点,还是padding点?加上****角色向量
padding 虚点 → 全 0 + "not-a-point"向量
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight # padding点
真实提示点 → 位置编码 + 角色向量
point_embedding[labels == 0] += self.point_embeddings[0].weight # 负提示点
point_embedding[labels == 1] += self.point_embeddings[1].weight # 正提示点
point_embedding[labels == 2] += self.point_embeddings[2].weight # 矩形左上角点
point_embedding[labels == 3] += self.point_embeddings[3].weight # 矩形右下角点
维度不变 point_embedding: torch.Size([1, 3, 256])
6. sparse_embeddings:torch.Size([1, 3, 256])
sparse_embeddings: torch.Size([1, 0, 256])
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
sparse_embeddings: torch.Size([1, 3, 256])
7.处理 mask 提示
有 mask → 用轻量 CNN 压成 256×64×64 的"dense embedding";
dense_embeddings = self._embed_masks(masks)
# 无 mask → 用可学习的"no-mask"向量铺满 64×64 空间,保持形状一致
self.no_mask_embed: nn.Embedding(1, 256)
self.image_embedding_size: (64, 64)
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
(1,256) => (1,256,1,1) => (1,256, 64, 64)
8. dense_embeddings: torch.Size([1, 256, 64, 64])
二、track_step
python
def track_step(
self,
frame_idx, # 当前帧的全局序号(从 0 开始)
is_init_cond_frame, # 是否为初始条件帧(第一帧或用户重新标注的帧)
current_vision_feats, # 当前帧的 backbone 特征列表,长度=层数,每层形状 (HW, B, C)
current_vision_pos_embeds, # 对应的位置编码,与 current_vision_feats 一一对应
feat_sizes, # 每层特征图的空间分辨率,如 [(256,256), (128,128)]
point_inputs, # 用户点击/提示,格式 dict{"point_coords": (B,K,2), "point_labels": (B,K)}
mask_inputs, # 用户输入的 mask(可选),形状 (B,1,H,W) 或 None
output_dict, # 全局状态字典,包含 memory bank、已编码的记忆等
num_frames, # 整个视频的总帧数
track_in_reverse=False, # 是否按倒序跟踪(demo 中回退播放时使用)
run_mem_encoder=True, # 是否要把当前预测结果编码成记忆供后续帧使用
prev_sam_mask_logits=None, # 上一帧预测的 low-res mask logits(demo 连续点击时累积用)
):
"""
单帧跟踪函数。
1. 融合历史记忆与当前视觉特征 -> 得到带记忆的 pix_feat
2. 用 SAM head 解码出 mask
3. 可选地把预测 mask 再编码成新的记忆
返回 current_out,包含本帧所有需要保存/上传的结果。
"""
# current_vision_feats: [
# torch.Size([65536, 1, 32])
# torch.Size([16384, 1, 64])
# torch.Size([4096, 1, 256])
# ]
# current_vision_pos_embeds: [
# torch.Size([65536, 1, 256])
# torch.Size([16384, 1, 256])
# torch.Size([4096, 1, 256])
# ]
# feat_sizes': [
# (256, 256),
# (128, 128),
# (64, 64)
# ]
#]
# ------------------------------------------------------------------
# 1. 组装本帧的输入字典,方便后续上传或可视化
# ------------------------------------------------------------------
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# current_out 示例:
# {
# "point_inputs": {
# "point_coords": tensor([[[416.5333, 541.3926],
# [656.5333, 926.340]]], device='cuda:0'),
# "point_labels": tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
# },
# "mask_inputs": None
# }
# ------------------------------------------------------------------
# 2. 取出 backbone 的多层特征,转成 BCHW 供 SAM 的高分辨率分支使用
# 最后一层特征留给记忆融合,其余层作为 high_res_features
# ------------------------------------------------------------------
if len(current_vision_feats) > 1:
# 将 (HW,B,C) -> (B,C,H,W)
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
# high_res_features:[
# torch.Size([1, 32, 256, 256]),
# torch.Size([1, 64, 128, 128])
#]
else:
high_res_features = None
# ------------------------------------------------------------------
# 3. 如果用户直接给了 GT mask 且配置要求跳过 SAM,则直接把它当输出
# ------------------------------------------------------------------
# mask_inputs:None use_mask_input_as_output_without_sam:True
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
# 取最后一层特征作为像素嵌入
pix_feat = current_vision_feats[-1].permute(1, 2, 0) # (B,C,H*W)
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) # (B,C,H,W)
sam_outputs = self._use_mask_as_output(
pix_feat, high_res_features, mask_inputs
)
else:
# ------------------------------------------------------------------
# 4. 正常路径:把当前特征与记忆库融合,得到带历史信息的 pix_feat
# ------------------------------------------------------------------
pix_feat_with_mem = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:], # 只用最后一层
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
# ------------------------------------------------------------------
# 5. 处理 demo 场景:如果之前已有 SAM 预测结果,把它作为 mask prompt
# ------------------------------------------------------------------
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None, \
"prev_sam_mask_logits 仅在有点击且无外部 mask 时生效"
mask_inputs = prev_sam_mask_logits
# ------------------------------------------------------------------
# 6. 决定 SAM 是否输出多 mask(初始帧或仅单点击时通常允许多 mask)
# ------------------------------------------------------------------
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
# ------------------------------------------------------------------
# 7. 真正调用 SAM 解码头,得到 low/high 分辨率 mask 及 object pointer
# ------------------------------------------------------------------
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
# sam_outputs 是一个 tuple,按顺序包含:
# 0 low_res_multimasks (B,num_masks,256,256)
# 1 high_res_multimasks(B,num_masks,1024,1024)
# 2 ious (B,num_masks)
# 3 low_res_masks (B,1,256,256) ← 最终选中的单 mask
# 4 high_res_masks (B,1,1024,1024)
# 5 obj_ptr (B,256) ← 用于记忆的 object token
# 6 object_score_logits (B,1) ← 该帧目标存在置信度
# ------------------------------------------------------------------
# 8. 解压 SAM 输出,把关键字段写回 current_out
# ------------------------------------------------------------------
(
_, # low_res_multimasks
_, # high_res_multimasks
_, # ious
low_res_masks, # 最终单 mask (B,1,256,256)
high_res_masks, # 对应高分辨率 (B,1,1024,1024)
obj_ptr, # object pointer (B,256)
_, # object_score_logits
) = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
# current_out 此时新增:
# {
# ...,
# "pred_masks": torch.Size([1, 1, 256, 256]),
# "pred_masks_high_res": torch.Size([1, 1, 1024, 1024]),
# "obj_ptr": torch.Size([1, 256])
# }
# ------------------------------------------------------------------
# 9. 可选:把当前预测 mask 编码成新的记忆,供后续帧使用
# 若 run_mem_encoder=False 或 num_maskmem=0 则跳过
# ------------------------------------------------------------------
if run_mem_encoder and self.num_maskmem > 0:
# 用高分辨率 mask 进行记忆编码,细节保留更好
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks_for_mem_enc,
is_mask_from_pts=(point_inputs is not None), # 区分点击 vs 外部 mask
)
current_out["maskmem_features"] = maskmem_features
current_out["maskmem_pos_enc"] = maskmem_pos_enc
else:
# 不编码记忆时置空,调用方据此跳过 memory bank 更新
current_out["maskmem_features"] = None
current_out["maskmem_pos_enc"] = None
# current_out:
# {
# "point_inputs": {...},
# "mask_inputs": None,
# "pred_masks": torch.Size([1, 1, 256, 256]),
# "pred_masks_high_res": torch.Size([1, 1, 1024, 1024]),
# "obj_ptr": torch.Size([1, 256]),
# "maskmem_features": None,
# "maskmem_pos_enc": None
# }
return current_out
一句话
把"当前帧图像特征 + 用户给的点击/遮罩提示 + 历史记忆"喂给 SAM 解码器,生成当前帧的高质量分割结果,同时(可选)把这一帧的预测再编码成新的"记忆",供后续帧继续跟踪。
三、_prepare_memory_conditioned_features
在track_step中,_prepare_memory_conditioned_features传入的参数,值得注意的是current_vision_feats和current_vision_pos_embeds和feat_sizes都是传的最后一层
------------------------------------------------------------------
4. 正常路径:把当前特征与记忆库融合,得到带历史信息的 pix_feat
------------------------------------------------------------------
pix_feat_with_mem = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:], # 只用最后一层
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
下面就是_prepare_memory_conditioned_features函数,注意现在我们是在初始条件帧:没有过去记忆,只能"凭空"编码,直接使把self.no_mem_embed: torch.Size([1, 1, 256]) 加到current_vision_feats[-1]: torch.Size([4096, 1, 256])得到torch.Size([4096, 1, 256]),然后再经过:
HW,B,C\] -\> \[B,C,H,W
得到:pix_feat_with_mem: torch.Size([1, 256, 64, 64]),就返回了。没有后面的memory attention
python
def _prepare_memory_conditioned_features(
self,
frame_idx, # 当前帧在整个视频序列中的绝对序号
is_init_cond_frame, # 当前帧是否属于"初始条件帧"(即带人工标注的 key-frame)
current_vision_feats, # 当前帧的视觉主干输出,list[tensor],每层分辨率一个
current_vision_pos_embeds, # 当前帧对应的空间位置编码,与上一项一一对应
feat_sizes, # 每项特征图的 (H, W)
output_dict, # 之前所有帧已经算好的输出字典,含 memory 与 object pointer
num_frames, # 整个视频总帧数(用于边界检查)
track_in_reverse=False, # 演示模式下倒着跟踪(从后往前)
):
"""
把当前帧的视觉特征与"记忆"做融合,得到 memory-conditioned 特征图。
如果 num_maskmem==0 则退化为普通 SAM(无记忆,单图分割)。
"""
# current_vision_feats: [torch.Size([4096, 1, 256]),]
B = current_vision_feats[-1].size(1) # batch size
# B: 1
C = self.hidden_dim # 隐层维数
# C: 256
H, W = feat_sizes[-1] # 最低分辨率特征图尺寸
# H: 64 W:64
device = current_vision_feats[-1].device
# device: device(type='cuda', index=0)
# 1) 无记忆模式:直接返回当前帧顶层特征(复现 SAM 单图效果)
# self.num_maskmem: 7
if self.num_maskmem == 0:
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
return pix_feat
num_obj_ptr_tokens = 0 # 记录后面要塞进 transformer 的 object pointer token 数
# 2) 非初始条件帧:需要把"过去记忆"拿出来做 cross-attention
# is_init_cond_frame: Ture
if not is_init_cond_frame:
to_cat_memory, to_cat_memory_pos_embed = [], []
# 2.1 先选最多 max_cond_frames_in_attn 个"离当前帧最近"的初始条件帧
assert len(output_dict["cond_frame_outputs"]) > 0
cond_outputs = output_dict["cond_frame_outputs"]
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
frame_idx, cond_outputs, self.max_cond_frames_in_attn
)
# t_pos=0 表示这些条件帧在时序上被视为"原点"
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
# 2.2 再取 (self.num_maskmem-1) 个"普通记忆帧"(非条件帧)
r = self.memory_temporal_stride_for_eval # 采样 stride
for t_pos in range(1, self.num_maskmem):
t_rel = self.num_maskmem - t_pos # 离当前帧多远
if t_rel == 1:
# 紧邻帧(t_rel==1)一定取
prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
else:
# 更远的帧按 stride r 跳着取
if track_in_reverse:
prev_frame_idx = -(-(frame_idx + 2) // r) * r + (t_rel - 2) * r
else:
prev_frame_idx = ((frame_idx - 2) // r) * r - (t_rel - 2) * r
# 先查非条件帧字典,没有再查"未被选中"的条件帧
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) or \
unselected_cond_outputs.get(prev_frame_idx, None)
t_pos_and_prevs.append((t_pos, out))
# 2.3 把选中的记忆特征和对应的位置编码拼到一起
for t_pos, prev in t_pos_and_prevs:
if prev is None:
continue # 该帧不存在,跳过(padding)
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) # [HW,B,C]
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# 再加上"时序位置编码":越早的记忆 t_pos 越大,embedding 越靠后
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
to_cat_memory_pos_embed.append(maskmem_enc)
# 2.4 构造"对象指针" object pointer(用于保持实例身份)
if self.use_obj_ptrs_in_encoder:
max_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder)
# 选哪些条件帧的 pointer:eval 时可限制只取"过去"的
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
ptr_cond_outputs = {
t: out for t, out in selected_cond_outputs.items()
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
}
else:
ptr_cond_outputs = selected_cond_outputs
pos_and_ptrs = [(abs(frame_idx - t), out["obj_ptr"])
for t, out in ptr_cond_outputs.items()]
# 继续往前(或往后)取普通帧的 pointer
for t_diff in range(1, max_ptrs):
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
if t < 0 or (num_frames is not None and t >= num_frames):
break
out = output_dict["non_cond_frame_outputs"].get(
t, unselected_cond_outputs.get(t, None))
if out is not None:
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
if pos_and_ptrs: # 至少有一个 pointer
pos_list, ptrs_list = zip(*pos_and_ptrs)
obj_ptrs = torch.stack(ptrs_list, dim=0) # [ptr_len, B, C]
if self.add_tpos_enc_to_obj_ptrs: # 给 pointer 也加时序位置
t_diff_max = max_ptrs - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = get_1d_sine_pe(torch.tensor(pos_list, device=device) / t_diff_max,
dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
else:
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
# 如果 mem_dim < C,把一个 pointer 拆成多个 token
if self.mem_dim < C:
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) \
.permute(0, 2, 1, 3).flatten(0, 1)
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
to_cat_memory.append(obj_ptrs)
to_cat_memory_pos_embed.append(obj_pos)
num_obj_ptr_tokens = obj_ptrs.shape[0]
# 3) 初始条件帧:没有过去记忆,只能"凭空"编码
else:
# self.directly_add_no_mem_embed:True
if self.directly_add_no_mem_embed:
# 简单地把"无记忆"embedding 加到特征上,跳过 transformer
# current_vision_feats[-1]: torch.Size([4096, 1, 256])
# self.no_mem_embed: torch.Size([1, 1, 256])
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
# pix_feat_with_mem: torch.Size([4096, 1, 256])
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
# pix_feat_with_mem: torch.Size([1, 256, 64, 64])
# 输出形状:[HW,B,C] -> [B,C,H,W]
return pix_feat_with_mem
# 否则喂一个 dummy memory token 给 transformer,防止空输入
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
# 4) 把以上所有记忆 token 拼成序列,送进 memory_attention 层
memory = torch.cat(to_cat_memory, dim=0)
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
pix_feat_with_mem = self.memory_attention(
curr=current_vision_feats, # 当前帧多层特征
curr_pos=current_vision_pos_embeds, # 当前帧位置编码
memory=memory, # 过去记忆 + object pointer
memory_pos=memory_pos_embed, # 对应位置编码
num_obj_ptr_tokens=num_obj_ptr_tokens,
)
# 输出形状:[HW,B,C] -> [B,C,H,W]
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
一句话总结
该函数是 SAM 2 "记忆解码器"的核心:
对于当前帧,先把"最近若干初始条件帧"和"按 stride 采样到的普通记忆帧"的特征、位置编码以及 object pointer 收集起来,拼成一个记忆序列;随后用轻量级 transformer(
memory_attention)让当前帧特征与这段记忆做 cross-attention,从而把"过去看到的掩码和对象身份"注入到当前特征图里,最终返回一张融合后的 B×C×H×W 特征,供后续掩码解码器使用。
四、_use_multimask
继续看track_step后面发生了什么
------------------------------------------------------------------
5. 处理 demo 场景:如果之前已有 SAM 预测结果,把它作为 mask prompt
------------------------------------------------------------------
prev_sam_mask_logits: None
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None, \
"prev_sam_mask_logits 仅在有点击且无外部 mask 时生效"
mask_inputs = prev_sam_mask_logits
------------------------------------------------------------------
6. 决定 SAM 是否输出多 mask(初始帧或仅单点击时通常允许多 mask)
------------------------------------------------------------------
is_init_cond_frame: True
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
python
def _use_multimask(self, is_init_cond_frame, point_inputs):
"""
根据当前帧类型和提示点数量,决定 SAM mask-decoder 是否输出**多掩膜**(multimask)。
背景:
SAM 默认对**模糊提示**(单点/粗框)同时给出 3 个候选掩膜,让用户或后续模块挑选;
对**清晰提示**(多点、已有较好 mask)则只输出 1 个精炼掩膜,节省计算与显存。
本函数即实现这一策略开关。
参数:
is_init_cond_frame (bool):
True → 当前帧是初始条件帧(用户首次点击或重新初始化)。
point_inputs (dict | None):
点提示字典,含 "point_labels" (B, N) 和 "point_coords" (B, N, 2)。
若为 None,表示无点提示。
返回:
bool:
True → 使用 multimask 输出(3 个候选);
False → 只输出单掩膜。
"""
# 1. 计算当前提示点数量(不含背景点)
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
# point_labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
# num_pts: 2 对于粗框,提示有2个,矩形左上角和矩形右下角
# 2. 同时满足以下 4 个条件才开启 multimask:
# a) 全局开关打开;
# b) 当前是初始条件帧,**或**配置允许跟踪阶段也 multimask;
# c) 点数量在 [min, max] 区间内;
# d) 实际上只要有提示点就会触发,没点时 num_pts=0 自动不满足。
multimask_output = (
# self.multimask_output_in_sam:True
self.multimask_output_in_sam # 全局超参:是否启用 multimask 功能
and
# is_init_cond_frame:True self.multimask_output_for_tracking:True
(is_init_cond_frame or self.multimask_output_for_tracking) # 帧类型限制
and
# self.multimask_min_pt_num:0 num_pt:2 self.multimask_max_pt_num:1
(self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) # 点数限制
)
# multimask_output:False
return multimask_output
一句话
按"全局开关 + 帧类型 + 提示点数量"三要素,快速决定 SAM 解码器是输出 3 个候选 mask(True)还是只输出 1 个精炼 mask(False)。
可以看到,在我们的例子中最后返回的是False,这是因为我们的例子中粗框提示包括两个,一个是矩形的左上角点,一个是矩形的右下角点,2个提示已经是比较精确了,只有在0或者1个提示的时候,提示才是模糊的,模糊提示情况下才输出多个掩码。
五、_forward_sam_heads
继续看track_step后面发生了什么
------------------------------------------------------------------
7. 真正调用 SAM 解码头,得到 low/high 分辨率 mask 及 object pointer
------------------------------------------------------------------
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
python
@torch.inference_mode()
def _forward_sam_heads(
self,
backbone_features,
point_inputs=None,
mask_inputs=None,
high_res_features=None,
multimask_output=False,
):
"""
完整的 SAM 风格「提示编码 + 掩膜解码」一条龙:
把 backbone 特征与用户提示(点/框/mask)一起送 SAM,
输出多组或单组掩膜 logits、IoU 估计、对象指针等。
参数:
backbone_features (Tensor):
已融合记忆的图像嵌入,形状 (B, C, H, W),
其中 H=W=sam_image_embedding_size(默认 64)。
point_inputs (dict | None):
{"point_coords": (B, P, 2), "point_labels": (B, P)}
坐标为**绝对像素**(会在内部归一化),label: 1=前, 0=背, -1=pad。
mask_inputs (Tensor | None):
低分辨率掩膜提示 (B,1,H',W'),与点提示互斥。
high_res_features (list[Tensor] | None):
额外两层更高分辨率特征 [4H,4W] 与 [2H,2W],供 decoder refine 边缘。
multimask_output (bool):
True → 输出 3 候选掩膜 + 3 IoU;False → 1 掩膜 + 1 IoU。
返回:
tuple:
0. low_res_multimasks -- (B,M,H/4,W/4) 低分多掩膜 logits
1. high_res_multimasks -- (B,M,H,W) 高分多掩膜 logits
2. ious -- (B,M) 各掩膜 IoU 估计
3. low_res_masks -- (B,1,H/4,W/4) **最佳**低分掩膜
4. high_res_masks -- (B,1,H,W) **最佳**高分掩膜
5. obj_ptr -- (B,C) 对象指针(用于记忆)
6. object_score_logits -- (B,) 对象出现置信度(可软可硬)
"""
# backbone_features: torch.Size([1, 256, 64, 64])
# point_inputs: {
# 'point_coords': tensor([[[421.3333, 540.4445],
# [654.9333, 921.6000]]], device='cuda:0'),
# 'point_labels': tensor([[2, 3]],
# device='cuda:0', dtype=torch.int32)
# }
# mask_inputs:None
# high_res_features:[
# torch.Size([1, 32, 256, 256]),
# torch.Size([1, 64, 128, 128])
# ]
# multimask_output: False
B = backbone_features.size(0)
# B: 1
device = backbone_features.device
# device: device(type='cuda', index=0)
# --- 1. 输入断言:尺寸必须匹配 SAM 预设 ---
assert backbone_features.size(1) == self.sam_prompt_embed_dim
assert backbone_features.size(2) == self.sam_image_embedding_size
assert backbone_features.size(3) == self.sam_image_embedding_size
# --- 2. 构造点提示 ---
if point_inputs is not None:
sam_point_coords = point_inputs["point_coords"] # (B, P, 2)
# sam_point_coords: torch.Size([1, 2, 2])
# sam_point_coords: tensor([[[421.3333, 540.4445],
# [654.9333, 921.6000]]], device='cuda:0'),
sam_point_labels = point_inputs["point_labels"] # (B, P)
# sam_point_labels: torch.Size([1, 2])
# sam_point_labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
else:
# 无点提示时,用 1 个 pad 点(label=-1)占位,保证 prompt encoder 正常 forward
sam_point_coords = torch.zeros(B, 1, 2, device=device)
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
# --- 3. 构造掩膜提示 ---
# mask_inputs: None
if mask_inputs is not None:
# 若外部 mask 分辨率不符,先双线性下采样到 prompt encoder 期望尺寸
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
mask_inputs.float(),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
antialias=True, # 抗锯齿,减少下采样 aliasing
)
else:
sam_mask_prompt = mask_inputs
else:
# 无 mask 时,prompt encoder 内部会加 learned `no_mask_embed`
sam_mask_prompt = None
# --- 4. 送进 SAM 提示编码器 ---
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
points=(sam_point_coords, sam_point_labels),
boxes=None, # 框提示在外部已转成 2 点,不再走这里
masks=sam_mask_prompt,
)
# --- 5. 送进 SAM mask 解码器 ---
(
low_res_multimasks, # (B, M, H/4, W/4) M=3 or 1
ious, # (B, M) IoU 估计
sam_output_tokens, # (B, M, C) 解码器输出 token
object_score_logits, # (B,) 对象出现/消失 logits
) = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=self.sam_prompt_encoder.get_dense_pe(), # 固定 2D 位置编码
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=False, # image 已 batch,无需再 repeat
high_res_features=high_res_features, # 高分特征供 refine 边缘
)
# --- 6. 对象 score 后处理:若模型预测"无对象",把掩膜 logits 置为 NO_OBJ_SCORE ---
if self.pred_obj_scores:
is_obj_appearing = object_score_logits > 0 # 硬阈值
# 记忆用掩膜必须**硬**选择:有对象才保留,否则置极大负值
low_res_multimasks = torch.where(
is_obj_appearing[:, None, None],
low_res_multimasks,
NO_OBJ_SCORE,
)
# --- 7. 数据类型转换:bf16/fp16 -> fp32(老版本 PyTorch interpolate 不支持 bf16)---
low_res_multimasks = low_res_multimasks.float()
# 上采样到图像原分辨率(stride=1)
high_res_multimasks = F.interpolate(
low_res_multimasks,
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
# --- 8. 选取最佳掩膜 ---
sam_output_token = sam_output_tokens[:, 0] # 默认取第 1 个 token(单掩膜时即自身)
if multimask_output:
# 多掩膜时,选 IoU 估计最高的那个
best_iou_inds = torch.argmax(ious, dim=-1) # (B,)
batch_inds = torch.arange(B, device=device)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
# 若解码器输出了多个 token,同样要选最佳
if sam_output_tokens.size(1) > 1:
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
else:
# 单掩膜时,最佳即唯一
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
# --- 9. 从最佳 token 提取对象指针(用于记忆)---
obj_ptr = self.obj_ptr_proj(sam_output_token) # (B, C)
# --- 10. 对象指针后处理:若模型认为"无对象",指针也被削弱或替换 ---
if self.pred_obj_scores:
if self.soft_no_obj_ptr:
# 软削弱:用 sigmoid 概率加权
assert not self.teacher_force_obj_scores_for_mem
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
# 硬削弱:0/1 加权
lambda_is_obj_appearing = is_obj_appearing.float()
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
# 剩余权重用"无对象指针"补齐,保证指针和为 1
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
# --- 11. 返回打包结果 ---
return (
low_res_multimasks, # 0 低分多掩膜
high_res_multimasks, # 1 高分多掩膜
ious, # 2 各掩膜 IoU 估计
low_res_masks, # 3 最佳低分掩膜
high_res_masks, # 4 最佳高分掩膜
obj_ptr, # 5 对象指针(记忆 key/query)
object_score_logits, # 6 对象出现置信度 logits
)
一句话
把"已融合记忆的图像特征"和"用户给的点/低分 mask 提示"一起塞进 SAM 的提示编码器 + 掩膜解码器,先产出 1 或 3 组低分 mask,再上采样到原图分辨率,挑 IoU 最高的作为最终 mask,并同步给出对应的对象指针与存在置信度,供后续记忆更新使用。
5.1 PromptEncoder.forward
sam2/modeling/sam/prompt_encoder.py
--- 4. 送进 SAM 提示编码器 ---
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
points=(sam_point_coords, sam_point_labels), # 走的是这里
boxes=None, # 框提示在外部已转成 2 点,不再走这里
masks=sam_mask_prompt,
)
python
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
# points:[
# tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0'),
# tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
# ]
# boxes:None
# masks:None
# 1. 统一推断 batch 大小(以第一个非 None 的提示为准)
bs = self._get_batch_size(points, boxes, masks)
# bs: 1
# 2. 先建一个空张量,后续把点/框嵌入拼进来
# self.embed_dim: 256
sparse_embeddings = torch.empty(
(bs, 0, self.embed_dim), device=self._get_device()
)
# sparse_embeddings: torch.Size([1, 0, 256])
# 3. 处理点提示:转成 256-d 向量并拼到 sparse_embeddings
if points is not None:
coords, labels = points
# coords: tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0')
# labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
# pad=(boxes is None):当没有框提示时,给点序列末尾补一个"padding 点"
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
# point_embeddings: torch.Size([1, 3, 256])
# sparse_embeddings: torch.Size([1, 0, 256])
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
# sparse_embeddings: torch.Size([1, 3, 256])
# 4. 处理框提示:左上+右下两个角点各生成 256-d 向量并拼接
# boxes: None
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
# 5. 处理 mask 提示:
# 有 mask → 用轻量 CNN 压成 256×64×64 的"dense embedding";
# 无 mask → 用可学习的"no-mask"向量铺满 64×64 空间,保持形状一致
# masks: None
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
# self.no_mask_embed: nn.Embedding(1, 256)
# self.image_embedding_size: (64, 64)
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
# dense_embeddings: torch.Size([1, 256, 64, 64])
# 6. 返回:
# sparse_embeddings: [B, N_pts+N_boxes, 256] 供 Transformer self/cross attention
# dense_embeddings : [B, 256, 64, 64] 与图像特征逐像素相加
return sparse_embeddings, dense_embeddings
功能:把三种提示(点、框、mask)全部映射成与 SAM 图像 embedding 同维度的向量。
sparse embeddings:点/框 → 1-D 向量序列,后续给 Transformer 做 cross-attention。
dense embeddings:mask → 2-D 特征图,与 64×64 的图像 embedding 逐像素相加;若无 mask,就用固定的"无 mask"向量铺一张同样大小的图,保证后续计算图不变。
5.1.1 _get_batch_size
python
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
# points:[
# tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0'),
# tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
# ]
# boxes:None
# masks:None
# 1. 点提示优先级最高:points[0] 是坐标张量,形状 (B, N, 2)
if points is not None:
return points[0].shape[0] # 直接返回第 0 维 batch 大小
# points[0]:torch.Size([1, 2, 2]) points[0].shape[0]: 1
# 2. 没有点时看框提示:boxes 形状 (B, 4)
elif boxes is not None:
return boxes.shape[0]
# 3. 也没有框时看 mask 提示:masks 形状 (B, 1, H, W)
elif masks is not None:
return masks.shape[0]
# 4. 三者皆空,默认 batch=1(推理时用户可能啥也不给)
else:
return 1
极简工具函数:在三种提示里按"点→框→mask"顺序挑一个现成的张量,把它的 batch 维大小读出来,保证后续 embedding 操作知道要处理多少张图。若没有任何提示,就保守地返回 1。
5.1.2 _embed_points
python
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
# 输入:
# points : tensor([[[421.3333, 540.4445], [654.9333, 921.600]]], device='cuda:0')
# labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
# pad: True
# 1. 把坐标从"左上角"挪到像素中心,与 SAM 训练保持一致
points = points + 0.5 # Shift to center of pixel
# tensor([[[421.8333, 540.9445], [655.4333, 922.100]]], device='cuda:0')
if pad:
# 2. 追加一个坐标 (0,0)、标签 -1 的"虚拟点",防止空序列
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
# padding_point: tensor([[[0., 0.]]], device='cuda:0')
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
# padding_label: tensor([[-1.]], device='cuda:0')
points = torch.cat([points, padding_point], dim=1)
# points: tensor([[[421.8333, 540.9445], [655.4333, 922.100], [0.0000, 0.0000]]], device='cuda:0')
labels = torch.cat([labels, padding_label], dim=1)
# labels: tensor([[ 2., 3., -1.]], device='cuda:0')
# 3. 正弦-余弦位置编码:坐标 → 256-D 向量
# points: torch.Size([1, 3, 2])
# self.input_image_size: (1024, 1024)
point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
# point_embedding: torch.Size([1, 3, 256])
# 4. 对 padding 点:先清零位置编码,再加可学习的"not-a-point"向量
point_embedding[labels == -1] = 0.0
# point_embedding: torch.Size([1, 3, 256])
# self.not_a_point_embed.weight: Embedding(1, 256)
point_embedding[labels == -1] += self.not_a_point_embed.weight
# point_embedding: torch.Size([1, 3, 256])
# 5. 按标签值加上对应的角色向量(0/1/2/3 各有一个可学习向量)
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
point_embedding[labels == 2] += self.point_embeddings[2].weight
point_embedding[labels == 3] += self.point_embeddings[3].weight
# 返回形状 (B, N, 256) 的嵌入序列
return point_embedding
坐标平移:让提示点落在像素中心,与 SAM 训练保持一致。
可选 padding:防止"空序列"导致模型报错;padding 点的标签为 -1。
位置编码:正弦-余弦映射,保证坐标信息可区分。
角色向量:四种标签各自对应一个可学习向量,加到对应点的嵌入上,实现"前景/背景/框角"语义区分。
1. 把坐标从"左上角"挪到像素中心,与 SAM 训练保持一致
points = points + 0.5
为什么加0.5就是挪到像素中心
图像里记录的坐标通常是"左上角"原点、以像素为单位的整数网格:
(0,0) ------ (1,0) ------ (2,0) ------ ... | | | (0,1) ------ (1,1) ------ (2,1) ------ ...
点
(x, y)在内存里表示的是这个像素格子的左上角。真正想用的"物理"中心是
(x + 0.5, y + 0.5)。于是给原坐标统一加 0.5,就把"左上角"挪到了该像素正方形的中心,与 SAM 训练时采用的坐标系一致,后续正弦-余弦编码才能正确反映空间位置。
5.1.2.1 PositionEmbeddingRandom.forward_with_coords
segment-anything-2/sam2/modeling/position_encoding.py
3. 正弦-余弦位置编码:坐标 → 256-D 向量
points: torch.Size([1, 3, 2])
self.input_image_size: (1024, 1024)
point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
point_embedding: torch.Size([1, 3, 256])
python
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
# 随机生成 2×128 的高斯矩阵,作为"空间频率"权重,训练期间冻结
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
# 输入:
# coords: torch.Size([1, 3, 2])
# coords: tensor([[[0.4093, 0.5264], [0.6349, 0.8959], [0.0000, 0.0000]]], device='cuda:0'))
# 1. 把 [0,1] 映射到 [-1,1],与 NeRF 等做法一致
coords = 2 * coords - 1
# coords: torch.Size([1, 3, 2])
# coords: tensor([[[-0.1813, 0.0528], [ 0.2697, 0.7917], [-1.0000, -1.0000]]], device='cuda:0')
# 2. 用随机高斯矩阵做线性变换:2D 坐标 → 128 维"频率向量"
# coords: torch.Size([1, 3, 2])
# self.positional_encoding_gaussian_matrix: torch.Size([2, 128])
coords = coords @ self.positional_encoding_gaussian_matrix
# coords: torch.Size([1, 3, 128])
# 3. 放大 2π,让 sin/cos 充分振荡
coords = 2 * np.pi * coords
# coords: torch.Size([1, 3, 128])
# 4. 拼接 sin 和 cos,得到 256 维最终编码(128×2)
# outputs d_1 x ... x d_n x C shape
# torch.Size([1, 3, 256])
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
# 构造 [1,1] 起步的累加网格,每个像素坐标 = 整数索引 - 0.5
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5 # 行方向累加 → y 坐标
x_embed = grid.cumsum(dim=1) - 0.5 # 列方向累加 → x 坐标
# 归一化到 [0,1]
y_embed = y_embed / h
x_embed = x_embed / w
# 堆成 (H,W,2) 后过 _pe_encoding → (H,W,256)
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
# 输入:
# coords_input: tensor([[[419.1667, 539.0482], [650.1000, 917.359, device='cuda:0')
# image_size: (1024, 1024)
# 1. 克隆一份,避免原地修改用户张量
coords = coords_input.clone()
# 2. 按图像宽高归一化到 [0,1];注意 x 对应 W,y 对应 H
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
# coords: tensor([[[0.4093, 0.5264], [0.6349, 0.8959], [0.0000, 0.0000]]], device='cuda:0')
# 3. 过 _pe_encoding 得到 (B, N, 256) 的位置编码
return self._pe_encoding(coords.to(torch.float)) # B x N x C
核心思想:用固定的高斯随机矩阵把 2D 坐标映射到高维向量,再经 sin/cos 得到与坐标平滑相关的位置编码。
forward给整幅图生成网格编码,形状256×H×W,与图像特征相加。
forward_with_coords给任意像素坐标(未归一化)即时计算编码,形状B×N×256,用于点/框提示。
_pe_encoding函数中为什么要进行这些操作,为什么维度是这样变化的?
一句话总结
"先归一化→随机线性投影→2π 放大→sin/cos 拼接" 这套流程,把任意 2-D 坐标变成 256-D 高频、平滑、可反推的位置编码;维度变化
B×N×2 → B×N×128 → B×N×256只是通道扩维,Batch 与点数始终不变。下面按步骤说明"为什么"以及"为什么维度这样走"。
- 归一化到 [-1,1]
coords = 2 * coords - 1
原因:与 NeRF、Transformer 位置编码常用做法一致,把值中心放到 0,使后续 sin/cos 在正负两侧对称振荡,避免偏置。
维度:仍保持
B×N×2,只是数值区间从 [0,1] 变 [-1,1]。
- 随机高斯矩阵投影
coords = coords @ self.positional_encoding_gaussian_matrix# [B,N,2] @ [2,128] → [B,N,128]
原因:
-- 直接把 2-D 喂进 sin/cos 只能得到 4 个值(sin x, cos x, sin y, cos y),信息量太少。
-- 用固定但随机的矩阵
R^(2×128)把 2-D 映射到 128 维"频率空间",每个维度都是原始坐标的不同线性组合,等价于给每个通道分配不同的"波数"。-- 矩阵在训练前一次性生成并 register_buffer,不参与梯度更新,保证同一坐标永远得到同一编码。
维度:
B×N×2 → B×N×128,通道扩到 128,Batch 大小 B 与点数 N 不变。
- 放大 2π
coords = 2π * coords
原因:sin/cos 周期 2π,若数值范围太小(如 ±1),函数几乎线性,缺乏高频分量;放大后不同坐标差值更明显,提高空间分辨率。
维度:仍为
B×N×128,只是数值放大。
- sin/cos 拼接
return torch.cat([sin(coords), cos(coords)], dim=-1)# [B,N,128] + [B,N,128] → [B,N,256]
原因:
-- sin 与 cos 相位差 90°,组合后能唯一恢复原始坐标,同时提供平滑、可导、有界的表示。
-- 拼接后通道翻倍,形成 256-D 向量,与 SAM 的 prompt_embed_dim 对齐。
维度:
B×N×128 → B×N×256,最终输出 256 维位置编码。
维度变化小结
B×N×2 │ 1. 归一化 [-1,1] (数值变,shape 不变) ▼ B×N×2 │ 2. @ [2,128] (矩阵乘,通道扩 128) ▼ B×N×128 │ 3. ×2π (数值放大,shape 不变) ▼ B×N×128 │ 4. cat(sin,cos) (通道再翻倍) ▼ B×N×256Batch 维(B)和点数维(N)始终不变,只有最后一维从 2 → 128 → 256。
为什么torch.Size([1, 3, 2]) @ torch.Size([2, 128]) => coords: torch.Size([1, 3, 128])?
PyTorch 的
@(torch.matmul)在三维张量上做广播式批量矩阵乘法:
左侧
A:(1, 3, 2)右侧
B:(2, 128)规则:把最后两维当成真正的矩阵,前面所有维当作批维(batch)。
这里批维都是
1,于是自动广播成相同的批维形状,然后逐批做矩阵乘:
(3, 2) @ (2, 128) → (3, 128)批维
1保持不变,最终得到:
torch.Size([1, 3, 128])一句话:最后一维是"契约"维度(必须相等),倒数第二维是结果维度;前面所有维只要可广播即可。
如何理解随机高斯矩阵投影与放大2π
把"随机高斯矩阵投影 + 放大 2π"想成"给 2-D 坐标配一把随机但固定的尺子,再把刻度拉得足够细",让 sin/cos 能"看"到高频细节。
- 随机高斯矩阵投影 ------ 造一把多刻度的尺子
原始坐标只有 x,y 两个数,信息量极少。
用固定的随机矩阵
R^(2×128)做一次线性映射,相当于对 (x,y) 做 128 组不同的加权求和 ,每组权重服从高斯分布。
结果:128 个"新特征",每个都是原始坐标的不同线性组合,
等价于 128 条不同斜率、不同截距的"尺子"去读同一坐标。
因为矩阵固定,同一坐标永远得到同一 128-D 向量,保证可复现。
随机性让不同通道之间不相关,后续 sin/cos 就能产出丰富、不冗余的频率分量。
- 放大 2π ------ 把刻度拉密,让 sin/cos 足够敏感
sin/cos 的一个完整周期是 2π。
若输入数值范围只有 [-1,1],sin/cos 变化非常缓慢,几乎呈线性,
导致相邻坐标编码几乎一样,空间分辨率低。
乘以 2π 后,数值范围变成 [-2π, 2π],
同样坐标差 0.01 在函数值上会被放大到 Δ ≈ 0.06π,
sin/cos 明显跳变,相邻坐标即可区分。
直观效果:把"粗刻度"变"细刻度",引入高频分量 ,
使网络既能捕捉大尺度平滑,也能分辨小尺度抖动。
一句话总结
随机投影 = 用 128 把随机尺子量坐标 ;
2π 放大 = 把尺子刻度加密,让 sin/cos 能"看"到更细的空间纹理。
先归一化→随机线性投影→2π 放大→sin/cos 拼接,这有必然的顺序吗?比如先2π 放大再随机线性投影可以吗?
不能调换------"先投影、再 2π 放大"是数学上保证"每条随机方向都能均匀覆盖 sin/cos 全周期"的关键顺序;反过来做会把高频噪声塞进投影矩阵,破坏编码的平滑性与可区分性。
数学视角
想让每条随机方向
ω_i都充分利用 sin/cos 的 2π 周期:
PE_i = sin(2π ⋅ ω_i^T [x;y])必须把 2π 放在最外层 ;否则
PE_i = sin(ω_i^T (2π [x;y])) = sin((2π ω_i)^T [x;y])等价于把随机向量直接放大 2π------
导致同一矩阵元素量级差异巨大(有的方向 0.1,有的方向 6.28),
数值分布不均匀,训练/推理时容易出现梯度爆炸、频率分布失衡。
几何视角
"归一化 → 投影"一步把 2-D 点均匀地撒进 128-D 球体;
"×2π"再把球体半径拉到 2π,使得 sin/cos 在整个球面上均匀振荡。
若先 ×2π 再投影,相当于把原始坐标先拉到 [-2π,2π],
再做随机投影------此时投影输入的数值范围已经变大,
随机矩阵无法"感知"原来的 [-1,1] 边界,
不同样本的投影值可能落在 [-20,20] 甚至更大,
导致 sin/cos 进入多周期混叠,相邻坐标编码反而变得不可区分。
实验经验
原版 SAM、NeRF、Transformer 位置编码无一例外采用
"归一化 → 随机/固定投影 → 2π 放大 → sin/cos"顺序;
调换后复现的跟踪精度/重建 PSNR 均明显下降。
结论
顺序是刚性的:
归一化 → 2. 随机投影 → 3. 2π 放大 → 4. sin/cos 拼接
不可颠倒。
5.1.2.2 self.pe_layer.forward_with_coords之后做了什么
一、
self.pe_layer.forward_with_coords返回之后做了什么?它只干了两件事:
把"坐标"变成"256-D 向量"
返回的
point_embedding形状[B, N, 256]里,每个点已经带上了唯一、平滑、可反推的位置信息(正弦-余弦编码)。把"语义角色"塞进同一向量
随后代码用
labels做掩码,把 4 个可学习的"角色向量"逐点加进去:
label 0→ 背景点击
label 1→ 前景点击
label 2→ 框左上角
label 3→ 框右下角
label -1→ padding 虚点(先清零再加not_a_point_embed)这样同一物理位置只要角色不同,最终 256-D 向量也不同,Transformer 就能区分"左上角"与"右下角"或"前景"与"背景"。
二、整个
_embed_points在做什么?把用户给的"原始像素坐标 + 标签"变成可直接送进 Transformer 的 256-D 提示向量,流程:
像素中心化 → 对齐 SAM 训练方式
可选 padding → 防止空序列
位置编码 → 让网络知道"点在哪儿"
角色加向量 → 让网络知道"这是什么点"
三、为什么要这么做?
位置必须可区分:同一标签的两个点如果坐标不同,向量必须不同 → 靠正弦-余弦编码。
角色必须可区分:同一坐标的两个点如果标签不同,向量必须不同 → 靠加角色向量。
统一接口 :最终出来的
[B, N, 256]与 SAM 的prompt_embed_dim一致,可直接和图像 token 做 cross-attention,无需额外转换。一句话:
_embed_points就是"坐标 + 标签 → 256-D 提示向量"的流水线,保证位置可辨、角色可辨、格式兼容。
point_embedding[labels == -1] = 0.0 什么意思
labels形状是(B=1, N=3),里面只有 3 个标量:
tensor([[2, 3, -1]])它们分别对应第 0 个点、第 1 个点、第 2 个点 的标签,不是 256 维向量里的某个分量。
point_embedding形状是(B=1, N=3, C=256),可以看成 3 条 256-D 向量:
[0] 号向量 → 标签 2 [1] 号向量 → 标签 3 [2] 号向量 → 标签 -1布尔索引
labels == -1返回同样(1,3)的 mask:
tensor([[False, False, True]])用这条 mask 去挑
point_embedding时,按前两个维度对齐:
只要 mask 为
True的位置,就把整条 256-D 向量抓出来(共 256 个数)置 0;其余两条向量完全不动。
所以"256"跟"3"并不直接对应,而是:
3 个点 → 3 条 256-D 向量 → 用 3 个标签值决定哪一整条向量要被清零。
point_embedding[labels == -1] += self.not_a_point_embed.weight 什么意思?
这行代码的意思是:
把"标签为 -1 的点"对应的那一整条 256 维向量,清零后再单独加上一个可学习的"虚拟点"嵌入向量 (
self.not_a_point_embed.weight),从而告诉模型"这是一个 padding 虚点,不是真实用户提示"。具体步骤:
前面已经把
point_embedding[labels == -1]整行置 0;
self.not_a_point_embed是一个nn.Embedding(1, 256),即 1 个 256 维的可学习向量;
+=把这 256 个数加到对应行上,使该行不再全 0,而是带有"我是 padding"这一固定语义。结果:
真实提示点 → 位置编码 + 角色向量
padding 虚点 → 全 0 + "not-a-point"向量
模型后续看到这条向量,就能区分"用户没给提示"与"用户给了提示"。
5. 按标签值加上对应的角色向量(0/1/2/3 各有一个可学习向量)
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
point_embedding[labels == 2] += self.point_embeddings[2].weight
point_embedding[labels == 3] += self.point_embeddings[3].weight
这什么意思
把"角色"这一离散信息写进向量里:
每种标签值(0/1/2/3)都对应一个可学习的 256 维向量 (
self.point_embeddings[i].weight)。哪条点的标签等于 i,就把对应向量整行加到这点的 256-D 嵌入上,
于是同一坐标、不同角色也能产生不同的最终提示向量,供 Transformer 区分"背景点 / 前景点 / 框左上角 / 框右下角"。
5.1.3 torch.cat([sparse_embeddings, point_embeddings], dim=1)什么意思?
pad=(boxes is None):当没有框提示时,给点序列末尾补一个"padding 点"
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
point_embeddings: torch.Size([1, 3, 256])
sparse_embeddings: torch.Size([1, 0, 256])
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
sparse_embeddings: torch.Size([1, 3, 256])
torch.cat在这里就是"把空袋子换成实袋子":
一开始
sparse_embeddings是形状(1, 0, 256)------第二维长度为 0,相当于一个"空列表"。
point_embeddings形状(1, 3, 256)是刚刚算好的 3 条 256-D 提示向量。
torch.cat([sparse_embeddings, point_embeddings], dim=1)沿
dim=1(提示点个数维度)把空张量和 3 条向量拼起来,结果变成(1, 3, 256)。一句话:用拼接把"空提示序列"变成"含有 3 个点的提示序列",后续就能直接送进 Transformer 做交叉注意力。
5.1.4 dense_embeddings
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
self.no_mask_embed: nn.Embedding(1, 256)
self.image_embedding_size: (64, 64)
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
dense_embeddings: torch.Size([1, 256, 64, 64])
什么意思?
这段代码在生成"mask 提示"的密集嵌入(dense_embeddings),供后续与图像特征逐像素相加:
如果用户给了 mask
dense_embeddings = self._embed_masks(masks)→ 用轻量 CNN 把
[B,1,H,W]的 mask 压成[B,256,64,64]的特征图。如果没给 mask
self.no_mask_embed是一个nn.Embedding(1,256),里面只有 1 个 256-D 向量(可学习)。先
.reshape(1, -1, 1, 1)变成[1,256,1,1]再
.expand(bs, -1, 64, 64)沿 batch 和空间维铺平→ 得到 一张全是同一向量的 64×64 特征图 ,形状
[B,256,64,64]。结果:
无论有没有外部 mask,网络都拿到一张 通道=256、空间=64×64 的密集嵌入,后续可直接与 SAM 的图像 embedding 逐像素相加;没 mask 时就用"可学习的'无 mask'向量"填满整张图,保持计算图一致。
为什么.reshape(1, -1, 1, 1) 变成 [1,256,1,1],为什么 .expand(bs, -1, 64, 64) 沿 batch 和空间维铺平 → 得到 一张全是同一向量的 64×64 特征图,形状 [B,256,64,64],reshape和expand有啥区别?
一句话先答:
reshape只"重新划分尺寸"不复制数据 ;expand只"广播重复"不分配新内存 ,两者配合才能把 1 个 256-D 向量"变"成一张[B,256,64,64]的常量特征图。
reshape(1, -1, 1, 1)怎么得到[1,256,1,1]
self.no_mask_embed.weight原始形状[256](Embedding 的 weight 是 2-D,但第 0 维=1,所以 view 出来就是 1×256)。
reshape(1, -1, 1, 1)里-1代表"自动推断",于是
256 → (1, 256, 1, 1)。结果:4-D 张量,只有 256 个数,内存没复制。
expand(bs, -1, 64, 64)怎么得到[B,256,64,64]
expand不会真的复制数据,它只在逻辑上把尺寸为 1 的维度"广播"成任意大小。这里把
batch=1 → B
H=1 → 64
W=1 → 64通道 256 本来就是 1 份,保持不变。
于是形状变成
[B,256,64,64],内部仍共享那 256 个数,几乎不占额外显存。
- 为什么要两步
reshape先把向量"立"成 4-D,才能参与广播;
expand再把"单像素"广播成"整张图",保证后续与图像特征相加时尺寸一致。如果直接用
repeat会真的复制数据,显存翻倍;expand是零成本视图。
总结
reshape负责"改形状",expand负责"铺平面";一个改维度,一个做广播,两者配合就能把 1 个 256-D 向量零成本地变成[B,256,64,64]的常量特征图。
你的意思是256个数,每个数都扩展到了64x64个,但是内存没复制,是这样理解吗?
不是把"每个数"都扩成 64×64,而是:
256 个数整体 当成一条 256 通道的"样板",在空间维度 (64×64) 上重复引用同一份内存 ------逻辑上看起来是
[B,256,64,64],物理上仍然只有那 256 个 float,外加一点点广播元信息。所以没有产生 256×64×64 份新数据,显存几乎零增长。